| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401 |
- use std::collections::BTreeMap;
- use std::fmt::{Display, Formatter};
- use serde_json::{Map, Value};
- use telemetry::SessionTracer;
- use crate::compact::{
- compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
- };
- use crate::config::RuntimeFeatureConfig;
- use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
- use crate::permissions::{
- PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
- };
- use crate::session::{ContentBlock, ConversationMessage, Session};
- use crate::usage::{TokenUsage, UsageTracker};
// Fallback auto-compaction threshold (in input tokens) used when the environment
// variable is unset, unparseable, or zero.
- const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 100_000;
// Name of the environment variable that overrides the auto-compaction threshold.
- const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct ApiRequest {
- pub system_prompt: Vec<String>,
- pub messages: Vec<ConversationMessage>,
- }
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub enum AssistantEvent {
- TextDelta(String),
- ToolUse {
- id: String,
- name: String,
- input: String,
- },
- Usage(TokenUsage),
- PromptCache(PromptCacheEvent),
- MessageStop,
- }
/// Prompt-cache observation reported by the API client alongside a response.
///
/// The runtime does not interpret these fields; it only forwards the events on
/// [`TurnSummary::prompt_cache_events`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PromptCacheEvent {
    /// Whether the reporter flagged the change as unexpected.
    pub unexpected: bool,
    /// Human-readable explanation supplied by the reporter.
    pub reason: String,
    /// Cache-read input tokens before the change (presumably from the prior call).
    pub previous_cache_read_input_tokens: u32,
    /// Cache-read input tokens after the change.
    pub current_cache_read_input_tokens: u32,
    /// Reported size of the drop, in tokens.
    pub token_drop: u32,
}
- pub trait ApiClient {
- fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
- }
- pub trait ToolExecutor {
- fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
- }
/// Failure reported by a tool execution, carrying only a message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ToolError {
    message: String,
}

impl ToolError {
    /// Creates an error whose display text is `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for ToolError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
/// Failure that aborts a conversation turn, carrying only a message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RuntimeError {
    message: String,
}

