| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583 |
- use std::collections::BTreeMap;
- use std::fmt::{Display, Formatter};
- use crate::compact::{
- compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
- };
- use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
- use crate::session::{ContentBlock, ConversationMessage, Session};
- use crate::usage::{TokenUsage, UsageTracker};
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct ApiRequest {
- pub system_prompt: Vec<String>,
- pub messages: Vec<ConversationMessage>,
- }
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub enum AssistantEvent {
- TextDelta(String),
- ToolUse {
- id: String,
- name: String,
- input: String,
- },
- Usage(TokenUsage),
- MessageStop,
- }
- pub trait ApiClient {
- fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
- }
- pub trait ToolExecutor {
- fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
- }
/// Failure reported by a [`ToolExecutor`] while running a tool.
///
/// The wrapped message is exactly what `Display` prints.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ToolError {
    message: String,
}

impl ToolError {
    /// Builds an error whose display text is `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for ToolError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Writes the raw text; like the nested `write!` form, this does not
        // apply the outer formatter's width/fill options.
        f.write_str(&self.message)
    }
}

impl std::error::Error for ToolError {}
/// Error that aborts a conversation turn (API failure, malformed stream, or
/// iteration limit).
///
/// The wrapped message is exactly what `Display` prints.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RuntimeError {
    message: String,
}

impl RuntimeError {
    /// Builds an error whose display text is `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for RuntimeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Writes the raw text; like the nested `write!` form, this does not
        // apply the outer formatter's width/fill options.
        f.write_str(&self.message)
    }
}

