|
|
@@ -5,7 +5,7 @@ use crate::compact::{
|
|
|
compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
|
|
|
};
|
|
|
use crate::config::RuntimeFeatureConfig;
|
|
|
-use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
|
|
|
+use crate::hooks::{HookAbortSignal, HookRunResult, HookRunner};
|
|
|
use crate::permissions::{
|
|
|
PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
|
|
|
};
|
|
|
@@ -100,7 +100,6 @@ pub struct ConversationRuntime<C, T> {
|
|
|
usage_tracker: UsageTracker,
|
|
|
hook_runner: HookRunner,
|
|
|
hook_abort_signal: HookAbortSignal,
|
|
|
- hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
|
|
|
}
|
|
|
|
|
|
impl<C, T> ConversationRuntime<C, T>
|
|
|
@@ -122,18 +121,19 @@ where
|
|
|
tool_executor,
|
|
|
permission_policy,
|
|
|
system_prompt,
|
|
|
- &RuntimeFeatureConfig::default(),
|
|
|
+ RuntimeFeatureConfig::default(),
|
|
|
)
|
|
|
}
|
|
|
|
|
|
#[must_use]
|
|
|
+ #[allow(clippy::needless_pass_by_value)]
|
|
|
pub fn new_with_features(
|
|
|
session: Session,
|
|
|
api_client: C,
|
|
|
tool_executor: T,
|
|
|
permission_policy: PermissionPolicy,
|
|
|
system_prompt: Vec<String>,
|
|
|
- feature_config: &RuntimeFeatureConfig,
|
|
|
+ feature_config: RuntimeFeatureConfig,
|
|
|
) -> Self {
|
|
|
let usage_tracker = UsageTracker::from_session(&session);
|
|
|
Self {
|
|
|
@@ -144,8 +144,9 @@ where
|
|
|
system_prompt,
|
|
|
max_iterations: usize::MAX,
|
|
|
usage_tracker,
|
|
|
- hook_runner: HookRunner::from_feature_config(feature_config),
|
|
|
+ hook_runner: HookRunner::from_feature_config(&feature_config),
|
|
|
hook_abort_signal: HookAbortSignal::default(),
|
|
|
+ hook_progress_reporter: None,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -220,17 +221,12 @@ where
|
|
|
}
|
|
|
|
|
|
for (tool_use_id, tool_name, input) in pending_tool_uses {
|
|
|
- let pre_hook_result = self.hook_runner.run_pre_tool_use_with_context(
|
|
|
- &tool_name,
|
|
|
- &input,
|
|
|
- Some(&self.hook_abort_signal),
|
|
|
- self.hook_progress_reporter.as_deref_mut(),
|
|
|
- );
|
|
|
+ let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
|
|
|
let effective_input = pre_hook_result
|
|
|
- .updated_input_json()
|
|
|
+ .updated_input()
|
|
|
.map_or_else(|| input.clone(), ToOwned::to_owned);
|
|
|
let permission_context = PermissionContext::new(
|
|
|
- pre_hook_result.permission_decision(),
|
|
|
+ pre_hook_result.permission_override(),
|
|
|
pre_hook_result.permission_reason().map(ToOwned::to_owned),
|
|
|
);
|
|
|
|
|
|
@@ -274,21 +270,17 @@ where
|
|
|
output = merge_hook_feedback(pre_hook_result.messages(), output, false);
|
|
|
|
|
|
let post_hook_result = if is_error {
|
|
|
- self.hook_runner.run_post_tool_use_failure_with_context(
|
|
|
+ self.run_post_tool_use_failure_hook(
|
|
|
&tool_name,
|
|
|
&effective_input,
|
|
|
&output,
|
|
|
- Some(&self.hook_abort_signal),
|
|
|
- self.hook_progress_reporter.as_deref_mut(),
|
|
|
)
|
|
|
} else {
|
|
|
- self.hook_runner.run_post_tool_use_with_context(
|
|
|
+ self.run_post_tool_use_hook(
|
|
|
&tool_name,
|
|
|
&effective_input,
|
|
|
&output,
|
|
|
false,
|
|
|
- Some(&self.hook_abort_signal),
|
|
|
- self.hook_progress_reporter.as_deref_mut(),
|
|
|
)
|
|
|
};
|
|
|
if post_hook_result.is_denied() || post_hook_result.is_cancelled() {
|
|
|
@@ -322,6 +314,77 @@ where
|
|
|
})
|
|
|
}
|
|
|
|
|
|
+ fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
|
|
|
+ if let Some(reporter) = self.hook_progress_reporter.as_mut() {
|
|
|
+ self.hook_runner.run_pre_tool_use_with_context(
|
|
|
+ tool_name,
|
|
|
+ input,
|
|
|
+ Some(&self.hook_abort_signal),
|
|
|
+ Some(reporter.as_mut()),
|
|
|
+ )
|
|
|
+ } else {
|
|
|
+ self.hook_runner.run_pre_tool_use_with_context(
|
|
|
+ tool_name,
|
|
|
+ input,
|
|
|
+ Some(&self.hook_abort_signal),
|
|
|
+ None,
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fn run_post_tool_use_hook(
|
|
|
+ &mut self,
|
|
|
+ tool_name: &str,
|
|
|
+ input: &str,
|
|
|
+ output: &str,
|
|
|
+ is_error: bool,
|
|
|
+ ) -> HookRunResult {
|
|
|
+ if let Some(reporter) = self.hook_progress_reporter.as_mut() {
|
|
|
+ self.hook_runner.run_post_tool_use_with_context(
|
|
|
+ tool_name,
|
|
|
+ input,
|
|
|
+ output,
|
|
|
+ is_error,
|
|
|
+ Some(&self.hook_abort_signal),
|
|
|
+ Some(reporter.as_mut()),
|
|
|
+ )
|
|
|
+ } else {
|
|
|
+ self.hook_runner.run_post_tool_use_with_context(
|
|
|
+ tool_name,
|
|
|
+ input,
|
|
|
+ output,
|
|
|
+ is_error,
|
|
|
+ Some(&self.hook_abort_signal),
|
|
|
+ None,
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fn run_post_tool_use_failure_hook(
|
|
|
+ &mut self,
|
|
|
+ tool_name: &str,
|
|
|
+ input: &str,
|
|
|
+ output: &str,
|
|
|
+ ) -> HookRunResult {
|
|
|
+ if let Some(reporter) = self.hook_progress_reporter.as_mut() {
|
|
|
+ self.hook_runner.run_post_tool_use_failure_with_context(
|
|
|
+ tool_name,
|
|
|
+ input,
|
|
|
+ output,
|
|
|
+ Some(&self.hook_abort_signal),
|
|
|
+ Some(reporter.as_mut()),
|
|
|
+ )
|
|
|
+ } else {
|
|
|
+ self.hook_runner.run_post_tool_use_failure_with_context(
|
|
|
+ tool_name,
|
|
|
+ input,
|
|
|
+ output,
|
|
|
+ Some(&self.hook_abort_signal),
|
|
|
+ None,
|
|
|
+ )
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
#[must_use]
|
|
|
pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
|
|
|
compact_session(&self.session, config)
|
|
|
@@ -669,7 +732,7 @@ mod tests {
|
|
|
}),
|
|
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
|
|
vec!["system".to_string()],
|
|
|
- &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
|
|
+ RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
|
|
vec![shell_snippet("printf 'blocked by hook'; exit 2")],
|
|
|
Vec::new(),
|
|
|
Vec::new(),
|
|
|
@@ -736,7 +799,7 @@ mod tests {
|
|
|
StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
|
|
|
PermissionPolicy::new(PermissionMode::DangerFullAccess),
|
|
|
vec!["system".to_string()],
|
|
|
- &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
|
|
+ RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
|
|
|
vec![shell_snippet("printf 'pre hook ran'")],
|
|
|
vec![shell_snippet("printf 'post hook ran'")],
|
|
|
Vec::new(),
|