impl RuntimeError {
    /// Creates an error whose display text is `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for RuntimeError {
    /// Writes the stored message verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for RuntimeError {}
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct TurnSummary {
- pub assistant_messages: Vec<ConversationMessage>,
- pub tool_results: Vec<ConversationMessage>,
- pub prompt_cache_events: Vec<PromptCacheEvent>,
- pub iterations: usize,
- pub usage: TokenUsage,
- pub auto_compaction: Option<AutoCompactionEvent>,
- }
/// Describes an automatic compaction performed at the end of a turn.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AutoCompactionEvent {
    /// Number of messages the compaction removed from the session.
    pub removed_message_count: usize,
}
- pub struct ConversationRuntime<C, T> {
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- max_iterations: usize,
- usage_tracker: UsageTracker,
- hook_runner: HookRunner,
- auto_compaction_input_tokens_threshold: u32,
- hook_abort_signal: HookAbortSignal,
- hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
- session_tracer: Option<SessionTracer>,
- }
- impl<C, T> ConversationRuntime<C, T>
- where
- C: ApiClient,
- T: ToolExecutor,
- {
- #[must_use]
- pub fn new(
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- ) -> Self {
- Self::new_with_features(
- session,
- api_client,
- tool_executor,
- permission_policy,
- system_prompt,
- &RuntimeFeatureConfig::default(),
- )
- }
- #[must_use]
- #[allow(clippy::needless_pass_by_value)]
- pub fn new_with_features(
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- feature_config: &RuntimeFeatureConfig,
- ) -> Self {
- let usage_tracker = UsageTracker::from_session(&session);
- Self {
- session,
- api_client,
- tool_executor,
- permission_policy,
- system_prompt,
- max_iterations: usize::MAX,
- usage_tracker,
- hook_runner: HookRunner::from_feature_config(feature_config),
- auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
- hook_abort_signal: HookAbortSignal::default(),
- hook_progress_reporter: None,
- session_tracer: None,
- }
- }
- #[must_use]
- pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
- self.max_iterations = max_iterations;
- self
- }
- #[must_use]
- pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
- self.auto_compaction_input_tokens_threshold = threshold;
- self
- }
- #[must_use]
- pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self {
- self.hook_abort_signal = hook_abort_signal;
- self
- }
- #[must_use]
- pub fn with_hook_progress_reporter(
- mut self,
- hook_progress_reporter: Box<dyn HookProgressReporter>,
- ) -> Self {
- self.hook_progress_reporter = Some(hook_progress_reporter);
- self
- }
- #[must_use]
- pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
- self.session_tracer = Some(session_tracer);
- self
- }
- fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
- if let Some(reporter) = self.hook_progress_reporter.as_mut() {
- self.hook_runner.run_pre_tool_use_with_context(
- tool_name,
- input,
- Some(&self.hook_abort_signal),
- Some(reporter.as_mut()),
- )
- } else {
- self.hook_runner.run_pre_tool_use_with_context(
- tool_name,
- input,
- Some(&self.hook_abort_signal),
- None,
- )
- }
- }
- fn run_post_tool_use_hook(
- &mut self,
- tool_name: &str,
- input: &str,
- output: &str,
- is_error: bool,
- ) -> HookRunResult {
- if let Some(reporter) = self.hook_progress_reporter.as_mut() {
- self.hook_runner.run_post_tool_use_with_context(
- tool_name,
- input,
- output,
- is_error,
- Some(&self.hook_abort_signal),
- Some(reporter.as_mut()),
- )
- } else {
- self.hook_runner.run_post_tool_use_with_context(
- tool_name,
- input,
- output,
- is_error,
- Some(&self.hook_abort_signal),
- None,
- )
- }
- }
- fn run_post_tool_use_failure_hook(
- &mut self,
- tool_name: &str,
- input: &str,
- output: &str,
- ) -> HookRunResult {
- if let Some(reporter) = self.hook_progress_reporter.as_mut() {
- self.hook_runner.run_post_tool_use_failure_with_context(
- tool_name,
- input,
- output,
- Some(&self.hook_abort_signal),
- Some(reporter.as_mut()),
- )
- } else {
- self.hook_runner.run_post_tool_use_failure_with_context(
- tool_name,
- input,
- output,
- Some(&self.hook_abort_signal),
- None,
- )
- }
- }
- #[allow(clippy::too_many_lines)]
- pub fn run_turn(
- &mut self,
- user_input: impl Into<String>,
- mut prompter: Option<&mut dyn PermissionPrompter>,
- ) -> Result<TurnSummary, RuntimeError> {
- let user_input = user_input.into();
- self.record_turn_started(&user_input);
- self.session
- .push_user_text(user_input)
- .map_err(|error| RuntimeError::new(error.to_string()))?;
- let mut assistant_messages = Vec::new();
- let mut tool_results = Vec::new();
- let mut prompt_cache_events = Vec::new();
- let mut iterations = 0;
- loop {
- iterations += 1;
- if iterations > self.max_iterations {
- let error = RuntimeError::new(
- "conversation loop exceeded the maximum number of iterations",
- );
- self.record_turn_failed(iterations, &error);
- return Err(error);
- }
- let request = ApiRequest {
- system_prompt: self.system_prompt.clone(),
- messages: self.session.messages.clone(),
- };
- let events = match self.api_client.stream(request) {
- Ok(events) => events,
- Err(error) => {
- self.record_turn_failed(iterations, &error);
- return Err(error);
- }
- };
- let (assistant_message, usage, turn_prompt_cache_events) =
- match build_assistant_message(events) {
- Ok(result) => result,
- Err(error) => {
- self.record_turn_failed(iterations, &error);
- return Err(error);
- }
- };
- if let Some(usage) = usage {
- self.usage_tracker.record(usage);
- }
- prompt_cache_events.extend(turn_prompt_cache_events);
- let pending_tool_uses = assistant_message
- .blocks
- .iter()
- .filter_map(|block| match block {
- ContentBlock::ToolUse { id, name, input } => {
- Some((id.clone(), name.clone(), input.clone()))
- }
- _ => None,
- })
- .collect::<Vec<_>>();
- self.record_assistant_iteration(
- iterations,
- &assistant_message,
- pending_tool_uses.len(),
- );
- self.session
- .push_message(assistant_message.clone())
- .map_err(|error| RuntimeError::new(error.to_string()))?;
- assistant_messages.push(assistant_message);
- if pending_tool_uses.is_empty() {
- break;
- }
- for (tool_use_id, tool_name, input) in pending_tool_uses {
- let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
- let effective_input = pre_hook_result
- .updated_input()
- .map_or_else(|| input.clone(), ToOwned::to_owned);
- let permission_context = PermissionContext::new(
- pre_hook_result.permission_override(),
- pre_hook_result.permission_reason().map(ToOwned::to_owned),
- );
- let permission_outcome = if pre_hook_result.is_cancelled() {
- PermissionOutcome::Deny {
- reason: format_hook_message(
- &pre_hook_result,
- &format!("PreToolUse hook cancelled tool `{tool_name}`"),
- ),
- }
- } else if pre_hook_result.is_denied() {
- PermissionOutcome::Deny {
- reason: format_hook_message(
- &pre_hook_result,
- &format!("PreToolUse hook denied tool `{tool_name}`"),
- ),
- }
- } else if let Some(prompt) = prompter.as_mut() {
- self.permission_policy.authorize_with_context(
- &tool_name,
- &effective_input,
- &permission_context,
- Some(*prompt),
- )
- } else {
- self.permission_policy.authorize_with_context(
- &tool_name,
- &effective_input,
- &permission_context,
- None,
- )
- };
- let result_message = match permission_outcome {
- PermissionOutcome::Allow => {
- self.record_tool_started(iterations, &tool_name);
- let (mut output, mut is_error) =
- match self.tool_executor.execute(&tool_name, &effective_input) {
- Ok(output) => (output, false),
- Err(error) => (error.to_string(), true),
- };
- output = merge_hook_feedback(pre_hook_result.messages(), output, false);
- let post_hook_result = if is_error {
- self.run_post_tool_use_failure_hook(
- &tool_name,
- &effective_input,
- &output,
- )
- } else {
- self.run_post_tool_use_hook(
- &tool_name,
- &effective_input,
- &output,
- false,
- )
- };
- if post_hook_result.is_denied() || post_hook_result.is_cancelled() {
- is_error = true;
- }
- output = merge_hook_feedback(
- post_hook_result.messages(),
- output,
- post_hook_result.is_denied() || post_hook_result.is_cancelled(),
- );
- ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error)
- }
- PermissionOutcome::Deny { reason } => ConversationMessage::tool_result(
- tool_use_id,
- tool_name,
- merge_hook_feedback(pre_hook_result.messages(), reason, true),
- true,
- ),
- };
- self.session
- .push_message(result_message.clone())
- .map_err(|error| RuntimeError::new(error.to_string()))?;
- self.record_tool_finished(iterations, &result_message);
- tool_results.push(result_message);
- }
- }
- let auto_compaction = self.maybe_auto_compact();
- let summary = TurnSummary {
- assistant_messages,
- tool_results,
- prompt_cache_events,
- iterations,
- usage: self.usage_tracker.cumulative_usage(),
- auto_compaction,
- };
- self.record_turn_completed(&summary);
- Ok(summary)
- }
- #[must_use]
- pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
- compact_session(&self.session, config)
- }
- #[must_use]
- pub fn estimated_tokens(&self) -> usize {
- estimate_session_tokens(&self.session)
- }
- #[must_use]
- pub fn usage(&self) -> &UsageTracker {
- &self.usage_tracker
- }
- #[must_use]
- pub fn session(&self) -> &Session {
- &self.session
- }
- #[must_use]
- pub fn fork_session(&self, branch_name: Option<String>) -> Session {
- self.session.fork(branch_name)
- }
- #[must_use]
- pub fn into_session(self) -> Session {
- self.session
- }
- fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
- if self.usage_tracker.cumulative_usage().input_tokens
- < self.auto_compaction_input_tokens_threshold
- {
- return None;
- }
- let result = compact_session(
- &self.session,
- CompactionConfig {
- max_estimated_tokens: 0,
- ..CompactionConfig::default()
- },
- );
- if result.removed_message_count == 0 {
- return None;
- }
- self.session = result.compacted_session;
- Some(AutoCompactionEvent {
- removed_message_count: result.removed_message_count,
- })
- }
- fn record_turn_started(&self, user_input: &str) {
- let Some(session_tracer) = &self.session_tracer else {
- return;
- };
- let mut attributes = Map::new();
- attributes.insert(
- "user_input".to_string(),
- Value::String(user_input.to_string()),
- );
- session_tracer.record("turn_started", attributes);
- }
- fn record_assistant_iteration(
- &self,
- iteration: usize,
- assistant_message: &ConversationMessage,
- pending_tool_use_count: usize,
- ) {
- let Some(session_tracer) = &self.session_tracer else {
- return;
- };
- let mut attributes = Map::new();
- attributes.insert("iteration".to_string(), Value::from(iteration as u64));
- attributes.insert(
- "assistant_blocks".to_string(),
- Value::from(assistant_message.blocks.len() as u64),
- );
- attributes.insert(
- "pending_tool_use_count".to_string(),
- Value::from(pending_tool_use_count as u64),
- );
- session_tracer.record("assistant_iteration_completed", attributes);
- }
- fn record_tool_started(&self, iteration: usize, tool_name: &str) {
- let Some(session_tracer) = &self.session_tracer else {
- return;
- };
- let mut attributes = Map::new();
- attributes.insert("iteration".to_string(), Value::from(iteration as u64));
- attributes.insert(
- "tool_name".to_string(),
- Value::String(tool_name.to_string()),
- );
- session_tracer.record("tool_execution_started", attributes);
- }
- fn record_tool_finished(&self, iteration: usize, result_message: &ConversationMessage) {
- let Some(session_tracer) = &self.session_tracer else {
- return;
- };
- let Some(ContentBlock::ToolResult {
- tool_name,
- is_error,
- ..
- }) = result_message.blocks.first()
- else {
- return;
- };
- let mut attributes = Map::new();
- attributes.insert("iteration".to_string(), Value::from(iteration as u64));
- attributes.insert("tool_name".to_string(), Value::String(tool_name.clone()));
- attributes.insert("is_error".to_string(), Value::Bool(*is_error));
- session_tracer.record("tool_execution_finished", attributes);
- }
- fn record_turn_completed(&self, summary: &TurnSummary) {
- let Some(session_tracer) = &self.session_tracer else {
- return;
- };
- let mut attributes = Map::new();
- attributes.insert(
- "iterations".to_string(),
- Value::from(summary.iterations as u64),
- );
- attributes.insert(
- "assistant_messages".to_string(),
- Value::from(summary.assistant_messages.len() as u64),
- );
- attributes.insert(
- "tool_results".to_string(),
- Value::from(summary.tool_results.len() as u64),
- );
- attributes.insert(
- "prompt_cache_events".to_string(),
- Value::from(summary.prompt_cache_events.len() as u64),
- );
- session_tracer.record("turn_completed", attributes);
- }
- fn record_turn_failed(&self, iteration: usize, error: &RuntimeError) {
- let Some(session_tracer) = &self.session_tracer else {
- return;
- };
- let mut attributes = Map::new();
- attributes.insert("iteration".to_string(), Value::from(iteration as u64));
- attributes.insert("error".to_string(), Value::String(error.to_string()));
- session_tracer.record("turn_failed", attributes);
- }
- }
- #[must_use]
- pub fn auto_compaction_threshold_from_env() -> u32 {
- parse_auto_compaction_threshold(
- std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
- .ok()
- .as_deref(),
- )
- }
- #[must_use]
- fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
- value
- .and_then(|raw| raw.trim().parse::<u32>().ok())
- .filter(|threshold| *threshold > 0)
- .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
- }
- fn build_assistant_message(
- events: Vec<AssistantEvent>,
- ) -> Result<
- (
- ConversationMessage,
- Option<TokenUsage>,
- Vec<PromptCacheEvent>,
- ),
- RuntimeError,
- > {
- let mut text = String::new();
- let mut blocks = Vec::new();
- let mut prompt_cache_events = Vec::new();
- let mut finished = false;
- let mut usage = None;
- for event in events {
- match event {
- AssistantEvent::TextDelta(delta) => text.push_str(&delta),
- AssistantEvent::ToolUse { id, name, input } => {
- flush_text_block(&mut text, &mut blocks);
- blocks.push(ContentBlock::ToolUse { id, name, input });
- }
- AssistantEvent::Usage(value) => usage = Some(value),
- AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
- AssistantEvent::MessageStop => {
- finished = true;
- }
- }
- }
- flush_text_block(&mut text, &mut blocks);
- if !finished {
- return Err(RuntimeError::new(
- "assistant stream ended without a message stop event",
- ));
- }
- if blocks.is_empty() {
- return Err(RuntimeError::new("assistant stream produced no content"));
- }
- Ok((
- ConversationMessage::assistant_with_usage(blocks, usage),
- usage,
- prompt_cache_events,
- ))
- }
- fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
- if !text.is_empty() {
- blocks.push(ContentBlock::Text {
- text: std::mem::take(text),
- });
- }
- }
- fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
- if result.messages().is_empty() {
- fallback.to_string()
- } else {
- result.messages().join("\n")
- }
- }
/// Appends hook messages to a tool output as a labelled section.
///
/// With no messages the output is returned unchanged. Otherwise the messages
/// are joined by newlines under a `Hook feedback` heading (marked `(denied)`
/// when `denied` is set) and placed after the output, separated by a blank
/// line. An output that is empty or whitespace-only is dropped.
fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
    if messages.is_empty() {
        return output;
    }

    let label = if denied {
        "Hook feedback (denied)"
    } else {
        "Hook feedback"
    };
    let feedback = format!("{label}:\n{}", messages.join("\n"));

    if output.trim().is_empty() {
        feedback
    } else {
        format!("{output}\n\n{feedback}")
    }
}
- type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
- #[derive(Default)]
- pub struct StaticToolExecutor {
- handlers: BTreeMap<String, ToolHandler>,
- }
- impl StaticToolExecutor {
- #[must_use]
- pub fn new() -> Self {
- Self::default()
- }
- #[must_use]
- pub fn register(
- mut self,
- tool_name: impl Into<String>,
- handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
- ) -> Self {
- self.handlers.insert(tool_name.into(), Box::new(handler));
- self
- }
- }
- impl ToolExecutor for StaticToolExecutor {
- fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
- self.handlers
- .get_mut(tool_name)
- .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
- }
- }
- #[cfg(test)]
- mod tests {
- use super::{
- parse_auto_compaction_threshold, ApiClient, ApiRequest, AssistantEvent,
- AutoCompactionEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
- StaticToolExecutor, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
- };
- use crate::compact::CompactionConfig;
- use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
- use crate::permissions::{
- PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
- PermissionRequest,
- };
- use crate::prompt::{ProjectContext, SystemPromptBuilder};
- use crate::session::{ContentBlock, MessageRole, Session};
- use crate::usage::TokenUsage;
- use std::fs;
- use std::path::PathBuf;
- use std::sync::Arc;
- use std::time::{SystemTime, UNIX_EPOCH};
- use telemetry::{MemoryTelemetrySink, SessionTracer, TelemetryEvent};
- struct ScriptedApiClient {
- call_count: usize,
- }
- impl ApiClient for ScriptedApiClient {
- fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
- self.call_count += 1;
- match self.call_count {
- 1 => {
- assert!(request
- .messages
- .iter()
- .any(|message| message.role == MessageRole::User));
- Ok(vec![
- AssistantEvent::TextDelta("Let me calculate that.".to_string()),
- AssistantEvent::ToolUse {
- id: "tool-1".to_string(),
- name: "add".to_string(),
- input: "2,2".to_string(),
- },
- AssistantEvent::Usage(TokenUsage {
- input_tokens: 20,
- output_tokens: 6,
- cache_creation_input_tokens: 1,
- cache_read_input_tokens: 2,
- }),
- AssistantEvent::MessageStop,
- ])
- }
- 2 => {
- let last_message = request
- .messages
- .last()
- .expect("tool result should be present");
- assert_eq!(last_message.role, MessageRole::Tool);
- Ok(vec![
- AssistantEvent::TextDelta("The answer is 4.".to_string()),
- AssistantEvent::Usage(TokenUsage {
- input_tokens: 24,
- output_tokens: 4,
- cache_creation_input_tokens: 1,
- cache_read_input_tokens: 3,
- }),
- AssistantEvent::PromptCache(PromptCacheEvent {
- unexpected: true,
- reason:
- "cache read tokens dropped while prompt fingerprint remained stable"
- .to_string(),
- previous_cache_read_input_tokens: 6_000,
- current_cache_read_input_tokens: 1_000,
- token_drop: 5_000,
- }),
- AssistantEvent::MessageStop,
- ])
- }
- _ => Err(RuntimeError::new("unexpected extra API call")),
- }
- }
- }
- struct PromptAllowOnce;
- impl PermissionPrompter for PromptAllowOnce {
- fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
- assert_eq!(request.tool_name, "add");
- PermissionPromptDecision::Allow
- }
- }
// Full loop: user message -> tool use -> tool result -> final answer, with
// usage accumulated across both model calls.
#[test]
fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
    let add_tool = StaticToolExecutor::new().register("add", |input| {
        let mut total = 0_i32;
        for part in input.split(',') {
            total += part.parse::<i32>().expect("input must be valid integer");
        }
        Ok(total.to_string())
    });
    let project_context = ProjectContext {
        cwd: PathBuf::from("/tmp/project"),
        current_date: "2026-03-31".to_string(),
        git_status: None,
        git_diff: None,
        instruction_files: Vec::new(),
    };
    let system_prompt = SystemPromptBuilder::new()
        .with_project_context(project_context)
        .with_os("linux", "6.8")
        .build();

    let mut runtime = ConversationRuntime::new(
        Session::new(),
        ScriptedApiClient { call_count: 0 },
        add_tool,
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        system_prompt,
    );
    let mut prompter = PromptAllowOnce;
    let summary = runtime
        .run_turn("what is 2 + 2?", Some(&mut prompter))
        .expect("conversation loop should succeed");

    assert_eq!(summary.iterations, 2);
    assert_eq!(summary.assistant_messages.len(), 2);
    assert_eq!(summary.tool_results.len(), 1);
    assert_eq!(summary.prompt_cache_events.len(), 1);
    assert_eq!(runtime.session().messages.len(), 4);
    assert_eq!(summary.usage.output_tokens, 10);
    assert_eq!(summary.auto_compaction, None);

    let messages = &runtime.session().messages;
    assert!(matches!(
        &messages[1].blocks[1],
        ContentBlock::ToolUse { .. }
    ));
    assert!(matches!(
        &messages[2].blocks[0],
        ContentBlock::ToolResult {
            is_error: false,
            ..
        }
    ));
}
// A runtime with a tracer attached must emit each lifecycle trace event at
// least once over a tool-using turn.
#[test]
fn records_runtime_session_trace_events() {
    let sink = Arc::new(MemoryTelemetrySink::default());
    let tracer = SessionTracer::new("session-runtime", sink.clone());
    let executor = StaticToolExecutor::new().register("add", |_input| Ok("4".to_string()));
    let mut runtime = ConversationRuntime::new(
        Session::new(),
        ScriptedApiClient { call_count: 0 },
        executor,
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        vec!["system".to_string()],
    )
    .with_session_tracer(tracer);

    let mut prompter = PromptAllowOnce;
    runtime
        .run_turn("what is 2 + 2?", Some(&mut prompter))
        .expect("conversation loop should succeed");

    let events = sink.events();
    let mut trace_names = Vec::new();
    for event in events.iter() {
        if let TelemetryEvent::SessionTrace(trace) = event {
            trace_names.push(trace.name.as_str());
        }
    }

    for expected in [
        "turn_started",
        "assistant_iteration_completed",
        "tool_execution_started",
        "tool_execution_finished",
        "turn_completed",
    ] {
        assert!(trace_names.contains(&expected));
    }
}
// When the prompter denies a tool, the turn continues and the denial reason
// becomes an error tool result.
#[test]
fn records_denied_tool_results_when_prompt_rejects() {
    // Prompter that refuses every request with a fixed reason.
    struct RejectPrompter;
    impl PermissionPrompter for RejectPrompter {
        fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
            let reason = "not now".to_string();
            PermissionPromptDecision::Deny { reason }
        }
    }

    // Requests the `blocked` tool until a tool result appears in the history,
    // then answers with plain text.
    struct SingleCallApiClient;
    impl ApiClient for SingleCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let tool_result_seen = request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool);
            let events = if tool_result_seen {
                vec![
                    AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                    AssistantEvent::MessageStop,
                ]
            } else {
                vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: "secret".to_string(),
                    },
                    AssistantEvent::MessageStop,
                ]
            };
            Ok(events)
        }
    }

    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SingleCallApiClient,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::WorkspaceWrite),
        vec!["system".to_string()],
    );
    let mut prompter = RejectPrompter;
    let summary = runtime
        .run_turn("use the tool", Some(&mut prompter))
        .expect("conversation should continue after denied tool");

    assert_eq!(summary.tool_results.len(), 1);
    let first_block = &summary.tool_results[0].blocks[0];
    assert!(matches!(
        first_block,
        ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
    ));
}
// A pre-tool hook that exits with status 2 must stop the tool from running and
// yield an error tool result.
#[test]
fn denies_tool_use_when_pre_tool_hook_blocks() {
    // Requests the `blocked` tool until a tool result appears in the history,
    // then answers with plain text.
    struct SingleCallApiClient;
    impl ApiClient for SingleCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let tool_result_seen = request
                .messages
                .iter()
                .any(|message| message.role == MessageRole::Tool);
            let events = if tool_result_seen {
                vec![
                    AssistantEvent::TextDelta("blocked".to_string()),
                    AssistantEvent::MessageStop,
                ]
            } else {
                vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: r#"{"path":"secret.txt"}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ]
            };
            Ok(events)
        }
    }

    let hooks = RuntimeHookConfig::new(
        vec![shell_snippet("printf 'blocked by hook'; exit 2")],
        Vec::new(),
        Vec::new(),
    );
    let features = RuntimeFeatureConfig::default().with_hooks(hooks);
    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        SingleCallApiClient,
        StaticToolExecutor::new().register("blocked", |_input| {
            panic!("tool should not execute when hook denies")
        }),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
        &features,
    );

    let summary = runtime
        .run_turn("use the tool", None)
        .expect("conversation should continue after hook denial");

    assert_eq!(summary.tool_results.len(), 1);
    let (is_error, output) = match &summary.tool_results[0].blocks[0] {
        ContentBlock::ToolResult {
            is_error, output, ..
        } => (*is_error, output),
        _ => panic!("expected tool result block"),
    };
    assert!(
        is_error,
        "hook denial should produce an error result: {output}"
    );
    assert!(
        output.contains("denied tool") || output.contains("blocked by hook"),
        "unexpected hook denial output: {output:?}"
    );
}
// Output from both the pre and post tool hooks must be appended to a
// successful tool result without turning it into an error.
#[test]
fn appends_post_tool_hook_feedback_to_tool_result() {
    // Requests the `add` tool on the first call and finishes on the second,
    // which must already see the tool result.
    struct TwoCallApiClient {
        calls: usize,
    }
    impl ApiClient for TwoCallApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.calls += 1;
            if self.calls == 1 {
                return Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "add".to_string(),
                        input: r#"{"lhs":2,"rhs":2}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ]);
            }
            if self.calls == 2 {
                let tool_result_seen = request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool);
                assert!(tool_result_seen);
                return Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ]);
            }
            Err(RuntimeError::new("unexpected extra API call"))
        }
    }

    let hooks = RuntimeHookConfig::new(
        vec![shell_snippet("printf 'pre hook ran'")],
        vec![shell_snippet("printf 'post hook ran'")],
        Vec::new(),
    );
    let features = RuntimeFeatureConfig::default().with_hooks(hooks);
    let mut runtime = ConversationRuntime::new_with_features(
        Session::new(),
        TwoCallApiClient { calls: 0 },
        StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
        &features,
    );

    let summary = runtime
        .run_turn("use add", None)
        .expect("tool loop succeeds");

    assert_eq!(summary.tool_results.len(), 1);
    let (is_error, output) = match &summary.tool_results[0].blocks[0] {
        ContentBlock::ToolResult {
            is_error, output, ..
        } => (*is_error, output),
        _ => panic!("expected tool result block"),
    };
    assert!(
        !is_error,
        "post hook should preserve non-error result: {output:?}"
    );
    assert!(
        output.contains('4'),
        "tool output missing value: {output:?}"
    );
    assert!(
        output.contains("pre hook ran"),
        "tool output missing pre hook feedback: {output:?}"
    );
    assert!(
        output.contains("post hook ran"),
        "tool output missing post hook feedback: {output:?}"
    );
}
/// Seeds a session with one assistant message that carries token usage and
/// checks that a runtime built on top of it reports one turn and 21 total tokens.
#[test]
fn reconstructs_usage_tracker_from_restored_session() {
    /// Minimal client; no turn is run in this test.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(
            &mut self,
            _request: ApiRequest,
        ) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let events = vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::MessageStop,
            ];
            Ok(events)
        }
    }

    // Usage attached to the pre-existing assistant message: 11 + 7 + 2 + 1 = 21.
    let earlier_usage = TokenUsage {
        input_tokens: 11,
        output_tokens: 7,
        cache_creation_input_tokens: 2,
        cache_read_input_tokens: 1,
    };
    let earlier_message = crate::session::ConversationMessage::assistant_with_usage(
        vec![ContentBlock::Text {
            text: "earlier".to_string(),
        }],
        Some(earlier_usage),
    );

    let mut restored = Session::new();
    restored.messages.push(earlier_message);

    let runtime = ConversationRuntime::new(
        restored,
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    );

    assert_eq!(runtime.usage().turns(), 1);
    assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
}
/// Runs three turns, compacts the session explicitly, and checks the shape of
/// the compaction result: summary text, leading system message, unchanged
/// session id, and recorded compaction metadata.
#[test]
fn compacts_session_after_turns() {
    /// Client that always answers with the text "done".
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(
            &mut self,
            _request: ApiRequest,
        ) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let events = vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::MessageStop,
            ];
            Ok(events)
        }
    }

    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    );

    // Each entry pairs the user input with the panic message used if that turn fails.
    for (prompt, failure) in [("a", "turn a"), ("b", "turn b"), ("c", "turn c")] {
        runtime.run_turn(prompt, None).expect(failure);
    }

    let config = CompactionConfig {
        preserve_recent_messages: 2,
        max_estimated_tokens: 1,
    };
    let result = runtime.compact(config);

    assert!(result.summary.contains("Conversation summary"));

    let compacted = &result.compacted_session;
    assert_eq!(compacted.messages[0].role, MessageRole::System);
    assert_eq!(compacted.session_id, runtime.session().session_id);
    assert!(result.compacted_session.compaction.is_some());
}
/// Runs one turn on a session configured with a persistence path, reloads the
/// session from that path, and checks that the user and assistant messages and
/// the session id were written out.
#[test]
fn persists_conversation_turn_messages_to_jsonl_session() {
    /// Client that always answers with the text "done".
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(
            &mut self,
            _request: ApiRequest,
        ) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let events = vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::MessageStop,
            ];
            Ok(events)
        }
    }

    let session_file = temp_session_path("persisted-turn");
    let mut runtime = ConversationRuntime::new(
        Session::new().with_persistence_path(session_file.clone()),
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    );

    runtime
        .run_turn("persist this turn", None)
        .expect("turn should succeed");

    let reloaded =
        Session::load_from_path(&session_file).expect("persisted session should reload");
    // Remove the temp file before asserting so a failed assertion does not leave it behind.
    fs::remove_file(&session_file).expect("temp session file should be removable");

    assert_eq!(reloaded.messages.len(), 2);
    assert_eq!(reloaded.messages[0].role, MessageRole::User);
    assert_eq!(reloaded.messages[1].role, MessageRole::Assistant);
    assert_eq!(reloaded.session_id, runtime.session().session_id);
}
/// Forks a runtime's session under the branch name "alt-path" and checks that
/// the fork copies the messages, gets a different session id, records its
/// parent and branch name, and leaves the runtime's own session without fork data.
#[test]
fn forks_runtime_session_without_mutating_original() {
    let mut original = Session::new();
    original
        .push_user_text("branch me")
        .expect("message should append");

    let runtime = ConversationRuntime::new(
        original.clone(),
        ScriptedApiClient { call_count: 0 },
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    );

    let forked = runtime.fork_session(Some("alt-path".to_string()));

    assert_eq!(forked.messages, original.messages);
    assert_ne!(forked.session_id, original.session_id);

    // (parent session id, branch name) as recorded on the fork, if any.
    let lineage = forked
        .fork
        .as_ref()
        .map(|fork| (fork.parent_session_id.as_str(), fork.branch_name.as_deref()));
    let expected_lineage = Some((original.session_id.as_str(), Some("alt-path")));
    assert_eq!(lineage, expected_lineage);

    assert!(runtime.session().fork.is_none());
}
/// Builds a path inside the system temp directory for a throwaway session file.
///
/// The file name is `runtime-conversation-{label}-{nanos}.json`, where `nanos`
/// is the current time in nanoseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
fn temp_session_path(label: &str) -> PathBuf {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system time should be after epoch");
    let file_name = format!(
        "runtime-conversation-{label}-{}.json",
        since_epoch.as_nanos()
    );

    let mut path = std::env::temp_dir();
    path.push(file_name);
    path
}
/// Adapts a hook script for Windows by replacing every single quote with a
/// double quote.
// NOTE(review): presumably needed because the Windows shell used for hooks does
// not treat single quotes as quoting characters — confirm against the hook runner.
#[cfg(windows)]
fn shell_snippet(script: &str) -> String {
    script
        .chars()
        .map(|ch| if ch == '\'' { '"' } else { ch })
        .collect()
}

