conversation.rs 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189
  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use crate::compact::{
  4. compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
  5. };
  6. use crate::config::RuntimeFeatureConfig;
  7. use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
  8. use crate::permissions::{
  9. PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
  10. };
  11. use crate::session::{ContentBlock, ConversationMessage, Session};
  12. use crate::usage::{TokenUsage, UsageTracker};
  /// Auto-compaction threshold (in cumulative input tokens) used when the
  /// environment override is missing, unparsable, or zero.
  const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 100_000;

  /// Environment variable consulted for an auto-compaction threshold override.
  const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
  15. #[derive(Debug, Clone, PartialEq, Eq)]
  16. pub struct ApiRequest {
  17. pub system_prompt: Vec<String>,
  18. pub messages: Vec<ConversationMessage>,
  19. }
  20. #[derive(Debug, Clone, PartialEq, Eq)]
  21. pub enum AssistantEvent {
  22. TextDelta(String),
  23. ToolUse {
  24. id: String,
  25. name: String,
  26. input: String,
  27. },
  28. Usage(TokenUsage),
  29. MessageStop,
  30. }
  31. pub trait ApiClient {
  32. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
  33. }
  34. pub trait ToolExecutor {
  35. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
  36. }
  /// Failure reported by a tool execution; carries only a human-readable message.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct ToolError {
      message: String,
  }

  impl ToolError {
      /// Creates an error whose `Display` output is exactly `message`.
      #[must_use]
      pub fn new(message: impl Into<String>) -> Self {
          let message = message.into();
          Self { message }
      }
  }

  impl Display for ToolError {
      /// Writes the stored message verbatim, ignoring width/fill flags.
      fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
          f.write_str(&self.message)
      }
  }

  impl std::error::Error for ToolError {}
  /// Failure of the conversation runtime itself; carries only a human-readable message.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct RuntimeError {
      message: String,
  }

  impl RuntimeError {
      /// Creates an error whose `Display` output is exactly `message`.
      #[must_use]
      pub fn new(message: impl Into<String>) -> Self {
          let message = message.into();
          Self { message }
      }
  }

  impl Display for RuntimeError {
      /// Writes the stored message verbatim, ignoring width/fill flags.
      fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
          f.write_str(&self.message)
      }
  }

  impl std::error::Error for RuntimeError {}
  73. #[derive(Debug, Clone, PartialEq, Eq)]
  74. pub struct TurnSummary {
  75. pub assistant_messages: Vec<ConversationMessage>,
  76. pub tool_results: Vec<ConversationMessage>,
  77. pub iterations: usize,
  78. pub usage: TokenUsage,
  79. pub auto_compaction: Option<AutoCompactionEvent>,
  80. }
  /// Reports that auto-compaction replaced the session at the end of a turn.
  #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  pub struct AutoCompactionEvent {
      /// Number of messages the compaction removed (always non-zero when emitted
      /// by the runtime).
      pub removed_message_count: usize,
  }
  85. pub struct ConversationRuntime<C, T> {
  86. session: Session,
  87. api_client: C,
  88. tool_executor: T,
  89. permission_policy: PermissionPolicy,
  90. system_prompt: Vec<String>,
  91. max_iterations: usize,
  92. usage_tracker: UsageTracker,
  93. hook_runner: HookRunner,
  94. auto_compaction_input_tokens_threshold: u32,
  95. hook_abort_signal: HookAbortSignal,
  96. hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
  97. }
  98. impl<C, T> ConversationRuntime<C, T>
  99. where
  100. C: ApiClient,
  101. T: ToolExecutor,
  102. {
  103. #[must_use]
  104. pub fn new(
  105. session: Session,
  106. api_client: C,
  107. tool_executor: T,
  108. permission_policy: PermissionPolicy,
  109. system_prompt: Vec<String>,
  110. ) -> Self {
  111. Self::new_with_features(
  112. session,
  113. api_client,
  114. tool_executor,
  115. permission_policy,
  116. system_prompt,
  117. RuntimeFeatureConfig::default(),
  118. )
  119. }
  120. #[must_use]
  121. #[allow(clippy::needless_pass_by_value)]
  122. pub fn new_with_features(
  123. session: Session,
  124. api_client: C,
  125. tool_executor: T,
  126. permission_policy: PermissionPolicy,
  127. system_prompt: Vec<String>,
  128. feature_config: RuntimeFeatureConfig,
  129. ) -> Self {
  130. let usage_tracker = UsageTracker::from_session(&session);
  131. Self {
  132. session,
  133. api_client,
  134. tool_executor,
  135. permission_policy,
  136. system_prompt,
  137. max_iterations: usize::MAX,
  138. usage_tracker,
  139. hook_runner: HookRunner::from_feature_config(&feature_config),
  140. auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
  141. hook_abort_signal: HookAbortSignal::default(),
  142. hook_progress_reporter: None,
  143. }
  144. }
  145. #[must_use]
  146. pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
  147. self.max_iterations = max_iterations;
  148. self
  149. }
  150. #[must_use]
  151. pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
  152. self.auto_compaction_input_tokens_threshold = threshold;
  153. self
  154. }
  155. #[must_use]
  156. pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self {
  157. self.hook_abort_signal = hook_abort_signal;
  158. self
  159. }
  160. #[must_use]
  161. pub fn with_hook_progress_reporter(
  162. mut self,
  163. hook_progress_reporter: Box<dyn HookProgressReporter>,
  164. ) -> Self {
  165. self.hook_progress_reporter = Some(hook_progress_reporter);
  166. self
  167. }
  168. fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
  169. if let Some(reporter) = self.hook_progress_reporter.as_mut() {
  170. self.hook_runner.run_pre_tool_use_with_context(
  171. tool_name,
  172. input,
  173. Some(&self.hook_abort_signal),
  174. Some(reporter.as_mut()),
  175. )
  176. } else {
  177. self.hook_runner.run_pre_tool_use_with_context(
  178. tool_name,
  179. input,
  180. Some(&self.hook_abort_signal),
  181. None,
  182. )
  183. }
  184. }
  185. fn run_post_tool_use_hook(
  186. &mut self,
  187. tool_name: &str,
  188. input: &str,
  189. output: &str,
  190. is_error: bool,
  191. ) -> HookRunResult {
  192. if let Some(reporter) = self.hook_progress_reporter.as_mut() {
  193. self.hook_runner.run_post_tool_use_with_context(
  194. tool_name,
  195. input,
  196. output,
  197. is_error,
  198. Some(&self.hook_abort_signal),
  199. Some(reporter.as_mut()),
  200. )
  201. } else {
  202. self.hook_runner.run_post_tool_use_with_context(
  203. tool_name,
  204. input,
  205. output,
  206. is_error,
  207. Some(&self.hook_abort_signal),
  208. None,
  209. )
  210. }
  211. }
  212. fn run_post_tool_use_failure_hook(
  213. &mut self,
  214. tool_name: &str,
  215. input: &str,
  216. output: &str,
  217. ) -> HookRunResult {
  218. if let Some(reporter) = self.hook_progress_reporter.as_mut() {
  219. self.hook_runner.run_post_tool_use_failure_with_context(
  220. tool_name,
  221. input,
  222. output,
  223. Some(&self.hook_abort_signal),
  224. Some(reporter.as_mut()),
  225. )
  226. } else {
  227. self.hook_runner.run_post_tool_use_failure_with_context(
  228. tool_name,
  229. input,
  230. output,
  231. Some(&self.hook_abort_signal),
  232. None,
  233. )
  234. }
  235. }
  236. #[allow(clippy::too_many_lines)]
  237. pub fn run_turn(
  238. &mut self,
  239. user_input: impl Into<String>,
  240. mut prompter: Option<&mut dyn PermissionPrompter>,
  241. ) -> Result<TurnSummary, RuntimeError> {
  242. self.session
  243. .push_user_text(user_input.into())
  244. .map_err(|error| RuntimeError::new(error.to_string()))?;
  245. let mut assistant_messages = Vec::new();
  246. let mut tool_results = Vec::new();
  247. let mut iterations = 0;
  248. loop {
  249. iterations += 1;
  250. if iterations > self.max_iterations {
  251. return Err(RuntimeError::new(
  252. "conversation loop exceeded the maximum number of iterations",
  253. ));
  254. }
  255. let request = ApiRequest {
  256. system_prompt: self.system_prompt.clone(),
  257. messages: self.session.messages.clone(),
  258. };
  259. let events = self.api_client.stream(request)?;
  260. let (assistant_message, usage) = build_assistant_message(events)?;
  261. if let Some(usage) = usage {
  262. self.usage_tracker.record(usage);
  263. }
  264. let pending_tool_uses = assistant_message
  265. .blocks
  266. .iter()
  267. .filter_map(|block| match block {
  268. ContentBlock::ToolUse { id, name, input } => {
  269. Some((id.clone(), name.clone(), input.clone()))
  270. }
  271. _ => None,
  272. })
  273. .collect::<Vec<_>>();
  274. self.session
  275. .push_message(assistant_message.clone())
  276. .map_err(|error| RuntimeError::new(error.to_string()))?;
  277. assistant_messages.push(assistant_message);
  278. if pending_tool_uses.is_empty() {
  279. break;
  280. }
  281. for (tool_use_id, tool_name, input) in pending_tool_uses {
  282. let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
  283. let effective_input = pre_hook_result
  284. .updated_input()
  285. .map_or_else(|| input.clone(), ToOwned::to_owned);
  286. let permission_context = PermissionContext::new(
  287. pre_hook_result.permission_override(),
  288. pre_hook_result.permission_reason().map(ToOwned::to_owned),
  289. );
  290. let permission_outcome = if pre_hook_result.is_cancelled() {
  291. PermissionOutcome::Deny {
  292. reason: format_hook_message(
  293. &pre_hook_result,
  294. &format!("PreToolUse hook cancelled tool `{tool_name}`"),
  295. ),
  296. }
  297. } else if pre_hook_result.is_denied() {
  298. PermissionOutcome::Deny {
  299. reason: format_hook_message(
  300. &pre_hook_result,
  301. &format!("PreToolUse hook denied tool `{tool_name}`"),
  302. ),
  303. }
  304. } else if let Some(prompt) = prompter.as_mut() {
  305. self.permission_policy.authorize_with_context(
  306. &tool_name,
  307. &effective_input,
  308. &permission_context,
  309. Some(*prompt),
  310. )
  311. } else {
  312. self.permission_policy.authorize_with_context(
  313. &tool_name,
  314. &effective_input,
  315. &permission_context,
  316. None,
  317. )
  318. };
  319. let result_message = match permission_outcome {
  320. PermissionOutcome::Allow => {
  321. let (mut output, mut is_error) =
  322. match self.tool_executor.execute(&tool_name, &effective_input) {
  323. Ok(output) => (output, false),
  324. Err(error) => (error.to_string(), true),
  325. };
  326. output = merge_hook_feedback(pre_hook_result.messages(), output, false);
  327. let post_hook_result = if is_error {
  328. self.run_post_tool_use_failure_hook(
  329. &tool_name,
  330. &effective_input,
  331. &output,
  332. )
  333. } else {
  334. self.run_post_tool_use_hook(
  335. &tool_name,
  336. &effective_input,
  337. &output,
  338. false,
  339. )
  340. };
  341. if post_hook_result.is_denied() || post_hook_result.is_cancelled() {
  342. is_error = true;
  343. }
  344. output = merge_hook_feedback(
  345. post_hook_result.messages(),
  346. output,
  347. post_hook_result.is_denied() || post_hook_result.is_cancelled(),
  348. );
  349. ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error)
  350. }
  351. PermissionOutcome::Deny { reason } => ConversationMessage::tool_result(
  352. tool_use_id,
  353. tool_name,
  354. merge_hook_feedback(pre_hook_result.messages(), reason, true),
  355. true,
  356. ),
  357. };
  358. self.session
  359. .push_message(result_message.clone())
  360. .map_err(|error| RuntimeError::new(error.to_string()))?;
  361. tool_results.push(result_message);
  362. }
  363. }
  364. let auto_compaction = self.maybe_auto_compact();
  365. Ok(TurnSummary {
  366. assistant_messages,
  367. tool_results,
  368. iterations,
  369. usage: self.usage_tracker.cumulative_usage(),
  370. auto_compaction,
  371. })
  372. }
  373. #[must_use]
  374. pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
  375. compact_session(&self.session, config)
  376. }
  377. #[must_use]
  378. pub fn estimated_tokens(&self) -> usize {
  379. estimate_session_tokens(&self.session)
  380. }
  381. #[must_use]
  382. pub fn usage(&self) -> &UsageTracker {
  383. &self.usage_tracker
  384. }
  385. #[must_use]
  386. pub fn session(&self) -> &Session {
  387. &self.session
  388. }
  389. #[must_use]
  390. pub fn fork_session(&self, branch_name: Option<String>) -> Session {
  391. self.session.fork(branch_name)
  392. }
  393. #[must_use]
  394. pub fn into_session(self) -> Session {
  395. self.session
  396. }
  397. fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
  398. if self.usage_tracker.cumulative_usage().input_tokens
  399. < self.auto_compaction_input_tokens_threshold
  400. {
  401. return None;
  402. }
  403. let result = compact_session(
  404. &self.session,
  405. CompactionConfig {
  406. max_estimated_tokens: 0,
  407. ..CompactionConfig::default()
  408. },
  409. );
  410. if result.removed_message_count == 0 {
  411. return None;
  412. }
  413. self.session = result.compacted_session;
  414. Some(AutoCompactionEvent {
  415. removed_message_count: result.removed_message_count,
  416. })
  417. }
  418. }
  419. #[must_use]
  420. pub fn auto_compaction_threshold_from_env() -> u32 {
  421. parse_auto_compaction_threshold(
  422. std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
  423. .ok()
  424. .as_deref(),
  425. )
  426. }
  427. #[must_use]
  428. fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
  429. value
  430. .and_then(|raw| raw.trim().parse::<u32>().ok())
  431. .filter(|threshold| *threshold > 0)
  432. .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
  433. }
  434. fn build_assistant_message(
  435. events: Vec<AssistantEvent>,
  436. ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
  437. let mut text = String::new();
  438. let mut blocks = Vec::new();
  439. let mut finished = false;
  440. let mut usage = None;
  441. for event in events {
  442. match event {
  443. AssistantEvent::TextDelta(delta) => text.push_str(&delta),
  444. AssistantEvent::ToolUse { id, name, input } => {
  445. flush_text_block(&mut text, &mut blocks);
  446. blocks.push(ContentBlock::ToolUse { id, name, input });
  447. }
  448. AssistantEvent::Usage(value) => usage = Some(value),
  449. AssistantEvent::MessageStop => {
  450. finished = true;
  451. }
  452. }
  453. }
  454. flush_text_block(&mut text, &mut blocks);
  455. if !finished {
  456. return Err(RuntimeError::new(
  457. "assistant stream ended without a message stop event",
  458. ));
  459. }
  460. if blocks.is_empty() {
  461. return Err(RuntimeError::new("assistant stream produced no content"));
  462. }
  463. Ok((
  464. ConversationMessage::assistant_with_usage(blocks, usage),
  465. usage,
  466. ))
  467. }
  468. fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
  469. if !text.is_empty() {
  470. blocks.push(ContentBlock::Text {
  471. text: std::mem::take(text),
  472. });
  473. }
  474. }
  475. fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
  476. if result.messages().is_empty() {
  477. fallback.to_string()
  478. } else {
  479. result.messages().join("\n")
  480. }
  481. }
  /// Appends hook messages to a tool output.
  ///
  /// With no messages the output is returned untouched. Otherwise a labelled
  /// feedback section is built (the label notes a denial when `denied` is set)
  /// and placed after the output, separated by a blank line; an output that is
  /// empty or whitespace-only is dropped so only the feedback remains.
  fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
      if messages.is_empty() {
          return output;
      }

      let label = if denied {
          "Hook feedback (denied)"
      } else {
          "Hook feedback"
      };
      let feedback = format!("{label}:\n{}", messages.join("\n"));

      if output.trim().is_empty() {
          feedback
      } else {
          format!("{output}\n\n{feedback}")
      }
  }
  498. type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
  499. #[derive(Default)]
  500. pub struct StaticToolExecutor {
  501. handlers: BTreeMap<String, ToolHandler>,
  502. }
  503. impl StaticToolExecutor {
  504. #[must_use]
  505. pub fn new() -> Self {
  506. Self::default()
  507. }
  508. #[must_use]
  509. pub fn register(
  510. mut self,
  511. tool_name: impl Into<String>,
  512. handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
  513. ) -> Self {
  514. self.handlers.insert(tool_name.into(), Box::new(handler));
  515. self
  516. }
  517. }
  518. impl ToolExecutor for StaticToolExecutor {
  519. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
  520. self.handlers
  521. .get_mut(tool_name)
  522. .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
  523. }
  524. }
  525. #[cfg(test)]
  526. mod tests {
  527. use super::{
  528. parse_auto_compaction_threshold, ApiClient, ApiRequest, AssistantEvent,
  529. AutoCompactionEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
  530. DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
  531. };
  532. use crate::compact::CompactionConfig;
  533. use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
  534. use crate::permissions::{
  535. PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
  536. PermissionRequest,
  537. };
  538. use crate::prompt::{ProjectContext, SystemPromptBuilder};
  539. use crate::session::{ContentBlock, MessageRole, Session};
  540. use crate::usage::TokenUsage;
  541. use std::fs;
  542. use std::path::PathBuf;
  543. use std::time::{SystemTime, UNIX_EPOCH};
  /// API client that replays a fixed two-response script, keyed by how many
  /// times it has been called.
  struct ScriptedApiClient {
      /// Number of `stream` calls received so far.
      call_count: usize,
  }
  547. impl ApiClient for ScriptedApiClient {
  548. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
  549. self.call_count += 1;
  550. match self.call_count {
  551. 1 => {
  552. assert!(request
  553. .messages
  554. .iter()
  555. .any(|message| message.role == MessageRole::User));
  556. Ok(vec![
  557. AssistantEvent::TextDelta("Let me calculate that.".to_string()),
  558. AssistantEvent::ToolUse {
  559. id: "tool-1".to_string(),
  560. name: "add".to_string(),
  561. input: "2,2".to_string(),
  562. },
  563. AssistantEvent::Usage(TokenUsage {
  564. input_tokens: 20,
  565. output_tokens: 6,
  566. cache_creation_input_tokens: 1,
  567. cache_read_input_tokens: 2,
  568. }),
  569. AssistantEvent::MessageStop,
  570. ])
  571. }
  572. 2 => {
  573. let last_message = request
  574. .messages
  575. .last()
  576. .expect("tool result should be present");
  577. assert_eq!(last_message.role, MessageRole::Tool);
  578. Ok(vec![
  579. AssistantEvent::TextDelta("The answer is 4.".to_string()),
  580. AssistantEvent::Usage(TokenUsage {
  581. input_tokens: 24,
  582. output_tokens: 4,
  583. cache_creation_input_tokens: 1,
  584. cache_read_input_tokens: 3,
  585. }),
  586. AssistantEvent::MessageStop,
  587. ])
  588. }
  589. _ => Err(RuntimeError::new("unexpected extra API call")),
  590. }
  591. }
  592. }
  593. struct PromptAllowOnce;
  594. impl PermissionPrompter for PromptAllowOnce {
  595. fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
  596. assert_eq!(request.tool_name, "add");
  597. PermissionPromptDecision::Allow
  598. }
  599. }
  /// Full loop: user text -> assistant tool request -> tool result -> final
  /// answer, with usage accumulated across both API calls.
  #[test]
  fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
      let system_prompt = SystemPromptBuilder::new()
          .with_project_context(ProjectContext {
              cwd: PathBuf::from("/tmp/project"),
              current_date: "2026-03-31".to_string(),
              git_status: None,
              git_diff: None,
              instruction_files: Vec::new(),
          })
          .with_os("linux", "6.8")
          .build();

      // The `add` tool sums a comma-separated list of integers.
      let tool_executor = StaticToolExecutor::new().register("add", |input| {
          let mut total = 0_i32;
          for part in input.split(',') {
              total += part.parse::<i32>().expect("input must be valid integer");
          }
          Ok(total.to_string())
      });

      let mut runtime = ConversationRuntime::new(
          Session::new(),
          ScriptedApiClient { call_count: 0 },
          tool_executor,
          PermissionPolicy::new(PermissionMode::WorkspaceWrite),
          system_prompt,
      );

      let mut prompter = PromptAllowOnce;
      let summary = runtime
          .run_turn("what is 2 + 2?", Some(&mut prompter))
          .expect("conversation loop should succeed");

      assert_eq!(summary.iterations, 2);
      assert_eq!(summary.assistant_messages.len(), 2);
      assert_eq!(summary.tool_results.len(), 1);
      assert_eq!(runtime.session().messages.len(), 4);
      assert_eq!(summary.usage.output_tokens, 10);
      assert_eq!(summary.auto_compaction, None);
      assert!(matches!(
          runtime.session().messages[1].blocks[1],
          ContentBlock::ToolUse { .. }
      ));
      assert!(matches!(
          runtime.session().messages[2].blocks[0],
          ContentBlock::ToolResult {
              is_error: false,
              ..
          }
      ));
  }
  /// A prompter rejection is recorded as an error tool result carrying the
  /// prompter's reason, and the turn still completes.
  #[test]
  fn records_denied_tool_results_when_prompt_rejects() {
      struct RejectPrompter;
      impl PermissionPrompter for RejectPrompter {
          fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
              let reason = "not now".to_string();
              PermissionPromptDecision::Deny { reason }
          }
      }

      /// Requests the tool until a tool result appears in the history, then
      /// answers with plain text.
      struct SingleCallApiClient;
      impl ApiClient for SingleCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let tool_already_answered = request
                  .messages
                  .iter()
                  .any(|message| message.role == MessageRole::Tool);
              let events = if tool_already_answered {
                  vec![
                      AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                      AssistantEvent::MessageStop,
                  ]
              } else {
                  vec![
                      AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "blocked".to_string(),
                          input: "secret".to_string(),
                      },
                      AssistantEvent::MessageStop,
                  ]
              };
              Ok(events)
          }
      }

      let mut runtime = ConversationRuntime::new(
          Session::new(),
          SingleCallApiClient,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::WorkspaceWrite),
          vec!["system".to_string()],
      );

      let mut prompter = RejectPrompter;
      let summary = runtime
          .run_turn("use the tool", Some(&mut prompter))
          .expect("conversation should continue after denied tool");

      assert_eq!(summary.tool_results.len(), 1);
      assert!(matches!(
          &summary.tool_results[0].blocks[0],
          ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
      ));
  }
  /// A pre-tool hook that exits with status 2 keeps the tool from running and
  /// yields an error tool result.
  #[test]
  fn denies_tool_use_when_pre_tool_hook_blocks() {
      /// Requests the tool until a tool result appears in the history, then
      /// answers with plain text.
      struct SingleCallApiClient;
      impl ApiClient for SingleCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let tool_already_answered = request
                  .messages
                  .iter()
                  .any(|message| message.role == MessageRole::Tool);
              let events = if tool_already_answered {
                  vec![
                      AssistantEvent::TextDelta("blocked".to_string()),
                      AssistantEvent::MessageStop,
                  ]
              } else {
                  vec![
                      AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "blocked".to_string(),
                          input: r#"{"path":"secret.txt"}"#.to_string(),
                      },
                      AssistantEvent::MessageStop,
                  ]
              };
              Ok(events)
          }
      }

      // Only a pre-tool hook is configured; the other two hook lists stay empty.
      let features = RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
          vec![shell_snippet("printf 'blocked by hook'; exit 2")],
          Vec::new(),
          Vec::new(),
      ));
      let tool_executor = StaticToolExecutor::new().register("blocked", |_input| {
          panic!("tool should not execute when hook denies")
      });

      let mut runtime = ConversationRuntime::new_with_features(
          Session::new(),
          SingleCallApiClient,
          tool_executor,
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
          features,
      );

      let summary = runtime
          .run_turn("use the tool", None)
          .expect("conversation should continue after hook denial");

      assert_eq!(summary.tool_results.len(), 1);
      let ContentBlock::ToolResult {
          is_error, output, ..
      } = &summary.tool_results[0].blocks[0]
      else {
          panic!("expected tool result block");
      };
      assert!(
          *is_error,
          "hook denial should produce an error result: {output}"
      );
      assert!(
          output.contains("denied tool") || output.contains("blocked by hook"),
          "unexpected hook denial output: {output:?}"
      );
  }
  /// Output printed by succeeding pre- and post-tool hooks is appended to the
  /// tool result without turning it into an error.
  #[test]
  fn appends_post_tool_hook_feedback_to_tool_result() {
      /// First call requests `add`; second call finishes; a third is an error.
      struct TwoCallApiClient {
          calls: usize,
      }
      impl ApiClient for TwoCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              self.calls += 1;

              if self.calls == 1 {
                  return Ok(vec![
                      AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "add".to_string(),
                          input: r#"{"lhs":2,"rhs":2}"#.to_string(),
                      },
                      AssistantEvent::MessageStop,
                  ]);
              }

              if self.calls == 2 {
                  assert!(request
                      .messages
                      .iter()
                      .any(|message| message.role == MessageRole::Tool));
                  return Ok(vec![
                      AssistantEvent::TextDelta("done".to_string()),
                      AssistantEvent::MessageStop,
                  ]);
              }

              Err(RuntimeError::new("unexpected extra API call"))
          }
      }

      // One pre-tool hook and one post-tool hook; the third hook list stays empty.
      let features = RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
          vec![shell_snippet("printf 'pre hook ran'")],
          vec![shell_snippet("printf 'post hook ran'")],
          Vec::new(),
      ));

      let mut runtime = ConversationRuntime::new_with_features(
          Session::new(),
          TwoCallApiClient { calls: 0 },
          StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
          features,
      );

      let summary = runtime
          .run_turn("use add", None)
          .expect("tool loop succeeds");

      assert_eq!(summary.tool_results.len(), 1);
      let ContentBlock::ToolResult {
          is_error, output, ..
      } = &summary.tool_results[0].blocks[0]
      else {
          panic!("expected tool result block");
      };
      assert!(
          !*is_error,
          "post hook should preserve non-error result: {output:?}"
      );
      assert!(
          output.contains('4'),
          "tool output missing value: {output:?}"
      );
      assert!(
          output.contains("pre hook ran"),
          "tool output missing pre hook feedback: {output:?}"
      );
      assert!(
          output.contains("post hook ran"),
          "tool output missing post hook feedback: {output:?}"
      );
  }
  /// A runtime built on a session that already contains an assistant message
  /// with usage starts with that usage in its tracker.
  #[test]
  fn reconstructs_usage_tracker_from_restored_session() {
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let events = vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::MessageStop,
              ];
              Ok(events)
          }
      }

      let earlier_usage = TokenUsage {
          input_tokens: 11,
          output_tokens: 7,
          cache_creation_input_tokens: 2,
          cache_read_input_tokens: 1,
      };
      let earlier_message = crate::session::ConversationMessage::assistant_with_usage(
          vec![ContentBlock::Text {
              text: "earlier".to_string(),
          }],
          Some(earlier_usage),
      );
      let mut session = Session::new();
      session.messages.push(earlier_message);

      let runtime = ConversationRuntime::new(
          session,
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      assert_eq!(runtime.usage().turns(), 1);
      assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
  }
  /// After three turns, an explicit `compact` call produces a summary-led
  /// session that keeps the original session id.
  #[test]
  fn compacts_session_after_turns() {
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let events = vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::MessageStop,
              ];
              Ok(events)
          }
      }

      let mut runtime = ConversationRuntime::new(
          Session::new(),
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      for (input, failure_label) in [("a", "turn a"), ("b", "turn b"), ("c", "turn c")] {
          runtime.run_turn(input, None).expect(failure_label);
      }

      let config = CompactionConfig {
          preserve_recent_messages: 2,
          max_estimated_tokens: 1,
      };
      let result = runtime.compact(config);

      assert!(result.summary.contains("Conversation summary"));
      assert_eq!(
          result.compacted_session.messages[0].role,
          MessageRole::System
      );
      assert_eq!(
          result.compacted_session.session_id,
          runtime.session().session_id
      );
      assert!(result.compacted_session.compaction.is_some());
  }
  /// Running one turn on a session configured with a persistence path must leave a file
  /// that reloads into a session with the user and assistant messages and the same id.
  #[test]
  fn persists_conversation_turn_messages_to_jsonl_session() {
      // API stub that answers every request with the same short reply.
      struct DoneApi;
      impl ApiClient for DoneApi {
          fn stream(&mut self, _: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let events = vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::MessageStop,
              ];
              Ok(events)
          }
      }

      let session_file = temp_session_path("persisted-turn");
      let persisted_session = Session::new().with_persistence_path(session_file.clone());
      let mut runtime = ConversationRuntime::new(
          persisted_session,
          DoneApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      runtime
          .run_turn("persist this turn", None)
          .expect("turn should succeed");

      // Reload, then delete the file before asserting so a failed assertion does not
      // leave the temp file behind.
      let reloaded =
          Session::load_from_path(&session_file).expect("persisted session should reload");
      fs::remove_file(&session_file).expect("temp session file should be removable");

      let reloaded_messages = &reloaded.messages;
      assert_eq!(reloaded_messages.len(), 2);
      assert_eq!(reloaded_messages[0].role, MessageRole::User);
      assert_eq!(reloaded_messages[1].role, MessageRole::Assistant);
      assert_eq!(reloaded.session_id, runtime.session().session_id);
  }
  /// Forking a runtime's session must yield a copy with the same messages, a different
  /// session id, and fork metadata naming the parent session and branch, while the
  /// runtime's own session keeps no fork metadata.
  #[test]
  fn forks_runtime_session_without_mutating_original() {
      let mut original = Session::new();
      original
          .push_user_text("branch me")
          .expect("message should append");

      // The runtime gets its own clone so `original` stays available for comparison.
      let runtime = ConversationRuntime::new(
          original.clone(),
          ScriptedApiClient { call_count: 0 },
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      let branch = runtime.fork_session(Some("alt-path".to_string()));

      assert_eq!(branch.messages, original.messages);
      assert_ne!(branch.session_id, original.session_id);

      let recorded_lineage = branch
          .fork
          .as_ref()
          .map(|meta| (meta.parent_session_id.as_str(), meta.branch_name.as_deref()));
      let expected_lineage = Some((original.session_id.as_str(), Some("alt-path")));
      assert_eq!(recorded_lineage, expected_lineage);

      assert!(runtime.session().fork.is_none());
  }
  /// Builds a unique path inside the system temp directory for a test session file.
  ///
  /// The file name has the form
  /// `runtime-conversation-{label}-{pid}-{nanos}-{seq}.json`, where `nanos` is the current
  /// time in nanoseconds since the Unix epoch and `seq` is a process-wide counter.
  /// The timestamp alone is not enough: two calls within one clock tick (coarse timers,
  /// tests running in parallel) would otherwise produce the same path.
  ///
  /// # Panics
  ///
  /// Panics if the system clock reports a time before the Unix epoch.
  fn temp_session_path(label: &str) -> PathBuf {
      // Monotonic per-process sequence number; guarantees distinct names within a process.
      static NEXT_SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

      let nanos = SystemTime::now()
          .duration_since(UNIX_EPOCH)
          .expect("system time should be after epoch")
          .as_nanos();
      // The process id separates concurrently running test binaries.
      let pid = std::process::id();
      let seq = NEXT_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

      let file_name = format!("runtime-conversation-{label}-{pid}-{nanos}-{seq}.json");
      std::env::temp_dir().join(file_name)
  }
  /// Adapts a shell snippet to the host platform's quoting rules.
  ///
  /// On Windows every single quote is replaced with a double quote; on every other
  /// platform the script is returned unchanged.
  fn shell_snippet(script: &str) -> String {
      if cfg!(windows) {
          script.replace('\'', "\"")
      } else {
          script.to_owned()
      }
  }
  /// When a turn reports more input tokens than the configured auto-compaction threshold
  /// (120_000 reported against a threshold of 100_000), `run_turn` must report an
  /// auto-compaction event that removed two messages, and the session must afterwards
  /// start with a system message.
  #[test]
  fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
      // API stub whose single reply carries a usage event above the threshold.
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              Ok(vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::Usage(TokenUsage {
                      input_tokens: 120_000,
                      output_tokens: 4,
                      cache_creation_input_tokens: 0,
                      cache_read_input_tokens: 0,
                  }),
                  AssistantEvent::MessageStop,
              ])
          }
      }

      // Start from `Session::new()` and overwrite `messages` instead of spelling out a
      // struct literal: `Session` also has `session_id`, `fork` and `compaction` fields,
      // so a literal naming only `version` and `messages` does not compile.
      let mut session = Session::new();
      session.messages = vec![
          crate::session::ConversationMessage::user_text("one"),
          crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
              text: "two".to_string(),
          }]),
          crate::session::ConversationMessage::user_text("three"),
          crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
              text: "four".to_string(),
          }]),
      ];

      let mut runtime = ConversationRuntime::new(
          session,
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      )
      .with_auto_compaction_input_tokens_threshold(100_000);

      let summary = runtime
          .run_turn("trigger", None)
          .expect("turn should succeed");

      assert_eq!(
          summary.auto_compaction,
          Some(AutoCompactionEvent {
              removed_message_count: 2,
          })
      );
      assert_eq!(runtime.session().messages[0].role, MessageRole::System);
  }
  /// When a turn reports fewer input tokens than the configured threshold (99_999 against
  /// 100_000), `run_turn` must report no auto-compaction and the session must hold exactly
  /// the two messages of that turn.
  #[test]
  fn skips_auto_compaction_below_threshold() {
      // API stub whose single reply carries a usage event just under the threshold.
      struct JustUnderApi;
      impl ApiClient for JustUnderApi {
          fn stream(&mut self, _: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let reported_usage = TokenUsage {
                  input_tokens: 99_999,
                  output_tokens: 4,
                  cache_creation_input_tokens: 0,
                  cache_read_input_tokens: 0,
              };
              let events = vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::Usage(reported_usage),
                  AssistantEvent::MessageStop,
              ];
              Ok(events)
          }
      }

      let base_runtime = ConversationRuntime::new(
          Session::new(),
          JustUnderApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );
      let mut runtime = base_runtime.with_auto_compaction_input_tokens_threshold(100_000);

      let turn = runtime
          .run_turn("trigger", None)
          .expect("turn should succeed");

      assert_eq!(turn.auto_compaction, None);
      let message_count = runtime.session().messages.len();
      assert_eq!(message_count, 2);
  }
  /// `parse_auto_compaction_threshold` must return the default threshold for a missing or
  /// non-numeric value and the parsed number for a numeric one.
  #[test]
  fn auto_compaction_threshold_defaults_and_parses_values() {
      // (raw input, expected threshold), checked in the same order as listed.
      let cases = [
          (None, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD),
          (Some("4321"), 4321),
          (Some("not-a-number"), DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD),
      ];
      for (raw, expected) in cases {
          assert_eq!(parse_auto_compaction_threshold(raw), expected);
      }
  }
  1074. }