conversation.rs 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973
  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use crate::compact::{
  4. compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
  5. };
  6. use crate::config::RuntimeFeatureConfig;
  7. use crate::hooks::{HookRunResult, HookRunner};
  8. use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
  9. use crate::session::{ContentBlock, ConversationMessage, Session};
  10. use crate::usage::{TokenUsage, UsageTracker};
  11. const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 200_000;
  12. const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
/// A single request handed to [`ApiClient::stream`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApiRequest {
    /// System prompt sections; `run_turn` sends the runtime's configured prompt.
    pub system_prompt: Vec<String>,
    /// Conversation history; `run_turn` sends a full copy of the session's messages.
    pub messages: Vec<ConversationMessage>,
}
/// One event of a streamed assistant response, as returned by [`ApiClient::stream`].
///
/// `build_assistant_message` folds a sequence of these into a single assistant message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AssistantEvent {
    /// A fragment of assistant text; consecutive deltas are concatenated.
    TextDelta(String),
    /// A request to invoke a tool.
    ToolUse {
        /// Identifier echoed back in the matching tool-result message.
        id: String,
        /// Name of the tool to run.
        name: String,
        /// Raw tool input, passed to the tool executor as-is.
        input: String,
    },
    /// Token accounting for the response; if several arrive, the last one is kept.
    Usage(TokenUsage),
    /// End-of-message marker; a stream that lacks it is rejected.
    MessageStop,
}
/// Transport to the model: sends one request and returns the events of the response.
pub trait ApiClient {
    /// Sends `request` and returns the response events in order.
    ///
    /// # Errors
    ///
    /// Returns a [`RuntimeError`] when the request fails; `run_turn` propagates it.
    fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
}
/// Runs tools requested by the model.
pub trait ToolExecutor {
    /// Runs the tool named `tool_name` with its raw `input` and returns its output.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the tool fails; `run_turn` records the error text
    /// as an error tool result instead of aborting the turn.
    fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
}
  35. #[derive(Debug, Clone, PartialEq, Eq)]
  36. pub struct ToolError {
  37. message: String,
  38. }
  39. impl ToolError {
  40. #[must_use]
  41. pub fn new(message: impl Into<String>) -> Self {
  42. Self {
  43. message: message.into(),
  44. }
  45. }
  46. }
  47. impl Display for ToolError {
  48. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  49. write!(f, "{}", self.message)
  50. }
  51. }
  52. impl std::error::Error for ToolError {}
  53. #[derive(Debug, Clone, PartialEq, Eq)]
  54. pub struct RuntimeError {
  55. message: String,
  56. }
  57. impl RuntimeError {
  58. #[must_use]
  59. pub fn new(message: impl Into<String>) -> Self {
  60. Self {
  61. message: message.into(),
  62. }
  63. }
  64. }
  65. impl Display for RuntimeError {
  66. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  67. write!(f, "{}", self.message)
  68. }
  69. }
  70. impl std::error::Error for RuntimeError {}
