conversation.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use crate::compact::{
  4. compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
  5. };
  6. use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
  7. use crate::session::{ContentBlock, ConversationMessage, Session};
  8. use crate::usage::{TokenUsage, UsageTracker};
  9. #[derive(Debug, Clone, PartialEq, Eq)]
  10. pub struct ApiRequest {
  11. pub system_prompt: Vec<String>,
  12. pub messages: Vec<ConversationMessage>,
  13. }
  14. #[derive(Debug, Clone, PartialEq, Eq)]
  15. pub enum AssistantEvent {
  16. TextDelta(String),
  17. ToolUse {
  18. id: String,
  19. name: String,
  20. input: String,
  21. },
  22. Usage(TokenUsage),
  23. MessageStop,
  24. }
  25. pub trait ApiClient {
  26. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
  27. }
  28. pub trait ToolExecutor {
  29. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
  30. }
  31. #[derive(Debug, Clone, PartialEq, Eq)]
  32. pub struct ToolError {
  33. message: String,
  34. }
  35. impl ToolError {
  36. #[must_use]
  37. pub fn new(message: impl Into<String>) -> Self {
  38. Self {
  39. message: message.into(),
  40. }
  41. }
  42. }
  43. impl Display for ToolError {
  44. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  45. write!(f, "{}", self.message)
  46. }
  47. }
  48. impl std::error::Error for ToolError {}
  49. #[derive(Debug, Clone, PartialEq, Eq)]
  50. pub struct RuntimeError {
  51. message: String,
  52. }
  53. impl RuntimeError {
  54. #[must_use]
  55. pub fn new(message: impl Into<String>) -> Self {
  56. Self {
  57. message: message.into(),
  58. }
  59. }
  60. }
  61. impl Display for RuntimeError {
  62. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  63. write!(f, "{}", self.message)
  64. }
  65. }
  66. impl std::error::Error for RuntimeError {}
  67. #[derive(Debug, Clone, PartialEq, Eq)]
  68. pub struct TurnSummary {
  69. pub assistant_messages: Vec<ConversationMessage>,
  70. pub tool_results: Vec<ConversationMessage>,
  71. pub iterations: usize,
  72. pub usage: TokenUsage,
  73. }
  74. pub struct ConversationRuntime<C, T> {
  75. session: Session,
  76. api_client: C,
  77. tool_executor: T,
  78. permission_policy: PermissionPolicy,
  79. system_prompt: Vec<String>,
  80. max_iterations: usize,
  81. usage_tracker: UsageTracker,
  82. }
  83. impl<C, T> ConversationRuntime<C, T>
  84. where
  85. C: ApiClient,
  86. T: ToolExecutor,
  87. {
  88. #[must_use]
  89. pub fn new(
  90. session: Session,
  91. api_client: C,
  92. tool_executor: T,
  93. permission_policy: PermissionPolicy,
  94. system_prompt: Vec<String>,
  95. ) -> Self {
  96. let usage_tracker = UsageTracker::from_session(&session);
  97. Self {
  98. session,
  99. api_client,
  100. tool_executor,
  101. permission_policy,
  102. system_prompt,
  103. max_iterations: usize::MAX,
  104. usage_tracker,
  105. }
  106. }
  107. #[must_use]
  108. pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
  109. self.max_iterations = max_iterations;
  110. self
  111. }
  112. pub fn run_turn(
  113. &mut self,
  114. user_input: impl Into<String>,
  115. mut prompter: Option<&mut dyn PermissionPrompter>,
  116. ) -> Result<TurnSummary, RuntimeError> {
  117. self.session
  118. .messages
  119. .push(ConversationMessage::user_text(user_input.into()));
  120. let mut assistant_messages = Vec::new();
  121. let mut tool_results = Vec::new();
  122. let mut iterations = 0;
  123. loop {
  124. iterations += 1;
  125. if iterations > self.max_iterations {
  126. return Err(RuntimeError::new(
  127. "conversation loop exceeded the maximum number of iterations",
  128. ));
  129. }
  130. let request = ApiRequest {
  131. system_prompt: self.system_prompt.clone(),
  132. messages: self.session.messages.clone(),
  133. };
  134. let events = self.api_client.stream(request)?;
  135. let (assistant_message, usage) = build_assistant_message(events)?;
  136. if let Some(usage) = usage {
  137. self.usage_tracker.record(usage);
  138. }
  139. let pending_tool_uses = assistant_message
  140. .blocks
  141. .iter()
  142. .filter_map(|block| match block {
  143. ContentBlock::ToolUse { id, name, input } => {
  144. Some((id.clone(), name.clone(), input.clone()))
  145. }
  146. _ => None,
  147. })
  148. .collect::<Vec<_>>();
  149. self.session.messages.push(assistant_message.clone());
  150. assistant_messages.push(assistant_message);
  151. if pending_tool_uses.is_empty() {
  152. break;
  153. }
  154. for (tool_use_id, tool_name, input) in pending_tool_uses {
  155. let permission_outcome = if let Some(prompt) = prompter.as_mut() {
  156. self.permission_policy
  157. .authorize(&tool_name, &input, Some(*prompt))
  158. } else {
  159. self.permission_policy.authorize(&tool_name, &input, None)
  160. };
  161. let result_message = match permission_outcome {
  162. PermissionOutcome::Allow => {
  163. match self.tool_executor.execute(&tool_name, &input) {
  164. Ok(output) => ConversationMessage::tool_result(
  165. tool_use_id,
  166. tool_name,
  167. output,
  168. false,
  169. ),
  170. Err(error) => ConversationMessage::tool_result(
  171. tool_use_id,
  172. tool_name,
  173. error.to_string(),
  174. true,
  175. ),
  176. }
  177. }
  178. PermissionOutcome::Deny { reason } => {
  179. ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
  180. }
  181. };
  182. self.session.messages.push(result_message.clone());
  183. tool_results.push(result_message);
  184. }
  185. }
  186. Ok(TurnSummary {
  187. assistant_messages,
  188. tool_results,
  189. iterations,
  190. usage: self.usage_tracker.cumulative_usage(),
  191. })
  192. }
  193. #[must_use]
  194. pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
  195. compact_session(&self.session, config)
  196. }
  197. #[must_use]
  198. pub fn estimated_tokens(&self) -> usize {
  199. estimate_session_tokens(&self.session)
  200. }
  201. #[must_use]
  202. pub fn usage(&self) -> &UsageTracker {
  203. &self.usage_tracker
  204. }
  205. #[must_use]
  206. pub fn session(&self) -> &Session {
  207. &self.session
  208. }
  209. #[must_use]
  210. pub fn into_session(self) -> Session {
  211. self.session
  212. }
  213. }
  214. fn build_assistant_message(
  215. events: Vec<AssistantEvent>,
  216. ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
  217. let mut text = String::new();
  218. let mut blocks = Vec::new();
  219. let mut finished = false;
  220. let mut usage = None;
  221. for event in events {
  222. match event {
  223. AssistantEvent::TextDelta(delta) => text.push_str(&delta),
  224. AssistantEvent::ToolUse { id, name, input } => {
  225. flush_text_block(&mut text, &mut blocks);
  226. blocks.push(ContentBlock::ToolUse { id, name, input });
  227. }
  228. AssistantEvent::Usage(value) => usage = Some(value),
  229. AssistantEvent::MessageStop => {
  230. finished = true;
  231. }
  232. }
  233. }
  234. flush_text_block(&mut text, &mut blocks);
  235. if !finished {
  236. return Err(RuntimeError::new(
  237. "assistant stream ended without a message stop event",
  238. ));
  239. }
  240. if blocks.is_empty() {
  241. return Err(RuntimeError::new("assistant stream produced no content"));
  242. }
  243. Ok((
  244. ConversationMessage::assistant_with_usage(blocks, usage),
  245. usage,
  246. ))
  247. }
  248. fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
  249. if !text.is_empty() {
  250. blocks.push(ContentBlock::Text {
  251. text: std::mem::take(text),
  252. });
  253. }
  254. }
  255. type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
  256. #[derive(Default)]
  257. pub struct StaticToolExecutor {
  258. handlers: BTreeMap<String, ToolHandler>,
  259. }
  260. impl StaticToolExecutor {
  261. #[must_use]
  262. pub fn new() -> Self {
  263. Self::default()
  264. }
  265. #[must_use]
  266. pub fn register(
  267. mut self,
  268. tool_name: impl Into<String>,
  269. handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
  270. ) -> Self {
  271. self.handlers.insert(tool_name.into(), Box::new(handler));
  272. self
  273. }
  274. }
  275. impl ToolExecutor for StaticToolExecutor {
  276. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
  277. self.handlers
  278. .get_mut(tool_name)
  279. .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
  280. }
  281. }
