hooks.rs 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987
  1. use std::ffi::OsStr;
  2. use std::io::Write;
  3. use std::process::{Command, Stdio};
  4. use std::sync::{
  5. atomic::{AtomicBool, Ordering},
  6. Arc,
  7. };
  8. use std::thread;
  9. use std::time::Duration;
  10. use serde_json::{json, Value};
  11. use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
  12. use crate::permissions::PermissionOverride;
  13. pub type HookPermissionDecision = PermissionOverride;
  14. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  15. pub enum HookEvent {
  16. PreToolUse,
  17. PostToolUse,
  18. PostToolUseFailure,
  19. }
  20. impl HookEvent {
  21. #[must_use]
  22. pub fn as_str(self) -> &'static str {
  23. match self {
  24. Self::PreToolUse => "PreToolUse",
  25. Self::PostToolUse => "PostToolUse",
  26. Self::PostToolUseFailure => "PostToolUseFailure",
  27. }
  28. }
  29. }
  30. #[derive(Debug, Clone, PartialEq, Eq)]
  31. pub enum HookProgressEvent {
  32. Started {
  33. event: HookEvent,
  34. tool_name: String,
  35. command: String,
  36. },
  37. Completed {
  38. event: HookEvent,
  39. tool_name: String,
  40. command: String,
  41. },
  42. Cancelled {
  43. event: HookEvent,
  44. tool_name: String,
  45. command: String,
  46. },
  47. }
  48. pub trait HookProgressReporter {
  49. fn on_event(&mut self, event: &HookProgressEvent);
  50. }
  51. #[derive(Debug, Clone, Default)]
  52. pub struct HookAbortSignal {
  53. aborted: Arc<AtomicBool>,
  54. }
  55. impl HookAbortSignal {
  56. #[must_use]
  57. pub fn new() -> Self {
  58. Self::default()
  59. }
  60. pub fn abort(&self) {
  61. self.aborted.store(true, Ordering::SeqCst);
  62. }
  63. #[must_use]
  64. pub fn is_aborted(&self) -> bool {
  65. self.aborted.load(Ordering::SeqCst)
  66. }
  67. }
  68. #[derive(Debug, Clone, PartialEq, Eq)]
  69. pub struct HookRunResult {
  70. denied: bool,
  71. failed: bool,
  72. cancelled: bool,
  73. messages: Vec<String>,
  74. permission_override: Option<PermissionOverride>,
  75. permission_reason: Option<String>,
  76. updated_input: Option<String>,
  77. }
  78. impl HookRunResult {
  79. #[must_use]
  80. pub fn allow(messages: Vec<String>) -> Self {
  81. Self {
  82. denied: false,
  83. failed: false,
  84. cancelled: false,
  85. messages,
  86. permission_override: None,
  87. permission_reason: None,
  88. updated_input: None,
  89. }
  90. }
  91. #[must_use]
  92. pub fn is_denied(&self) -> bool {
  93. self.denied
  94. }
  95. #[must_use]
  96. pub fn is_failed(&self) -> bool {
  97. self.failed
  98. }
  99. #[must_use]
  100. pub fn is_cancelled(&self) -> bool {
  101. self.cancelled
  102. }
  103. #[must_use]
  104. pub fn messages(&self) -> &[String] {
  105. &self.messages
  106. }
  107. #[must_use]
  108. pub fn permission_override(&self) -> Option<PermissionOverride> {
  109. self.permission_override
  110. }
  111. #[must_use]
  112. pub fn permission_decision(&self) -> Option<HookPermissionDecision> {
  113. self.permission_override
  114. }
  115. #[must_use]
  116. pub fn permission_reason(&self) -> Option<&str> {
  117. self.permission_reason.as_deref()
  118. }
  119. #[must_use]
  120. pub fn updated_input(&self) -> Option<&str> {
  121. self.updated_input.as_deref()
  122. }
  123. #[must_use]
  124. pub fn updated_input_json(&self) -> Option<&str> {
  125. self.updated_input()
  126. }
  127. }
  128. #[derive(Debug, Clone, PartialEq, Eq, Default)]
  129. pub struct HookRunner {
  130. config: RuntimeHookConfig,
  131. }
  132. impl HookRunner {
  133. #[must_use]
  134. pub fn new(config: RuntimeHookConfig) -> Self {
  135. Self { config }
  136. }
  137. #[must_use]
  138. pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self {
  139. Self::new(feature_config.hooks().clone())
  140. }
  141. #[must_use]
  142. pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
  143. self.run_pre_tool_use_with_context(tool_name, tool_input, None, None)
  144. }
  145. #[must_use]
  146. pub fn run_pre_tool_use_with_context(
  147. &self,
  148. tool_name: &str,
  149. tool_input: &str,
  150. abort_signal: Option<&HookAbortSignal>,
  151. reporter: Option<&mut dyn HookProgressReporter>,
  152. ) -> HookRunResult {
  153. Self::run_commands(
  154. HookEvent::PreToolUse,
  155. self.config.pre_tool_use(),
  156. tool_name,
  157. tool_input,
  158. None,
  159. false,
  160. abort_signal,
  161. reporter,
  162. )
  163. }
  164. #[must_use]
  165. pub fn run_pre_tool_use_with_signal(
  166. &self,
  167. tool_name: &str,
  168. tool_input: &str,
  169. abort_signal: Option<&HookAbortSignal>,
  170. ) -> HookRunResult {
  171. self.run_pre_tool_use_with_context(tool_name, tool_input, abort_signal, None)
  172. }
  173. #[must_use]
  174. pub fn run_post_tool_use(
  175. &self,
  176. tool_name: &str,
  177. tool_input: &str,
  178. tool_output: &str,
  179. is_error: bool,
  180. ) -> HookRunResult {
  181. self.run_post_tool_use_with_context(
  182. tool_name,
  183. tool_input,
  184. tool_output,
  185. is_error,
  186. None,
  187. None,
  188. )
  189. }
  190. #[must_use]
  191. pub fn run_post_tool_use_with_context(
  192. &self,
  193. tool_name: &str,
  194. tool_input: &str,
  195. tool_output: &str,
  196. is_error: bool,
  197. abort_signal: Option<&HookAbortSignal>,
  198. reporter: Option<&mut dyn HookProgressReporter>,
  199. ) -> HookRunResult {
  200. Self::run_commands(
  201. HookEvent::PostToolUse,
  202. self.config.post_tool_use(),
  203. tool_name,
  204. tool_input,
  205. Some(tool_output),
  206. is_error,
  207. abort_signal,
  208. reporter,
  209. )
  210. }
  211. #[must_use]
  212. pub fn run_post_tool_use_with_signal(
  213. &self,
  214. tool_name: &str,
  215. tool_input: &str,
  216. tool_output: &str,
  217. is_error: bool,
  218. abort_signal: Option<&HookAbortSignal>,
  219. ) -> HookRunResult {
  220. self.run_post_tool_use_with_context(
  221. tool_name,
  222. tool_input,
  223. tool_output,
  224. is_error,
  225. abort_signal,
  226. None,
  227. )
  228. }
  229. #[must_use]
  230. pub fn run_post_tool_use_failure(
  231. &self,
  232. tool_name: &str,
  233. tool_input: &str,
  234. tool_error: &str,
  235. ) -> HookRunResult {
  236. self.run_post_tool_use_failure_with_context(tool_name, tool_input, tool_error, None, None)
  237. }
  238. #[must_use]
  239. pub fn run_post_tool_use_failure_with_context(
  240. &self,
  241. tool_name: &str,
  242. tool_input: &str,
  243. tool_error: &str,
  244. abort_signal: Option<&HookAbortSignal>,
  245. reporter: Option<&mut dyn HookProgressReporter>,
  246. ) -> HookRunResult {
  247. Self::run_commands(
  248. HookEvent::PostToolUseFailure,
  249. self.config.post_tool_use_failure(),
  250. tool_name,
  251. tool_input,
  252. Some(tool_error),
  253. true,
  254. abort_signal,
  255. reporter,
  256. )
  257. }
  258. #[must_use]
  259. pub fn run_post_tool_use_failure_with_signal(
  260. &self,
  261. tool_name: &str,
  262. tool_input: &str,
  263. tool_error: &str,
  264. abort_signal: Option<&HookAbortSignal>,
  265. ) -> HookRunResult {
  266. self.run_post_tool_use_failure_with_context(
  267. tool_name,
  268. tool_input,
  269. tool_error,
  270. abort_signal,
  271. None,
  272. )
  273. }
  274. #[allow(clippy::too_many_arguments)]
  275. fn run_commands(
  276. event: HookEvent,
  277. commands: &[String],
  278. tool_name: &str,
  279. tool_input: &str,
  280. tool_output: Option<&str>,
  281. is_error: bool,
  282. abort_signal: Option<&HookAbortSignal>,
  283. mut reporter: Option<&mut dyn HookProgressReporter>,
  284. ) -> HookRunResult {
  285. if commands.is_empty() {
  286. return HookRunResult::allow(Vec::new());
  287. }
  288. if abort_signal.is_some_and(HookAbortSignal::is_aborted) {
  289. return HookRunResult {
  290. denied: false,
  291. failed: false,
  292. cancelled: true,
  293. messages: vec![format!(
  294. "{} hook cancelled before execution",
  295. event.as_str()
  296. )],
  297. permission_override: None,
  298. permission_reason: None,
  299. updated_input: None,
  300. };
  301. }
  302. let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string();
  303. let mut result = HookRunResult::allow(Vec::new());
  304. for command in commands {
  305. if let Some(reporter) = reporter.as_deref_mut() {
  306. reporter.on_event(&HookProgressEvent::Started {
  307. event,
  308. tool_name: tool_name.to_string(),
  309. command: command.clone(),
  310. });
  311. }
  312. match Self::run_command(
  313. command,
  314. event,
  315. tool_name,
  316. tool_input,
  317. tool_output,
  318. is_error,
  319. &payload,
  320. abort_signal,
  321. ) {
  322. HookCommandOutcome::Allow { parsed } => {
  323. if let Some(reporter) = reporter.as_deref_mut() {
  324. reporter.on_event(&HookProgressEvent::Completed {
  325. event,
  326. tool_name: tool_name.to_string(),
  327. command: command.clone(),
  328. });
  329. }
  330. merge_parsed_hook_output(&mut result, parsed);
  331. }
  332. HookCommandOutcome::Deny { parsed } => {
  333. if let Some(reporter) = reporter.as_deref_mut() {
  334. reporter.on_event(&HookProgressEvent::Completed {
  335. event,
  336. tool_name: tool_name.to_string(),
  337. command: command.clone(),
  338. });
  339. }
  340. merge_parsed_hook_output(&mut result, parsed);
  341. result.denied = true;
  342. return result;
  343. }
  344. HookCommandOutcome::Failed { parsed } => {
  345. if let Some(reporter) = reporter.as_deref_mut() {
  346. reporter.on_event(&HookProgressEvent::Completed {
  347. event,
  348. tool_name: tool_name.to_string(),
  349. command: command.clone(),
  350. });
  351. }
  352. merge_parsed_hook_output(&mut result, parsed);
  353. result.failed = true;
  354. return result;
  355. }
  356. HookCommandOutcome::Cancelled { message } => {
  357. if let Some(reporter) = reporter.as_deref_mut() {
  358. reporter.on_event(&HookProgressEvent::Cancelled {
  359. event,
  360. tool_name: tool_name.to_string(),
  361. command: command.clone(),
  362. });
  363. }
  364. result.cancelled = true;
  365. result.messages.push(message);
  366. return result;
  367. }
  368. }
  369. }
  370. result
  371. }
  372. #[allow(clippy::too_many_arguments)]
  373. fn run_command(
  374. command: &str,
  375. event: HookEvent,
  376. tool_name: &str,
  377. tool_input: &str,
  378. tool_output: Option<&str>,
  379. is_error: bool,
  380. payload: &str,
  381. abort_signal: Option<&HookAbortSignal>,
  382. ) -> HookCommandOutcome {
  383. let mut child = shell_command(command);
  384. child.stdin(Stdio::piped());
  385. child.stdout(Stdio::piped());
  386. child.stderr(Stdio::piped());
  387. child.env("HOOK_EVENT", event.as_str());
  388. child.env("HOOK_TOOL_NAME", tool_name);
  389. child.env("HOOK_TOOL_INPUT", tool_input);
  390. child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" });
  391. if let Some(tool_output) = tool_output {
  392. child.env("HOOK_TOOL_OUTPUT", tool_output);
  393. }
  394. match child.output_with_stdin(payload.as_bytes(), abort_signal) {
  395. Ok(CommandExecution::Finished(output)) => {
  396. let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
  397. let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
  398. let parsed = parse_hook_output(&stdout);
  399. let primary_message = parsed.primary_message().map(ToOwned::to_owned);
  400. match output.status.code() {
  401. Some(0) => {
  402. if parsed.deny {
  403. HookCommandOutcome::Deny { parsed }
  404. } else {
  405. HookCommandOutcome::Allow { parsed }
  406. }
  407. }
  408. Some(2) => HookCommandOutcome::Deny {
  409. parsed: parsed.with_fallback_message(format!(
  410. "{} hook denied tool `{tool_name}`",
  411. event.as_str()
  412. )),
  413. },
  414. Some(code) => HookCommandOutcome::Failed {
  415. parsed: parsed.with_fallback_message(format_hook_failure(
  416. command,
  417. code,
  418. primary_message.as_deref(),
  419. stderr.as_str(),
  420. )),
  421. },
  422. None => HookCommandOutcome::Failed {
  423. parsed: parsed.with_fallback_message(format!(
  424. "{} hook `{command}` terminated by signal while handling `{}`",
  425. event.as_str(),
  426. tool_name
  427. )),
  428. },
  429. }
  430. }
  431. Ok(CommandExecution::Cancelled) => HookCommandOutcome::Cancelled {
  432. message: format!(
  433. "{} hook `{command}` cancelled while handling `{tool_name}`",
  434. event.as_str()
  435. ),
  436. },
  437. Err(error) => HookCommandOutcome::Failed {
  438. parsed: ParsedHookOutput {
  439. messages: vec![format!(
  440. "{} hook `{command}` failed to start for `{}`: {error}",
  441. event.as_str(),
  442. tool_name
  443. )],
  444. ..ParsedHookOutput::default()
  445. },
  446. },
  447. }
  448. }
  449. }
  450. enum HookCommandOutcome {
  451. Allow { parsed: ParsedHookOutput },
  452. Deny { parsed: ParsedHookOutput },
  453. Failed { parsed: ParsedHookOutput },
  454. Cancelled { message: String },
  455. }
  456. #[derive(Debug, Clone, PartialEq, Eq, Default)]
  457. struct ParsedHookOutput {
  458. messages: Vec<String>,
  459. deny: bool,
  460. permission_override: Option<PermissionOverride>,
  461. permission_reason: Option<String>,
  462. updated_input: Option<String>,
  463. }
  464. impl ParsedHookOutput {
  465. fn with_fallback_message(mut self, fallback: String) -> Self {
  466. if self.messages.is_empty() {
  467. self.messages.push(fallback);
  468. }
  469. self
  470. }
  471. fn primary_message(&self) -> Option<&str> {
  472. self.messages.first().map(String::as_str)
  473. }
  474. }
  475. fn merge_parsed_hook_output(target: &mut HookRunResult, parsed: ParsedHookOutput) {
  476. target.messages.extend(parsed.messages);
  477. if parsed.permission_override.is_some() {
  478. target.permission_override = parsed.permission_override;
  479. }
  480. if parsed.permission_reason.is_some() {
  481. target.permission_reason = parsed.permission_reason;
  482. }
  483. if parsed.updated_input.is_some() {
  484. target.updated_input = parsed.updated_input;
  485. }
  486. }
  487. fn parse_hook_output(stdout: &str) -> ParsedHookOutput {
  488. if stdout.is_empty() {
  489. return ParsedHookOutput::default();
  490. }
  491. let Ok(Value::Object(root)) = serde_json::from_str::<Value>(stdout) else {
  492. return ParsedHookOutput {
  493. messages: vec![stdout.to_string()],
  494. ..ParsedHookOutput::default()
  495. };
  496. };
  497. let mut parsed = ParsedHookOutput::default();
  498. if let Some(message) = root.get("systemMessage").and_then(Value::as_str) {
  499. parsed.messages.push(message.to_string());
  500. }
  501. if let Some(message) = root.get("reason").and_then(Value::as_str) {
  502. parsed.messages.push(message.to_string());
  503. }
  504. if root.get("continue").and_then(Value::as_bool) == Some(false)
  505. || root.get("decision").and_then(Value::as_str) == Some("block")
  506. {
  507. parsed.deny = true;
  508. }
  509. if let Some(Value::Object(specific)) = root.get("hookSpecificOutput") {
  510. if let Some(Value::String(additional_context)) = specific.get("additionalContext") {
  511. parsed.messages.push(additional_context.clone());
  512. }
  513. if let Some(decision) = specific.get("permissionDecision").and_then(Value::as_str) {
  514. parsed.permission_override = match decision {
  515. "allow" => Some(PermissionOverride::Allow),
  516. "deny" => Some(PermissionOverride::Deny),
  517. "ask" => Some(PermissionOverride::Ask),
  518. _ => None,
  519. };
  520. }
  521. if let Some(reason) = specific
  522. .get("permissionDecisionReason")
  523. .and_then(Value::as_str)
  524. {
  525. parsed.permission_reason = Some(reason.to_string());
  526. }
  527. if let Some(updated_input) = specific.get("updatedInput") {
  528. parsed.updated_input = serde_json::to_string(updated_input).ok();
  529. }
  530. }
  531. if parsed.messages.is_empty() {
  532. parsed.messages.push(stdout.to_string());
  533. }
  534. parsed
  535. }
  536. fn hook_payload(
  537. event: HookEvent,
  538. tool_name: &str,
  539. tool_input: &str,
  540. tool_output: Option<&str>,
  541. is_error: bool,
  542. ) -> Value {
  543. match event {
  544. HookEvent::PostToolUseFailure => json!({
  545. "hook_event_name": event.as_str(),
  546. "tool_name": tool_name,
  547. "tool_input": parse_tool_input(tool_input),
  548. "tool_input_json": tool_input,
  549. "tool_error": tool_output,
  550. "tool_result_is_error": true,
  551. }),
  552. _ => json!({
  553. "hook_event_name": event.as_str(),
  554. "tool_name": tool_name,
  555. "tool_input": parse_tool_input(tool_input),
  556. "tool_input_json": tool_input,
  557. "tool_output": tool_output,
  558. "tool_result_is_error": is_error,
  559. }),
  560. }
  561. }
  562. fn parse_tool_input(tool_input: &str) -> Value {
  563. serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
  564. }
  565. fn format_hook_failure(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
  566. let mut message = format!("Hook `{command}` exited with status {code}");
  567. if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) {
  568. message.push_str(": ");
  569. message.push_str(stdout);
  570. } else if !stderr.is_empty() {
  571. message.push_str(": ");
  572. message.push_str(stderr);
  573. }
  574. message
  575. }
  576. fn shell_command(command: &str) -> CommandWithStdin {
  577. #[cfg(windows)]
  578. let mut command_builder = {
  579. let mut command_builder = Command::new("cmd");
  580. command_builder.arg("/C").arg(command);
  581. CommandWithStdin::new(command_builder)
  582. };
  583. #[cfg(not(windows))]
  584. let command_builder = {
  585. let mut command_builder = Command::new("sh");
  586. command_builder.arg("-lc").arg(command);
  587. CommandWithStdin::new(command_builder)
  588. };
  589. command_builder
  590. }
  591. struct CommandWithStdin {
  592. command: Command,
  593. }
  594. impl CommandWithStdin {
  595. fn new(command: Command) -> Self {
  596. Self { command }
  597. }
  598. fn stdin(&mut self, cfg: Stdio) -> &mut Self {
  599. self.command.stdin(cfg);
  600. self
  601. }
  602. fn stdout(&mut self, cfg: Stdio) -> &mut Self {
  603. self.command.stdout(cfg);
  604. self
  605. }
  606. fn stderr(&mut self, cfg: Stdio) -> &mut Self {
  607. self.command.stderr(cfg);
  608. self
  609. }
  610. fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
  611. where
  612. K: AsRef<OsStr>,
  613. V: AsRef<OsStr>,
  614. {
  615. self.command.env(key, value);
  616. self
  617. }
  618. fn output_with_stdin(
  619. &mut self,
  620. stdin: &[u8],
  621. abort_signal: Option<&HookAbortSignal>,
  622. ) -> std::io::Result<CommandExecution> {
  623. let mut child = self.command.spawn()?;
  624. if let Some(mut child_stdin) = child.stdin.take() {
  625. child_stdin.write_all(stdin)?;
  626. }
  627. loop {
  628. if abort_signal.is_some_and(HookAbortSignal::is_aborted) {
  629. let _ = child.kill();
  630. let _ = child.wait_with_output();
  631. return Ok(CommandExecution::Cancelled);
  632. }
  633. match child.try_wait()? {
  634. Some(_) => return child.wait_with_output().map(CommandExecution::Finished),
  635. None => thread::sleep(Duration::from_millis(20)),
  636. }
  637. }
  638. }
  639. }
  640. enum CommandExecution {
  641. Finished(std::process::Output),
  642. Cancelled,
  643. }
  644. #[cfg(test)]
  645. mod tests {
  646. use std::thread;
  647. use std::time::Duration;
  648. use super::{
  649. HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult,
  650. HookRunner,
  651. };
  652. use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
  653. use crate::permissions::PermissionOverride;
  654. struct RecordingReporter {
  655. events: Vec<HookProgressEvent>,
  656. }
  657. impl HookProgressReporter for RecordingReporter {
  658. fn on_event(&mut self, event: &HookProgressEvent) {
  659. self.events.push(event.clone());
  660. }
  661. }
  662. #[test]
  663. fn allows_exit_code_zero_and_captures_stdout() {
  664. let runner = HookRunner::new(RuntimeHookConfig::new(
  665. vec![shell_snippet("printf 'pre ok'")],
  666. Vec::new(),
  667. Vec::new(),
  668. ));
  669. let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#);
  670. assert_eq!(result, HookRunResult::allow(vec!["pre ok".to_string()]));
  671. }
  672. #[test]
  673. fn denies_exit_code_two() {
  674. let runner = HookRunner::new(RuntimeHookConfig::new(
  675. vec![shell_snippet("printf 'blocked by hook'; exit 2")],
  676. Vec::new(),
  677. Vec::new(),
  678. ));
  679. let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);
  680. assert!(result.is_denied());
  681. assert_eq!(result.messages(), &["blocked by hook".to_string()]);
  682. }
  683. #[test]
  684. fn propagates_other_non_zero_statuses_as_failures() {
  685. let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks(
  686. RuntimeHookConfig::new(
  687. vec![shell_snippet("printf 'warning hook'; exit 1")],
  688. Vec::new(),
  689. Vec::new(),
  690. ),
  691. ));
  692. // given
  693. // when
  694. let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#);
  695. // then
  696. assert!(result.is_failed());
  697. assert!(result
  698. .messages()
  699. .iter()
  700. .any(|message| message.contains("warning hook")));
  701. }
  702. #[test]
  703. fn parses_pre_hook_permission_override_and_updated_input() {
  704. let runner = HookRunner::new(RuntimeHookConfig::new(
  705. vec![shell_snippet(
  706. r#"printf '%s' '{"systemMessage":"updated","hookSpecificOutput":{"permissionDecision":"allow","permissionDecisionReason":"hook ok","updatedInput":{"command":"git status"}}}'"#,
  707. )],
  708. Vec::new(),
  709. Vec::new(),
  710. ));
  711. let result = runner.run_pre_tool_use("bash", r#"{"command":"pwd"}"#);
  712. assert_eq!(
  713. result.permission_override(),
  714. Some(PermissionOverride::Allow)
  715. );
  716. assert_eq!(result.permission_reason(), Some("hook ok"));
  717. assert_eq!(result.updated_input(), Some(r#"{"command":"git status"}"#));
  718. assert!(result.messages().iter().any(|message| message == "updated"));
  719. }
  720. #[test]
  721. fn runs_post_tool_use_failure_hooks() {
  722. // given
  723. let runner = HookRunner::new(RuntimeHookConfig::new(
  724. Vec::new(),
  725. Vec::new(),
  726. vec![shell_snippet("printf 'failure hook ran'")],
  727. ));
  728. // when
  729. let result =
  730. runner.run_post_tool_use_failure("bash", r#"{"command":"false"}"#, "command failed");
  731. // then
  732. assert!(!result.is_denied());
  733. assert_eq!(result.messages(), &["failure hook ran".to_string()]);
  734. }
  735. #[test]
  736. fn stops_running_failure_hooks_after_failure() {
  737. // given
  738. let runner = HookRunner::new(RuntimeHookConfig::new(
  739. Vec::new(),
  740. Vec::new(),
  741. vec![
  742. shell_snippet("printf 'broken failure hook'; exit 1"),
  743. shell_snippet("printf 'later failure hook'"),
  744. ],
  745. ));
  746. // when
  747. let result =
  748. runner.run_post_tool_use_failure("bash", r#"{"command":"false"}"#, "command failed");
  749. // then
  750. assert!(result.is_failed());
  751. assert!(result
  752. .messages()
  753. .iter()
  754. .any(|message| message.contains("broken failure hook")));
  755. assert!(!result
  756. .messages()
  757. .iter()
  758. .any(|message| message == "later failure hook"));
  759. }
  760. #[test]
  761. fn executes_hooks_in_configured_order() {
  762. // given
  763. let runner = HookRunner::new(RuntimeHookConfig::new(
  764. vec![
  765. shell_snippet("printf 'first'"),
  766. shell_snippet("printf 'second'"),
  767. ],
  768. Vec::new(),
  769. Vec::new(),
  770. ));
  771. let mut reporter = RecordingReporter { events: Vec::new() };
  772. // when
  773. let result = runner.run_pre_tool_use_with_context(
  774. "Read",
  775. r#"{"path":"README.md"}"#,
  776. None,
  777. Some(&mut reporter),
  778. );
  779. // then
  780. assert_eq!(
  781. result,
  782. HookRunResult::allow(vec!["first".to_string(), "second".to_string()])
  783. );
  784. assert_eq!(reporter.events.len(), 4);
  785. assert!(matches!(
  786. &reporter.events[0],
  787. HookProgressEvent::Started {
  788. event: HookEvent::PreToolUse,
  789. command,
  790. ..
  791. } if command == "printf 'first'"
  792. ));
  793. assert!(matches!(
  794. &reporter.events[1],
  795. HookProgressEvent::Completed {
  796. event: HookEvent::PreToolUse,
  797. command,
  798. ..
  799. } if command == "printf 'first'"
  800. ));
  801. assert!(matches!(
  802. &reporter.events[2],
  803. HookProgressEvent::Started {
  804. event: HookEvent::PreToolUse,
  805. command,
  806. ..
  807. } if command == "printf 'second'"
  808. ));
  809. assert!(matches!(
  810. &reporter.events[3],
  811. HookProgressEvent::Completed {
  812. event: HookEvent::PreToolUse,
  813. command,
  814. ..
  815. } if command == "printf 'second'"
  816. ));
  817. }
  818. #[test]
  819. fn stops_running_hooks_after_failure() {
  820. // given
  821. let runner = HookRunner::new(RuntimeHookConfig::new(
  822. vec![
  823. shell_snippet("printf 'broken'; exit 1"),
  824. shell_snippet("printf 'later'"),
  825. ],
  826. Vec::new(),
  827. Vec::new(),
  828. ));
  829. // when
  830. let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#);
  831. // then
  832. assert!(result.is_failed());
  833. assert!(result
  834. .messages()
  835. .iter()
  836. .any(|message| message.contains("broken")));
  837. assert!(!result.messages().iter().any(|message| message == "later"));
  838. }
  839. #[test]
  840. fn abort_signal_cancels_long_running_hook_and_reports_progress() {
  841. let runner = HookRunner::new(RuntimeHookConfig::new(
  842. vec![shell_snippet("sleep 5")],
  843. Vec::new(),
  844. Vec::new(),
  845. ));
  846. let abort_signal = HookAbortSignal::new();
  847. let abort_signal_for_thread = abort_signal.clone();
  848. let mut reporter = RecordingReporter { events: Vec::new() };
  849. thread::spawn(move || {
  850. thread::sleep(Duration::from_millis(100));
  851. abort_signal_for_thread.abort();
  852. });
  853. let result = runner.run_pre_tool_use_with_context(
  854. "bash",
  855. r#"{"command":"sleep 5"}"#,
  856. Some(&abort_signal),
  857. Some(&mut reporter),
  858. );
  859. assert!(result.is_cancelled());
  860. assert!(reporter.events.iter().any(|event| matches!(
  861. event,
  862. HookProgressEvent::Started {
  863. event: HookEvent::PreToolUse,
  864. ..
  865. }
  866. )));
  867. assert!(reporter.events.iter().any(|event| matches!(
  868. event,
  869. HookProgressEvent::Cancelled {
  870. event: HookEvent::PreToolUse,
  871. ..
  872. }
  873. )));
  874. }
  875. #[cfg(windows)]
  876. fn shell_snippet(script: &str) -> String {
  877. script.replace('\'', "\"")
  878. }
  879. #[cfg(not(windows))]
  880. fn shell_snippet(script: &str) -> String {
  881. script.to_string()
  882. }
  883. }