|
@@ -9,6 +9,8 @@ use std::io::{self, Read, Write};
|
|
|
use std::net::TcpListener;
|
|
use std::net::TcpListener;
|
|
|
use std::path::{Path, PathBuf};
|
|
use std::path::{Path, PathBuf};
|
|
|
use std::process::Command;
|
|
use std::process::Command;
|
|
|
|
|
+use std::sync::mpsc::{self, Receiver, Sender};
|
|
|
|
|
+use std::thread::{self, JoinHandle};
|
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
|
|
|
|
|
|
use api::{
|
|
use api::{
|
|
@@ -984,6 +986,61 @@ struct LiveCli {
|
|
|
session: SessionHandle,
|
|
session: SessionHandle,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+struct HookAbortMonitor {
|
|
|
|
|
+ stop_tx: Option<Sender<()>>,
|
|
|
|
|
+ join_handle: Option<JoinHandle<()>>,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl HookAbortMonitor {
|
|
|
|
|
+ fn spawn(abort_signal: runtime::HookAbortSignal) -> Self {
|
|
|
|
|
+ Self::spawn_with_waiter(abort_signal, move |stop_rx, abort_signal| {
|
|
|
|
|
+ let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
|
|
|
|
|
+ .enable_all()
|
|
|
|
|
+ .build()
|
|
|
|
|
+ else {
|
|
|
|
|
+ return;
|
|
|
|
|
+ };
|
|
|
|
|
+
|
|
|
|
|
+ runtime.block_on(async move {
|
|
|
|
|
+ let wait_for_stop = tokio::task::spawn_blocking(move || {
|
|
|
|
|
+ let _ = stop_rx.recv();
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ tokio::select! {
|
|
|
|
|
+ result = tokio::signal::ctrl_c() => {
|
|
|
|
|
+ if result.is_ok() {
|
|
|
|
|
+ abort_signal.abort();
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ _ = wait_for_stop => {}
|
|
|
|
|
+ }
|
|
|
|
|
+ });
|
|
|
|
|
+ })
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ fn spawn_with_waiter<F>(abort_signal: runtime::HookAbortSignal, wait_for_interrupt: F) -> Self
|
|
|
|
|
+ where
|
|
|
|
|
+ F: FnOnce(Receiver<()>, runtime::HookAbortSignal) + Send + 'static,
|
|
|
|
|
+ {
|
|
|
|
|
+ let (stop_tx, stop_rx) = mpsc::channel();
|
|
|
|
|
+ let join_handle = thread::spawn(move || wait_for_interrupt(stop_rx, abort_signal));
|
|
|
|
|
+
|
|
|
|
|
+ Self {
|
|
|
|
|
+ stop_tx: Some(stop_tx),
|
|
|
|
|
+ join_handle: Some(join_handle),
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ fn stop(mut self) {
|
|
|
|
|
+ if let Some(stop_tx) = self.stop_tx.take() {
|
|
|
|
|
+ let _ = stop_tx.send(());
|
|
|
|
|
+ }
|
|
|
|
|
+ if let Some(join_handle) = self.join_handle.take() {
|
|
|
|
|
+ let _ = join_handle.join();
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
impl LiveCli {
|
|
impl LiveCli {
|
|
|
fn new(
|
|
fn new(
|
|
|
model: String,
|
|
model: String,
|
|
@@ -1040,6 +1097,19 @@ impl LiveCli {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
fn run_turn(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|
fn run_turn(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
|
+ let session = self.runtime.session().clone();
|
|
|
|
|
+ let hook_abort_signal = runtime::HookAbortSignal::new();
|
|
|
|
|
+ let mut runtime = build_runtime(
|
|
|
|
|
+ session,
|
|
|
|
|
+ self.model.clone(),
|
|
|
|
|
+ self.system_prompt.clone(),
|
|
|
|
|
+ true,
|
|
|
|
|
+ true,
|
|
|
|
|
+ self.allowed_tools.clone(),
|
|
|
|
|
+ self.permission_mode,
|
|
|
|
|
+ )?
|
|
|
|
|
+ .with_hook_abort_signal(hook_abort_signal.clone());
|
|
|
|
|
+ let hook_abort_monitor = HookAbortMonitor::spawn(hook_abort_signal);
|
|
|
let mut spinner = Spinner::new();
|
|
let mut spinner = Spinner::new();
|
|
|
let mut stdout = io::stdout();
|
|
let mut stdout = io::stdout();
|
|
|
spinner.tick(
|
|
spinner.tick(
|
|
@@ -1048,7 +1118,9 @@ impl LiveCli {
|
|
|
&mut stdout,
|
|
&mut stdout,
|
|
|
)?;
|
|
)?;
|
|
|
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
|
|
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
|
|
|
- let result = self.runtime.run_turn(input, Some(&mut permission_prompter));
|
|
|
|
|
|
|
+ let result = runtime.run_turn(input, Some(&mut permission_prompter));
|
|
|
|
|
+ hook_abort_monitor.stop();
|
|
|
|
|
+ self.runtime = runtime;
|
|
|
match result {
|
|
match result {
|
|
|
Ok(_) => {
|
|
Ok(_) => {
|
|
|
spinner.finish(
|
|
spinner.finish(
|
|
@@ -1084,6 +1156,7 @@ impl LiveCli {
|
|
|
|
|
|
|
|
fn run_prompt_json(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|
fn run_prompt_json(&mut self, input: &str) -> Result<(), Box<dyn std::error::Error>> {
|
|
|
let session = self.runtime.session().clone();
|
|
let session = self.runtime.session().clone();
|
|
|
|
|
+ let hook_abort_signal = runtime::HookAbortSignal::new();
|
|
|
let mut runtime = build_runtime(
|
|
let mut runtime = build_runtime(
|
|
|
session,
|
|
session,
|
|
|
self.model.clone(),
|
|
self.model.clone(),
|
|
@@ -1092,9 +1165,13 @@ impl LiveCli {
|
|
|
false,
|
|
false,
|
|
|
self.allowed_tools.clone(),
|
|
self.allowed_tools.clone(),
|
|
|
self.permission_mode,
|
|
self.permission_mode,
|
|
|
- )?;
|
|
|
|
|
|
|
+ )?
|
|
|
|
|
+ .with_hook_abort_signal(hook_abort_signal.clone());
|
|
|
|
|
+ let hook_abort_monitor = HookAbortMonitor::spawn(hook_abort_signal);
|
|
|
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
|
|
let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode);
|
|
|
- let summary = runtime.run_turn(input, Some(&mut permission_prompter))?;
|
|
|
|
|
|
|
+ let result = runtime.run_turn(input, Some(&mut permission_prompter));
|
|
|
|
|
+ hook_abort_monitor.stop();
|
|
|
|
|
+ let summary = result?;
|
|
|
self.runtime = runtime;
|
|
self.runtime = runtime;
|
|
|
self.persist_session()?;
|
|
self.persist_session()?;
|
|
|
println!(
|
|
println!(
|
|
@@ -2871,12 +2948,17 @@ mod tests {
|
|
|
normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to,
|
|
normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to,
|
|
|
push_output_block, render_config_report, render_memory_report, render_repl_help,
|
|
push_output_block, render_config_report, render_memory_report, render_repl_help,
|
|
|
resolve_model_alias, response_to_events, resume_supported_slash_commands, status_context,
|
|
resolve_model_alias, response_to_events, resume_supported_slash_commands, status_context,
|
|
|
- CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL,
|
|
|
|
|
|
|
+ CliAction, CliOutputFormat, HookAbortMonitor, SlashCommand, StatusUsage, DEFAULT_MODEL,
|
|
|
};
|
|
};
|
|
|
use api::{MessageResponse, OutputContentBlock, Usage};
|
|
use api::{MessageResponse, OutputContentBlock, Usage};
|
|
|
- use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode};
|
|
|
|
|
|
|
+ use runtime::{
|
|
|
|
|
+ AssistantEvent, ContentBlock, ConversationMessage, HookAbortSignal, MessageRole,
|
|
|
|
|
+ PermissionMode,
|
|
|
|
|
+ };
|
|
|
use serde_json::json;
|
|
use serde_json::json;
|
|
|
use std::path::PathBuf;
|
|
use std::path::PathBuf;
|
|
|
|
|
+ use std::sync::mpsc;
|
|
|
|
|
+ use std::time::Duration;
|
|
|
|
|
|
|
|
#[test]
|
|
#[test]
|
|
|
fn defaults_to_repl_when_no_args() {
|
|
fn defaults_to_repl_when_no_args() {
|
|
@@ -3535,4 +3617,43 @@ mod tests {
|
|
|
if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}"
|
|
if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}"
|
|
|
));
|
|
));
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn hook_abort_monitor_stops_without_aborting() {
|
|
|
|
|
+ let abort_signal = HookAbortSignal::new();
|
|
|
|
|
+ let (ready_tx, ready_rx) = mpsc::channel();
|
|
|
|
|
+ let monitor = HookAbortMonitor::spawn_with_waiter(
|
|
|
|
|
+ abort_signal.clone(),
|
|
|
|
|
+ move |stop_rx, abort_signal| {
|
|
|
|
|
+ ready_tx.send(()).expect("ready signal");
|
|
|
|
|
+ let _ = stop_rx.recv();
|
|
|
|
|
+ assert!(!abort_signal.is_aborted());
|
|
|
|
|
+ },
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ ready_rx.recv().expect("waiter should be ready");
|
|
|
|
|
+ monitor.stop();
|
|
|
|
|
+
|
|
|
|
|
+ assert!(!abort_signal.is_aborted());
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn hook_abort_monitor_propagates_interrupt() {
|
|
|
|
|
+ let abort_signal = HookAbortSignal::new();
|
|
|
|
|
+ let (done_tx, done_rx) = mpsc::channel();
|
|
|
|
|
+ let monitor = HookAbortMonitor::spawn_with_waiter(
|
|
|
|
|
+ abort_signal.clone(),
|
|
|
|
|
+ move |_stop_rx, abort_signal| {
|
|
|
|
|
+ abort_signal.abort();
|
|
|
|
|
+ done_tx.send(()).expect("done signal");
|
|
|
|
|
+ },
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ done_rx
|
|
|
|
|
+ .recv_timeout(Duration::from_secs(1))
|
|
|
|
|
+ .expect("interrupt should complete");
|
|
|
|
|
+ monitor.stop();
|
|
|
|
|
+
|
|
|
|
|
+ assert!(abort_signal.is_aborted());
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|