/// Returns the hook script unchanged as an owned `String` on non-Windows targets.
#[cfg(not(windows))]
fn shell_snippet(script: &str) -> String {
    String::from(script)
}
/// Starts from a four-message session, runs one turn whose reported input
/// token usage (120_000) exceeds the configured threshold (100_000), and checks
/// that the turn summary reports an auto-compaction that removed two messages
/// and that the session now begins with a system message.
#[test]
fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
    /// Client whose every response reports 120_000 input tokens.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(
            &mut self,
            _request: ApiRequest,
        ) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let reported_usage = TokenUsage {
                input_tokens: 120_000,
                output_tokens: 4,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            };
            Ok(vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::Usage(reported_usage),
                AssistantEvent::MessageStop,
            ])
        }
    }

    // Builds an assistant message holding a single text block.
    let assistant_text = |text: &str| {
        crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
            text: text.to_string(),
        }])
    };

    let mut session = Session::new();
    session.messages = vec![
        crate::session::ConversationMessage::user_text("one"),
        assistant_text("two"),
        crate::session::ConversationMessage::user_text("three"),
        assistant_text("four"),
    ];

    let mut runtime = ConversationRuntime::new(
        session,
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    )
    .with_auto_compaction_input_tokens_threshold(100_000);

    let turn = runtime
        .run_turn("trigger", None)
        .expect("turn should succeed");

    let expected_event = Some(AutoCompactionEvent {
        removed_message_count: 2,
    });
    assert_eq!(turn.auto_compaction, expected_event);
    assert_eq!(runtime.session().messages[0].role, MessageRole::System);
}
/// Runs one turn whose reported input token usage (99_999) is one below the
/// configured threshold (100_000) and checks that no auto-compaction is
/// reported and the session holds exactly two messages afterwards.
#[test]
fn skips_auto_compaction_below_threshold() {
    /// Client whose every response reports 99_999 input tokens.
    struct SimpleApi;

    impl ApiClient for SimpleApi {
        fn stream(
            &mut self,
            _request: ApiRequest,
        ) -> Result<Vec<AssistantEvent>, RuntimeError> {
            let reported_usage = TokenUsage {
                input_tokens: 99_999,
                output_tokens: 4,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            };
            Ok(vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::Usage(reported_usage),
                AssistantEvent::MessageStop,
            ])
        }
    }

    let mut runtime = ConversationRuntime::new(
        Session::new(),
        SimpleApi,
        StaticToolExecutor::new(),
        PermissionPolicy::new(PermissionMode::DangerFullAccess),
        vec!["system".to_string()],
    )
    .with_auto_compaction_input_tokens_threshold(100_000);

    let turn = runtime
        .run_turn("trigger", None)
        .expect("turn should succeed");

    assert_eq!(turn.auto_compaction, None);
    assert_eq!(runtime.session().messages.len(), 2);
}
/// Checks `parse_auto_compaction_threshold` on a missing value, a numeric
/// string, and a non-numeric string: the numeric string is parsed, the other
/// two yield `DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD`.
#[test]
fn auto_compaction_threshold_defaults_and_parses_values() {
    let fallback = DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD;

    // (raw input, expected threshold)
    let cases = [
        (None, fallback),
        (Some("4321"), 4321),
        (Some("not-a-number"), fallback),
    ];

    for (raw, expected) in cases {
        assert_eq!(parse_auto_compaction_threshold(raw), expected);
    }
}
- }
|