Преглед изворни кода

feat: hook abort signal + Ctrl-C cancellation pipeline

Yeachan-Heo пре 2 месеци
родитељ
комит
eaf7dc83f0
2 измењених фајлова са 127 додато и 6 уклоњено
  1. 1 1
      rust/crates/rusty-claude-cli/Cargo.toml
  2. 126 5
      rust/crates/rusty-claude-cli/src/main.rs

+ 1 - 1
rust/crates/rusty-claude-cli/Cargo.toml

@@ -19,7 +19,7 @@ rustyline = "15"
 runtime = { path = "../runtime" }
 serde_json = "1"
 syntect = "5"
-tokio = { version = "1", features = ["rt-multi-thread", "time"] }
+tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] }
 tools = { path = "../tools" }
 
 [lints]

+ 126 - 5
rust/crates/rusty-claude-cli/src/main.rs

@@ -9,6 +9,8 @@ use std::io::{self, Read, Write};
 use std::net::TcpListener;
 use std::path::{Path, PathBuf};
 use std::process::Command;
+use std::sync::mpsc::{self, Receiver, Sender};
+use std::thread::{self, JoinHandle};
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use api::{
@@ -984,6 +986,61 @@ struct LiveCli {
     session: SessionHandle,
 }
 
+struct HookAbortMonitor {
+    stop_tx: Option<Sender<()>>,
+    join_handle: Option<JoinHandle<()>>,
+}
+
+impl HookAbortMonitor {
+    fn spawn(abort_signal: runtime::HookAbortSignal) -> Self {
+        Self::spawn_with_waiter(abort_signal, move |stop_rx, abort_signal| {
+            let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
+                .enable_all()
+                .build()
+            else {
+                return;
+            };
+
+            runtime.block_on(async move {
+                let wait_for_stop = tokio::task::spawn_blocking(move || {
+                    let _ = stop_rx.recv();
+                });
+
+                tokio::select! {
+                    result = tokio::signal::ctrl_c() => {
+                        if result.is_ok() {
+                            abort_signal.abort();
+                        }
+                    }
+                    _ = wait_for_stop => {}
+                }
+            });
+        })
+    }
+
+    fn spawn_with_waiter<F>(abort_signal: runtime::HookAbortSignal, wait_for_interrupt: F) -> Self
+    where
+        F: FnOnce(Receiver<()>, runtime::HookAbortSignal) + Send + 'static,
+    {
+        let (stop_tx, stop_rx) = mpsc::channel();
+        let join_handle = thread::spawn(move || wait_for_interrupt(stop_rx, abort_signal));
+
+        Self {
+            stop_tx: Some(stop_tx),
+            join_handle: Some(join_handle),
+        }
+    }
+
+    fn stop(mut self) {
+        if let Some(stop_tx) = self.stop_tx.take() {
+            let _ = stop_tx.send(());
+        }
+        if let Some(join_handle) = self.join_handle.take() {
+            let _ = join_handle.join();
+        }
+    }
+}
+
 impl LiveCli {
     fn new(
         model: String,
@@ -1040,6 +1097,19 @@ impl LiveCli {
     }
 
     fn run_turn(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
+        let session = self.runtime.session().clone();
+        let hook_abort_signal = runtime::HookAbortSignal::new();
+        let mut runtime = build_runtime(
+            session,
+            self.model.clone(),
+            self.system_prompt.clone(),
+            true,
+            true,
+            self.allowed_tools.clone(),
+            self.permission_mode,
+        )?
+        .with_hook_abort_signal(hook_abort_signal.clone());
+        let hook_abort_monitor = HookAbortMonitor::spawn(hook_abort_signal);
         let mut spinner = Spinner::new();
         let mut stdout = io::stdout();
         spinner.tick(
@@ -1048,7 +1118,9 @@ impl LiveCli {
             &mut stdout,
         )?;
         let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
-        let result = self.runtime.run_turn(input, Some(&mut permission_prompter));
+        let result = runtime.run_turn(input, Some(&mut permission_prompter));
+        hook_abort_monitor.stop();
+        self.runtime = runtime;
         match result {
             Ok(_) => {
                 spinner.finish(
@@ -1084,6 +1156,7 @@ impl LiveCli {
 
     fn run_prompt_json(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
         let session = self.runtime.session().clone();
+        let hook_abort_signal = runtime::HookAbortSignal::new();
         let mut runtime = build_runtime(
             session,
             self.model.clone(),
@@ -1092,9 +1165,13 @@ impl LiveCli {
             false,
             self.allowed_tools.clone(),
             self.permission_mode,
-        )?;
+        )?
+        .with_hook_abort_signal(hook_abort_signal.clone());
+        let hook_abort_monitor = HookAbortMonitor::spawn(hook_abort_signal);
         let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
-        let summary = runtime.run_turn(input, Some(&mut permission_prompter))?;
+        let result = runtime.run_turn(input, Some(&mut permission_prompter));
+        hook_abort_monitor.stop();
+        let summary = result?;
         self.runtime = runtime;
         self.persist_session()?;
         println!(
@@ -2871,12 +2948,17 @@ mod tests {
         normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to,
         push_output_block, render_config_report, render_memory_report, render_repl_help,
         resolve_model_alias, response_to_events, resume_supported_slash_commands, status_context,
-        CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL,
+        CliAction, CliOutputFormat, HookAbortMonitor, SlashCommand, StatusUsage, DEFAULT_MODEL,
     };
     use api::{MessageResponse, OutputContentBlock, Usage};
-    use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode};
+    use runtime::{
+        AssistantEvent, ContentBlock, ConversationMessage, HookAbortSignal, MessageRole,
+        PermissionMode,
+    };
     use serde_json::json;
     use std::path::PathBuf;
+    use std::sync::mpsc;
+    use std::time::Duration;
 
     #[test]
     fn defaults_to_repl_when_no_args() {
@@ -3535,4 +3617,43 @@ mod tests {
                 if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}"
         ));
     }
+
+    #[test]
+    fn hook_abort_monitor_stops_without_aborting() {
+        let abort_signal = HookAbortSignal::new();
+        let (ready_tx, ready_rx) = mpsc::channel();
+        let monitor = HookAbortMonitor::spawn_with_waiter(
+            abort_signal.clone(),
+            move |stop_rx, abort_signal| {
+                ready_tx.send(()).expect("ready signal");
+                let _ = stop_rx.recv();
+                assert!(!abort_signal.is_aborted());
+            },
+        );
+
+        ready_rx.recv().expect("waiter should be ready");
+        monitor.stop();
+
+        assert!(!abort_signal.is_aborted());
+    }
+
+    #[test]
+    fn hook_abort_monitor_propagates_interrupt() {
+        let abort_signal = HookAbortSignal::new();
+        let (done_tx, done_rx) = mpsc::channel();
+        let monitor = HookAbortMonitor::spawn_with_waiter(
+            abort_signal.clone(),
+            move |_stop_rx, abort_signal| {
+                abort_signal.abort();
+                done_tx.send(()).expect("done signal");
+            },
+        );
+
+        done_rx
+            .recv_timeout(Duration::from_secs(1))
+            .expect("interrupt should complete");
+        monitor.stop();
+
+        assert!(abort_signal.is_aborted());
+    }
 }