| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189 |
- use std::collections::BTreeMap;
- use std::fmt::{Display, Formatter};
- use crate::compact::{
- compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
- };
- use crate::config::RuntimeFeatureConfig;
- use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
- use crate::permissions::{
- PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
- };
- use crate::session::{ContentBlock, ConversationMessage, Session};
- use crate::usage::{TokenUsage, UsageTracker};
/// Environment variable that, when set to a positive integer, overrides
/// [`DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD`].
const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
/// Cumulative input-token count at which the runtime auto-compacts the session after a turn,
/// used when the environment variable is unset or not a positive `u32`.
const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 100_000;
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct ApiRequest {
- pub system_prompt: Vec<String>,
- pub messages: Vec<ConversationMessage>,
- }
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub enum AssistantEvent {
- TextDelta(String),
- ToolUse {
- id: String,
- name: String,
- input: String,
- },
- Usage(TokenUsage),
- MessageStop,
- }
- pub trait ApiClient {
- fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
- }
- pub trait ToolExecutor {
- fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
- }
/// Failure reported by a [`ToolExecutor`].
///
/// Its display text is what the runtime stores as the (error) tool result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolError {
    message: String,
}

impl ToolError {
    /// Wraps `message` as a tool error.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for ToolError {
    /// Writes the message verbatim; width/fill flags on the outer format string are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
/// Error that aborts a conversation turn.
///
/// Produced by the runtime itself (session writes, malformed assistant streams, iteration
/// limit) and by [`ApiClient`] implementations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RuntimeError {
    message: String,
}

impl RuntimeError {
    /// Creates an error whose display text is `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl Display for RuntimeError {
    /// Writes the message verbatim; width/fill flags on the outer format string are ignored.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(self.message.as_str())
    }
}

