conversation.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634
  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use crate::compact::{
  4. compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
  5. };
  6. use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
  7. use crate::session::{ContentBlock, ConversationMessage, Session};
  8. use crate::usage::{TokenUsage, UsageTracker};
  9. #[derive(Debug, Clone, PartialEq, Eq)]
  10. pub struct ApiRequest {
  11. pub system_prompt: Vec<String>,
  12. pub messages: Vec<ConversationMessage>,
  13. }
  14. #[derive(Debug, Clone, PartialEq, Eq)]
  15. pub enum AssistantEvent {
  16. TextDelta(String),
  17. ThinkingDelta(String),
  18. ThinkingSignature(String),
  19. ToolUse {
  20. id: String,
  21. name: String,
  22. input: String,
  23. },
  24. Usage(TokenUsage),
  25. MessageStop,
  26. }
  27. pub trait ApiClient {
  28. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
  29. }
  30. pub trait ToolExecutor {
  31. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
  32. }
  /// Error reported by a tool, carrying only a human-readable message.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct ToolError {
      message: String,
  }

  impl ToolError {
      /// Creates an error whose `Display` output is exactly `message`.
      #[must_use]
      pub fn new(message: impl Into<String>) -> Self {
          let message = message.into();
          Self { message }
      }
  }

  impl Display for ToolError {
      /// Writes the stored message verbatim.
      fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
          f.write_str(&self.message)
      }
  }

  impl std::error::Error for ToolError {}
  /// Error raised by the conversation runtime, carrying only a human-readable
  /// message.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct RuntimeError {
      message: String,
  }

  impl RuntimeError {
      /// Creates an error whose `Display` output is exactly `message`.
      #[must_use]
      pub fn new(message: impl Into<String>) -> Self {
          let message = message.into();
          Self { message }
      }
  }

  impl Display for RuntimeError {
      /// Writes the stored message verbatim.
      fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
          f.write_str(&self.message)
      }
  }

  impl std::error::Error for RuntimeError {}
  69. #[derive(Debug, Clone, PartialEq, Eq)]
  70. pub struct TurnSummary {
  71. pub assistant_messages: Vec<ConversationMessage>,
  72. pub tool_results: Vec<ConversationMessage>,
  73. pub iterations: usize,
  74. pub usage: TokenUsage,
  75. }
  76. pub struct ConversationRuntime<C, T> {
  77. session: Session,
  78. api_client: C,
  79. tool_executor: T,
  80. permission_policy: PermissionPolicy,
  81. system_prompt: Vec<String>,
  82. max_iterations: usize,
  83. usage_tracker: UsageTracker,
  84. }
  85. impl<C, T> ConversationRuntime<C, T>
  86. where
  87. C: ApiClient,
  88. T: ToolExecutor,
  89. {
  90. #[must_use]
  91. pub fn new(
  92. session: Session,
  93. api_client: C,
  94. tool_executor: T,
  95. permission_policy: PermissionPolicy,
  96. system_prompt: Vec<String>,
  97. ) -> Self {
  98. let usage_tracker = UsageTracker::from_session(&session);
  99. Self {
  100. session,
  101. api_client,
  102. tool_executor,
  103. permission_policy,
  104. system_prompt,
  105. max_iterations: 16,
  106. usage_tracker,
  107. }
  108. }
  109. #[must_use]
  110. pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
  111. self.max_iterations = max_iterations;
  112. self
  113. }
  114. pub fn run_turn(
  115. &mut self,
  116. user_input: impl Into<String>,
  117. mut prompter: Option<&mut dyn PermissionPrompter>,
  118. ) -> Result<TurnSummary, RuntimeError> {
  119. self.session
  120. .messages
  121. .push(ConversationMessage::user_text(user_input.into()));
  122. let mut assistant_messages = Vec::new();
  123. let mut tool_results = Vec::new();
  124. let mut iterations = 0;
  125. loop {
  126. iterations += 1;
  127. if iterations > self.max_iterations {
  128. return Err(RuntimeError::new(
  129. "conversation loop exceeded the maximum number of iterations",
  130. ));
  131. }
  132. let request = ApiRequest {
  133. system_prompt: self.system_prompt.clone(),
  134. messages: self.session.messages.clone(),
  135. };
  136. let events = self.api_client.stream(request)?;
  137. let (assistant_message, usage) = build_assistant_message(events)?;
  138. if let Some(usage) = usage {
  139. self.usage_tracker.record(usage);
  140. }
  141. let pending_tool_uses = assistant_message
  142. .blocks
  143. .iter()
  144. .filter_map(|block| match block {
  145. ContentBlock::ToolUse { id, name, input } => {
  146. Some((id.clone(), name.clone(), input.clone()))
  147. }
  148. _ => None,
  149. })
  150. .collect::<Vec<_>>();
  151. self.session.messages.push(assistant_message.clone());
  152. assistant_messages.push(assistant_message);
  153. if pending_tool_uses.is_empty() {
  154. break;
  155. }
  156. for (tool_use_id, tool_name, input) in pending_tool_uses {
  157. let permission_outcome = if let Some(prompt) = prompter.as_mut() {
  158. self.permission_policy
  159. .authorize(&tool_name, &input, Some(*prompt))
  160. } else {
  161. self.permission_policy.authorize(&tool_name, &input, None)
  162. };
  163. let result_message = match permission_outcome {
  164. PermissionOutcome::Allow => {
  165. match self.tool_executor.execute(&tool_name, &input) {
  166. Ok(output) => ConversationMessage::tool_result(
  167. tool_use_id,
  168. tool_name,
  169. output,
  170. false,
  171. ),
  172. Err(error) => ConversationMessage::tool_result(
  173. tool_use_id,
  174. tool_name,
  175. error.to_string(),
  176. true,
  177. ),
  178. }
  179. }
  180. PermissionOutcome::Deny { reason } => {
  181. ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
  182. }
  183. };
  184. self.session.messages.push(result_message.clone());
  185. tool_results.push(result_message);
  186. }
  187. }
  188. Ok(TurnSummary {
  189. assistant_messages,
  190. tool_results,
  191. iterations,
  192. usage: self.usage_tracker.cumulative_usage(),
  193. })
  194. }
  195. #[must_use]
  196. pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
  197. compact_session(&self.session, config)
  198. }
  199. #[must_use]
  200. pub fn estimated_tokens(&self) -> usize {
  201. estimate_session_tokens(&self.session)
  202. }
  203. #[must_use]
  204. pub fn usage(&self) -> &UsageTracker {
  205. &self.usage_tracker
  206. }
  207. #[must_use]
  208. pub fn session(&self) -> &Session {
  209. &self.session
  210. }
  211. #[must_use]
  212. pub fn into_session(self) -> Session {
  213. self.session
  214. }
  215. }
  216. fn build_assistant_message(
  217. events: Vec<AssistantEvent>,
  218. ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
  219. let mut text = String::new();
  220. let mut thinking = String::new();
  221. let mut thinking_signature: Option<String> = None;
  222. let mut blocks = Vec::new();
  223. let mut finished = false;
  224. let mut usage = None;
  225. for event in events {
  226. match event {
  227. AssistantEvent::TextDelta(delta) => {
  228. flush_thinking_block(&mut thinking, &mut thinking_signature, &mut blocks);
  229. text.push_str(&delta);
  230. }
  231. AssistantEvent::ThinkingDelta(delta) => {
  232. flush_text_block(&mut text, &mut blocks);
  233. thinking.push_str(&delta);
  234. }
  235. AssistantEvent::ThinkingSignature(signature) => thinking_signature = Some(signature),
  236. AssistantEvent::ToolUse { id, name, input } => {
  237. flush_text_block(&mut text, &mut blocks);
  238. flush_thinking_block(&mut thinking, &mut thinking_signature, &mut blocks);
  239. blocks.push(ContentBlock::ToolUse { id, name, input });
  240. }
  241. AssistantEvent::Usage(value) => usage = Some(value),
  242. AssistantEvent::MessageStop => {
  243. finished = true;
  244. }
  245. }
  246. }
  247. flush_text_block(&mut text, &mut blocks);
  248. flush_thinking_block(&mut thinking, &mut thinking_signature, &mut blocks);
  249. if !finished {
  250. return Err(RuntimeError::new(
  251. "assistant stream ended without a message stop event",
  252. ));
  253. }
  254. if blocks.is_empty() {
  255. return Err(RuntimeError::new("assistant stream produced no content"));
  256. }
  257. Ok((
  258. ConversationMessage::assistant_with_usage(blocks, usage),
  259. usage,
  260. ))
  261. }
  262. fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
  263. if !text.is_empty() {
  264. blocks.push(ContentBlock::Text {
  265. text: std::mem::take(text),
  266. });
  267. }
  268. }
  269. fn flush_thinking_block(
  270. thinking: &mut String,
  271. signature: &mut Option<String>,
  272. blocks: &mut Vec<ContentBlock>,
  273. ) {
  274. if !thinking.is_empty() || signature.is_some() {
  275. blocks.push(ContentBlock::Thinking {
  276. text: std::mem::take(thinking),
  277. signature: signature.take(),
  278. });
  279. }
  280. }
  281. type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
  282. #[derive(Default)]
  283. pub struct StaticToolExecutor {
  284. handlers: BTreeMap<String, ToolHandler>,
  285. }
  286. impl StaticToolExecutor {
  287. #[must_use]
  288. pub fn new() -> Self {
  289. Self::default()
  290. }
  291. #[must_use]
  292. pub fn register(
  293. mut self,
  294. tool_name: impl Into<String>,
  295. handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
  296. ) -> Self {
  297. self.handlers.insert(tool_name.into(), Box::new(handler));
  298. self
  299. }
  300. }
  301. impl ToolExecutor for StaticToolExecutor {
  302. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
  303. self.handlers
  304. .get_mut(tool_name)
  305. .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
  306. }
  307. }
  #[cfg(test)]
  mod tests {
      use super::{
          build_assistant_message, ApiClient, ApiRequest, AssistantEvent, ConversationRuntime,
          RuntimeError, StaticToolExecutor,
      };
      use crate::compact::CompactionConfig;
      use crate::permissions::{
          PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
          PermissionRequest,
      };
      use crate::prompt::{ProjectContext, SystemPromptBuilder};
      use crate::session::{ContentBlock, MessageRole, Session};
      use crate::usage::TokenUsage;
      use std::path::PathBuf;

      /// API double with a fixed script: the first call requests the `add`
      /// tool, the second call answers in text, and any later call fails.
      struct ScriptedApiClient {
          call_count: usize,
      }

      impl ApiClient for ScriptedApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              self.call_count += 1;
              match self.call_count {
                  1 => {
                      assert!(request
                          .messages
                          .iter()
                          .any(|message| message.role == MessageRole::User));
                      let tool_request = AssistantEvent::ToolUse {
                          id: String::from("tool-1"),
                          name: String::from("add"),
                          input: String::from("2,2"),
                      };
                      let usage = TokenUsage {
                          input_tokens: 20,
                          output_tokens: 6,
                          cache_creation_input_tokens: 1,
                          cache_read_input_tokens: 2,
                      };
                      Ok(vec![
                          AssistantEvent::TextDelta(String::from("Let me calculate that.")),
                          tool_request,
                          AssistantEvent::Usage(usage),
                          AssistantEvent::MessageStop,
                      ])
                  }
                  2 => {
                      // The tool result must be the newest message by now.
                      let newest = request
                          .messages
                          .last()
                          .expect("tool result should be present");
                      assert_eq!(newest.role, MessageRole::Tool);
                      let usage = TokenUsage {
                          input_tokens: 24,
                          output_tokens: 4,
                          cache_creation_input_tokens: 1,
                          cache_read_input_tokens: 3,
                      };
                      Ok(vec![
                          AssistantEvent::TextDelta(String::from("The answer is 4.")),
                          AssistantEvent::Usage(usage),
                          AssistantEvent::MessageStop,
                      ])
                  }
                  _ => Err(RuntimeError::new("unexpected extra API call")),
              }
          }
      }

      /// Prompter that approves every request after checking it targets `add`.
      struct PromptAllowOnce;

      impl PermissionPrompter for PromptAllowOnce {
          fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
              assert_eq!(request.tool_name, "add");
              PermissionPromptDecision::Allow
          }
      }

      #[test]
      fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
          let tool_executor = StaticToolExecutor::new().register("add", |input| {
              let total: i32 = input
                  .split(',')
                  .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                  .sum();
              Ok(total.to_string())
          });
          let project_context = ProjectContext {
              cwd: PathBuf::from("/tmp/project"),
              current_date: String::from("2026-03-31"),
              git_status: None,
              instruction_files: Vec::new(),
              memory_files: Vec::new(),
          };
          let system_prompt = SystemPromptBuilder::new()
              .with_project_context(project_context)
              .with_os("linux", "6.8")
              .build();
          let mut runtime = ConversationRuntime::new(
              Session::new(),
              ScriptedApiClient { call_count: 0 },
              tool_executor,
              PermissionPolicy::new(PermissionMode::WorkspaceWrite),
              system_prompt,
          );

          let mut prompter = PromptAllowOnce;
          let summary = runtime
              .run_turn("what is 2 + 2?", Some(&mut prompter))
              .expect("conversation loop should succeed");

          assert_eq!(summary.iterations, 2);
          assert_eq!(summary.assistant_messages.len(), 2);
          assert_eq!(summary.tool_results.len(), 1);
          // user, assistant (tool request), tool result, assistant (answer)
          assert_eq!(runtime.session().messages.len(), 4);
          // 6 output tokens from the first call plus 4 from the second
          assert_eq!(summary.usage.output_tokens, 10);
          assert!(matches!(
              runtime.session().messages[1].blocks[1],
              ContentBlock::ToolUse { .. }
          ));
          assert!(matches!(
              runtime.session().messages[2].blocks[0],
              ContentBlock::ToolResult {
                  is_error: false,
                  ..
              }
          ));
      }

      #[test]
      fn records_denied_tool_results_when_prompt_rejects() {
          /// Prompter that turns every request down with "not now".
          struct RejectPrompter;

          impl PermissionPrompter for RejectPrompter {
              fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
                  PermissionPromptDecision::Deny {
                      reason: String::from("not now"),
                  }
              }
          }

          /// Requests the `blocked` tool until a tool result shows up in the
          /// history, then answers in plain text.
          struct SingleCallApiClient;

          impl ApiClient for SingleCallApiClient {
              fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                  let has_tool_result = request
                      .messages
                      .iter()
                      .any(|message| message.role == MessageRole::Tool);
                  let events = if has_tool_result {
                      vec![
                          AssistantEvent::TextDelta(String::from("I could not use the tool.")),
                          AssistantEvent::MessageStop,
                      ]
                  } else {
                      vec![
                          AssistantEvent::ToolUse {
                              id: String::from("tool-1"),
                              name: String::from("blocked"),
                              input: String::from("secret"),
                          },
                          AssistantEvent::MessageStop,
                      ]
                  };
                  Ok(events)
              }
          }

          let mut runtime = ConversationRuntime::new(
              Session::new(),
              SingleCallApiClient,
              StaticToolExecutor::new(),
              PermissionPolicy::new(PermissionMode::WorkspaceWrite),
              vec![String::from("system")],
          );

          let mut prompter = RejectPrompter;
          let summary = runtime
              .run_turn("use the tool", Some(&mut prompter))
              .expect("conversation should continue after denied tool");

          assert_eq!(summary.tool_results.len(), 1);
          assert!(matches!(
              &summary.tool_results[0].blocks[0],
              ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
          ));
      }

      #[test]
      fn thinking_blocks_are_preserved_separately_from_text() {
          let events = vec![
              AssistantEvent::ThinkingDelta(String::from("first ")),
              AssistantEvent::ThinkingDelta(String::from("second")),
              AssistantEvent::ThinkingSignature(String::from("sig-1")),
              AssistantEvent::TextDelta(String::from("final")),
              AssistantEvent::MessageStop,
          ];

          let (message, usage) =
              build_assistant_message(events).expect("assistant message should build");

          assert_eq!(usage, None);
          assert!(matches!(
              &message.blocks[0],
              ContentBlock::Thinking { text, signature }
                  if text == "first second" && signature.as_deref() == Some("sig-1")
          ));
          assert!(matches!(
              &message.blocks[1],
              ContentBlock::Text { text } if text == "final"
          ));
      }

      #[test]
      fn reconstructs_usage_tracker_from_restored_session() {
          /// API double that always answers "done"; never called in this test.
          struct SimpleApi;

          impl ApiClient for SimpleApi {
              fn stream(
                  &mut self,
                  _request: ApiRequest,
              ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                  Ok(vec![
                      AssistantEvent::TextDelta(String::from("done")),
                      AssistantEvent::MessageStop,
                  ])
              }
          }

          // A restored session that already holds one assistant message with usage.
          let earlier_usage = TokenUsage {
              input_tokens: 11,
              output_tokens: 7,
              cache_creation_input_tokens: 2,
              cache_read_input_tokens: 1,
          };
          let earlier_message = crate::session::ConversationMessage::assistant_with_usage(
              vec![ContentBlock::Text {
                  text: String::from("earlier"),
              }],
              Some(earlier_usage),
          );
          let mut session = Session::new();
          session.messages.push(earlier_message);

          let runtime = ConversationRuntime::new(
              session,
              SimpleApi,
              StaticToolExecutor::new(),
              PermissionPolicy::new(PermissionMode::DangerFullAccess),
              vec![String::from("system")],
          );

          assert_eq!(runtime.usage().turns(), 1);
          assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
      }

      #[test]
      fn compacts_session_after_turns() {
          /// API double that always answers "done".
          struct SimpleApi;

          impl ApiClient for SimpleApi {
              fn stream(
                  &mut self,
                  _request: ApiRequest,
              ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                  Ok(vec![
                      AssistantEvent::TextDelta(String::from("done")),
                      AssistantEvent::MessageStop,
                  ])
              }
          }

          let mut runtime = ConversationRuntime::new(
              Session::new(),
              SimpleApi,
              StaticToolExecutor::new(),
              PermissionPolicy::new(PermissionMode::DangerFullAccess),
              vec![String::from("system")],
          );

          // Three plain turns give the session enough history to compact.
          for (input, failure_message) in [("a", "turn a"), ("b", "turn b"), ("c", "turn c")] {
              runtime.run_turn(input, None).expect(failure_message);
          }

          let config = CompactionConfig {
              preserve_recent_messages: 2,
              max_estimated_tokens: 1,
          };
          let result = runtime.compact(config);

          assert!(result.summary.contains("Conversation summary"));
          assert_eq!(
              result.compacted_session.messages[0].role,
              MessageRole::System
          );
      }
  }