/// Unit tests for the conversation loop, driven by in-memory API and tool doubles.
#[cfg(test)]
mod tests {
    use super::{
        ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError,
        StaticToolExecutor,
    };
    use crate::compact::CompactionConfig;
    use crate::permissions::{
        PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
        PermissionRequest,
    };
    use crate::prompt::{ProjectContext, SystemPromptBuilder};
    use crate::session::{ContentBlock, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::path::PathBuf;

    /// API double that replays a fixed two-call script.
    ///
    /// Call 1 replies with text plus a request for the `add` tool. Call 2
    /// replies with the final text. Any further call returns an error.
    struct ScriptedApiClient {
        call_count: usize,
    }

    impl ApiClient for ScriptedApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.call_count += 1;
            match self.call_count {
                1 => {
                    // The user's prompt must already be in the transcript.
                    assert!(request
                        .messages
                        .iter()
                        .any(|message| message.role == MessageRole::User));
                    Ok(vec![
                        AssistantEvent::TextDelta("Let me calculate that.".to_string()),
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: "2,2".to_string(),
                        },
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 20,
                            output_tokens: 6,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 2,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                2 => {
                    // The runtime must have appended the tool result as the last message.
                    let last_message = request
                        .messages
                        .last()
                        .expect("tool result should be present");
                    assert_eq!(last_message.role, MessageRole::Tool);
                    Ok(vec![
                        AssistantEvent::TextDelta("The answer is 4.".to_string()),
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 24,
                            output_tokens: 4,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 3,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                _ => Err(RuntimeError::new("unexpected extra API call")),
            }
        }
    }

    /// Prompter that approves every request it sees, asserting each is for `add`.
    /// Despite the name, it does not stop after one approval.
    struct PromptAllowOnce;
    impl PermissionPrompter for PromptAllowOnce {
        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
            assert_eq!(request.tool_name, "add");
            PermissionPromptDecision::Allow
        }
    }

    /// Full loop: user → assistant (tool use) → tool result → assistant (text),
    /// with usage summed over both replies.
    #[test]
    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
        let api_client = ScriptedApiClient { call_count: 0 };
        // "add" sums comma-separated integers; the script sends "2,2".
        let tool_executor = StaticToolExecutor::new().register("add", |input| {
            let total = input
                .split(',')
                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                .sum::<i32>();
            Ok(total.to_string())
        });
        // NOTE(review): assumes WorkspaceWrite routes the "add" tool to the prompter — confirm in PermissionPolicy.
        let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
        let system_prompt = SystemPromptBuilder::new()
            .with_project_context(ProjectContext {
                cwd: PathBuf::from("/tmp/project"),
                current_date: "2026-03-31".to_string(),
                git_status: None,
                git_diff: None,
                instruction_files: Vec::new(),
            })
            .with_os("linux", "6.8")
            .build();
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            api_client,
            tool_executor,
            permission_policy,
            system_prompt,
        );
        let summary = runtime
            .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
            .expect("conversation loop should succeed");
        assert_eq!(summary.iterations, 2);
        assert_eq!(summary.assistant_messages.len(), 2);
        assert_eq!(summary.tool_results.len(), 1);
        // Transcript: user, assistant (text + tool use), tool result, assistant.
        assert_eq!(runtime.session().messages.len(), 4);
        // 6 + 4 output tokens across the two replies.
        assert_eq!(summary.usage.output_tokens, 10);
        // Block 0 of the first reply is the text; block 1 is the tool use.
        assert!(matches!(
            runtime.session().messages[1].blocks[1],
            ContentBlock::ToolUse { .. }
        ));
        assert!(matches!(
            runtime.session().messages[2].blocks[0],
            ContentBlock::ToolResult {
                is_error: false,
                ..
            }
        ));
    }

    /// A denied tool use becomes an error tool result carrying the prompter's
    /// reason, and the turn still completes.
    #[test]
    fn records_denied_tool_results_when_prompt_rejects() {
        struct RejectPrompter;
        impl PermissionPrompter for RejectPrompter {
            fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
                PermissionPromptDecision::Deny {
                    reason: "not now".to_string(),
                }
            }
        }
        // Called twice: it requests a tool until a Tool message exists, then replies with text.
        struct SingleCallApiClient;
        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                if request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool)
                {
                    return Ok(vec![
                        AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: "secret".to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::WorkspaceWrite),
            vec!["system".to_string()],
        );
        let summary = runtime
            .run_turn("use the tool", Some(&mut RejectPrompter))
            .expect("conversation should continue after denied tool");
        assert_eq!(summary.tool_results.len(), 1);
        assert!(matches!(
            &summary.tool_results[0].blocks[0],
            ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
        ));
    }