impl std::error::Error for RuntimeError {}
- #[derive(Debug, Clone, PartialEq, Eq)]
- pub struct TurnSummary {
- pub assistant_messages: Vec<ConversationMessage>,
- pub tool_results: Vec<ConversationMessage>,
- pub iterations: usize,
- pub usage: TokenUsage,
- }
- pub struct ConversationRuntime<C, T> {
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- max_iterations: usize,
- usage_tracker: UsageTracker,
- }
- impl<C, T> ConversationRuntime<C, T>
- where
- C: ApiClient,
- T: ToolExecutor,
- {
- #[must_use]
- pub fn new(
- session: Session,
- api_client: C,
- tool_executor: T,
- permission_policy: PermissionPolicy,
- system_prompt: Vec<String>,
- ) -> Self {
- let usage_tracker = UsageTracker::from_session(&session);
- Self {
- session,
- api_client,
- tool_executor,
- permission_policy,
- system_prompt,
- max_iterations: 16,
- usage_tracker,
- }
- }
- #[must_use]
- pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
- self.max_iterations = max_iterations;
- self
- }
- pub fn run_turn(
- &mut self,
- user_input: impl Into<String>,
- mut prompter: Option<&mut dyn PermissionPrompter>,
- ) -> Result<TurnSummary, RuntimeError> {
- self.session
- .messages
- .push(ConversationMessage::user_text(user_input.into()));
- let mut assistant_messages = Vec::new();
- let mut tool_results = Vec::new();
- let mut iterations = 0;
- loop {
- iterations += 1;
- if iterations > self.max_iterations {
- return Err(RuntimeError::new(
- "conversation loop exceeded the maximum number of iterations",
- ));
- }
- let request = ApiRequest {
- system_prompt: self.system_prompt.clone(),
- messages: self.session.messages.clone(),
- };
- let events = self.api_client.stream(request)?;
- let (assistant_message, usage) = build_assistant_message(events)?;
- if let Some(usage) = usage {
- self.usage_tracker.record(usage);
- }
- let pending_tool_uses = assistant_message
- .blocks
- .iter()
- .filter_map(|block| match block {
- ContentBlock::ToolUse { id, name, input } => {
- Some((id.clone(), name.clone(), input.clone()))
- }
- _ => None,
- })
- .collect::<Vec<_>>();
- self.session.messages.push(assistant_message.clone());
- assistant_messages.push(assistant_message);
- if pending_tool_uses.is_empty() {
- break;
- }
- for (tool_use_id, tool_name, input) in pending_tool_uses {
- let permission_outcome = if let Some(prompt) = prompter.as_mut() {
- self.permission_policy
- .authorize(&tool_name, &input, Some(*prompt))
- } else {
- self.permission_policy.authorize(&tool_name, &input, None)
- };
- let result_message = match permission_outcome {
- PermissionOutcome::Allow => {
- match self.tool_executor.execute(&tool_name, &input) {
- Ok(output) => ConversationMessage::tool_result(
- tool_use_id,
- tool_name,
- output,
- false,
- ),
- Err(error) => ConversationMessage::tool_result(
- tool_use_id,
- tool_name,
- error.to_string(),
- true,
- ),
- }
- }
- PermissionOutcome::Deny { reason } => {
- ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
- }
- };
- self.session.messages.push(result_message.clone());
- tool_results.push(result_message);
- }
- }
- Ok(TurnSummary {
- assistant_messages,
- tool_results,
- iterations,
- usage: self.usage_tracker.cumulative_usage(),
- })
- }
- #[must_use]
- pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
- compact_session(&self.session, config)
- }
- #[must_use]
- pub fn estimated_tokens(&self) -> usize {
- estimate_session_tokens(&self.session)
- }
- #[must_use]
- pub fn usage(&self) -> &UsageTracker {
- &self.usage_tracker
- }
- #[must_use]
- pub fn session(&self) -> &Session {
- &self.session
- }
- #[must_use]
- pub fn into_session(self) -> Session {
- self.session
- }
- }
- fn build_assistant_message(
- events: Vec<AssistantEvent>,
- ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
- let mut text = String::new();
- let mut blocks = Vec::new();
- let mut finished = false;
- let mut usage = None;
- for event in events {
- match event {
- AssistantEvent::TextDelta(delta) => text.push_str(&delta),
- AssistantEvent::ToolUse { id, name, input } => {
- flush_text_block(&mut text, &mut blocks);
- blocks.push(ContentBlock::ToolUse { id, name, input });
- }
- AssistantEvent::Usage(value) => usage = Some(value),
- AssistantEvent::MessageStop => {
- finished = true;
- }
- }
- }
- flush_text_block(&mut text, &mut blocks);
- if !finished {
- return Err(RuntimeError::new(
- "assistant stream ended without a message stop event",
- ));
- }
- if blocks.is_empty() {
- return Err(RuntimeError::new("assistant stream produced no content"));
- }
- Ok((
- ConversationMessage::assistant_with_usage(blocks, usage),
- usage,
- ))
- }
- fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
- if !text.is_empty() {
- blocks.push(ContentBlock::Text {
- text: std::mem::take(text),
- });
- }
- }
- type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
- #[derive(Default)]
- pub struct StaticToolExecutor {
- handlers: BTreeMap<String, ToolHandler>,
- }
- impl StaticToolExecutor {
- #[must_use]
- pub fn new() -> Self {
- Self::default()
- }
- #[must_use]
- pub fn register(
- mut self,
- tool_name: impl Into<String>,
- handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
- ) -> Self {
- self.handlers.insert(tool_name.into(), Box::new(handler));
- self
- }
- }
- impl ToolExecutor for StaticToolExecutor {
- fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
- self.handlers
- .get_mut(tool_name)
- .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
- }
- }
#[cfg(test)]
mod tests {
    use super::{
        ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
        StaticToolExecutor,
    };
    use crate::compact::CompactionConfig;
    use crate::permissions::{
        PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
        PermissionRequest,
    };
    use crate::prompt::{ProjectContext, SystemPromptBuilder};
    use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::path::PathBuf;

    /// API stub that always replies with the single text "done".
    struct DoneApi;

    impl ApiClient for DoneApi {
        fn stream(&mut self, _request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            Ok(vec![
                AssistantEvent::TextDelta("done".to_string()),
                AssistantEvent::MessageStop,
            ])
        }
    }

    /// Two-step script. The first call asks to run `add` on "2,2". The second
    /// answers once the tool result is the newest message. Any further call is
    /// an error.
    struct ScriptedApiClient {
        call_count: usize,
    }

    impl ApiClient for ScriptedApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.call_count += 1;