impl std::error::Error for RuntimeError {}
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct TurnSummary {
- pub assistant_messages: Vec<ConversationMessage>,
- pub tool_results: Vec<ConversationMessage>,
- pub iterations: usize,
- pub usage: TokenUsage,
- pub auto_compaction: Option<AutoCompactionEvent>,
- }
/// Reported in `TurnSummary::auto_compaction` when the runtime compacted the session at the
/// end of a turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AutoCompactionEvent {
    /// How many messages compaction removed from the session.
    pub removed_message_count: usize,
}
- pub struct ConversationRuntime<C, T> {
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- max_iterations: usize,
- usage_tracker: UsageTracker,
- hook_runner: HookRunner,
- auto_compaction_input_tokens_threshold: u32,
- hook_abort_signal: HookAbortSignal,
- hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
- }
- impl<C, T> ConversationRuntime<C, T>
- where
- C: ApiClient,
- T: ToolExecutor,
- {
- #[must_use]
- pub fn new(
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- ) -> Self {
- Self::new_with_features(
- session,
- api_client,
- tool_executor,
- permission_policy,
- system_prompt,
- RuntimeFeatureConfig::default(),
- )
- }
- #[must_use]
- #[allow(clippy::needless_pass_by_value)]
- pub fn new_with_features(
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- feature_config: RuntimeFeatureConfig,
- ) -> Self {
- let usage_tracker = UsageTracker::from_session(&session);
- Self {
- session,
- api_client,
- tool_executor,
- permission_policy,
- system_prompt,
- max_iterations: usize::MAX,
- usage_tracker,
- hook_runner: HookRunner::from_feature_config(&feature_config),
- auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
- hook_abort_signal: HookAbortSignal::default(),
- hook_progress_reporter: None,
- }
- }
- #[must_use]
- pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
- self.max_iterations = max_iterations;
- self
- }
- #[must_use]
- pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
- self.auto_compaction_input_tokens_threshold = threshold;
- self
- }
- #[must_use]
- pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self {
- self.hook_abort_signal = hook_abort_signal;
- self
- }
- #[must_use]
- pub fn with_hook_progress_reporter(
- mut self,
- hook_progress_reporter: Box<dyn HookProgressReporter>,
- ) -> Self {
- self.hook_progress_reporter = Some(hook_progress_reporter);
- self
- }
- fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
- if let Some(reporter) = self.hook_progress_reporter.as_mut() {
- self.hook_runner.run_pre_tool_use_with_context(
- tool_name,
- input,
- Some(&self.hook_abort_signal),
- Some(reporter.as_mut()),
- )
- } else {
- self.hook_runner.run_pre_tool_use_with_context(
- tool_name,
- input,
- Some(&self.hook_abort_signal),
- None,
- )
- }
- }
- fn run_post_tool_use_hook(
- &mut self,
- tool_name: &str,
- input: &str,
- output: &str,
- is_error: bool,
- ) -> HookRunResult {
- if let Some(reporter) = self.hook_progress_reporter.as_mut() {
- self.hook_runner.run_post_tool_use_with_context(
- tool_name,
- input,
- output,
- is_error,
- Some(&self.hook_abort_signal),
- Some(reporter.as_mut()),
- )
- } else {
- self.hook_runner.run_post_tool_use_with_context(
- tool_name,
- input,
- output,
- is_error,
- Some(&self.hook_abort_signal),
- None,
- )
- }
- }
- fn run_post_tool_use_failure_hook(
- &mut self,
- tool_name: &str,
- input: &str,
- output: &str,
- ) -> HookRunResult {
- if let Some(reporter) = self.hook_progress_reporter.as_mut() {
- self.hook_runner.run_post_tool_use_failure_with_context(
- tool_name,
- input,
- output,
- Some(&self.hook_abort_signal),
- Some(reporter.as_mut()),
- )
- } else {
- self.hook_runner.run_post_tool_use_failure_with_context(
- tool_name,
- input,
- output,
- Some(&self.hook_abort_signal),
- None,
- )
- }
- }
- #[allow(clippy::too_many_lines)]
- pub fn run_turn(
- &mut self,
- user_input: impl Into<String>,
- mut prompter: Option<&mut dyn PermissionPrompter>,
- ) -> Result<TurnSummary, RuntimeError> {
- self.session
- .push_user_text(user_input.into())
- .map_err(|error| RuntimeError::new(error.to_string()))?;
- let mut assistant_messages = Vec::new();
- let mut tool_results = Vec::new();
- let mut iterations = 0;
- loop {
- iterations += 1;
- if iterations > self.max_iterations {
- return Err(RuntimeError::new(
- "conversation loop exceeded the maximum number of iterations",
- ));
- }
- let request = ApiRequest {
- system_prompt: self.system_prompt.clone(),
- messages: self.session.messages.clone(),
- };
- let events = self.api_client.stream(request)?;
- let (assistant_message, usage) = build_assistant_message(events)?;
- if let Some(usage) = usage {
- self.usage_tracker.record(usage);
- }
- let pending_tool_uses = assistant_message
- .blocks
- .iter()
- .filter_map(|block| match block {
- ContentBlock::ToolUse { id, name, input } => {
- Some((id.clone(), name.clone(), input.clone()))
- }
- _ => None,
- })
- .collect::<Vec<_>>();
- self.session
- .push_message(assistant_message.clone())
- .map_err(|error| RuntimeError::new(error.to_string()))?;
- assistant_messages.push(assistant_message);
- if pending_tool_uses.is_empty() {
- break;
- }
- for (tool_use_id, tool_name, input) in pending_tool_uses {
- let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
- let effective_input = pre_hook_result
- .updated_input()
- .map_or_else(|| input.clone(), ToOwned::to_owned);
- let permission_context = PermissionContext::new(
- pre_hook_result.permission_override(),
- pre_hook_result.permission_reason().map(ToOwned::to_owned),
- );
- let permission_outcome = if pre_hook_result.is_cancelled() {
- PermissionOutcome::Deny {
- reason: format_hook_message(
- &pre_hook_result,
- &format!("PreToolUse hook cancelled tool `{tool_name}`"),
- ),
- }
- } else if pre_hook_result.is_denied() {
- PermissionOutcome::Deny {
- reason: format_hook_message(
- &pre_hook_result,
- &format!("PreToolUse hook denied tool `{tool_name}`"),
- ),
- }
- } else if let Some(prompt) = prompter.as_mut() {
- self.permission_policy.authorize_with_context(
- &tool_name,
- &effective_input,
- &permission_context,
- Some(*prompt),
- )
- } else {
- self.permission_policy.authorize_with_context(
- &tool_name,
- &effective_input,
- &permission_context,
- None,
- )
- };
- let result_message = match permission_outcome {
- PermissionOutcome::Allow => {
- let (mut output, mut is_error) =
- match self.tool_executor.execute(&tool_name, &effective_input) {
- Ok(output) => (output, false),
- Err(error) => (error.to_string(), true),
- };
- output = merge_hook_feedback(pre_hook_result.messages(), output, false);
- let post_hook_result = if is_error {
- self.run_post_tool_use_failure_hook(
- &tool_name,
- &effective_input,
- &output,
- )
- } else {
- self.run_post_tool_use_hook(
- &tool_name,
- &effective_input,
- &output,
- false,
- )
- };
- if post_hook_result.is_denied() || post_hook_result.is_cancelled() {
- is_error = true;
- }
- output = merge_hook_feedback(
- post_hook_result.messages(),
- output,
- post_hook_result.is_denied() || post_hook_result.is_cancelled(),
- );
- ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error)
- }
- PermissionOutcome::Deny { reason } => ConversationMessage::tool_result(
- tool_use_id,
- tool_name,
- merge_hook_feedback(pre_hook_result.messages(), reason, true),
- true,
- ),
- };
- self.session
- .push_message(result_message.clone())
- .map_err(|error| RuntimeError::new(error.to_string()))?;
- tool_results.push(result_message);
- }
- }
- let auto_compaction = self.maybe_auto_compact();
- Ok(TurnSummary {
- assistant_messages,
- tool_results,
- iterations,
- usage: self.usage_tracker.cumulative_usage(),
- auto_compaction,
- })
- }
- #[must_use]
- pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
- compact_session(&self.session, config)
- }
- #[must_use]
- pub fn estimated_tokens(&self) -> usize {
- estimate_session_tokens(&self.session)
- }
- #[must_use]
- pub fn usage(&self) -> &UsageTracker {
- &self.usage_tracker
- }
- #[must_use]
- pub fn session(&self) -> &Session {
- &self.session
- }
- #[must_use]
- pub fn fork_session(&self, branch_name: Option<String>) -> Session {
- self.session.fork(branch_name)
- }
- #[must_use]
- pub fn into_session(self) -> Session {
- self.session
- }
- fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
- if self.usage_tracker.cumulative_usage().input_tokens
- < self.auto_compaction_input_tokens_threshold
- {
- return None;
- }
- let result = compact_session(
- &self.session,
- CompactionConfig {
- max_estimated_tokens: 0,
- ..CompactionConfig::default()
- },
- );
- if result.removed_message_count == 0 {
- return None;
- }
- self.session = result.compacted_session;
- Some(AutoCompactionEvent {
- removed_message_count: result.removed_message_count,
- })
- }
- }
- #[must_use]
- pub fn auto_compaction_threshold_from_env() -> u32 {
- parse_auto_compaction_threshold(
- std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
- .ok()
- .as_deref(),
- )
- }
- #[must_use]
- fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
- value
- .and_then(|raw| raw.trim().parse::<u32>().ok())
- .filter(|threshold| *threshold > 0)
- .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
- }
- fn build_assistant_message(
- events: Vec<AssistantEvent>,
- ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
- let mut text = String::new();
- let mut blocks = Vec::new();
- let mut finished = false;
- let mut usage = None;
- for event in events {
- match event {
- AssistantEvent::TextDelta(delta) => text.push_str(&delta),
- AssistantEvent::ToolUse { id, name, input } => {
- flush_text_block(&mut text, &mut blocks);
- blocks.push(ContentBlock::ToolUse { id, name, input });
- }
- AssistantEvent::Usage(value) => usage = Some(value),
- AssistantEvent::MessageStop => {
- finished = true;
- }
- }
- }
- flush_text_block(&mut text, &mut blocks);
- if !finished {
- return Err(RuntimeError::new(
- "assistant stream ended without a message stop event",
- ));
- }
- if blocks.is_empty() {
- return Err(RuntimeError::new("assistant stream produced no content"));
- }
- Ok((
- ConversationMessage::assistant_with_usage(blocks, usage),
- usage,
- ))
- }
- fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
- if !text.is_empty() {
- blocks.push(ContentBlock::Text {
- text: std::mem::take(text),
- });
- }
- }
- fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
- if result.messages().is_empty() {
- fallback.to_string()
- } else {
- result.messages().join("\n")
- }
- }
/// Appends hook `messages` to a tool `output` under a "Hook feedback" heading.
///
/// Behaviour:
/// - with no messages, `output` is returned untouched;
/// - a blank (whitespace-only) `output` is dropped so only the feedback remains;
/// - `denied` switches the heading to "Hook feedback (denied)".
fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
    if messages.is_empty() {
        return output;
    }
    let heading = if denied {
        "Hook feedback (denied)"
    } else {
        "Hook feedback"
    };
    let feedback = format!("{heading}:\n{}", messages.join("\n"));
    if output.trim().is_empty() {
        feedback
    } else {
        format!("{output}\n\n{feedback}")
    }
}
- type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
- #[derive(Default)]
- pub struct StaticToolExecutor {
- handlers: BTreeMap<String, ToolHandler>,
- }
- impl StaticToolExecutor {
- #[must_use]
- pub fn new() -> Self {
- Self::default()
- }
- #[must_use]
- pub fn register(
- mut self,
- tool_name: impl Into<String>,
- handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
- ) -> Self {
- self.handlers.insert(tool_name.into(), Box::new(handler));
- self
- }
- }
- impl ToolExecutor for StaticToolExecutor {
- fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
- self.handlers
- .get_mut(tool_name)
- .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
- }
- }
#[cfg(test)]
mod tests {
    //! Unit tests for `ConversationRuntime`.
    //!
    //! The tests are driven by scripted `ApiClient`s and `StaticToolExecutor` handlers.