    /// Usage already recorded on a restored session is picked up at construction.
    #[test]
    fn reconstructs_usage_tracker_from_restored_session() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }
        let mut session = Session::new();
        session
            .messages
            .push(crate::session::ConversationMessage::assistant_with_usage(
                vec![ContentBlock::Text {
                    text: "earlier".to_string(),
                }],
                Some(TokenUsage {
                    input_tokens: 11,
                    output_tokens: 7,
                    cache_creation_input_tokens: 2,
                    cache_read_input_tokens: 1,
                }),
            ));
        let runtime = ConversationRuntime::new(
            session,
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );
        assert_eq!(runtime.usage().turns(), 1);
        // 11 + 7 + 2 + 1; presumably total_tokens sums all four counters — confirm in usage.rs.
        assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
    }

    /// After three text-only turns, compaction yields a summary and puts a
    /// System message first in the compacted session.
    #[test]
    fn compacts_session_after_turns() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );
        runtime.run_turn("a", None).expect("turn a");
        runtime.run_turn("b", None).expect("turn b");
        runtime.run_turn("c", None).expect("turn c");
        // A 1-token budget is meant to force compaction of everything but the last 2 messages.
        let result = runtime.compact(CompactionConfig {
            preserve_recent_messages: 2,
            max_estimated_tokens: 1,
        });
        assert!(result.summary.contains("Conversation summary"));
        assert_eq!(
            result.compacted_session.messages[0].role,
            MessageRole::System
        );
    }
}