/// Outcome of one successful `ConversationRuntime::run_turn` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TurnSummary {
    /// Every assistant message produced during the turn, in order.
    pub assistant_messages: Vec<ConversationMessage>,
    /// Tool-result messages produced during the turn (including denials), in order.
    pub tool_results: Vec<ConversationMessage>,
    /// Number of model requests made during the turn.
    pub iterations: usize,
    /// The runtime's cumulative usage after the turn — not just this turn's usage.
    pub usage: TokenUsage,
    /// Set when the session was auto-compacted at the end of the turn.
    pub auto_compaction: Option<AutoCompactionEvent>,
}
/// Reports that the session was auto-compacted at the end of a turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AutoCompactionEvent {
    /// Number of messages removed by the compaction (non-zero whenever reported).
    pub removed_message_count: usize,
}
/// Drives a conversation.
///
/// It sends the session to the model through `C`, runs the tools the model requests
/// through `T` (subject to a permission policy and hooks), and tracks token usage.
pub struct ConversationRuntime<C, T> {
    /// Conversation history; every turn appends to it.
    session: Session,
    /// Client used to query the model.
    api_client: C,
    /// Executor for the tools the model asks to run.
    tool_executor: T,
    /// Decides whether each requested tool call may run.
    permission_policy: PermissionPolicy,
    /// System prompt sections sent with every request.
    system_prompt: Vec<String>,
    /// Maximum model requests per turn (`usize::MAX` unless set via `with_max_iterations`).
    max_iterations: usize,
    /// Token usage accumulated across turns, seeded from the initial session.
    usage_tracker: UsageTracker,
    /// Runs the configured pre- and post-tool-use hooks.
    hook_runner: HookRunner,
    /// Cumulative input-token count at which the session is auto-compacted.
    auto_compaction_input_tokens_threshold: u32,
}
  94. impl<C, T> ConversationRuntime<C, T>
  95. where
  96. C: ApiClient,
  97. T: ToolExecutor,
  98. {
  99. #[must_use]
  100. pub fn new(
  101. session: Session,
  102. api_client: C,
  103. tool_executor: T,
  104. permission_policy: PermissionPolicy,
  105. system_prompt: Vec<String>,
  106. ) -> Self {
  107. Self::new_with_features(
  108. session,
  109. api_client,
  110. tool_executor,
  111. permission_policy,
  112. system_prompt,
  113. RuntimeFeatureConfig::default(),
  114. )
  115. }
  116. #[must_use]
  117. #[allow(clippy::needless_pass_by_value)]
  118. pub fn new_with_features(
  119. session: Session,
  120. api_client: C,
  121. tool_executor: T,
  122. permission_policy: PermissionPolicy,
  123. system_prompt: Vec<String>,
  124. feature_config: RuntimeFeatureConfig,
  125. ) -> Self {
  126. let usage_tracker = UsageTracker::from_session(&session);
  127. Self {
  128. session,
  129. api_client,
  130. tool_executor,
  131. permission_policy,
  132. system_prompt,
  133. max_iterations: usize::MAX,
  134. usage_tracker,
  135. hook_runner: HookRunner::from_feature_config(&feature_config),
  136. auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
  137. }
  138. }
  139. #[must_use]
  140. pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
  141. self.max_iterations = max_iterations;
  142. self
  143. }
  144. #[must_use]
  145. pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
  146. self.auto_compaction_input_tokens_threshold = threshold;
  147. self
  148. }
  149. pub fn run_turn(
  150. &mut self,
  151. user_input: impl Into<String>,
  152. mut prompter: Option<&mut dyn PermissionPrompter>,
  153. ) -> Result<TurnSummary, RuntimeError> {
  154. self.session
  155. .messages
  156. .push(ConversationMessage::user_text(user_input.into()));
  157. let mut assistant_messages = Vec::new();
  158. let mut tool_results = Vec::new();
  159. let mut iterations = 0;
  160. loop {
  161. iterations += 1;
  162. if iterations > self.max_iterations {
  163. return Err(RuntimeError::new(
  164. "conversation loop exceeded the maximum number of iterations",
  165. ));
  166. }
  167. let request = ApiRequest {
  168. system_prompt: self.system_prompt.clone(),
  169. messages: self.session.messages.clone(),
  170. };
  171. let events = self.api_client.stream(request)?;
  172. let (assistant_message, usage) = build_assistant_message(events)?;
  173. if let Some(usage) = usage {
  174. self.usage_tracker.record(usage);
  175. }
  176. let pending_tool_uses = assistant_message
  177. .blocks
  178. .iter()
  179. .filter_map(|block| match block {
  180. ContentBlock::ToolUse { id, name, input } => {
  181. Some((id.clone(), name.clone(), input.clone()))
  182. }
  183. _ => None,
  184. })
  185. .collect::<Vec<_>>();
  186. self.session.messages.push(assistant_message.clone());
  187. assistant_messages.push(assistant_message);
  188. if pending_tool_uses.is_empty() {
  189. break;
  190. }
  191. for (tool_use_id, tool_name, input) in pending_tool_uses {
  192. let permission_outcome = if let Some(prompt) = prompter.as_mut() {
  193. self.permission_policy
  194. .authorize(&tool_name, &input, Some(*prompt))
  195. } else {
  196. self.permission_policy.authorize(&tool_name, &input, None)
  197. };
  198. let result_message = match permission_outcome {
  199. PermissionOutcome::Allow => {
  200. let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input);
  201. if pre_hook_result.is_denied() {
  202. let deny_message = format!("PreToolUse hook denied tool `{tool_name}`");
  203. ConversationMessage::tool_result(
  204. tool_use_id,
  205. tool_name,
  206. format_hook_message(&pre_hook_result, &deny_message),
  207. true,
  208. )
  209. } else {
  210. let (mut output, mut is_error) =
  211. match self.tool_executor.execute(&tool_name, &input) {
  212. Ok(output) => (output, false),
  213. Err(error) => (error.to_string(), true),
  214. };
  215. output = merge_hook_feedback(pre_hook_result.messages(), output, false);
  216. let post_hook_result = self
  217. .hook_runner
  218. .run_post_tool_use(&tool_name, &input, &output, is_error);
  219. if post_hook_result.is_denied() {
  220. is_error = true;
  221. }
  222. output = merge_hook_feedback(
  223. post_hook_result.messages(),
  224. output,
  225. post_hook_result.is_denied(),
  226. );
  227. ConversationMessage::tool_result(
  228. tool_use_id,
  229. tool_name,
  230. output,
  231. is_error,
  232. )
  233. }
  234. }
  235. PermissionOutcome::Deny { reason } => {
  236. ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
  237. }
  238. };
  239. self.session.messages.push(result_message.clone());
  240. tool_results.push(result_message);
  241. }
  242. }
  243. let auto_compaction = self.maybe_auto_compact();
  244. Ok(TurnSummary {
  245. assistant_messages,
  246. tool_results,
  247. iterations,
  248. usage: self.usage_tracker.cumulative_usage(),
  249. auto_compaction,
  250. })
  251. }
  252. #[must_use]
  253. pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
  254. compact_session(&self.session, config)
  255. }
  256. #[must_use]
  257. pub fn estimated_tokens(&self) -> usize {
  258. estimate_session_tokens(&self.session)
  259. }
  260. #[must_use]
  261. pub fn usage(&self) -> &UsageTracker {
  262. &self.usage_tracker
  263. }
  264. #[must_use]
  265. pub fn session(&self) -> &Session {
  266. &self.session
  267. }
  268. #[must_use]
  269. pub fn into_session(self) -> Session {
  270. self.session
  271. }
  272. fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
  273. if self.usage_tracker.cumulative_usage().input_tokens
  274. < self.auto_compaction_input_tokens_threshold
  275. {
  276. return None;
  277. }
  278. let result = compact_session(
  279. &self.session,
  280. CompactionConfig {
  281. max_estimated_tokens: 0,
  282. ..CompactionConfig::default()
  283. },
  284. );
  285. if result.removed_message_count == 0 {
  286. return None;
  287. }
  288. self.session = result.compacted_session;
  289. Some(AutoCompactionEvent {
  290. removed_message_count: result.removed_message_count,
  291. })
  292. }
  293. }
  294. #[must_use]
  295. pub fn auto_compaction_threshold_from_env() -> u32 {
  296. parse_auto_compaction_threshold(
  297. std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
  298. .ok()
  299. .as_deref(),
  300. )
  301. }
  302. #[must_use]
  303. fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
  304. value
  305. .and_then(|raw| raw.trim().parse::<u32>().ok())
  306. .filter(|threshold| *threshold > 0)
  307. .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
  308. }
  309. fn build_assistant_message(
  310. events: Vec<AssistantEvent>,
  311. ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
  312. let mut text = String::new();
  313. let mut blocks = Vec::new();
  314. let mut finished = false;
  315. let mut usage = None;
  316. for event in events {
  317. match event {
  318. AssistantEvent::TextDelta(delta) => text.push_str(&delta),
  319. AssistantEvent::ToolUse { id, name, input } => {
  320. flush_text_block(&mut text, &mut blocks);
  321. blocks.push(ContentBlock::ToolUse { id, name, input });
  322. }
  323. AssistantEvent::Usage(value) => usage = Some(value),
  324. AssistantEvent::MessageStop => {
  325. finished = true;
  326. }
  327. }
  328. }
  329. flush_text_block(&mut text, &mut blocks);
  330. if !finished {
  331. return Err(RuntimeError::new(
  332. "assistant stream ended without a message stop event",
  333. ));
  334. }
  335. if blocks.is_empty() {
  336. return Err(RuntimeError::new("assistant stream produced no content"));
  337. }
  338. Ok((
  339. ConversationMessage::assistant_with_usage(blocks, usage),
  340. usage,
  341. ))
  342. }
  343. fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
  344. if !text.is_empty() {
  345. blocks.push(ContentBlock::Text {
  346. text: std::mem::take(text),
  347. });
  348. }
  349. }
  350. fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
  351. if result.messages().is_empty() {
  352. fallback.to_string()
  353. } else {
  354. result.messages().join("\n")
  355. }
  356. }
  357. fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
  358. if messages.is_empty() {
  359. return output;
  360. }
  361. let mut sections = Vec::new();
  362. if !output.trim().is_empty() {
  363. sections.push(output);
  364. }
  365. let label = if denied {
  366. "Hook feedback (denied)"
  367. } else {
  368. "Hook feedback"
  369. };
  370. sections.push(format!("{label}:\n{}", messages.join("\n")));
  371. sections.join("\n\n")
  372. }
  373. type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
  374. #[derive(Default)]
  375. pub struct StaticToolExecutor {
  376. handlers: BTreeMap<String, ToolHandler>,
  377. }
  378. impl StaticToolExecutor {
  379. #[must_use]
  380. pub fn new() -> Self {
  381. Self::default()
  382. }
  383. #[must_use]
  384. pub fn register(
  385. mut self,
  386. tool_name: impl Into<String>,
  387. handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
  388. ) -> Self {
  389. self.handlers.insert(tool_name.into(), Box::new(handler));
  390. self
  391. }
  392. }
  393. impl ToolExecutor for StaticToolExecutor {
  394. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
  395. self.handlers
  396. .get_mut(tool_name)
  397. .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
  398. }
  399. }
