// hooks.rs (15 KB)
  1. use std::ffi::OsStr;
  2. use std::path::Path;
  3. use std::process::Command;
  4. use serde_json::json;
  5. use crate::{PluginError, PluginHooks, PluginRegistry};
  /// Tool lifecycle events for which hook commands can be run.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  pub enum HookEvent {
      /// Event used by `HookRunner::run_pre_tool_use`.
      PreToolUse,
      /// Event used by `HookRunner::run_post_tool_use`.
      PostToolUse,
      /// Event used by `HookRunner::run_post_tool_use_failure`.
      PostToolUseFailure,
  }

  impl HookEvent {
      /// Returns the event name exactly as it appears in hook payloads, the
      /// `HOOK_EVENT` environment variable and diagnostic messages.
      fn as_str(self) -> &'static str {
          match self {
              Self::PreToolUse => "PreToolUse",
              Self::PostToolUse => "PostToolUse",
              Self::PostToolUseFailure => "PostToolUseFailure",
          }
      }
  }
  /// Aggregated outcome of running every hook command for a single event.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct HookRunResult {
      // Set when a hook command vetoed the tool call.
      denied: bool,
      // Set when a hook command could not be run or exited abnormally.
      failed: bool,
      // Messages gathered from the hook commands, in execution order.
      messages: Vec<String>,
  }

  impl HookRunResult {
      /// Creates a result that is neither denied nor failed and carries
      /// `messages`.
      #[must_use]
      pub fn allow(messages: Vec<String>) -> Self {
          Self {
              messages,
              failed: false,
              denied: false,
          }
      }

      /// Reports whether the run was denied.
      #[must_use]
      pub fn is_denied(&self) -> bool {
          self.denied
      }

      /// Reports whether the run failed.
      #[must_use]
      pub fn is_failed(&self) -> bool {
          self.failed
      }

      /// Returns the collected messages as a slice.
      #[must_use]
      pub fn messages(&self) -> &[String] {
          self.messages.as_slice()
      }
  }
  49. #[derive(Debug, Clone, PartialEq, Eq, Default)]
  50. pub struct HookRunner {
  51. hooks: PluginHooks,
  52. }
  53. impl HookRunner {
  54. #[must_use]
  55. pub fn new(hooks: PluginHooks) -> Self {
  56. Self { hooks }
  57. }
  58. pub fn from_registry(plugin_registry: &PluginRegistry) -> Result<Self, PluginError> {
  59. Ok(Self::new(plugin_registry.aggregated_hooks()?))
  60. }
  61. #[must_use]
  62. pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
  63. self.run_commands(
  64. HookEvent::PreToolUse,
  65. &self.hooks.pre_tool_use,
  66. tool_name,
  67. tool_input,
  68. None,
  69. false,
  70. )
  71. }
  72. #[must_use]
  73. pub fn run_post_tool_use(
  74. &self,
  75. tool_name: &str,
  76. tool_input: &str,
  77. tool_output: &str,
  78. is_error: bool,
  79. ) -> HookRunResult {
  80. self.run_commands(
  81. HookEvent::PostToolUse,
  82. &self.hooks.post_tool_use,
  83. tool_name,
  84. tool_input,
  85. Some(tool_output),
  86. is_error,
  87. )
  88. }
  89. #[must_use]
  90. pub fn run_post_tool_use_failure(
  91. &self,
  92. tool_name: &str,
  93. tool_input: &str,
  94. tool_error: &str,
  95. ) -> HookRunResult {
  96. self.run_commands(
  97. HookEvent::PostToolUseFailure,
  98. &self.hooks.post_tool_use_failure,
  99. tool_name,
  100. tool_input,
  101. Some(tool_error),
  102. true,
  103. )
  104. }
  105. fn run_commands(
  106. &self,
  107. event: HookEvent,
  108. commands: &[String],
  109. tool_name: &str,
  110. tool_input: &str,
  111. tool_output: Option<&str>,
  112. is_error: bool,
  113. ) -> HookRunResult {
  114. if commands.is_empty() {
  115. return HookRunResult::allow(Vec::new());
  116. }
  117. let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string();
  118. let mut messages = Vec::new();
  119. for command in commands {
  120. match self.run_command(
  121. command,
  122. event,
  123. tool_name,
  124. tool_input,
  125. tool_output,
  126. is_error,
  127. &payload,
  128. ) {
  129. HookCommandOutcome::Allow { message } => {
  130. if let Some(message) = message {
  131. messages.push(message);
  132. }
  133. }
  134. HookCommandOutcome::Deny { message } => {
  135. messages.push(message.unwrap_or_else(|| {
  136. format!("{} hook denied tool `{tool_name}`", event.as_str())
  137. }));
  138. return HookRunResult {
  139. denied: true,
  140. failed: false,
  141. messages,
  142. };
  143. }
  144. HookCommandOutcome::Failed { message } => {
  145. messages.push(message);
  146. return HookRunResult {
  147. denied: false,
  148. failed: true,
  149. messages,
  150. };
  151. }
  152. }
  153. }
  154. HookRunResult::allow(messages)
  155. }
  156. #[allow(clippy::too_many_arguments, clippy::unused_self)]
  157. fn run_command(
  158. &self,
  159. command: &str,
  160. event: HookEvent,
  161. tool_name: &str,
  162. tool_input: &str,
  163. tool_output: Option<&str>,
  164. is_error: bool,
  165. payload: &str,
  166. ) -> HookCommandOutcome {
  167. let mut child = shell_command(command);
  168. child.stdin(std::process::Stdio::piped());
  169. child.stdout(std::process::Stdio::piped());
  170. child.stderr(std::process::Stdio::piped());
  171. child.env("HOOK_EVENT", event.as_str());
  172. child.env("HOOK_TOOL_NAME", tool_name);
  173. child.env("HOOK_TOOL_INPUT", tool_input);
  174. child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" });
  175. if let Some(tool_output) = tool_output {
  176. child.env("HOOK_TOOL_OUTPUT", tool_output);
  177. }
  178. match child.output_with_stdin(payload.as_bytes()) {
  179. Ok(output) => {
  180. let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
  181. let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
  182. let message = (!stdout.is_empty()).then_some(stdout);
  183. match output.status.code() {
  184. Some(0) => HookCommandOutcome::Allow { message },
  185. Some(2) => HookCommandOutcome::Deny { message },
  186. Some(code) => HookCommandOutcome::Failed {
  187. message: format_hook_warning(
  188. command,
  189. code,
  190. message.as_deref(),
  191. stderr.as_str(),
  192. ),
  193. },
  194. None => HookCommandOutcome::Failed {
  195. message: format!(
  196. "{} hook `{command}` terminated by signal while handling `{tool_name}`",
  197. event.as_str()
  198. ),
  199. },
  200. }
  201. }
  202. Err(error) => HookCommandOutcome::Failed {
  203. message: format!(
  204. "{} hook `{command}` failed to start for `{tool_name}`: {error}",
  205. event.as_str()
  206. ),
  207. },
  208. }
  209. }
  210. }
  /// Classification of a single hook command's execution.
  enum HookCommandOutcome {
      /// The command permitted the tool call, optionally with a message.
      Allow { message: Option<String> },
      /// The command vetoed the tool call, optionally with a message.
      Deny { message: Option<String> },
      /// The command could not be run or ended abnormally; `message`
      /// describes what went wrong.
      Failed { message: String },
  }
  216. fn hook_payload(
  217. event: HookEvent,
  218. tool_name: &str,
  219. tool_input: &str,
  220. tool_output: Option<&str>,
  221. is_error: bool,
  222. ) -> serde_json::Value {
  223. match event {
  224. HookEvent::PostToolUseFailure => json!({
  225. "hook_event_name": event.as_str(),
  226. "tool_name": tool_name,
  227. "tool_input": parse_tool_input(tool_input),
  228. "tool_input_json": tool_input,
  229. "tool_error": tool_output,
  230. "tool_result_is_error": true,
  231. }),
  232. _ => json!({
  233. "hook_event_name": event.as_str(),
  234. "tool_name": tool_name,
  235. "tool_input": parse_tool_input(tool_input),
  236. "tool_input_json": tool_input,
  237. "tool_output": tool_output,
  238. "tool_result_is_error": is_error,
  239. }),
  240. }
  241. }
  242. fn parse_tool_input(tool_input: &str) -> serde_json::Value {
  243. serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
  244. }
  /// Formats the message reported when a hook exits with an unexpected status.
  ///
  /// The message names the command and exit code, followed by the hook's
  /// standard output when it is non-empty, otherwise by its standard error
  /// when that is non-empty, otherwise by nothing.
  fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
      // Standard output takes precedence over standard error as the detail.
      let detail = match stdout {
          Some(text) if !text.is_empty() => Some(text),
          _ if !stderr.is_empty() => Some(stderr),
          _ => None,
      };
      let headline = format!("Hook `{command}` exited with status {code}");
      match detail {
          Some(detail) => format!("{headline}: {detail}"),
          None => headline,
      }
  }
  256. fn shell_command(command: &str) -> CommandWithStdin {
  257. #[cfg(windows)]
  258. let command_builder = {
  259. let mut command_builder = Command::new("cmd");
  260. command_builder.arg("/C").arg(command);
  261. CommandWithStdin::new(command_builder)
  262. };
  263. #[cfg(not(windows))]
  264. let command_builder = if Path::new(command).exists() {
  265. let mut command_builder = Command::new("sh");
  266. command_builder.arg(command);
  267. CommandWithStdin::new(command_builder)
  268. } else {
  269. let mut command_builder = Command::new("sh");
  270. command_builder.arg("-lc").arg(command);
  271. CommandWithStdin::new(command_builder)
  272. };
  273. command_builder
  274. }
  /// Thin wrapper around [`Command`] that can feed a byte payload to the
  /// child's standard input before collecting its output.
  struct CommandWithStdin {
      command: Command,
  }

  impl CommandWithStdin {
      /// Wraps an already configured command.
      fn new(command: Command) -> Self {
          Self { command }
      }

      /// Configures the child's standard input.
      fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self {
          self.command.stdin(cfg);
          self
      }

      /// Configures the child's standard output.
      fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self {
          self.command.stdout(cfg);
          self
      }

      /// Configures the child's standard error.
      fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self {
          self.command.stderr(cfg);
          self
      }

      /// Sets an environment variable for the child.
      fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
      where
          K: AsRef<OsStr>,
          V: AsRef<OsStr>,
      {
          self.command.env(key, value);
          self
      }

      /// Spawns the command, writes `stdin` to its standard input when that
      /// stream is piped, and waits for it to finish.
      ///
      /// A child that exits or closes its standard input without consuming the
      /// payload is not treated as an error: the resulting `BrokenPipe` is
      /// ignored so the child's real exit status and output are still
      /// reported.
      ///
      /// # Errors
      ///
      /// Returns the error from spawning the child, from any write failure
      /// other than `BrokenPipe`, or from waiting on the child.
      fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> {
          use std::io::{ErrorKind, Write as _};

          let mut child = self.command.spawn()?;

          // The pipe is owned by the closure and therefore closed as soon as
          // the write finishes, which lets the child observe end-of-file.
          let write_result = child
              .stdin
              .take()
              .map_or(Ok(()), |mut pipe| pipe.write_all(stdin));

          if let Err(error) = write_result {
              if error.kind() != ErrorKind::BrokenPipe {
                  // Do not leave a running or un-reaped child behind.
                  let _ = child.kill();
                  let _ = child.wait();
                  return Err(error);
              }
          }

          child.wait_with_output()
      }
  }
  #[cfg(test)]
  mod tests {
      use super::{HookRunResult, HookRunner};
      use crate::{PluginManager, PluginManagerConfig};
      use std::fs;
      use std::path::{Path, PathBuf};
      use std::time::{SystemTime, UNIX_EPOCH};

      /// Returns a path under the system temp directory whose name combines
      /// `label` with the current time in nanoseconds.
      fn temp_dir(label: &str) -> PathBuf {
          let since_epoch = SystemTime::now()
              .duration_since(UNIX_EPOCH)
              .expect("time should be after epoch");
          let dir_name = format!("plugins-hook-runner-{label}-{}", since_epoch.as_nanos());
          std::env::temp_dir().join(dir_name)
      }

      /// Writes a plugin under `root` with one hook script per event, each
      /// printing its message, plus the manifest that registers them.
      fn write_hook_plugin(
          root: &Path,
          name: &str,
          pre_message: &str,
          post_message: &str,
          failure_message: &str,
      ) {
          let manifest_dir = root.join(".claude-plugin");
          let hooks_dir = root.join("hooks");
          fs::create_dir_all(&manifest_dir).expect("manifest dir");
          fs::create_dir_all(&hooks_dir).expect("hooks dir");

          let scripts = [
              ("pre.sh", pre_message, "write pre hook"),
              ("post.sh", post_message, "write post hook"),
              ("failure.sh", failure_message, "write failure hook"),
          ];
          for (file_name, message, context) in scripts {
              let script = format!("#!/bin/sh\nprintf '%s\\n' '{message}'\n");
              fs::write(hooks_dir.join(file_name), script).expect(context);
          }

          let manifest = format!(
              "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"],\n \"PostToolUseFailure\": [\"./hooks/failure.sh\"]\n }}\n}}"
          );
          fs::write(manifest_dir.join("plugin.json"), manifest).expect("write plugin manifest");
      }

      #[test]
      fn collects_and_runs_hooks_from_enabled_plugins() {
          // given: two installed plugins, each contributing one hook per event
          let config_home = temp_dir("config");
          let first_source_root = temp_dir("source-a");
          let second_source_root = temp_dir("source-b");
          write_hook_plugin(
              &first_source_root,
              "first",
              "plugin pre one",
              "plugin post one",
              "plugin failure one",
          );
          write_hook_plugin(
              &second_source_root,
              "second",
              "plugin pre two",
              "plugin post two",
              "plugin failure two",
          );
          let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
          let first_path = first_source_root.to_str().expect("utf8 path");
          manager
              .install(first_path)
              .expect("first plugin install should succeed");
          let second_path = second_source_root.to_str().expect("utf8 path");
          manager
              .install(second_path)
              .expect("second plugin install should succeed");
          let registry = manager.plugin_registry().expect("registry should build");

          // when: the runner is built from the registry
          let runner = HookRunner::from_registry(&registry).expect("plugin hooks should load");

          // then: every event runs both plugins' hooks in order
          let tool_input = r#"{"path":"README.md"}"#;
          let allowed = |first: &str, second: &str| {
              HookRunResult::allow(vec![first.to_string(), second.to_string()])
          };
          assert_eq!(
              runner.run_pre_tool_use("Read", tool_input),
              allowed("plugin pre one", "plugin pre two")
          );
          assert_eq!(
              runner.run_post_tool_use("Read", tool_input, "ok", false),
              allowed("plugin post one", "plugin post two")
          );
          assert_eq!(
              runner.run_post_tool_use_failure("Read", tool_input, "tool failed"),
              allowed("plugin failure one", "plugin failure two")
          );

          // Best-effort cleanup of the scratch directories.
          let _ = fs::remove_dir_all(config_home);
          let _ = fs::remove_dir_all(first_source_root);
          let _ = fs::remove_dir_all(second_source_root);
      }

      #[test]
      fn pre_tool_use_denies_when_plugin_hook_exits_two() {
          // given: a single pre-tool hook that prints a reason and exits with 2
          let hooks = crate::PluginHooks {
              pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()],
              post_tool_use: Vec::new(),
              post_tool_use_failure: Vec::new(),
          };
          let runner = HookRunner::new(hooks);

          // when
          let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);

          // then: the call is denied with the hook's own message
          assert!(result.is_denied());
          assert_eq!(result.messages(), &["blocked by plugin".to_string()]);
      }

      #[test]
      fn propagates_plugin_hook_failures() {
          // given: a failing hook followed by one that must never run
          let hooks = crate::PluginHooks {
              pre_tool_use: vec![
                  "printf 'broken plugin hook'; exit 1".to_string(),
                  "printf 'later plugin hook'".to_string(),
              ],
              post_tool_use: Vec::new(),
              post_tool_use_failure: Vec::new(),
          };
          let runner = HookRunner::new(hooks);

          // when
          let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);

          // then: the failure is reported and the later hook's output is absent
          assert!(result.is_failed());
          let messages = result.messages();
          assert!(messages
              .iter()
              .any(|message| message.contains("broken plugin hook")));
          assert!(messages
              .iter()
              .all(|message| message != "later plugin hook"));
      }
  }