            if self.call_count == 1 {
                let has_user_message = request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::User);
                assert!(has_user_message);
                return Ok(vec![
                    AssistantEvent::TextDelta("Let me calculate that.".to_string()),
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "add".to_string(),
                        input: "2,2".to_string(),
                    },
                    AssistantEvent::Usage(TokenUsage {
                        input_tokens: 20,
                        output_tokens: 6,
                        cache_creation_input_tokens: 1,
                        cache_read_input_tokens: 2,
                    }),
                    AssistantEvent::MessageStop,
                ]);
            }

            if self.call_count == 2 {
                let newest = request
                    .messages
                    .last()
                    .expect("tool result should be present");
                assert_eq!(newest.role, MessageRole::Tool);
                return Ok(vec![
                    AssistantEvent::TextDelta("The answer is 4.".to_string()),
                    AssistantEvent::Usage(TokenUsage {
                        input_tokens: 24,
                        output_tokens: 4,
                        cache_creation_input_tokens: 1,
                        cache_read_input_tokens: 3,
                    }),
                    AssistantEvent::MessageStop,
                ]);
            }

            Err(RuntimeError::new("unexpected extra API call"))
        }
    }

    /// Prompter that approves requests, checking they target the `add` tool.
    struct PromptAllowOnce;

    impl PermissionPrompter for PromptAllowOnce {
        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
            assert_eq!(request.tool_name, "add");
            PermissionPromptDecision::Allow
        }
    }

    #[test]
    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
        let adder = StaticToolExecutor::new().register("add", |input| {
            let total: i32 = input
                .split(',')
                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                .sum();
            Ok(total.to_string())
        });
        let project = ProjectContext {
            cwd: PathBuf::from("/tmp/project"),
            current_date: "2026-03-31".to_string(),
            git_status: None,
            instruction_files: Vec::new(),
        };
        let system_prompt = SystemPromptBuilder::new()
            .with_project_context(project)
            .with_os("linux", "6.8")
            .build();
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            ScriptedApiClient { call_count: 0 },
            adder,
            PermissionPolicy::new(PermissionMode::Prompt),
            system_prompt,
        );

        let mut prompter = PromptAllowOnce;
        let summary = runtime
            .run_turn("what is 2 + 2?", Some(&mut prompter))
            .expect("conversation loop should succeed");

        assert_eq!(summary.iterations, 2);
        assert_eq!(summary.assistant_messages.len(), 2);
        assert_eq!(summary.tool_results.len(), 1);

        let history = &runtime.session().messages;
        assert_eq!(history.len(), 4);
        assert_eq!(summary.usage.output_tokens, 10);
        assert!(matches!(history[1].blocks[1], ContentBlock::ToolUse { .. }));
        assert!(matches!(
            history[2].blocks[0],
            ContentBlock::ToolResult {
                is_error: false,
                ..
            }
        ));
    }

    #[test]
    fn records_denied_tool_results_when_prompt_rejects() {
        /// Prompter that refuses every request.
        struct RejectPrompter;

        impl PermissionPrompter for RejectPrompter {
            fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
                PermissionPromptDecision::Deny {
                    reason: "not now".to_string(),
                }
            }
        }

        /// Asks for the `blocked` tool until a tool result is in the history,
        /// then gives up with plain text.
        struct SingleCallApiClient;

        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                let tool_already_answered = request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool);
                let events = if tool_already_answered {
                    vec![
                        AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                        AssistantEvent::MessageStop,
                    ]
                } else {
                    vec![
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "blocked".to_string(),
                            input: "secret".to_string(),
                        },
                        AssistantEvent::MessageStop,
                    ]
                };
                Ok(events)
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::Prompt),
            vec!["system".to_string()],
        );

        let mut prompter = RejectPrompter;
        let summary = runtime
            .run_turn("use the tool", Some(&mut prompter))
            .expect("conversation should continue after denied tool");

        assert_eq!(summary.tool_results.len(), 1);
        let first_block = &summary.tool_results[0].blocks[0];
        assert!(matches!(
            first_block,
            ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
        ));
    }

    #[test]
    fn reconstructs_usage_tracker_from_restored_session() {
        let earlier_reply = ConversationMessage::assistant_with_usage(
            vec![ContentBlock::Text {
                text: "earlier".to_string(),
            }],
            Some(TokenUsage {
                input_tokens: 11,
                output_tokens: 7,
                cache_creation_input_tokens: 2,
                cache_read_input_tokens: 1,
            }),
        );
        let mut session = Session::new();
        session.messages.push(earlier_reply);

        let runtime = ConversationRuntime::new(
            session,
            DoneApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::Allow),
            vec!["system".to_string()],
        );

        let tracker = runtime.usage();
        assert_eq!(tracker.turns(), 1);
        assert_eq!(tracker.cumulative_usage().total_tokens(), 21);
    }

    #[test]
    fn compacts_session_after_turns() {
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            DoneApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::Allow),
            vec!["system".to_string()],
        );

        runtime.run_turn("a", None).expect("turn a");
        runtime.run_turn("b", None).expect("turn b");
        runtime.run_turn("c", None).expect("turn c");

        let result = runtime.compact(CompactionConfig {
            preserve_recent_messages: 2,
            max_estimated_tokens: 1,
        });
        assert!(result.summary.contains("Conversation summary"));
        let first = &result.compacted_session.messages[0];
        assert_eq!(first.role, MessageRole::System);
    }
}
|