// Unit tests for the conversation runtime: the tool loop, permission denials,
// pre/post tool hooks, usage reconstruction, and manual and automatic compaction.
#[cfg(test)]
mod tests {
    use super::{
        parse_auto_compaction_threshold, ApiClient, ApiRequest, AssistantEvent,
        AutoCompactionEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
        DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
    };
    use crate::compact::CompactionConfig;
    use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
    use crate::permissions::{
        PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
        PermissionRequest,
    };
    use crate::prompt::{ProjectContext, SystemPromptBuilder};
    use crate::session::{ContentBlock, MessageRole, Session};
    use crate::usage::TokenUsage;
    use std::path::PathBuf;

    // Scripted two-call client. Call 1 asks for the `add` tool; call 2 checks that
    // the tool result was sent back and then answers. Any further call is an error.
    struct ScriptedApiClient {
        call_count: usize,
    }

    impl ApiClient for ScriptedApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.call_count += 1;
            match self.call_count {
                1 => {
                    assert!(request
                        .messages
                        .iter()
                        .any(|message| message.role == MessageRole::User));
                    Ok(vec![
                        AssistantEvent::TextDelta("Let me calculate that.".to_string()),
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: "2,2".to_string(),
                        },
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 20,
                            output_tokens: 6,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 2,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                2 => {
                    // The tool result must be the last message sent to the model.
                    let last_message = request
                        .messages
                        .last()
                        .expect("tool result should be present");
                    assert_eq!(last_message.role, MessageRole::Tool);
                    Ok(vec![
                        AssistantEvent::TextDelta("The answer is 4.".to_string()),
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 24,
                            output_tokens: 4,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 3,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                _ => Err(RuntimeError::new("unexpected extra API call")),
            }
        }
    }

    // Prompter that approves every request, asserting it is only ever asked about `add`.
    struct PromptAllowOnce;

    impl PermissionPrompter for PromptAllowOnce {
        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
            assert_eq!(request.tool_name, "add");
            PermissionPromptDecision::Allow
        }
    }

    #[test]
    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
        let api_client = ScriptedApiClient { call_count: 0 };
        let tool_executor = StaticToolExecutor::new().register("add", |input| {
            let total = input
                .split(',')
                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                .sum::<i32>();
            Ok(total.to_string())
        });
        let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
        let system_prompt = SystemPromptBuilder::new()
            .with_project_context(ProjectContext {
                cwd: PathBuf::from("/tmp/project"),
                current_date: "2026-03-31".to_string(),
                git_status: None,
                git_diff: None,
                instruction_files: Vec::new(),
            })
            .with_os("linux", "6.8")
            .build();
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            api_client,
            tool_executor,
            permission_policy,
            system_prompt,
        );

        let summary = runtime
            .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
            .expect("conversation loop should succeed");

        // Two model calls, with one tool result in between. The session holds:
        // user, assistant (text + tool use), tool result, and the final assistant reply.
        assert_eq!(summary.iterations, 2);
        assert_eq!(summary.assistant_messages.len(), 2);
        assert_eq!(summary.tool_results.len(), 1);
        assert_eq!(runtime.session().messages.len(), 4);
        // Output tokens from both responses: 6 + 4.
        assert_eq!(summary.usage.output_tokens, 10);
        assert_eq!(summary.auto_compaction, None);
        assert!(matches!(
            runtime.session().messages[1].blocks[1],
            ContentBlock::ToolUse { .. }
        ));
        assert!(matches!(
            runtime.session().messages[2].blocks[0],
            ContentBlock::ToolResult {
                is_error: false,
                ..
            }
        ));
    }

