소스 검색

wip: hook-pipeline progress

Yeachan-Heo 2 달 전
부모
커밋
9efd029e26
4개의 변경된 파일124개의 추가작업 그리고 72개의 파일을 삭제
  1. 84 21
      rust/crates/runtime/src/conversation.rs
  2. 19 35
      rust/crates/runtime/src/hooks.rs
  3. 19 14
      rust/crates/rusty-claude-cli/src/main.rs
  4. 2 2
      rust/crates/rusty-claude-cli/src/render.rs

+ 84 - 21
rust/crates/runtime/src/conversation.rs

@@ -5,7 +5,7 @@ use crate::compact::{
     compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
 };
 use crate::config::RuntimeFeatureConfig;
-use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
+use crate::hooks::{HookAbortSignal, HookRunResult, HookRunner};
 use crate::permissions::{
     PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
 };
@@ -100,7 +100,6 @@ pub struct ConversationRuntime<C, T> {
     usage_tracker: UsageTracker,
     hook_runner: HookRunner,
     hook_abort_signal: HookAbortSignal,
-    hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
 }
 
 impl<C, T> ConversationRuntime<C, T>
@@ -122,18 +121,19 @@ where
             tool_executor,
             permission_policy,
             system_prompt,
-            &RuntimeFeatureConfig::default(),
+            RuntimeFeatureConfig::default(),
         )
     }
 
     #[must_use]
+    #[allow(clippy::needless_pass_by_value)]
     pub fn new_with_features(
         session: Session,
         api_client: C,
         tool_executor: T,
         permission_policy: PermissionPolicy,
         system_prompt: Vec<String>,
-        feature_config: &RuntimeFeatureConfig,
+        feature_config: RuntimeFeatureConfig,
     ) -> Self {
         let usage_tracker = UsageTracker::from_session(&session);
         Self {
@@ -144,8 +144,9 @@ where
             system_prompt,
             max_iterations: usize::MAX,
             usage_tracker,
-            hook_runner: HookRunner::from_feature_config(feature_config),
+            hook_runner: HookRunner::from_feature_config(&feature_config),
             hook_abort_signal: HookAbortSignal::default(),
+            hook_progress_reporter: None,
         }
     }
 
@@ -220,17 +221,12 @@ where
             }
 
             for (tool_use_id, tool_name, input) in pending_tool_uses {
-                let pre_hook_result = self.hook_runner.run_pre_tool_use_with_context(
-                    &tool_name,
-                    &input,
-                    Some(&self.hook_abort_signal),
-                    self.hook_progress_reporter.as_deref_mut(),
-                );
+                let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
                 let effective_input = pre_hook_result
-                    .updated_input_json()
+                    .updated_input()
                     .map_or_else(|| input.clone(), ToOwned::to_owned);
                 let permission_context = PermissionContext::new(
-                    pre_hook_result.permission_decision(),
+                    pre_hook_result.permission_override(),
                     pre_hook_result.permission_reason().map(ToOwned::to_owned),
                 );
 
@@ -274,21 +270,17 @@ where
                         output = merge_hook_feedback(pre_hook_result.messages(), output, false);
 
                         let post_hook_result = if is_error {
-                            self.hook_runner.run_post_tool_use_failure_with_context(
+                            self.run_post_tool_use_failure_hook(
                                 &tool_name,
                                 &effective_input,
                                 &output,
-                                Some(&self.hook_abort_signal),
-                                self.hook_progress_reporter.as_deref_mut(),
                             )
                         } else {
-                            self.hook_runner.run_post_tool_use_with_context(
+                            self.run_post_tool_use_hook(
                                 &tool_name,
                                 &effective_input,
                                 &output,
                                 false,
-                                Some(&self.hook_abort_signal),
-                                self.hook_progress_reporter.as_deref_mut(),
                             )
                         };
                         if post_hook_result.is_denied() || post_hook_result.is_cancelled() {
@@ -322,6 +314,77 @@ where
         })
     }
 
+    fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
+        if let Some(reporter) = self.hook_progress_reporter.as_mut() {
+            self.hook_runner.run_pre_tool_use_with_context(
+                tool_name,
+                input,
+                Some(&self.hook_abort_signal),
+                Some(reporter.as_mut()),
+            )
+        } else {
+            self.hook_runner.run_pre_tool_use_with_context(
+                tool_name,
+                input,
+                Some(&self.hook_abort_signal),
+                None,
+            )
+        }
+    }
+
+    fn run_post_tool_use_hook(
+        &mut self,
+        tool_name: &str,
+        input: &str,
+        output: &str,
+        is_error: bool,
+    ) -> HookRunResult {
+        if let Some(reporter) = self.hook_progress_reporter.as_mut() {
+            self.hook_runner.run_post_tool_use_with_context(
+                tool_name,
+                input,
+                output,
+                is_error,
+                Some(&self.hook_abort_signal),
+                Some(reporter.as_mut()),
+            )
+        } else {
+            self.hook_runner.run_post_tool_use_with_context(
+                tool_name,
+                input,
+                output,
+                is_error,
+                Some(&self.hook_abort_signal),
+                None,
+            )
+        }
+    }
+
+    fn run_post_tool_use_failure_hook(
+        &mut self,
+        tool_name: &str,
+        input: &str,
+        output: &str,
+    ) -> HookRunResult {
+        if let Some(reporter) = self.hook_progress_reporter.as_mut() {
+            self.hook_runner.run_post_tool_use_failure_with_context(
+                tool_name,
+                input,
+                output,
+                Some(&self.hook_abort_signal),
+                Some(reporter.as_mut()),
+            )
+        } else {
+            self.hook_runner.run_post_tool_use_failure_with_context(
+                tool_name,
+                input,
+                output,
+                Some(&self.hook_abort_signal),
+                None,
+            )
+        }
+    }
+
     #[must_use]
     pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
         compact_session(&self.session, config)
