| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347 |
- use std::ffi::OsStr;
- use std::process::Command;
- use serde_json::json;
- use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
- #[derive(Debug, Clone, Copy, PartialEq, Eq)]
- pub enum HookEvent {
- PreToolUse,
- PostToolUse,
- }
- impl HookEvent {
- fn as_str(self) -> &'static str {
- match self {
- Self::PreToolUse => "PreToolUse",
- Self::PostToolUse => "PostToolUse",
- }
- }
- }
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct HookRunResult {
- denied: bool,
- messages: Vec<String>,
- }
- impl HookRunResult {
- #[must_use]
- pub fn allow(messages: Vec<String>) -> Self {
- Self {
- denied: false,
- messages,
- }
- }
- #[must_use]
- pub fn is_denied(&self) -> bool {
- self.denied
- }
- #[must_use]
- pub fn messages(&self) -> &[String] {
- &self.messages
- }
- }
- #[derive(Debug, Clone, PartialEq, Eq, Default)]
- pub struct HookRunner {
- config: RuntimeHookConfig,
- }
- impl HookRunner {
- #[must_use]
- pub fn new(config: RuntimeHookConfig) -> Self {
- Self { config }
- }
- #[must_use]
- pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self {
- Self::new(feature_config.hooks().clone())
- }
- #[must_use]
- pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
- Self::run_commands(
- HookEvent::PreToolUse,
- self.config.pre_tool_use(),
- tool_name,
- tool_input,
- None,
- false,
- )
- }
- #[must_use]
- pub fn run_post_tool_use(
- &self,
- tool_name: &str,
- tool_input: &str,
- tool_output: &str,
- is_error: bool,
- ) -> HookRunResult {
- Self::run_commands(
- HookEvent::PostToolUse,
- self.config.post_tool_use(),
- tool_name,
- tool_input,
- Some(tool_output),
- is_error,
- )
- }
- fn run_commands(
- event: HookEvent,
- commands: &[String],
- tool_name: &str,
- tool_input: &str,
- tool_output: Option<&str>,
- is_error: bool,
- ) -> HookRunResult {
- if commands.is_empty() {
- return HookRunResult::allow(Vec::new());
- }
- let payload = json!({
- "hook_event_name": event.as_str(),
- "tool_name": tool_name,
- "tool_input": parse_tool_input(tool_input),
- "tool_input_json": tool_input,
- "tool_output": tool_output,
- "tool_result_is_error": is_error,
- })
- .to_string();
- let mut messages = Vec::new();
- for command in commands {
- match Self::run_command(
- command,
- event,
- tool_name,
- tool_input,
- tool_output,
- is_error,
- &payload,
- ) {
- HookCommandOutcome::Allow { message } => {
- if let Some(message) = message {
- messages.push(message);
- }
- }
- HookCommandOutcome::Deny { message } => {
- let message = message.unwrap_or_else(|| {
- format!("{} hook denied tool `{tool_name}`", event.as_str())
- });
- messages.push(message);
- return HookRunResult {
- denied: true,
- messages,
- };
- }
- HookCommandOutcome::Warn { message } => messages.push(message),
- }
- }
- HookRunResult::allow(messages)
- }
- fn run_command(
- command: &str,
- event: HookEvent,
- tool_name: &str,
- tool_input: &str,
- tool_output: Option<&str>,
- is_error: bool,
- payload: &str,
- ) -> HookCommandOutcome {
- let mut child = shell_command(command);
- child.stdin(std::process::Stdio::piped());
- child.stdout(std::process::Stdio::piped());
- child.stderr(std::process::Stdio::piped());
- child.env("HOOK_EVENT", event.as_str());
- child.env("HOOK_TOOL_NAME", tool_name);
- child.env("HOOK_TOOL_INPUT", tool_input);
- child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" });
- if let Some(tool_output) = tool_output {
- child.env("HOOK_TOOL_OUTPUT", tool_output);
- }
- match child.output_with_stdin(payload.as_bytes()) {
- Ok(output) => {
- let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
- let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
- let message = (!stdout.is_empty()).then_some(stdout);
- match output.status.code() {
- Some(0) => HookCommandOutcome::Allow { message },
- Some(2) => HookCommandOutcome::Deny { message },
- Some(code) => HookCommandOutcome::Warn {
- message: format_hook_warning(
- command,
- code,
- message.as_deref(),
- stderr.as_str(),
- ),
- },
- None => HookCommandOutcome::Warn {
- message: format!(
- "{} hook `{command}` terminated by signal while handling `{tool_name}`",
- event.as_str()
- ),
- },
- }
- }
- Err(error) => HookCommandOutcome::Warn {
- message: format!(
- "{} hook `{command}` failed to start for `{tool_name}`: {error}",
- event.as_str()
- ),
- },
- }
- }
- }
- enum HookCommandOutcome {
- Allow { message: Option<String> },
- Deny { message: Option<String> },
- Warn { message: String },
- }
- fn parse_tool_input(tool_input: &str) -> serde_json::Value {
- serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
- }
- fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
- let mut message =
- format!("Hook `{command}` exited with status {code}; allowing tool execution to continue");
- if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) {
- message.push_str(": ");
- message.push_str(stdout);
- } else if !stderr.is_empty() {
- message.push_str(": ");
- message.push_str(stderr);
- }
- message
- }
- fn shell_command(command: &str) -> CommandWithStdin {
- #[cfg(windows)]
- let mut command_builder = {
- let mut command_builder = Command::new("cmd");
- command_builder.arg("/C").arg(command);
- CommandWithStdin::new(command_builder)
- };
- #[cfg(not(windows))]
- let command_builder = {
- let mut command_builder = Command::new("sh");
- command_builder.arg("-lc").arg(command);
- CommandWithStdin::new(command_builder)
- };
- command_builder
- }
- struct CommandWithStdin {
- command: Command,
- }
- impl CommandWithStdin {
- fn new(command: Command) -> Self {
- Self { command }
- }
- fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self {
- self.command.stdin(cfg);
- self
- }
- fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self {
- self.command.stdout(cfg);
- self
- }
- fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self {
- self.command.stderr(cfg);
- self
- }
- fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
- where
- K: AsRef<OsStr>,
- V: AsRef<OsStr>,
- {
- self.command.env(key, value);
- self
- }
- fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> {
- let mut child = self.command.spawn()?;
- if let Some(mut child_stdin) = child.stdin.take() {
- use std::io::Write;
- child_stdin.write_all(stdin)?;
- }
- child.wait_with_output()
- }
- }
- #[cfg(test)]
- mod tests {
- use super::{HookRunResult, HookRunner};
- use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
- #[test]
- fn allows_exit_code_zero_and_captures_stdout() {
- let runner = HookRunner::new(RuntimeHookConfig::new(
- vec![shell_snippet("printf 'pre ok'")],
- Vec::new(),
- ));
- let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#);
- assert_eq!(result, HookRunResult::allow(vec!["pre ok".to_string()]));
- }
- #[test]
- fn denies_exit_code_two() {
- let runner = HookRunner::new(RuntimeHookConfig::new(
- vec![shell_snippet("printf 'blocked by hook'; exit 2")],
- Vec::new(),
- ));
- let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);
- assert!(result.is_denied());
- assert_eq!(result.messages(), &["blocked by hook".to_string()]);
- }
- #[test]
- fn warns_for_other_non_zero_statuses() {
- let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks(
- RuntimeHookConfig::new(
- vec![shell_snippet("printf 'warning hook'; exit 1")],
- Vec::new(),
- ),
- ));
- let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#);
- assert!(!result.is_denied());
- assert!(result
- .messages()
- .iter()
- .any(|message| message.contains("allowing tool execution to continue")));
- }
- #[cfg(windows)]
- fn shell_snippet(script: &str) -> String {
- script.replace('\'', "\"")
- }
- #[cfg(not(windows))]
- fn shell_snippet(script: &str) -> String {
- script.to_string()
- }
- }
|