    #[test]
    fn records_denied_tool_results_when_prompt_rejects() {
        struct RejectPrompter;
        impl PermissionPrompter for RejectPrompter {
            fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
                PermissionPromptDecision::Deny {
                    reason: "not now".to_string(),
                }
            }
        }

        // Requests a tool first; once any tool result is present, replies with plain text.
        struct SingleCallApiClient;
        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                if request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool)
                {
                    return Ok(vec![
                        AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: "secret".to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::WorkspaceWrite),
            vec!["system".to_string()],
        );

        let summary = runtime
            .run_turn("use the tool", Some(&mut RejectPrompter))
            .expect("conversation should continue after denied tool");

        // The prompter's denial reason becomes the text of an error tool result.
        assert_eq!(summary.tool_results.len(), 1);
        assert!(matches!(
            &summary.tool_results[0].blocks[0],
            ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
        ));
    }

    #[test]
    fn denies_tool_use_when_pre_tool_hook_blocks() {
        struct SingleCallApiClient;
        impl ApiClient for SingleCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                if request
                    .messages
                    .iter()
                    .any(|message| message.role == MessageRole::Tool)
                {
                    return Ok(vec![
                        AssistantEvent::TextDelta("blocked".to_string()),
                        AssistantEvent::MessageStop,
                    ]);
                }
                Ok(vec![
                    AssistantEvent::ToolUse {
                        id: "tool-1".to_string(),
                        name: "blocked".to_string(),
                        input: r#"{"path":"secret.txt"}"#.to_string(),
                    },
                    AssistantEvent::MessageStop,
                ])
            }
        }

        // The pre-tool hook exits with status 2, which presumably signals a denial
        // to the hook runner (see crate::hooks). The tool handler panics if it is
        // ever reached.
        let mut runtime = ConversationRuntime::new_with_features(
            Session::new(),
            SingleCallApiClient,
            StaticToolExecutor::new().register("blocked", |_input| {
                panic!("tool should not execute when hook denies")
            }),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                vec![shell_snippet("printf 'blocked by hook'; exit 2")],
                Vec::new(),
            )),
        );

        let summary = runtime
            .run_turn("use the tool", None)
            .expect("conversation should continue after hook denial");

        assert_eq!(summary.tool_results.len(), 1);
        let ContentBlock::ToolResult {
            is_error, output, ..
        } = &summary.tool_results[0].blocks[0]
        else {
            panic!("expected tool result block");
        };
        assert!(
            *is_error,
            "hook denial should produce an error result: {output}"
        );
        // Either the hook's own message or the runtime's fallback denial text.
        assert!(
            output.contains("denied tool") || output.contains("blocked by hook"),
            "unexpected hook denial output: {output:?}"
        );
    }

    #[test]
    fn appends_post_tool_hook_feedback_to_tool_result() {
        struct TwoCallApiClient {
            calls: usize,
        }

        impl ApiClient for TwoCallApiClient {
            fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
                self.calls += 1;
                match self.calls {
                    1 => Ok(vec![
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: r#"{"lhs":2,"rhs":2}"#.to_string(),
                        },
                        AssistantEvent::MessageStop,
                    ]),
                    2 => {
                        assert!(request
                            .messages
                            .iter()
                            .any(|message| message.role == MessageRole::Tool));
                        Ok(vec![
                            AssistantEvent::TextDelta("done".to_string()),
                            AssistantEvent::MessageStop,
                        ])
                    }
                    _ => Err(RuntimeError::new("unexpected extra API call")),
                }
            }
        }

        // Both hooks succeed and print feedback, which should be appended to the tool output.
        let mut runtime = ConversationRuntime::new_with_features(
            Session::new(),
            TwoCallApiClient { calls: 0 },
            StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                vec![shell_snippet("printf 'pre hook ran'")],
                vec![shell_snippet("printf 'post hook ran'")],
            )),
        );

        let summary = runtime
            .run_turn("use add", None)
            .expect("tool loop succeeds");

        assert_eq!(summary.tool_results.len(), 1);
        let ContentBlock::ToolResult {
            is_error, output, ..
        } = &summary.tool_results[0].blocks[0]
        else {
            panic!("expected tool result block");
        };
        assert!(
            !*is_error,
            "post hook should preserve non-error result: {output:?}"
        );
        assert!(
            output.contains('4'),
            "tool output missing value: {output:?}"
        );
        assert!(
            output.contains("pre hook ran"),
            "tool output missing pre hook feedback: {output:?}"
        );
        assert!(
            output.contains("post hook ran"),
            "tool output missing post hook feedback: {output:?}"
        );
    }

    #[test]
    fn reconstructs_usage_tracker_from_restored_session() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        // A restored session with one assistant message that carries recorded usage.
        let mut session = Session::new();
        session
            .messages
            .push(crate::session::ConversationMessage::assistant_with_usage(
                vec![ContentBlock::Text {
                    text: "earlier".to_string(),
                }],
                Some(TokenUsage {
                    input_tokens: 11,
                    output_tokens: 7,
                    cache_creation_input_tokens: 2,
                    cache_read_input_tokens: 1,
                }),
            ));

        let runtime = ConversationRuntime::new(
            session,
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );

        // 11 + 7 + 2 + 1 = 21 total tokens from the restored message.
        assert_eq!(runtime.usage().turns(), 1);
        assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
    }

    #[test]
    fn compacts_session_after_turns() {
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        );
        runtime.run_turn("a", None).expect("turn a");
        runtime.run_turn("b", None).expect("turn b");
        runtime.run_turn("c", None).expect("turn c");

        // A tiny token budget forces compaction; older messages are summarized
        // into a leading system message.
        let result = runtime.compact(CompactionConfig {
            preserve_recent_messages: 2,
            max_estimated_tokens: 1,
        });
        assert!(result.summary.contains("Conversation summary"));
        assert_eq!(
            result.compacted_session.messages[0].role,
            MessageRole::System
        );
    }

    // Hook commands are written with single quotes; on Windows they are switched
    // to double quotes, presumably for the Windows shell's quoting rules.
    #[cfg(windows)]
    fn shell_snippet(script: &str) -> String {
        script.replace('\'', "\"")
    }

    #[cfg(not(windows))]
    fn shell_snippet(script: &str) -> String {
        script.to_string()
    }

    #[test]
    fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
        // Reports 120k input tokens per call, above the 100k threshold set below.
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::Usage(TokenUsage {
                        input_tokens: 120_000,
                        output_tokens: 4,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                    }),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let session = Session {
            version: 1,
            messages: vec![
                crate::session::ConversationMessage::user_text("one"),
                crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
                    text: "two".to_string(),
                }]),
                crate::session::ConversationMessage::user_text("three"),
                crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
                    text: "four".to_string(),
                }]),
            ],
        };

        let mut runtime = ConversationRuntime::new(
            session,
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        )
        .with_auto_compaction_input_tokens_threshold(100_000);

        let summary = runtime
            .run_turn("trigger", None)
            .expect("turn should succeed");

        // The turn brings the session to 6 messages. The default compaction config
        // presumably keeps the most recent ones, so 2 are removed and replaced by a
        // summary system message.
        assert_eq!(
            summary.auto_compaction,
            Some(AutoCompactionEvent {
                removed_message_count: 2,
            })
        );
        assert_eq!(runtime.session().messages[0].role, MessageRole::System);
    }

    #[test]
    fn skips_auto_compaction_below_threshold() {
        // Reports one token below the 100k threshold set below.
        struct SimpleApi;
        impl ApiClient for SimpleApi {
            fn stream(
                &mut self,
                _request: ApiRequest,
            ) -> Result<Vec<AssistantEvent>, RuntimeError> {
                Ok(vec![
                    AssistantEvent::TextDelta("done".to_string()),
                    AssistantEvent::Usage(TokenUsage {
                        input_tokens: 99_999,
                        output_tokens: 4,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                    }),
                    AssistantEvent::MessageStop,
                ])
            }
        }

        let mut runtime = ConversationRuntime::new(
            Session::new(),
            SimpleApi,
            StaticToolExecutor::new(),
            PermissionPolicy::new(PermissionMode::DangerFullAccess),
            vec!["system".to_string()],
        )
        .with_auto_compaction_input_tokens_threshold(100_000);

        let summary = runtime
            .run_turn("trigger", None)
            .expect("turn should succeed");
        // No compaction: the session holds just the user message and the reply.
        assert_eq!(summary.auto_compaction, None);
        assert_eq!(runtime.session().messages.len(), 2);
    }

    #[test]
    fn auto_compaction_threshold_defaults_and_parses_values() {
        assert_eq!(
            parse_auto_compaction_threshold(None),
            DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
        );
        assert_eq!(parse_auto_compaction_threshold(Some("4321")), 4321);
        assert_eq!(
            parse_auto_compaction_threshold(Some("not-a-number")),
            DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
        );
    }
}