hooks.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. use std::ffi::OsStr;
  2. use std::path::Path;
  3. use std::process::Command;
  4. use serde_json::json;
  5. use crate::{PluginError, PluginHooks, PluginRegistry};
  /// The tool-use lifecycle points at which plugin hooks can be run.
  ///
  /// `HookRunner::run_pre_tool_use`, `run_post_tool_use` and
  /// `run_post_tool_use_failure` each run the hooks registered for one variant.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  pub enum HookEvent {
      /// Hooks run through `HookRunner::run_pre_tool_use`.
      PreToolUse,
      /// Hooks run through `HookRunner::run_post_tool_use`.
      PostToolUse,
      /// Hooks run through `HookRunner::run_post_tool_use_failure`.
      PostToolUseFailure,
  }

  impl HookEvent {
      /// Returns the event's canonical name.
      ///
      /// This is the value exported as `HOOK_EVENT` and as `hook_event_name` in the
      /// JSON payload.
      fn as_str(self) -> &'static str {
          use HookEvent::{PostToolUse, PostToolUseFailure, PreToolUse};
          match self {
              PreToolUse => "PreToolUse",
              PostToolUse => "PostToolUse",
              PostToolUseFailure => "PostToolUseFailure",
          }
      }
  }
  /// Combined outcome of running every hook registered for one event.
  ///
  /// `messages` holds the text collected from the hooks that ran, in order. At most
  /// one of `denied` / `failed` is set by `HookRunner`, which stops at the first
  /// hook that denies or fails.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct HookRunResult {
      // True when a hook denied the tool call.
      denied: bool,
      // True when a hook failed (non-0/non-2 status, signal, or spawn error).
      failed: bool,
      // Messages gathered from the hooks, in execution order.
      messages: Vec<String>,
  }

  impl HookRunResult {
      /// Builds a result in which no hook denied or failed.
      #[must_use]
      pub fn allow(messages: Vec<String>) -> Self {
          Self {
              messages,
              denied: false,
              failed: false,
          }
      }

      /// Whether a hook denied the tool call.
      #[must_use]
      pub fn is_denied(&self) -> bool {
          self.denied
      }

      /// Whether a hook failed to run successfully.
      #[must_use]
      pub fn is_failed(&self) -> bool {
          self.failed
      }

      /// Messages produced by the hooks, in the order they ran.
      #[must_use]
      pub fn messages(&self) -> &[String] {
          self.messages.as_slice()
      }
  }
  49. #[derive(Debug, Clone, PartialEq, Eq, Default)]
  50. pub struct HookRunner {
  51. hooks: PluginHooks,
  52. }
  53. impl HookRunner {
  54. #[must_use]
  55. pub fn new(hooks: PluginHooks) -> Self {
  56. Self { hooks }
  57. }
  58. pub fn from_registry(plugin_registry: &PluginRegistry) -> Result<Self, PluginError> {
  59. Ok(Self::new(plugin_registry.aggregated_hooks()?))
  60. }
  61. #[must_use]
  62. pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
  63. Self::run_commands(
  64. HookEvent::PreToolUse,
  65. &self.hooks.pre_tool_use,
  66. tool_name,
  67. tool_input,
  68. None,
  69. false,
  70. )
  71. }
  72. #[must_use]
  73. pub fn run_post_tool_use(
  74. &self,
  75. tool_name: &str,
  76. tool_input: &str,
  77. tool_output: &str,
  78. is_error: bool,
  79. ) -> HookRunResult {
  80. Self::run_commands(
  81. HookEvent::PostToolUse,
  82. &self.hooks.post_tool_use,
  83. tool_name,
  84. tool_input,
  85. Some(tool_output),
  86. is_error,
  87. )
  88. }
  89. #[must_use]
  90. pub fn run_post_tool_use_failure(
  91. &self,
  92. tool_name: &str,
  93. tool_input: &str,
  94. tool_error: &str,
  95. ) -> HookRunResult {
  96. Self::run_commands(
  97. HookEvent::PostToolUseFailure,
  98. &self.hooks.post_tool_use_failure,
  99. tool_name,
  100. tool_input,
  101. Some(tool_error),
  102. true,
  103. )
  104. }
  105. fn run_commands(
  106. event: HookEvent,
  107. commands: &[String],
  108. tool_name: &str,
  109. tool_input: &str,
  110. tool_output: Option<&str>,
  111. is_error: bool,
  112. ) -> HookRunResult {
  113. if commands.is_empty() {
  114. return HookRunResult::allow(Vec::new());
  115. }
  116. let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string();
  117. let mut messages = Vec::new();
  118. for command in commands {
  119. match Self::run_command(
  120. command,
  121. event,
  122. tool_name,
  123. tool_input,
  124. tool_output,
  125. is_error,
  126. &payload,
  127. ) {
  128. HookCommandOutcome::Allow { message } => {
  129. if let Some(message) = message {
  130. messages.push(message);
  131. }
  132. }
  133. HookCommandOutcome::Deny { message } => {
  134. messages.push(message.unwrap_or_else(|| {
  135. format!("{} hook denied tool `{tool_name}`", event.as_str())
  136. }));
  137. return HookRunResult {
  138. denied: true,
  139. failed: false,
  140. messages,
  141. };
  142. }
  143. HookCommandOutcome::Failed { message } => {
  144. messages.push(message);
  145. return HookRunResult {
  146. denied: false,
  147. failed: true,
  148. messages,
  149. };
  150. }
  151. }
  152. }
  153. HookRunResult::allow(messages)
  154. }
  155. #[allow(clippy::too_many_arguments)]
  156. fn run_command(
  157. command: &str,
  158. event: HookEvent,
  159. tool_name: &str,
  160. tool_input: &str,
  161. tool_output: Option<&str>,
  162. is_error: bool,
  163. payload: &str,
  164. ) -> HookCommandOutcome {
  165. let mut child = shell_command(command);
  166. child.stdin(std::process::Stdio::piped());
  167. child.stdout(std::process::Stdio::piped());
  168. child.stderr(std::process::Stdio::piped());
  169. child.env("HOOK_EVENT", event.as_str());
  170. child.env("HOOK_TOOL_NAME", tool_name);
  171. child.env("HOOK_TOOL_INPUT", tool_input);
  172. child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" });
  173. if let Some(tool_output) = tool_output {
  174. child.env("HOOK_TOOL_OUTPUT", tool_output);
  175. }
  176. match child.output_with_stdin(payload.as_bytes()) {
  177. Ok(output) => {
  178. let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
  179. let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
  180. let message = (!stdout.is_empty()).then_some(stdout);
  181. match output.status.code() {
  182. Some(0) => HookCommandOutcome::Allow { message },
  183. Some(2) => HookCommandOutcome::Deny { message },
  184. Some(code) => HookCommandOutcome::Failed {
  185. message: format_hook_warning(
  186. command,
  187. code,
  188. message.as_deref(),
  189. stderr.as_str(),
  190. ),
  191. },
  192. None => HookCommandOutcome::Failed {
  193. message: format!(
  194. "{} hook `{command}` terminated by signal while handling `{tool_name}`",
  195. event.as_str()
  196. ),
  197. },
  198. }
  199. }
  200. Err(error) => HookCommandOutcome::Failed {
  201. message: format!(
  202. "{} hook `{command}` failed to start for `{tool_name}`: {error}",
  203. event.as_str()
  204. ),
  205. },
  206. }
  207. }
  208. }
  /// Classification of a single hook command's result.
  enum HookCommandOutcome {
      /// The hook allowed the tool call.
      ///
      /// `message` is the hook's output, if it produced any.
      Allow {
          message: Option<String>,
      },
      /// The hook denied the tool call, optionally explaining why.
      Deny {
          message: Option<String>,
      },
      /// The hook could not run successfully.
      ///
      /// `message` describes what went wrong.
      Failed {
          message: String,
      },
  }
  214. fn hook_payload(
  215. event: HookEvent,
  216. tool_name: &str,
  217. tool_input: &str,
  218. tool_output: Option<&str>,
  219. is_error: bool,
  220. ) -> serde_json::Value {
  221. match event {
  222. HookEvent::PostToolUseFailure => json!({
  223. "hook_event_name": event.as_str(),
  224. "tool_name": tool_name,
  225. "tool_input": parse_tool_input(tool_input),
  226. "tool_input_json": tool_input,
  227. "tool_error": tool_output,
  228. "tool_result_is_error": true,
  229. }),
  230. _ => json!({
  231. "hook_event_name": event.as_str(),
  232. "tool_name": tool_name,
  233. "tool_input": parse_tool_input(tool_input),
  234. "tool_input_json": tool_input,
  235. "tool_output": tool_output,
  236. "tool_result_is_error": is_error,
  237. }),
  238. }
  239. }
  240. fn parse_tool_input(tool_input: &str) -> serde_json::Value {
  241. serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
  242. }
  /// Describes a hook that exited with an unexpected status `code`.
  ///
  /// The non-empty `stdout` is appended as detail when available, otherwise the
  /// non-empty `stderr`. With neither, only the status line is returned.
  fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
      let summary = format!("Hook `{command}` exited with status {code}");
      let detail = stdout
          .filter(|text| !text.is_empty())
          .or_else(|| Some(stderr).filter(|text| !text.is_empty()));
      match detail {
          Some(text) => format!("{summary}: {text}"),
          None => summary,
      }
  }
  254. fn shell_command(command: &str) -> CommandWithStdin {
  255. #[cfg(windows)]
  256. let command_builder = {
  257. let mut command_builder = Command::new("cmd");
  258. command_builder.arg("/C").arg(command);
  259. CommandWithStdin::new(command_builder)
  260. };
  261. #[cfg(not(windows))]
  262. let command_builder = if Path::new(command).exists() {
  263. let mut command_builder = Command::new("sh");
  264. command_builder.arg(command);
  265. CommandWithStdin::new(command_builder)
  266. } else {
  267. let mut command_builder = Command::new("sh");
  268. command_builder.arg("-lc").arg(command);
  269. CommandWithStdin::new(command_builder)
  270. };
  271. command_builder
  272. }
  /// Thin wrapper around [`Command`] that can feed a byte buffer to the child's stdin.
  struct CommandWithStdin {
      command: Command,
  }

  impl CommandWithStdin {
      /// Wraps an already-configured [`Command`].
      fn new(command: Command) -> Self {
          Self { command }
      }

      /// Configures the child's stdin (see [`Command::stdin`]).
      fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self {
          self.command.stdin(cfg);
          self
      }

      /// Configures the child's stdout (see [`Command::stdout`]).
      fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self {
          self.command.stdout(cfg);
          self
      }

      /// Configures the child's stderr (see [`Command::stderr`]).
      fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self {
          self.command.stderr(cfg);
          self
      }

      /// Sets an environment variable for the child (see [`Command::env`]).
      fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
      where
          K: AsRef<OsStr>,
          V: AsRef<OsStr>,
      {
          self.command.env(key, value);
          self
      }

      /// Spawns the command, writes `stdin` to it, closes its stdin, and collects its output.
      ///
      /// `stdin` is written on a separate scoped thread while this thread drains
      /// stdout and stderr. Otherwise a child that emits a lot of output before
      /// finishing its input (e.g. `cat` on a large payload) would block on a full
      /// stdout pipe while we block writing stdin.
      ///
      /// A `BrokenPipe` while writing is ignored. It only means the child exited or
      /// closed stdin without consuming all input, and its exit status is still
      /// reported. The child is always waited on, so it is never left unreaped.
      ///
      /// # Errors
      ///
      /// Returns an error in any of these cases:
      /// - the command cannot be spawned;
      /// - waiting for it or reading its output fails;
      /// - writing stdin fails for a reason other than `BrokenPipe`.
      fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> {
          let mut child = self.command.spawn()?;
          let child_stdin = child.stdin.take();
          std::thread::scope(|scope| {
              let writer = scope.spawn(move || -> std::io::Result<()> {
                  use std::io::Write as _;
                  if let Some(mut pipe) = child_stdin {
                      match pipe.write_all(stdin) {
                          Err(error) if error.kind() == std::io::ErrorKind::BrokenPipe => {}
                          other => other?,
                      }
                      // `pipe` is dropped here, so the child sees EOF on stdin.
                  }
                  Ok(())
              });
              let output = child.wait_with_output();
              let write_result = writer
                  .join()
                  .unwrap_or_else(|panic| std::panic::resume_unwind(panic));
              let output = output?;
              write_result?;
              Ok(output)
          })
      }
  }
  #[cfg(test)]
  mod tests {
      use super::{HookRunResult, HookRunner};
      use crate::{PluginManager, PluginManagerConfig};
      use std::fs;
      use std::path::{Path, PathBuf};
      use std::time::{SystemTime, UNIX_EPOCH};

      /// Tool input shared by the plugin-discovery test.
      const README_INPUT: &str = r#"{"path":"README.md"}"#;
      /// Tool input shared by the inline-hook tests.
      const PWD_INPUT: &str = r#"{"command":"pwd"}"#;

      /// Returns a not-yet-created path under the system temp dir.
      ///
      /// The path is made unique per call by a nanosecond timestamp.
      fn temp_dir(label: &str) -> PathBuf {
          let stamp = SystemTime::now()
              .duration_since(UNIX_EPOCH)
              .expect("time should be after epoch")
              .as_nanos();
          std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{stamp}"))
      }

      /// Renders a `/bin/sh` script that prints `message` followed by a newline.
      fn echo_script(message: &str) -> String {
          format!("#!/bin/sh\nprintf '%s\\n' '{message}'\n")
      }

      /// Writes a plugin under `root` whose three hooks each print a fixed message.
      fn write_hook_plugin(
          root: &Path,
          name: &str,
          pre_message: &str,
          post_message: &str,
          failure_message: &str,
      ) {
          let manifest_dir = root.join(".claude-plugin");
          let hooks_dir = root.join("hooks");
          fs::create_dir_all(&manifest_dir).expect("manifest dir");
          fs::create_dir_all(&hooks_dir).expect("hooks dir");
          let scripts = [
              ("pre.sh", pre_message, "write pre hook"),
              ("post.sh", post_message, "write post hook"),
              ("failure.sh", failure_message, "write failure hook"),
          ];
          for (file_name, message, context) in scripts {
              fs::write(hooks_dir.join(file_name), echo_script(message)).expect(context);
          }
          let manifest = format!(
              "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"],\n \"PostToolUseFailure\": [\"./hooks/failure.sh\"]\n }}\n}}"
          );
          fs::write(manifest_dir.join("plugin.json"), manifest).expect("write plugin manifest");
      }

      #[test]
      fn collects_and_runs_hooks_from_enabled_plugins() {
          // given: two installed plugins, each contributing one hook per event
          let config_home = temp_dir("config");
          let first_source_root = temp_dir("source-a");
          let second_source_root = temp_dir("source-b");
          write_hook_plugin(
              &first_source_root,
              "first",
              "plugin pre one",
              "plugin post one",
              "plugin failure one",
          );
          write_hook_plugin(
              &second_source_root,
              "second",
              "plugin pre two",
              "plugin post two",
              "plugin failure two",
          );
          let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
          let installs = [
              (&first_source_root, "first plugin install should succeed"),
              (&second_source_root, "second plugin install should succeed"),
          ];
          for (source_root, context) in installs {
              manager
                  .install(source_root.to_str().expect("utf8 path"))
                  .expect(context);
          }
          let registry = manager.plugin_registry().expect("registry should build");

          // when
          let runner = HookRunner::from_registry(&registry).expect("plugin hooks should load");

          // then: every event collects the first plugin's message before the second's
          assert_eq!(
              runner.run_pre_tool_use("Read", README_INPUT),
              HookRunResult::allow(vec![
                  "plugin pre one".to_string(),
                  "plugin pre two".to_string(),
              ])
          );
          assert_eq!(
              runner.run_post_tool_use("Read", README_INPUT, "ok", false),
              HookRunResult::allow(vec![
                  "plugin post one".to_string(),
                  "plugin post two".to_string(),
              ])
          );
          assert_eq!(
              runner.run_post_tool_use_failure("Read", README_INPUT, "tool failed"),
              HookRunResult::allow(vec![
                  "plugin failure one".to_string(),
                  "plugin failure two".to_string(),
              ])
          );

          // Best-effort cleanup; leftover temp dirs are harmless.
          for dir in [config_home, first_source_root, second_source_root] {
              let _ = fs::remove_dir_all(dir);
          }
      }

      #[test]
      fn pre_tool_use_denies_when_plugin_hook_exits_two() {
          // given
          let hooks = crate::PluginHooks {
              pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()],
              post_tool_use: Vec::new(),
              post_tool_use_failure: Vec::new(),
          };
          let runner = HookRunner::new(hooks);

          // when
          let result = runner.run_pre_tool_use("Bash", PWD_INPUT);

          // then
          assert!(result.is_denied());
          assert_eq!(result.messages(), &["blocked by plugin".to_string()]);
      }

      #[test]
      fn propagates_plugin_hook_failures() {
          // given: a failing hook followed by one that must never run
          let hooks = crate::PluginHooks {
              pre_tool_use: vec![
                  "printf 'broken plugin hook'; exit 1".to_string(),
                  "printf 'later plugin hook'".to_string(),
              ],
              post_tool_use: Vec::new(),
              post_tool_use_failure: Vec::new(),
          };
          let runner = HookRunner::new(hooks);

          // when
          let result = runner.run_pre_tool_use("Bash", PWD_INPUT);

          // then
          assert!(result.is_failed());
          let messages = result.messages();
          assert!(messages
              .iter()
              .any(|message| message.contains("broken plugin hook")));
          assert!(messages
              .iter()
              .all(|message| message != "later plugin hook"));
      }
  }