    use super::{
        parse_auto_compaction_threshold, ApiClient, ApiRequest, AssistantEvent,
        AutoCompactionEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
        DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
    };
    use crate::compact::CompactionConfig;
    use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
    use crate::permissions::{
        PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
        PermissionRequest,
    };
    use crate::prompt::{ProjectContext, SystemPromptBuilder};
    use crate::session::{ContentBlock, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::fs;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Two-call scripted client.
    ///
    /// Call 1 asks for `add("2,2")`. Call 2 must see the tool result as the last message and
    /// answers with text. Any further call is an error.
    struct ScriptedApiClient {
        call_count: usize,
    }

    impl ApiClient for ScriptedApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.call_count += 1;
            match self.call_count {
                1 => {
                    assert!(request
                        .messages
                        .iter()
                        .any(|message| message.role == MessageRole::User));
                    Ok(vec![
                        AssistantEvent::TextDelta("Let me calculate that.".to_string()),
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: "2,2".to_string(),
                        },
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 20,
                            output_tokens: 6,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 2,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                2 => {
                    let last_message = request
                        .messages
                        .last()
                        .expect("tool result should be present");
                    assert_eq!(last_message.role, MessageRole::Tool);
                    Ok(vec![
                        AssistantEvent::TextDelta("The answer is 4.".to_string()),
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 24,
                            output_tokens: 4,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 3,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                _ => Err(RuntimeError::new("unexpected extra API call")),
            }
        }
    }

    /// Prompter that approves every request, asserting it is only ever asked about `add`.
    struct PromptAllowOnce;

    impl PermissionPrompter for PromptAllowOnce {
        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
            assert_eq!(request.tool_name, "add");
            PermissionPromptDecision::Allow
        }
    }

    /// Full loop: user text, tool use, tool result, final answer; usage summed across calls.
    #[test]
    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
        let api_client = ScriptedApiClient { call_count: 0 };
        let tool_executor = StaticToolExecutor::new().register("add", |input| {
            let total = input
                .split(',')
                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                .sum::<i32>();
            Ok(total.to_string())
        });
        let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
        let system_prompt = SystemPromptBuilder::new()
            .with_project_context(ProjectContext {
                cwd: PathBuf::from("/tmp/project"),
                current_date: "2026-03-31".to_string(),
                git_status: None,
                git_diff: None,
                instruction_files: Vec::new(),
            })
            .with_os("linux", "6.8")
            .build();
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            api_client,
            tool_executor,
            permission_policy,
            system_prompt,
        );

        let summary = runtime
            .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
            .expect("conversation loop should succeed");

        assert_eq!(summary.iterations, 2);
        assert_eq!(summary.assistant_messages.len(), 2);
        assert_eq!(summary.tool_results.len(), 1);
        // user, assistant (text + tool use), tool result, final assistant answer
        assert_eq!(runtime.session().messages.len(), 4);
        // 6 + 4 output tokens across the two scripted responses
        assert_eq!(summary.usage.output_tokens, 10);
        assert_eq!(summary.auto_compaction, None);
        assert!(matches!(
            runtime.session().messages[1].blocks[1],
            ContentBlock::ToolUse { .. }
        ));
        assert!(matches!(
            runtime.session().messages[2].blocks[0],
            ContentBlock::ToolResult {
                is_error: false,
                ..
            }
        ));
    }

    /// A prompter denial becomes an error tool result carrying the prompter's reason.
    #[test]
    fn records_denied_tool_results_when_prompt_rejects() {
        struct RejectPrompter;
        impl PermissionPrompter for RejectPrompter {
            fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
                PermissionPromptDecision::Deny {
                    reason: "not now".to_string(),
                }
            }
        }

        struct SingleCallApiClient;
        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                if request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool)
                {
                    return Ok(vec![
                        AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: "secret".to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::WorkspaceWrite),
            vec!["system".to_string()],
        );

        let summary = runtime
            .run_turn("use the tool", Some(&mut RejectPrompter))
            .expect("conversation should continue after denied tool");

        assert_eq!(summary.tool_results.len(), 1);
        assert!(matches!(
            &summary.tool_results[0].blocks[0],
            ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
        ));
    }

