|
@@ -4,6 +4,8 @@ use std::fmt::{Display, Formatter};
|
|
|
use crate::compact::{
|
|
use crate::compact::{
|
|
|
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
|
|
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
|
|
|
};
|
|
};
|
|
|
|
|
+use crate::config::RuntimeFeatureConfig;
|
|
|
|
|
+use crate::hooks::{HookRunResult, HookRunner};
|
|
|
use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
|
|
use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
|
|
|
use crate::session::{ContentBlock, ConversationMessage, Session};
|
|
use crate::session::{ContentBlock, ConversationMessage, Session};
|
|
|
use crate::usage::{TokenUsage, UsageTracker};
|
|
use crate::usage::{TokenUsage, UsageTracker};
|
|
@@ -94,6 +96,7 @@ pub struct ConversationRuntime<C, T> {
|
|
|
system_prompt: Vec<String>,
|
|
system_prompt: Vec<String>,
|
|
|
max_iterations: usize,
|
|
max_iterations: usize,
|
|
|
usage_tracker: UsageTracker,
|
|
usage_tracker: UsageTracker,
|
|
|
|
|
+ hook_runner: HookRunner,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
impl<C, T> ConversationRuntime<C, T>
|
|
impl<C, T> ConversationRuntime<C, T>
|
|
@@ -108,6 +111,25 @@ where
|
|
|
tool_executor: T,
|
|
tool_executor: T,
|
|
|
permission_policy: PermissionPolicy,
|
|
permission_policy: PermissionPolicy,
|
|
|
system_prompt: Vec<String>,
|
|
system_prompt: Vec<String>,
|
|
|
|
|
+ ) -> Self {
|
|
|
|
|
+ Self::new_with_features(
|
|
|
|
|
+ session,
|
|
|
|
|
+ api_client,
|
|
|
|
|
+ tool_executor,
|
|
|
|
|
+ permission_policy,
|
|
|
|
|
+ system_prompt,
|
|
|
|
|
+ RuntimeFeatureConfig::default(),
|
|
|
|
|
+ )
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[must_use]
|
|
|
|
|
+ pub fn new_with_features(
|
|
|
|
|
+ session: Session,
|
|
|
|
|
+ api_client: C,
|
|
|
|
|
+ tool_executor: T,
|
|
|
|
|
+ permission_policy: PermissionPolicy,
|
|
|
|
|
+ system_prompt: Vec<String>,
|
|
|
|
|
+ feature_config: RuntimeFeatureConfig,
|
|
|
) -> Self {
|
|
) -> Self {
|
|
|
let usage_tracker = UsageTracker::from_session(&session);
|
|
let usage_tracker = UsageTracker::from_session(&session);
|
|
|
Self {
|
|
Self {
|
|
@@ -118,6 +140,7 @@ where
|
|
|
system_prompt,
|
|
system_prompt,
|
|
|
max_iterations: usize::MAX,
|
|
max_iterations: usize::MAX,
|
|
|
usage_tracker,
|
|
usage_tracker,
|
|
|
|
|
+ hook_runner: HookRunner::from_feature_config(&feature_config),
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -185,19 +208,41 @@ where
|
|
|
|
|
|
|
|
let result_message = match permission_outcome {
|
|
let result_message = match permission_outcome {
|
|
|
PermissionOutcome::Allow => {
|
|
PermissionOutcome::Allow => {
|
|
|
- match self.tool_executor.execute(&tool_name, &input) {
|
|
|
|
|
- Ok(output) => ConversationMessage::tool_result(
|
|
|
|
|
|
|
+ let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input);
|
|
|
|
|
+ if pre_hook_result.is_denied() {
|
|
|
|
|
+ let deny_message = format!("PreToolUse hook denied tool `{tool_name}`");
|
|
|
|
|
+ ConversationMessage::tool_result(
|
|
|
tool_use_id,
|
|
tool_use_id,
|
|
|
tool_name,
|
|
tool_name,
|
|
|
|
|
+ format_hook_message(&pre_hook_result, &deny_message),
|
|
|
|
|
+ true,
|
|
|
|
|
+ )
|
|
|
|
|
+ } else {
|
|
|
|
|
+ let (mut output, mut is_error) =
|
|
|
|
|
+ match self.tool_executor.execute(&tool_name, &input) {
|
|
|
|
|
+ Ok(output) => (output, false),
|
|
|
|
|
+ Err(error) => (error.to_string(), true),
|
|
|
|
|
+ };
|
|
|
|
|
+ output = merge_hook_feedback(pre_hook_result.messages(), output, false);
|
|
|
|
|
+
|
|
|
|
|
+ let post_hook_result = self
|
|
|
|
|
+ .hook_runner
|
|
|
|
|
+ .run_post_tool_use(&tool_name, &input, &output, is_error);
|
|
|
|
|
+ if post_hook_result.is_denied() {
|
|
|
|
|
+ is_error = true;
|
|
|
|
|
+ }
|
|
|
|
|
+ output = merge_hook_feedback(
|
|
|
|
|
+ post_hook_result.messages(),
|
|
|
output,
|
|
output,
|
|
|
- false,
|
|
|
|
|
- ),
|
|
|
|
|
- Err(error) => ConversationMessage::tool_result(
|
|
|
|
|
|
|
+ post_hook_result.is_denied(),
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ ConversationMessage::tool_result(
|
|
|
tool_use_id,
|
|
tool_use_id,
|
|
|
tool_name,
|
|
tool_name,
|
|
|
- error.to_string(),
|
|
|
|
|
- true,
|
|
|
|
|
- ),
|
|
|
|
|
|
|
+ output,
|
|
|
|
|
+ is_error,
|
|
|
|
|
+ )
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
PermissionOutcome::Deny { reason } => {
|
|
PermissionOutcome::Deny { reason } => {
|
|
@@ -290,6 +335,32 @@ fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
|
|
|
|
|
+ if result.messages().is_empty() {
|
|
|
|
|
+ fallback.to_string()
|
|
|
|
|
+ } else {
|
|
|
|
|
+ result.messages().join("\n")
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
|
|
|
|
|
+ if messages.is_empty() {
|
|
|
|
|
+ return output;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ let mut sections = Vec::new();
|
|
|
|
|
+ if !output.trim().is_empty() {
|
|
|
|
|
+ sections.push(output);
|
|
|
|
|
+ }
|
|
|
|
|
+ let label = if denied {
|
|
|
|
|
+ "Hook feedback (denied)"
|
|
|
|
|
+ } else {
|
|
|
|
|
+ "Hook feedback"
|
|
|
|
|
+ };
|
|
|
|
|
+ sections.push(format!("{label}:\n{}", messages.join("\n")));
|
|
|
|
|
+ sections.join("\n\n")
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
|
|
type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
|
|
|
|
|
|
|
|
#[derive(Default)]
|
|
#[derive(Default)]
|
|
@@ -329,6 +400,7 @@ mod tests {
|
|
|
StaticToolExecutor,
|
|
StaticToolExecutor,
|
|
|
};
|
|
};
|
|
|
use crate::compact::CompactionConfig;
|
|
use crate::compact::CompactionConfig;
|
|
|
|
|
+ use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
|
|
|
use crate::permissions::{
|
|
use crate::permissions::{
|
|
|
PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
|
|
PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
|
|
|
PermissionRequest,
|
|
PermissionRequest,
|
|
@@ -503,6 +575,141 @@ mod tests {
|
|
|
));
|
|
));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn denies_tool_use_when_pre_tool_hook_blocks() {
|
|
|
|
|
+ struct SingleCallApiClient;
|
|
|
|
|
+ impl ApiClient for SingleCallApiClient {
|
|
|
|
|
+ fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
|
|
|
|
+ if request
|
|
|
|
|
+ .messages
|
|
|
|
|
+ .iter()
|
|
|
|
|
+ .any(|message| message.role == MessageRole::Tool)
|
|
|
|
|
+ {
|
|
|
|
|
+ return Ok(vec![
|
|
|
|
|
+ AssistantEvent::TextDelta("blocked".to_string()),
|
|
|
|
|
+ AssistantEvent::MessageStop,
|
|
|
|
|
+ ]);
|
|
|
|
|
+ }
|
|
|
|
|
+ Ok(vec![
|
|
|
|
|
+ AssistantEvent::ToolUse {
|
|
|
|
|
+ id: "tool-1".to_string(),
|
|
|
|
|
+ name: "blocked".to_string(),
|
|
|
|
|
+ input: r#"{"path":"secret.txt"}"#.to_string(),
|
|
|
|
|
+ },
|
|
|
|
|
+ AssistantEvent::MessageStop,
|
|
|
|
|
+ ])
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ let mut runtime = ConversationRuntime::new_with_features(
|
|
|
|
|
+ Session::new(),
|
|
|
|
|
+ SingleCallApiClient,
|
|
|
|
|
+ StaticToolExecutor::new().register("blocked", |_input| {
|
|
|
|
|
+ panic!("tool should not execute when hook denies")
|
|
|
|
|
+ }),
|
|
|
|
|
+ PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
|
|
|
|
+ vec!["system".to_string()],
|
|
|
|
|
+ RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
|
|
|
|
+ vec![shell_snippet("printf 'blocked by hook'; exit 2")],
|
|
|
|
|
+ Vec::new(),
|
|
|
|
|
+ )),
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ let summary = runtime
|
|
|
|
|
+ .run_turn("use the tool", None)
|
|
|
|
|
+ .expect("conversation should continue after hook denial");
|
|
|
|
|
+
|
|
|
|
|
+ assert_eq!(summary.tool_results.len(), 1);
|
|
|
|
|
+ let ContentBlock::ToolResult {
|
|
|
|
|
+ is_error, output, ..
|
|
|
|
|
+ } = &summary.tool_results[0].blocks[0]
|
|
|
|
|
+ else {
|
|
|
|
|
+ panic!("expected tool result block");
|
|
|
|
|
+ };
|
|
|
|
|
+ assert!(
|
|
|
|
|
+ *is_error,
|
|
|
|
|
+ "hook denial should produce an error result: {output}"
|
|
|
|
|
+ );
|
|
|
|
|
+ assert!(
|
|
|
|
|
+ output.contains("denied tool") || output.contains("blocked by hook"),
|
|
|
|
|
+ "unexpected hook denial output: {output:?}"
|
|
|
|
|
+ );
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn appends_post_tool_hook_feedback_to_tool_result() {
|
|
|
|
|
+ struct TwoCallApiClient {
|
|
|
|
|
+ calls: usize,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ impl ApiClient for TwoCallApiClient {
|
|
|
|
|
+ fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
|
|
|
|
|
+ self.calls += 1;
|
|
|
|
|
+ match self.calls {
|
|
|
|
|
+ 1 => Ok(vec![
|
|
|
|
|
+ AssistantEvent::ToolUse {
|
|
|
|
|
+ id: "tool-1".to_string(),
|
|
|
|
|
+ name: "add".to_string(),
|
|
|
|
|
+ input: r#"{"lhs":2,"rhs":2}"#.to_string(),
|
|
|
|
|
+ },
|
|
|
|
|
+ AssistantEvent::MessageStop,
|
|
|
|
|
+ ]),
|
|
|
|
|
+ 2 => {
|
|
|
|
|
+ assert!(request
|
|
|
|
|
+ .messages
|
|
|
|
|
+ .iter()
|
|
|
|
|
+ .any(|message| message.role == MessageRole::Tool));
|
|
|
|
|
+ Ok(vec![
|
|
|
|
|
+ AssistantEvent::TextDelta("done".to_string()),
|
|
|
|
|
+ AssistantEvent::MessageStop,
|
|
|
|
|
+ ])
|
|
|
|
|
+ }
|
|
|
|
|
+ _ => Err(RuntimeError::new("unexpected extra API call")),
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ let mut runtime = ConversationRuntime::new_with_features(
|
|
|
|
|
+ Session::new(),
|
|
|
|
|
+ TwoCallApiClient { calls: 0 },
|
|
|
|
|
+ StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
|
|
|
|
|
+ PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
|
|
|
|
+ vec!["system".to_string()],
|
|
|
|
|
+ RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
|
|
|
|
+ vec![shell_snippet("printf 'pre hook ran'")],
|
|
|
|
|
+ vec![shell_snippet("printf 'post hook ran'")],
|
|
|
|
|
+ )),
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ let summary = runtime
|
|
|
|
|
+ .run_turn("use add", None)
|
|
|
|
|
+ .expect("tool loop succeeds");
|
|
|
|
|
+
|
|
|
|
|
+ assert_eq!(summary.tool_results.len(), 1);
|
|
|
|
|
+ let ContentBlock::ToolResult {
|
|
|
|
|
+ is_error, output, ..
|
|
|
|
|
+ } = &summary.tool_results[0].blocks[0]
|
|
|
|
|
+ else {
|
|
|
|
|
+ panic!("expected tool result block");
|
|
|
|
|
+ };
|
|
|
|
|
+ assert!(
|
|
|
|
|
+ !*is_error,
|
|
|
|
|
+ "post hook should preserve non-error result: {output:?}"
|
|
|
|
|
+ );
|
|
|
|
|
+ assert!(
|
|
|
|
|
+ output.contains("4"),
|
|
|
|
|
+ "tool output missing value: {output:?}"
|
|
|
|
|
+ );
|
|
|
|
|
+ assert!(
|
|
|
|
|
+ output.contains("pre hook ran"),
|
|
|
|
|
+ "tool output missing pre hook feedback: {output:?}"
|
|
|
|
|
+ );
|
|
|
|
|
+ assert!(
|
|
|
|
|
+ output.contains("post hook ran"),
|
|
|
|
|
+ "tool output missing post hook feedback: {output:?}"
|
|
|
|
|
+ );
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
#[test]
|
|
#[test]
|
|
|
fn reconstructs_usage_tracker_from_restored_session() {
|
|
fn reconstructs_usage_tracker_from_restored_session() {
|
|
|
struct SimpleApi;
|
|
struct SimpleApi;
|
|
@@ -581,4 +788,14 @@ mod tests {
|
|
|
MessageRole::System
|
|
MessageRole::System
|
|
|
);
|
|
);
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ #[cfg(windows)]
|
|
|
|
|
+ fn shell_snippet(script: &str) -> String {
|
|
|
|
|
+ script.replace('\'', "\"")
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[cfg(not(windows))]
|
|
|
|
|
+ fn shell_snippet(script: &str) -> String {
|
|
|
|
|
+ script.to_string()
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|