conversation.rs 36 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096
  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use plugins::PluginRegistry;
  4. use crate::compact::{
  5. compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
  6. };
  7. use crate::config::RuntimeFeatureConfig;
  8. use crate::hooks::{HookRunResult, HookRunner};
  9. use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter};
  10. use crate::session::{ContentBlock, ConversationMessage, Session};
  11. use crate::usage::{TokenUsage, UsageTracker};
  12. const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 200_000;
  13. const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
  14. #[derive(Debug, Clone, PartialEq, Eq)]
  15. pub struct ApiRequest {
  16. pub system_prompt: Vec<String>,
  17. pub messages: Vec<ConversationMessage>,
  18. }
  19. #[derive(Debug, Clone, PartialEq, Eq)]
  20. pub enum AssistantEvent {
  21. TextDelta(String),
  22. ToolUse {
  23. id: String,
  24. name: String,
  25. input: String,
  26. },
  27. Usage(TokenUsage),
  28. MessageStop,
  29. }
  30. pub trait ApiClient {
  31. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
  32. }
  33. pub trait ToolExecutor {
  34. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
  35. }
  36. #[derive(Debug, Clone, PartialEq, Eq)]
  37. pub struct ToolError {
  38. message: String,
  39. }
  40. impl ToolError {
  41. #[must_use]
  42. pub fn new(message: impl Into<String>) -> Self {
  43. Self {
  44. message: message.into(),
  45. }
  46. }
  47. }
  48. impl Display for ToolError {
  49. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  50. write!(f, "{}", self.message)
  51. }
  52. }
  53. impl std::error::Error for ToolError {}
  54. #[derive(Debug, Clone, PartialEq, Eq)]
  55. pub struct RuntimeError {
  56. message: String,
  57. }
  58. impl RuntimeError {
  59. #[must_use]
  60. pub fn new(message: impl Into<String>) -> Self {
  61. Self {
  62. message: message.into(),
  63. }
  64. }
  65. }
  66. impl Display for RuntimeError {
  67. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  68. write!(f, "{}", self.message)
  69. }
  70. }
  71. impl std::error::Error for RuntimeError {}
  72. #[derive(Debug, Clone, PartialEq, Eq)]
  73. pub struct TurnSummary {
  74. pub assistant_messages: Vec<ConversationMessage>,
  75. pub tool_results: Vec<ConversationMessage>,
  76. pub iterations: usize,
  77. pub usage: TokenUsage,
  78. pub auto_compaction: Option<AutoCompactionEvent>,
  79. }
  80. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  81. pub struct AutoCompactionEvent {
  82. pub removed_message_count: usize,
  83. }
  84. pub struct ConversationRuntime<C, T> {
  85. session: Session,
  86. api_client: C,
  87. tool_executor: T,
  88. permission_policy: PermissionPolicy,
  89. system_prompt: Vec<String>,
  90. max_iterations: usize,
  91. usage_tracker: UsageTracker,
  92. hook_runner: HookRunner,
  93. auto_compaction_input_tokens_threshold: u32,
  94. plugin_registry: Option<PluginRegistry>,
  95. plugins_shutdown: bool,
  96. }
  97. impl<C, T> ConversationRuntime<C, T>
  98. where
  99. C: ApiClient,
  100. T: ToolExecutor,
  101. {
  102. #[must_use]
  103. pub fn new(
  104. session: Session,
  105. api_client: C,
  106. tool_executor: T,
  107. permission_policy: PermissionPolicy,
  108. system_prompt: Vec<String>,
  109. ) -> Self {
  110. Self::new_with_features(
  111. session,
  112. api_client,
  113. tool_executor,
  114. permission_policy,
  115. system_prompt,
  116. RuntimeFeatureConfig::default(),
  117. )
  118. }
  119. #[must_use]
  120. #[allow(clippy::needless_pass_by_value)]
  121. pub fn new_with_features(
  122. session: Session,
  123. api_client: C,
  124. tool_executor: T,
  125. permission_policy: PermissionPolicy,
  126. system_prompt: Vec<String>,
  127. feature_config: RuntimeFeatureConfig,
  128. ) -> Self {
  129. let usage_tracker = UsageTracker::from_session(&session);
  130. Self {
  131. session,
  132. api_client,
  133. tool_executor,
  134. permission_policy,
  135. system_prompt,
  136. max_iterations: usize::MAX,
  137. usage_tracker,
  138. hook_runner: HookRunner::from_feature_config(&feature_config),
  139. auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
  140. plugin_registry: None,
  141. plugins_shutdown: false,
  142. }
  143. }
  144. #[allow(clippy::needless_pass_by_value)]
  145. pub fn new_with_plugins(
  146. session: Session,
  147. api_client: C,
  148. tool_executor: T,
  149. permission_policy: PermissionPolicy,
  150. system_prompt: Vec<String>,
  151. feature_config: RuntimeFeatureConfig,
  152. plugin_registry: PluginRegistry,
  153. ) -> Result<Self, RuntimeError> {
  154. plugin_registry
  155. .initialize()
  156. .map_err(|error| RuntimeError::new(format!("plugin initialization failed: {error}")))?;
  157. let mut runtime = Self::new_with_features(
  158. session,
  159. api_client,
  160. tool_executor,
  161. permission_policy,
  162. system_prompt,
  163. feature_config,
  164. );
  165. runtime.plugin_registry = Some(plugin_registry);
  166. Ok(runtime)
  167. }
  168. #[must_use]
  169. pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
  170. self.max_iterations = max_iterations;
  171. self
  172. }
  173. #[must_use]
  174. pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
  175. self.auto_compaction_input_tokens_threshold = threshold;
  176. self
  177. }
  178. pub fn run_turn(
  179. &mut self,
  180. user_input: impl Into<String>,
  181. mut prompter: Option<&mut dyn PermissionPrompter>,
  182. ) -> Result<TurnSummary, RuntimeError> {
  183. self.session
  184. .messages
  185. .push(ConversationMessage::user_text(user_input.into()));
  186. let mut assistant_messages = Vec::new();
  187. let mut tool_results = Vec::new();
  188. let mut iterations = 0;
  189. loop {
  190. iterations += 1;
  191. if iterations > self.max_iterations {
  192. return Err(RuntimeError::new(
  193. "conversation loop exceeded the maximum number of iterations",
  194. ));
  195. }
  196. let request = ApiRequest {
  197. system_prompt: self.system_prompt.clone(),
  198. messages: self.session.messages.clone(),
  199. };
  200. let events = self.api_client.stream(request)?;
  201. let (assistant_message, usage) = build_assistant_message(events)?;
  202. if let Some(usage) = usage {
  203. self.usage_tracker.record(usage);
  204. }
  205. let pending_tool_uses = assistant_message
  206. .blocks
  207. .iter()
  208. .filter_map(|block| match block {
  209. ContentBlock::ToolUse { id, name, input } => {
  210. Some((id.clone(), name.clone(), input.clone()))
  211. }
  212. _ => None,
  213. })
  214. .collect::<Vec<_>>();
  215. self.session.messages.push(assistant_message.clone());
  216. assistant_messages.push(assistant_message);
  217. if pending_tool_uses.is_empty() {
  218. break;
  219. }
  220. for (tool_use_id, tool_name, input) in pending_tool_uses {
  221. let permission_outcome = if let Some(prompt) = prompter.as_mut() {
  222. self.permission_policy
  223. .authorize(&tool_name, &input, Some(*prompt))
  224. } else {
  225. self.permission_policy.authorize(&tool_name, &input, None)
  226. };
  227. let result_message = match permission_outcome {
  228. PermissionOutcome::Allow => {
  229. let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input);
  230. if pre_hook_result.is_denied() {
  231. let deny_message = format!("PreToolUse hook denied tool `{tool_name}`");
  232. ConversationMessage::tool_result(
  233. tool_use_id,
  234. tool_name,
  235. format_hook_message(&pre_hook_result, &deny_message),
  236. true,
  237. )
  238. } else {
  239. let (mut output, mut is_error) =
  240. match self.tool_executor.execute(&tool_name, &input) {
  241. Ok(output) => (output, false),
  242. Err(error) => (error.to_string(), true),
  243. };
  244. output = merge_hook_feedback(pre_hook_result.messages(), output, false);
  245. let post_hook_result = self
  246. .hook_runner
  247. .run_post_tool_use(&tool_name, &input, &output, is_error);
  248. if post_hook_result.is_denied() {
  249. is_error = true;
  250. }
  251. output = merge_hook_feedback(
  252. post_hook_result.messages(),
  253. output,
  254. post_hook_result.is_denied(),
  255. );
  256. ConversationMessage::tool_result(
  257. tool_use_id,
  258. tool_name,
  259. output,
  260. is_error,
  261. )
  262. }
  263. }
  264. PermissionOutcome::Deny { reason } => {
  265. ConversationMessage::tool_result(tool_use_id, tool_name, reason, true)
  266. }
  267. };
  268. self.session.messages.push(result_message.clone());
  269. tool_results.push(result_message);
  270. }
  271. }
  272. let auto_compaction = self.maybe_auto_compact();
  273. Ok(TurnSummary {
  274. assistant_messages,
  275. tool_results,
  276. iterations,
  277. usage: self.usage_tracker.cumulative_usage(),
  278. auto_compaction,
  279. })
  280. }
  281. #[must_use]
  282. pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
  283. compact_session(&self.session, config)
  284. }
  285. #[must_use]
  286. pub fn estimated_tokens(&self) -> usize {
  287. estimate_session_tokens(&self.session)
  288. }
  289. #[must_use]
  290. pub fn usage(&self) -> &UsageTracker {
  291. &self.usage_tracker
  292. }
  293. #[must_use]
  294. pub fn session(&self) -> &Session {
  295. &self.session
  296. }
  297. #[must_use]
  298. pub fn into_session(mut self) -> Session {
  299. let _ = self.shutdown_plugins();
  300. std::mem::take(&mut self.session)
  301. }
  302. pub fn shutdown_plugins(&mut self) -> Result<(), RuntimeError> {
  303. if self.plugins_shutdown {
  304. return Ok(());
  305. }
  306. if let Some(registry) = &self.plugin_registry {
  307. registry
  308. .shutdown()
  309. .map_err(|error| RuntimeError::new(format!("plugin shutdown failed: {error}")))?;
  310. }
  311. self.plugins_shutdown = true;
  312. Ok(())
  313. }
  314. fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
  315. if self.usage_tracker.cumulative_usage().input_tokens
  316. < self.auto_compaction_input_tokens_threshold
  317. {
  318. return None;
  319. }
  320. let result = compact_session(
  321. &self.session,
  322. CompactionConfig {
  323. max_estimated_tokens: 0,
  324. ..CompactionConfig::default()
  325. },
  326. );
  327. if result.removed_message_count == 0 {
  328. return None;
  329. }
  330. self.session = result.compacted_session;
  331. Some(AutoCompactionEvent {
  332. removed_message_count: result.removed_message_count,
  333. })
  334. }
  335. }
  336. impl<C, T> Drop for ConversationRuntime<C, T> {
  337. fn drop(&mut self) {
  338. let _ = self.shutdown_plugins();
  339. }
  340. }
  341. #[must_use]
  342. pub fn auto_compaction_threshold_from_env() -> u32 {
  343. parse_auto_compaction_threshold(
  344. std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
  345. .ok()
  346. .as_deref(),
  347. )
  348. }
  349. #[must_use]
  350. fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
  351. value
  352. .and_then(|raw| raw.trim().parse::<u32>().ok())
  353. .filter(|threshold| *threshold > 0)
  354. .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
  355. }
  356. fn build_assistant_message(
  357. events: Vec<AssistantEvent>,
  358. ) -> Result<(ConversationMessage, Option<TokenUsage>), RuntimeError> {
  359. let mut text = String::new();
  360. let mut blocks = Vec::new();
  361. let mut finished = false;
  362. let mut usage = None;
  363. for event in events {
  364. match event {
  365. AssistantEvent::TextDelta(delta) => text.push_str(&delta),
  366. AssistantEvent::ToolUse { id, name, input } => {
  367. flush_text_block(&mut text, &mut blocks);
  368. blocks.push(ContentBlock::ToolUse { id, name, input });
  369. }
  370. AssistantEvent::Usage(value) => usage = Some(value),
  371. AssistantEvent::MessageStop => {
  372. finished = true;
  373. }
  374. }
  375. }
  376. flush_text_block(&mut text, &mut blocks);
  377. if !finished {
  378. return Err(RuntimeError::new(
  379. "assistant stream ended without a message stop event",
  380. ));
  381. }
  382. if blocks.is_empty() {
  383. return Err(RuntimeError::new("assistant stream produced no content"));
  384. }
  385. Ok((
  386. ConversationMessage::assistant_with_usage(blocks, usage),
  387. usage,
  388. ))
  389. }
  390. fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
  391. if !text.is_empty() {
  392. blocks.push(ContentBlock::Text {
  393. text: std::mem::take(text),
  394. });
  395. }
  396. }
  397. fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
  398. if result.messages().is_empty() {
  399. fallback.to_string()
  400. } else {
  401. result.messages().join("\n")
  402. }
  403. }
  404. fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String {
  405. if messages.is_empty() {
  406. return output;
  407. }
  408. let mut sections = Vec::new();
  409. if !output.trim().is_empty() {
  410. sections.push(output);
  411. }
  412. let label = if denied {
  413. "Hook feedback (denied)"
  414. } else {
  415. "Hook feedback"
  416. };
  417. sections.push(format!("{label}:\n{}", messages.join("\n")));
  418. sections.join("\n\n")
  419. }
  420. type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
  421. #[derive(Default)]
  422. pub struct StaticToolExecutor {
  423. handlers: BTreeMap<String, ToolHandler>,
  424. }
  425. impl StaticToolExecutor {
  426. #[must_use]
  427. pub fn new() -> Self {
  428. Self::default()
  429. }
  430. #[must_use]
  431. pub fn register(
  432. mut self,
  433. tool_name: impl Into<String>,
  434. handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
  435. ) -> Self {
  436. self.handlers.insert(tool_name.into(), Box::new(handler));
  437. self
  438. }
  439. }
  440. impl ToolExecutor for StaticToolExecutor {
  441. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
  442. self.handlers
  443. .get_mut(tool_name)
  444. .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
  445. }
  446. }
  447. #[cfg(test)]
  448. mod tests {
  449. use super::{
  450. parse_auto_compaction_threshold, ApiClient, ApiRequest, AssistantEvent,
  451. AutoCompactionEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
  452. DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
  453. };
  454. use crate::compact::CompactionConfig;
  455. use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
  456. use crate::permissions::{
  457. PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
  458. PermissionRequest,
  459. };
  460. use crate::prompt::{ProjectContext, SystemPromptBuilder};
  461. use crate::session::{ContentBlock, MessageRole, Session};
  462. use crate::usage::TokenUsage;
  463. use plugins::{PluginManager, PluginManagerConfig};
  464. use std::fs;
  465. use std::path::Path;
  466. use std::path::PathBuf;
  467. use std::time::{SystemTime, UNIX_EPOCH};
  468. struct ScriptedApiClient {
  469. call_count: usize,
  470. }
  471. impl ApiClient for ScriptedApiClient {
  472. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
  473. self.call_count += 1;
  474. match self.call_count {
  475. 1 => {
  476. assert!(request
  477. .messages
  478. .iter()
  479. .any(|message| message.role == MessageRole::User));
  480. Ok(vec![
  481. AssistantEvent::TextDelta("Let me calculate that.".to_string()),
  482. AssistantEvent::ToolUse {
  483. id: "tool-1".to_string(),
  484. name: "add".to_string(),
  485. input: "2,2".to_string(),
  486. },
  487. AssistantEvent::Usage(TokenUsage {
  488. input_tokens: 20,
  489. output_tokens: 6,
  490. cache_creation_input_tokens: 1,
  491. cache_read_input_tokens: 2,
  492. }),
  493. AssistantEvent::MessageStop,
  494. ])
  495. }
  496. 2 => {
  497. let last_message = request
  498. .messages
  499. .last()
  500. .expect("tool result should be present");
  501. assert_eq!(last_message.role, MessageRole::Tool);
  502. Ok(vec![
  503. AssistantEvent::TextDelta("The answer is 4.".to_string()),
  504. AssistantEvent::Usage(TokenUsage {
  505. input_tokens: 24,
  506. output_tokens: 4,
  507. cache_creation_input_tokens: 1,
  508. cache_read_input_tokens: 3,
  509. }),
  510. AssistantEvent::MessageStop,
  511. ])
  512. }
  513. _ => Err(RuntimeError::new("unexpected extra API call")),
  514. }
  515. }
  516. }
  517. struct PromptAllowOnce;
  518. impl PermissionPrompter for PromptAllowOnce {
  519. fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
  520. assert_eq!(request.tool_name, "add");
  521. PermissionPromptDecision::Allow
  522. }
  523. }
  524. fn temp_dir(label: &str) -> PathBuf {
  525. let nanos = SystemTime::now()
  526. .duration_since(UNIX_EPOCH)
  527. .expect("time should be after epoch")
  528. .as_nanos();
  529. std::env::temp_dir().join(format!("runtime-plugin-{label}-{nanos}"))
  530. }
  531. fn write_lifecycle_plugin(root: &Path, name: &str) -> PathBuf {
  532. fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir");
  533. fs::create_dir_all(root.join("lifecycle")).expect("lifecycle dir");
  534. let log_path = root.join("lifecycle.log");
  535. fs::write(
  536. root.join("lifecycle").join("init.sh"),
  537. "#!/bin/sh\nprintf 'init\\n' >> \"$(dirname \"$0\")/../lifecycle.log\"\n",
  538. )
  539. .expect("write init script");
  540. fs::write(
  541. root.join("lifecycle").join("shutdown.sh"),
  542. "#!/bin/sh\nprintf 'shutdown\\n' >> \"$(dirname \"$0\")/../lifecycle.log\"\n",
  543. )
  544. .expect("write shutdown script");
  545. fs::write(
  546. root.join(".claude-plugin").join("plugin.json"),
  547. format!(
  548. "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"runtime lifecycle plugin\",\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init.sh\"],\n \"Shutdown\": [\"./lifecycle/shutdown.sh\"]\n }}\n}}"
  549. ),
  550. )
  551. .expect("write plugin manifest");
  552. log_path
  553. }
  554. #[test]
  555. fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
  556. let api_client = ScriptedApiClient { call_count: 0 };
  557. let tool_executor = StaticToolExecutor::new().register("add", |input| {
  558. let total = input
  559. .split(',')
  560. .map(|part| part.parse::<i32>().expect("input must be valid integer"))
  561. .sum::<i32>();
  562. Ok(total.to_string())
  563. });
  564. let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
  565. let system_prompt = SystemPromptBuilder::new()
  566. .with_project_context(ProjectContext {
  567. cwd: PathBuf::from("/tmp/project"),
  568. current_date: "2026-03-31".to_string(),
  569. git_status: None,
  570. git_diff: None,
  571. instruction_files: Vec::new(),
  572. })
  573. .with_os("linux", "6.8")
  574. .build();
  575. let mut runtime = ConversationRuntime::new(
  576. Session::new(),
  577. api_client,
  578. tool_executor,
  579. permission_policy,
  580. system_prompt,
  581. );
  582. let summary = runtime
  583. .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
  584. .expect("conversation loop should succeed");
  585. assert_eq!(summary.iterations, 2);
  586. assert_eq!(summary.assistant_messages.len(), 2);
  587. assert_eq!(summary.tool_results.len(), 1);
  588. assert_eq!(runtime.session().messages.len(), 4);
  589. assert_eq!(summary.usage.output_tokens, 10);
  590. assert_eq!(summary.auto_compaction, None);
  591. assert!(matches!(
  592. runtime.session().messages[1].blocks[1],
  593. ContentBlock::ToolUse { .. }
  594. ));
  595. assert!(matches!(
  596. runtime.session().messages[2].blocks[0],
  597. ContentBlock::ToolResult {
  598. is_error: false,
  599. ..
  600. }
  601. ));
  602. }
  603. #[test]
  604. fn records_denied_tool_results_when_prompt_rejects() {
  605. struct RejectPrompter;
  606. impl PermissionPrompter for RejectPrompter {
  607. fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
  608. PermissionPromptDecision::Deny {
  609. reason: "not now".to_string(),
  610. }
  611. }
  612. }
  613. struct SingleCallApiClient;
  614. impl ApiClient for SingleCallApiClient {
  615. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
  616. if request
  617. .messages
  618. .iter()
  619. .any(|message| message.role == MessageRole::Tool)
  620. {
  621. return Ok(vec![
  622. AssistantEvent::TextDelta("I could not use the tool.".to_string()),
  623. AssistantEvent::MessageStop,
  624. ]);
  625. }
  626. Ok(vec![
  627. AssistantEvent::ToolUse {
  628. id: "tool-1".to_string(),
  629. name: "blocked".to_string(),
  630. input: "secret".to_string(),
  631. },
  632. AssistantEvent::MessageStop,
  633. ])
  634. }
  635. }
  636. let mut runtime = ConversationRuntime::new(
  637. Session::new(),
  638. SingleCallApiClient,
  639. StaticToolExecutor::new(),
  640. PermissionPolicy::new(PermissionMode::WorkspaceWrite),
  641. vec!["system".to_string()],
  642. );
  643. let summary = runtime
  644. .run_turn("use the tool", Some(&mut RejectPrompter))
  645. .expect("conversation should continue after denied tool");
  646. assert_eq!(summary.tool_results.len(), 1);
  647. assert!(matches!(
  648. &summary.tool_results[0].blocks[0],
  649. ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
  650. ));
  651. }
  652. #[test]
  653. fn denies_tool_use_when_pre_tool_hook_blocks() {
  654. struct SingleCallApiClient;
  655. impl ApiClient for SingleCallApiClient {
  656. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
  657. if request
  658. .messages
  659. .iter()
  660. .any(|message| message.role == MessageRole::Tool)
  661. {
  662. return Ok(vec![
  663. AssistantEvent::TextDelta("blocked".to_string()),
  664. AssistantEvent::MessageStop,
  665. ]);
  666. }
  667. Ok(vec![
  668. AssistantEvent::ToolUse {
  669. id: "tool-1".to_string(),
  670. name: "blocked".to_string(),
  671. input: r#"{"path":"secret.txt"}"#.to_string(),
  672. },
  673. AssistantEvent::MessageStop,
  674. ])
  675. }
  676. }
  677. let mut runtime = ConversationRuntime::new_with_features(
  678. Session::new(),
  679. SingleCallApiClient,
  680. StaticToolExecutor::new().register("blocked", |_input| {
  681. panic!("tool should not execute when hook denies")
  682. }),
  683. PermissionPolicy::new(PermissionMode::DangerFullAccess),
  684. vec!["system".to_string()],
  685. RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
  686. vec![shell_snippet("printf 'blocked by hook'; exit 2")],
  687. Vec::new(),
  688. )),
  689. );
  690. let summary = runtime
  691. .run_turn("use the tool", None)
  692. .expect("conversation should continue after hook denial");
  693. assert_eq!(summary.tool_results.len(), 1);
  694. let ContentBlock::ToolResult {
  695. is_error, output, ..
  696. } = &summary.tool_results[0].blocks[0]
  697. else {
  698. panic!("expected tool result block");
  699. };
  700. assert!(
  701. *is_error,
  702. "hook denial should produce an error result: {output}"
  703. );
  704. assert!(
  705. output.contains("denied tool") || output.contains("blocked by hook"),
  706. "unexpected hook denial output: {output:?}"
  707. );
  708. }
  709. #[test]
  710. fn appends_post_tool_hook_feedback_to_tool_result() {
  711. struct TwoCallApiClient {
  712. calls: usize,
  713. }
  714. impl ApiClient for TwoCallApiClient {
  715. fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
  716. self.calls += 1;
  717. match self.calls {
  718. 1 => Ok(vec![
  719. AssistantEvent::ToolUse {
  720. id: "tool-1".to_string(),
  721. name: "add".to_string(),
  722. input: r#"{"lhs":2,"rhs":2}"#.to_string(),
  723. },
  724. AssistantEvent::MessageStop,
  725. ]),
  726. 2 => {
  727. assert!(request
  728. .messages
  729. .iter()
  730. .any(|message| message.role == MessageRole::Tool));
  731. Ok(vec![
  732. AssistantEvent::TextDelta("done".to_string()),
  733. AssistantEvent::MessageStop,
  734. ])
  735. }
  736. _ => Err(RuntimeError::new("unexpected extra API call")),
  737. }
  738. }
  739. }
  740. let mut runtime = ConversationRuntime::new_with_features(
  741. Session::new(),
  742. TwoCallApiClient { calls: 0 },
  743. StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
  744. PermissionPolicy::new(PermissionMode::DangerFullAccess),
  745. vec!["system".to_string()],
  746. RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
  747. vec![shell_snippet("printf 'pre hook ran'")],
  748. vec![shell_snippet("printf 'post hook ran'")],
  749. )),
  750. );
  751. let summary = runtime
  752. .run_turn("use add", None)
  753. .expect("tool loop succeeds");
  754. assert_eq!(summary.tool_results.len(), 1);
  755. let ContentBlock::ToolResult {
  756. is_error, output, ..
  757. } = &summary.tool_results[0].blocks[0]
  758. else {
  759. panic!("expected tool result block");
  760. };
  761. assert!(
  762. !*is_error,
  763. "post hook should preserve non-error result: {output:?}"
  764. );
  765. assert!(
  766. output.contains('4'),
  767. "tool output missing value: {output:?}"
  768. );
  769. assert!(
  770. output.contains("pre hook ran"),
  771. "tool output missing pre hook feedback: {output:?}"
  772. );
  773. assert!(
  774. output.contains("post hook ran"),
  775. "tool output missing post hook feedback: {output:?}"
  776. );
  777. }
  778. #[test]
  779. fn initializes_and_shuts_down_plugins_with_runtime_lifecycle() {
  780. let config_home = temp_dir("config");
  781. let source_root = temp_dir("source");
  782. let log_path = write_lifecycle_plugin(&source_root, "runtime-lifecycle");
  783. let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
  784. manager
  785. .install(source_root.to_str().expect("utf8 path"))
  786. .expect("install should succeed");
  787. let registry = manager.plugin_registry().expect("registry should load");
  788. {
  789. let runtime = ConversationRuntime::new_with_plugins(
  790. Session::new(),
  791. ScriptedApiClient { call_count: 0 },
  792. StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
  793. PermissionPolicy::new(PermissionMode::WorkspaceWrite),
  794. vec!["system".to_string()],
  795. RuntimeFeatureConfig::default(),
  796. registry,
  797. )
  798. .expect("runtime should initialize plugins");
  799. let log = fs::read_to_string(&log_path).expect("init log should exist");
  800. assert_eq!(log, "init\n");
  801. drop(runtime);
  802. }
  803. let log = fs::read_to_string(&log_path).expect("shutdown log should exist");
  804. assert_eq!(log, "init\nshutdown\n");
  805. let _ = fs::remove_dir_all(config_home);
  806. let _ = fs::remove_dir_all(source_root);
  807. }
  808. #[test]
  809. fn reconstructs_usage_tracker_from_restored_session() {
  810. struct SimpleApi;
  811. impl ApiClient for SimpleApi {
  812. fn stream(
  813. &mut self,
  814. _request: ApiRequest,
  815. ) -> Result<Vec<AssistantEvent>, RuntimeError> {
  816. Ok(vec![
  817. AssistantEvent::TextDelta("done".to_string()),
  818. AssistantEvent::MessageStop,
  819. ])
  820. }
  821. }
  822. let mut session = Session::new();
  823. session
  824. .messages
  825. .push(crate::session::ConversationMessage::assistant_with_usage(
  826. vec![ContentBlock::Text {
  827. text: "earlier".to_string(),
  828. }],
  829. Some(TokenUsage {
  830. input_tokens: 11,
  831. output_tokens: 7,
  832. cache_creation_input_tokens: 2,
  833. cache_read_input_tokens: 1,
  834. }),
  835. ));
  836. let runtime = ConversationRuntime::new(
  837. session,
  838. SimpleApi,
  839. StaticToolExecutor::new(),
  840. PermissionPolicy::new(PermissionMode::DangerFullAccess),
  841. vec!["system".to_string()],
  842. );
  843. assert_eq!(runtime.usage().turns(), 1);
  844. assert_eq!(runtime.usage().cumulative_usage().total_tokens(), 21);
  845. }
  846. #[test]
  847. fn compacts_session_after_turns() {
  848. struct SimpleApi;
  849. impl ApiClient for SimpleApi {
  850. fn stream(
  851. &mut self,
  852. _request: ApiRequest,
  853. ) -> Result<Vec<AssistantEvent>, RuntimeError> {
  854. Ok(vec![
  855. AssistantEvent::TextDelta("done".to_string()),
  856. AssistantEvent::MessageStop,
  857. ])
  858. }
  859. }
  860. let mut runtime = ConversationRuntime::new(
  861. Session::new(),
  862. SimpleApi,
  863. StaticToolExecutor::new(),
  864. PermissionPolicy::new(PermissionMode::DangerFullAccess),
  865. vec!["system".to_string()],
  866. );
  867. runtime.run_turn("a", None).expect("turn a");
  868. runtime.run_turn("b", None).expect("turn b");
  869. runtime.run_turn("c", None).expect("turn c");
  870. let result = runtime.compact(CompactionConfig {
  871. preserve_recent_messages: 2,
  872. max_estimated_tokens: 1,
  873. });
  874. assert!(result.summary.contains("Conversation summary"));
  875. assert_eq!(
  876. result.compacted_session.messages[0].role,
  877. MessageRole::System
  878. );
  879. }
  880. #[cfg(windows)]
  881. fn shell_snippet(script: &str) -> String {
  882. script.replace('\'', "\"")
  883. }
  884. #[cfg(not(windows))]
  885. fn shell_snippet(script: &str) -> String {
  886. script.to_string()
  887. }
  888. #[test]
  889. fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
  890. struct SimpleApi;
  891. impl ApiClient for SimpleApi {
  892. fn stream(
  893. &mut self,
  894. _request: ApiRequest,
  895. ) -> Result<Vec<AssistantEvent>, RuntimeError> {
  896. Ok(vec![
  897. AssistantEvent::TextDelta("done".to_string()),
  898. AssistantEvent::Usage(TokenUsage {
  899. input_tokens: 120_000,
  900. output_tokens: 4,
  901. cache_creation_input_tokens: 0,
  902. cache_read_input_tokens: 0,
  903. }),
  904. AssistantEvent::MessageStop,
  905. ])
  906. }
  907. }
  908. let session = Session {
  909. version: 1,
  910. messages: vec![
  911. crate::session::ConversationMessage::user_text("one"),
  912. crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
  913. text: "two".to_string(),
  914. }]),
  915. crate::session::ConversationMessage::user_text("three"),
  916. crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
  917. text: "four".to_string(),
  918. }]),
  919. ],
  920. };
  921. let mut runtime = ConversationRuntime::new(
  922. session,
  923. SimpleApi,
  924. StaticToolExecutor::new(),
  925. PermissionPolicy::new(PermissionMode::DangerFullAccess),
  926. vec!["system".to_string()],
  927. )
  928. .with_auto_compaction_input_tokens_threshold(100_000);
  929. let summary = runtime
  930. .run_turn("trigger", None)
  931. .expect("turn should succeed");
  932. assert_eq!(
  933. summary.auto_compaction,
  934. Some(AutoCompactionEvent {
  935. removed_message_count: 2,
  936. })
  937. );
  938. assert_eq!(runtime.session().messages[0].role, MessageRole::System);
  939. }
  940. #[test]
  941. fn skips_auto_compaction_below_threshold() {
  942. struct SimpleApi;
  943. impl ApiClient for SimpleApi {
  944. fn stream(
  945. &mut self,
  946. _request: ApiRequest,
  947. ) -> Result<Vec<AssistantEvent>, RuntimeError> {
  948. Ok(vec![
  949. AssistantEvent::TextDelta("done".to_string()),
  950. AssistantEvent::Usage(TokenUsage {
  951. input_tokens: 99_999,
  952. output_tokens: 4,
  953. cache_creation_input_tokens: 0,
  954. cache_read_input_tokens: 0,
  955. }),
  956. AssistantEvent::MessageStop,
  957. ])
  958. }
  959. }
  960. let mut runtime = ConversationRuntime::new(
  961. Session::new(),
  962. SimpleApi,
  963. StaticToolExecutor::new(),
  964. PermissionPolicy::new(PermissionMode::DangerFullAccess),
  965. vec!["system".to_string()],
  966. )
  967. .with_auto_compaction_input_tokens_threshold(100_000);
  968. let summary = runtime
  969. .run_turn("trigger", None)
  970. .expect("turn should succeed");
  971. assert_eq!(summary.auto_compaction, None);
  972. assert_eq!(runtime.session().messages.len(), 2);
  973. }
  974. #[test]
  975. fn auto_compaction_threshold_defaults_and_parses_values() {
  976. assert_eq!(
  977. parse_auto_compaction_threshold(None),
  978. DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
  979. );
  980. assert_eq!(parse_auto_compaction_threshold(Some("4321")), 4321);
  981. assert_eq!(
  982. parse_auto_compaction_threshold(Some("not-a-number")),
  983. DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD
  984. );
  985. }
  986. }