@@ -669,7 +732,7 @@ mod tests {
             }),
             PermissionPolicy::new(PermissionMode::DangerFullAccess),
             vec!["system".to_string()],
-            &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
+            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                 vec![shell_snippet("printf 'blocked by hook'; exit 2")],
                 Vec::new(),
                 Vec::new(),
@@ -736,7 +799,7 @@ mod tests {
             StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
             PermissionPolicy::new(PermissionMode::DangerFullAccess),
             vec!["system".to_string()],
-            &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
+            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                 vec![shell_snippet("printf 'pre hook ran'")],
                 vec![shell_snippet("printf 'post hook ran'")],
                 Vec::new(),

+ 19 - 35
rust/crates/runtime/src/hooks.rs

@@ -1,16 +1,14 @@
 use std::ffi::OsStr;
+use std::io::Write;
 use std::process::{Command, Stdio};
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
+use std::thread;
 use std::time::Duration;
 
 use serde_json::{json, Value};
-use tokio::io::AsyncWriteExt;
-use tokio::process::Command as TokioCommand;
-use tokio::runtime::Builder;
-use tokio::time::sleep;
 
 use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
 use crate::permissions::PermissionOverride;
@@ -172,7 +170,7 @@ impl HookRunner {
         abort_signal: Option<&HookAbortSignal>,
         reporter: Option<&mut dyn HookProgressReporter>,
     ) -> HookRunResult {
-        self.run_commands(
+        Self::run_commands(
             HookEvent::PreToolUse,
             self.config.pre_tool_use(),
             tool_name,
@@ -222,7 +220,7 @@ impl HookRunner {
         abort_signal: Option<&HookAbortSignal>,
         reporter: Option<&mut dyn HookProgressReporter>,
     ) -> HookRunResult {
-        self.run_commands(
+        Self::run_commands(
             HookEvent::PostToolUse,
             self.config.post_tool_use(),
             tool_name,
@@ -272,7 +270,7 @@ impl HookRunner {
         abort_signal: Option<&HookAbortSignal>,
         reporter: Option<&mut dyn HookProgressReporter>,
     ) -> HookRunResult {
-        self.run_commands(
+        Self::run_commands(
             HookEvent::PostToolUseFailure,
             self.config.post_tool_use_failure(),
             tool_name,
@@ -303,7 +301,6 @@ impl HookRunner {
 
     #[allow(clippy::too_many_arguments)]
     fn run_commands(
-        &self,
         event: HookEvent,
         commands: &[String],
         tool_name: &str,
@@ -675,36 +672,23 @@ impl CommandWithStdin {
         stdin: &[u8],
         abort_signal: Option<&HookAbortSignal>,
     ) -> std::io::Result<CommandExecution> {
-        let runtime = Builder::new_current_thread().enable_all().build()?;
-        let mut command =
-            TokioCommand::from(std::mem::replace(&mut self.command, Command::new("true")));
-        let stdin = stdin.to_vec();
-        let abort_signal = abort_signal.cloned();
-        runtime.block_on(async move {
-            let mut child = command.spawn()?;
-            if let Some(mut child_stdin) = child.stdin.take() {
-                child_stdin.write_all(&stdin).await?;
-            }
-
-            loop {
-                if abort_signal
-                    .as_ref()
-                    .is_some_and(HookAbortSignal::is_aborted)
-                {
-                    let _ = child.start_kill();
-                    let _ = child.wait().await;
-                    return Ok(CommandExecution::Cancelled);
-                }
+        let mut child = self.command.spawn()?;
+        if let Some(mut child_stdin) = child.stdin.take() {
+            child_stdin.write_all(stdin)?;
+        }
 
-                if let Some(status) = child.try_wait()? {
-                    let output = child.wait_with_output().await?;
-                    debug_assert_eq!(output.status.code(), status.code());
-                    return Ok(CommandExecution::Finished(output));
-                }
+        loop {
+            if abort_signal.is_some_and(HookAbortSignal::is_aborted) {
+                let _ = child.kill();
+                let _ = child.wait_with_output();
+                return Ok(CommandExecution::Cancelled);
+            }
 
-                sleep(Duration::from_millis(20)).await;
+            match child.try_wait()? {
+                Some(_) => return child.wait_with_output().map(CommandExecution::Finished),
+                None => thread::sleep(Duration::from_millis(20)),
             }
-        })
+        }
     }
 }
 

+ 19 - 14
rust/crates/rusty-claude-cli/src/main.rs

@@ -1923,14 +1923,15 @@ fn build_runtime(
 ) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
 {
     let feature_config = build_runtime_feature_config()?;
-    Ok(ConversationRuntime::new_with_features(
+    let runtime = ConversationRuntime::new_with_features(
         session,
         AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
         CliToolExecutor::new(allowed_tools, emit_output),
         permission_policy(permission_mode, &feature_config),
         system_prompt,
-        &feature_config,
-    ))
+        feature_config,
+    );
+    Ok(runtime)
 }
 
 struct CliPermissionPrompter {
@@ -1953,6 +1954,9 @@ impl runtime::PermissionPrompter for CliPermissionPrompter {
         println!("  Tool             {}", request.tool_name);
         println!("  Current mode     {}", self.current_mode.as_str());
         println!("  Required mode    {}", request.required_mode.as_str());
+        if let Some(reason) = &request.reason {
+            println!("  Reason           {reason}");
+        }
         println!("  Input            {}", request.input);
         print!("Approve this tool call? [y/N]: ");
         let _ = io::stdout().flush();
@@ -2365,13 +2369,15 @@ fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String {
         .get("backgroundTaskId")
         .and_then(|value| value.as_str())
     {
-        lines[0].push_str(&format!(" backgrounded ({task_id})"));
+        use std::fmt::Write as _;
+        let _ = write!(lines[0], " backgrounded ({task_id})");
     } else if let Some(status) = parsed
         .get("returnCodeInterpretation")
         .and_then(|value| value.as_str())
         .filter(|status| !status.is_empty())
     {
-        lines[0].push_str(&format!(" {status}"));
+        use std::fmt::Write as _;
+        let _ = write!(lines[0], " {status}");
     }
 
     if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) {
@@ -2393,15 +2399,15 @@ fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String {
     let path = extract_tool_path(file);
     let start_line = file
         .get("startLine")
-        .and_then(|value| value.as_u64())
+        .and_then(serde_json::Value::as_u64)
         .unwrap_or(1);
     let num_lines = file
         .get("numLines")
-        .and_then(|value| value.as_u64())
+        .and_then(serde_json::Value::as_u64)
         .unwrap_or(0);
     let total_lines = file
         .get("totalLines")
-        .and_then(|value| value.as_u64())
+        .and_then(serde_json::Value::as_u64)
         .unwrap_or(num_lines);
     let content = file
         .get("content")
@@ -2427,8 +2433,7 @@ fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String {
     let line_count = parsed
         .get("content")
         .and_then(|value| value.as_str())
-        .map(|content| content.lines().count())
-        .unwrap_or(0);
+        .map_or(0, |content| content.lines().count());
     format!(
         "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m",
         if kind == "create" { "Wrote" } else { "Updated" },
@@ -2459,7 +2464,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
     let path = extract_tool_path(parsed);
     let suffix = if parsed
         .get("replaceAll")
-        .and_then(|value| value.as_bool())
+        .and_then(serde_json::Value::as_bool)
         .unwrap_or(false)
     {
         " (replace all)"
@@ -2487,7 +2492,7 @@ fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String {
 fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
     let num_files = parsed
         .get("numFiles")
-        .and_then(|value| value.as_u64())
+        .and_then(serde_json::Value::as_u64)
         .unwrap_or(0);
     let filenames = parsed
         .get("filenames")
@@ -2511,11 +2516,11 @@ fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String {
 fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String {
     let num_matches = parsed
         .get("numMatches")
-        .and_then(|value| value.as_u64())
+        .and_then(serde_json::Value::as_u64)
         .unwrap_or(0);
     let num_files = parsed
         .get("numFiles")
-        .and_then(|value| value.as_u64())
+        .and_then(serde_json::Value::as_u64)
         .unwrap_or(0);
     let content = parsed
         .get("content")

+ 2 - 2
rust/crates/rusty-claude-cli/src/render.rs

@@ -286,7 +286,7 @@ impl TerminalRenderer {
     ) {
         match event {
             Event::Start(Tag::Heading { level, .. }) => {
-                self.start_heading(state, level as u8, output)
+                Self::start_heading(state, level as u8, output);
             }
             Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
             Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
@@ -426,7 +426,7 @@ impl TerminalRenderer {
         }
     }
 
-    fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) {
+    fn start_heading(state: &mut RenderState, level: u8, output: &mut String) {
         state.heading_level = Some(level);
         if !output.is_empty() {
             output.push('\n');