    /// A pre-tool-use hook exiting with status 2 blocks the tool before it runs.
    #[test]
    fn denies_tool_use_when_pre_tool_hook_blocks() {
        struct SingleCallApiClient;
        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                if request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool)
                {
                    return Ok(vec![
                        AssistantEvent::TextDelta("blocked".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: r#"{"path":"secret.txt"}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new_with_features(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new().register("blocked", |_input| {
                panic!("tool should not execute when hook denies")
            }),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                vec![shell_snippet("printf 'blocked by hook'; exit 2")],
                Vec::new(),
                Vec::new(),
            )),
        );

        let summary = runtime
            .run_turn("use the tool", None)
            .expect("conversation should continue after hook denial");

        assert_eq!(summary.tool_results.len(), 1);
        let ContentBlock::ToolResult {
            is_error, output, ..
        } = &summary.tool_results[0].blocks[0]
        else {
            panic!("expected tool result block");
        };
        assert!(
            *is_error,
            "hook denial should produce an error result: {output}"
        );
        assert!(
            output.contains("denied tool") || output.contains("blocked by hook"),
            "unexpected hook denial output: {output:?}"
        );
    }

    /// Pre- and post-tool-use hook output is appended to a successful tool result.
    #[test]
    fn appends_post_tool_hook_feedback_to_tool_result() {
        struct TwoCallApiClient {
            calls: usize,
        }
        impl ApiClient for TwoCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                self.calls += 1;
                match self.calls {
                    1 => Ok(vec![
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: r#"{"lhs":2,"rhs":2}"#.to_string(),
                        },
                        AssistantEvent::MessageStop,
                    ]),
                    2 => {
                        assert!(request
                            .messages
                            .iter()
                            .any(|message| message.role == MessageRole::Tool));
                        Ok(vec![
                            AssistantEvent::TextDelta("done".to_string()),
                            AssistantEvent::MessageStop,
                        ])
                    }
                    _ => Err(RuntimeError::new("unexpected extra API call")),
                }
            }
        }

        let mut runtime = ConversationRuntime::new_with_features(
            Session::new(),
            TwoCallApiClient { calls: 0 },
            StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                vec![shell_snippet("printf 'pre hook ran'")],
                vec![shell_snippet("printf 'post hook ran'")],
                Vec::new(),
            )),
        );

        let summary = runtime
            .run_turn("use add", None)
            .expect("tool loop succeeds");

        assert_eq!(summary.tool_results.len(), 1);
        let ContentBlock::ToolResult {
            is_error, output, ..
        } = &summary.tool_results[0].blocks[0]
        else {
            panic!("expected tool result block");
        };
        assert!(
            !*is_error,
            "post hook should preserve non-error result: {output:?}"
        );
        assert!(
            output.contains('4'),
            "tool output missing value: {output:?}"
        );
        assert!(
            output.contains("pre hook ran"),
            "tool output missing pre hook feedback: {output:?}"
        );
        assert!(
            output.contains("post hook ran"),
            "tool output missing post hook feedback: {output:?}"
        );
    }

    /// Usage stored on restored assistant messages seeds the runtime's usage tracker.
    #[test]
    fn reconstructs_usage_tracker_from_restored_session() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut session = Session::new();
        session
            .messages
            .push(crate::session::ConversationMessage::assistant_with_usage(
                vec![ContentBlock::Text {
                    text: "earlier".to_string(),
                }],
                Some(TokenUsage {
                    input_tokens: 11,
                    output_tokens: 7,
                    cache_creation_input_tokens: 2,
                    cache_read_input_tokens: 1,
                }),
            ));

        let runtime = ConversationRuntime::new(
            session,
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );

        assert_eq!(runtime.usage().turns(), 1);
        // 11 + 7 + 2 + 1
        assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
    }

    /// Manual compaction yields a summary-led copy that keeps the session id.
    #[test]
    fn compacts_session_after_turns() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );
        runtime.run_turn("a", None).expect("turn a");
        runtime.run_turn("b", None).expect("turn b");
        runtime.run_turn("c", None).expect("turn c");

        let result = runtime.compact(CompactionConfig {
            preserve_recent_messages: 2,
            max_estimated_tokens: 1,
        });

        assert!(result.summary.contains("Conversation summary"));
        assert_eq!(
            result.compacted_session.messages[0].role,
            MessageRole::System
        );
        assert_eq!(
            result.compacted_session.session_id,
            runtime.session().session_id
        );
        assert!(result.compacted_session.compaction.is_some());
    }

    /// Messages pushed during a turn are persisted and can be reloaded from disk.
    #[test]
    fn persists_conversation_turn_messages_to_jsonl_session() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let path = temp_session_path("persisted-turn");
        let session = Session::new().with_persistence_path(path.clone());
        let mut runtime = ConversationRuntime::new(
            session,
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );

        runtime
            .run_turn("persist this turn", None)
            .expect("turn should succeed");

        let restored = Session::load_from_path(&path).expect("persisted session should reload");
        fs::remove_file(&path).expect("temp session file should be removable");

        assert_eq!(restored.messages.len(), 2);
        assert_eq!(restored.messages[0].role, MessageRole::User);
        assert_eq!(restored.messages[1].role, MessageRole::Assistant);
        assert_eq!(restored.session_id, runtime.session().session_id);
    }

    /// Forking returns a new session linked to its parent and leaves the runtime's untouched.
    #[test]
    fn forks_runtime_session_without_mutating_original() {
        let mut session = Session::new();
        session
            .push_user_text("branch me")
            .expect("message should append");

        let runtime = ConversationRuntime::new(
            session.clone(),
            ScriptedApiClient { call_count: 0 },
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );

        let forked = runtime.fork_session(Some("alt-path".to_string()));

        assert_eq!(forked.messages, session.messages);
        assert_ne!(forked.session_id, session.session_id);
        assert_eq!(
            forked
                .fork
                .as_ref()
                .map(|fork| (fork.parent_session_id.as_str(), fork.branch_name.as_deref())),
            Some((session.session_id.as_str(), Some("alt-path")))
        );
        assert!(runtime.session().fork.is_none());
    }

    /// Unique temp-file path per call, keyed by `label` and the current time in nanoseconds.
    fn temp_session_path(label: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("runtime-conversation-{label}-{nanos}.json"))
    }

    /// Adapts a POSIX-shell hook snippet for Windows by swapping single quotes for double quotes.
    #[cfg(windows)]
    fn shell_snippet(script: &str) -> String {
        script.replace('\'', "\"")
    }

    /// Hook snippets are used verbatim on non-Windows platforms.
    #[cfg(not(windows))]
    fn shell_snippet(script: &str) -> String {
        script.to_string()
    }

    /// Crossing the cumulative input-token threshold compacts the session after the turn.
    #[test]
    fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::Usage(TokenUsage {
                        input_tokens: 120_000,
                        output_tokens: 4,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                    }),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        // Start from `Session::new()` so the remaining `Session` fields (session id, fork,
        // compaction, ...) get their defaults; only the history is replaced.
        let mut session = Session::new();
        session.messages = vec![
            crate::session::ConversationMessage::user_text("one"),
            crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
                text: "two".to_string(),
            }]),
            crate::session::ConversationMessage::user_text("three"),
            crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
                text: "four".to_string(),
            }]),
        ];

        let mut runtime = ConversationRuntime::new(
            session,
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        )
        .with_auto_compaction_input_tokens_threshold(100_000);

        let summary = runtime
            .run_turn("trigger", None)
            .expect("turn should succeed");

        assert_eq!(
            summary.auto_compaction,
            Some(AutoCompactionEvent {
                removed_message_count: 2,
            })
        );
        assert_eq!(runtime.session().messages[0].role, MessageRole::System);
    }

    /// Staying one token below the threshold leaves the session uncompacted.
    #[test]
    fn skips_auto_compaction_below_threshold() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::Usage(TokenUsage {
                        input_tokens: 99_999,
                        output_tokens: 4,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                    }),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        )
        .with_auto_compaction_input_tokens_threshold(100_000);

        let summary = runtime
            .run_turn("trigger", None)
            .expect("turn should succeed");

        assert_eq!(summary.auto_compaction, None);
        assert_eq!(runtime.session().messages.len(), 2);
    }

    /// Missing or unparsable threshold values fall back to the default.
    #[test]
    fn auto_compaction_threshold_defaults_and_parses_values() {
        assert_eq!(
            parse_auto_compaction_threshold(None),
            DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
        );
        assert_eq!(parse_auto_compaction_threshold(Some("4321")), 4321);
        assert_eq!(
            parse_auto_compaction_threshold(Some("not-a-number")),
            DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
        );
    }
}
|