conversation.rs 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679
  1. use std::collections::BTreeMap;
  2. use std::fmt::{Display, Formatter};
  3. use serde_json::{Map, Value};
  4. use telemetry::SessionTracer;
  5. use crate::compact::{
  6. compact_session, estimate_session_tokens, CompactionConfig, CompactionResult,
  7. };
  8. use crate::config::RuntimeFeatureConfig;
  9. use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner};
  10. use crate::permissions::{
  11. PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter,
  12. };
  13. use crate::session::{ContentBlock, ConversationMessage, Session};
  14. use crate::usage::{TokenUsage, UsageTracker};
/// Fallback auto-compaction threshold, compared against cumulative input tokens.
/// Used when the environment override is missing, unparsable, or zero.
const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 100_000;
/// Environment variable read by `auto_compaction_threshold_from_env` to override
/// the default threshold.
const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS";
/// Payload handed to [`ApiClient::stream`] for one model call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApiRequest {
    /// System prompt sections; `run_turn` sends a clone of the runtime's prompt.
    pub system_prompt: Vec<String>,
    /// Conversation history; `run_turn` sends a clone of the session's messages.
    pub messages: Vec<ConversationMessage>,
}
/// One item of the event list returned by [`ApiClient::stream`].
///
/// `build_assistant_message` folds a list of these into a single assistant message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AssistantEvent {
    /// A fragment of assistant text; consecutive fragments are concatenated.
    TextDelta(String),
    /// A tool invocation requested by the assistant.
    ToolUse {
        /// Identifier echoed back on the matching tool result.
        id: String,
        /// Name of the tool to run.
        name: String,
        /// Tool input, passed to the executor as an opaque string.
        input: String,
    },
    /// Token usage for the call; when several arrive, the last one wins.
    Usage(TokenUsage),
    /// A prompt-cache observation, collected into the turn summary.
    PromptCache(PromptCacheEvent),
    /// End-of-message marker; a stream without it is rejected as incomplete.
    MessageStop,
}
/// A prompt-cache observation reported by the API client.
///
/// This file only transports these values; their exact semantics are defined by
/// whoever produces the event.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheEvent {
    /// NOTE(review): presumably flags a cache change the producer did not anticipate — confirm.
    pub unexpected: bool,
    /// Human-readable explanation supplied by the producer.
    pub reason: String,
    /// NOTE(review): presumably the cache-read input token count seen on the previous call — confirm.
    pub previous_cache_read_input_tokens: u32,
    /// NOTE(review): presumably the cache-read input token count seen on the current call — confirm.
    pub current_cache_read_input_tokens: u32,
    /// NOTE(review): looks like `previous - current` (the test fixture uses 6000, 1000, 5000) — confirm.
    pub token_drop: u32,
}
/// Abstraction over the model backend used by [`ConversationRuntime`].
pub trait ApiClient {
    /// Sends `request` and returns the complete list of events for one assistant message.
    ///
    /// # Errors
    /// Implementations return a [`RuntimeError`] when the call fails; `run_turn`
    /// aborts the turn and forwards that error.
    fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError>;
}
/// Abstraction over tool dispatch used by [`ConversationRuntime`].
pub trait ToolExecutor {
    /// Runs the tool called `tool_name` with the given `input` and returns its output text.
    ///
    /// # Errors
    /// Implementations return a [`ToolError`] on failure; `run_turn` turns it into an
    /// error tool result instead of aborting the turn.
    fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError>;
}
  48. #[derive(Debug, Clone, PartialEq, Eq)]
  49. pub struct ToolError {
  50. message: String,
  51. }
  52. impl ToolError {
  53. #[must_use]
  54. pub fn new(message: impl Into<String>) -> Self {
  55. Self {
  56. message: message.into(),
  57. }
  58. }
  59. }
  60. impl Display for ToolError {
  61. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  62. write!(f, "{}", self.message)
  63. }
  64. }
  65. impl std::error::Error for ToolError {}
  66. #[derive(Debug, Clone, PartialEq, Eq)]
  67. pub struct RuntimeError {
  68. message: String,
  69. }
  70. impl RuntimeError {
  71. #[must_use]
  72. pub fn new(message: impl Into<String>) -> Self {
  73. Self {
  74. message: message.into(),
  75. }
  76. }
  77. }
  78. impl Display for RuntimeError {
  79. fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
  80. write!(f, "{}", self.message)
  81. }
  82. }
  83. impl std::error::Error for RuntimeError {}
/// What a successful [`ConversationRuntime::run_turn`] call produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TurnSummary {
    /// Assistant messages generated during the turn, in order (one per model call).
    pub assistant_messages: Vec<ConversationMessage>,
    /// Tool result messages appended during the turn, in order.
    pub tool_results: Vec<ConversationMessage>,
    /// Prompt-cache events reported by the API client across all model calls of the turn.
    pub prompt_cache_events: Vec<PromptCacheEvent>,
    /// Number of passes through the turn loop (model calls) that were made.
    pub iterations: usize,
    /// Cumulative usage taken from the runtime's usage tracker when the turn ended.
    pub usage: TokenUsage,
    /// Set when the session was auto-compacted at the end of the turn.
    pub auto_compaction: Option<AutoCompactionEvent>,
}
/// Reports that the session was compacted automatically at the end of a turn.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AutoCompactionEvent {
    /// Number of messages the compaction removed (always non-zero when reported).
    pub removed_message_count: usize,
}
/// Drives the user → assistant → tool loop over a [`Session`].
///
/// `C` is the model backend and `T` the tool dispatcher; both are owned by the runtime.
pub struct ConversationRuntime<C, T> {
    /// Conversation history that every turn appends to.
    session: Session,
    /// Backend asked for the next assistant message.
    api_client: C,
    /// Dispatcher that runs the tools the assistant requests.
    tool_executor: T,
    /// Policy consulted before each tool call that the hooks did not already block.
    permission_policy: PermissionPolicy,
    /// System prompt sections sent with every request.
    system_prompt: Vec<String>,
    /// Upper bound on loop passes per turn; constructors set it to `usize::MAX`.
    max_iterations: usize,
    /// Accumulates the token usage reported by the API client.
    usage_tracker: UsageTracker,
    /// Runs the pre/post tool-use hooks.
    hook_runner: HookRunner,
    /// Cumulative input-token count at or above which a turn ends with a compaction attempt.
    auto_compaction_input_tokens_threshold: u32,
    /// Abort signal passed to every hook invocation.
    hook_abort_signal: HookAbortSignal,
    /// Optional progress reporter passed to every hook invocation.
    hook_progress_reporter: Option<Box<dyn HookProgressReporter>>,
    /// Optional tracer; when absent, the `record_*` helpers do nothing.
    session_tracer: Option<SessionTracer>,
}
  111. impl<C, T> ConversationRuntime<C, T>
  112. where
  113. C: ApiClient,
  114. T: ToolExecutor,
  115. {
  116. #[must_use]
  117. pub fn new(
  118. session: Session,
  119. api_client: C,
  120. tool_executor: T,
  121. permission_policy: PermissionPolicy,
  122. system_prompt: Vec<String>,
  123. ) -> Self {
  124. Self::new_with_features(
  125. session,
  126. api_client,
  127. tool_executor,
  128. permission_policy,
  129. system_prompt,
  130. &RuntimeFeatureConfig::default(),
  131. )
  132. }
  133. #[must_use]
  134. #[allow(clippy::needless_pass_by_value)]
  135. pub fn new_with_features(
  136. session: Session,
  137. api_client: C,
  138. tool_executor: T,
  139. permission_policy: PermissionPolicy,
  140. system_prompt: Vec<String>,
  141. feature_config: &RuntimeFeatureConfig,
  142. ) -> Self {
  143. let usage_tracker = UsageTracker::from_session(&session);
  144. Self {
  145. session,
  146. api_client,
  147. tool_executor,
  148. permission_policy,
  149. system_prompt,
  150. max_iterations: usize::MAX,
  151. usage_tracker,
  152. hook_runner: HookRunner::from_feature_config(feature_config),
  153. auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(),
  154. hook_abort_signal: HookAbortSignal::default(),
  155. hook_progress_reporter: None,
  156. session_tracer: None,
  157. }
  158. }
    /// Builder: caps the number of loop passes `run_turn` may make before failing.
    #[must_use]
    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = max_iterations;
        self
    }
    /// Builder: overrides the auto-compaction threshold that was read from the
    /// environment at construction time.
    #[must_use]
    pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self {
        self.auto_compaction_input_tokens_threshold = threshold;
        self
    }
    /// Builder: replaces the abort signal handed to every hook invocation.
    #[must_use]
    pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self {
        self.hook_abort_signal = hook_abort_signal;
        self
    }
    /// Builder: attaches a progress reporter that is handed to every hook invocation.
    #[must_use]
    pub fn with_hook_progress_reporter(
        mut self,
        hook_progress_reporter: Box<dyn HookProgressReporter>,
    ) -> Self {
        self.hook_progress_reporter = Some(hook_progress_reporter);
        self
    }
    /// Builder: attaches a tracer so the `record_*` helpers emit session trace events.
    #[must_use]
    pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
        self.session_tracer = Some(session_tracer);
        self
    }
  187. fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult {
  188. if let Some(reporter) = self.hook_progress_reporter.as_mut() {
  189. self.hook_runner.run_pre_tool_use_with_context(
  190. tool_name,
  191. input,
  192. Some(&self.hook_abort_signal),
  193. Some(reporter.as_mut()),
  194. )
  195. } else {
  196. self.hook_runner.run_pre_tool_use_with_context(
  197. tool_name,
  198. input,
  199. Some(&self.hook_abort_signal),
  200. None,
  201. )
  202. }
  203. }
  204. fn run_post_tool_use_hook(
  205. &mut self,
  206. tool_name: &str,
  207. input: &str,
  208. output: &str,
  209. is_error: bool,
  210. ) -> HookRunResult {
  211. if let Some(reporter) = self.hook_progress_reporter.as_mut() {
  212. self.hook_runner.run_post_tool_use_with_context(
  213. tool_name,
  214. input,
  215. output,
  216. is_error,
  217. Some(&self.hook_abort_signal),
  218. Some(reporter.as_mut()),
  219. )
  220. } else {
  221. self.hook_runner.run_post_tool_use_with_context(
  222. tool_name,
  223. input,
  224. output,
  225. is_error,
  226. Some(&self.hook_abort_signal),
  227. None,
  228. )
  229. }
  230. }
  231. fn run_post_tool_use_failure_hook(
  232. &mut self,
  233. tool_name: &str,
  234. input: &str,
  235. output: &str,
  236. ) -> HookRunResult {
  237. if let Some(reporter) = self.hook_progress_reporter.as_mut() {
  238. self.hook_runner.run_post_tool_use_failure_with_context(
  239. tool_name,
  240. input,
  241. output,
  242. Some(&self.hook_abort_signal),
  243. Some(reporter.as_mut()),
  244. )
  245. } else {
  246. self.hook_runner.run_post_tool_use_failure_with_context(
  247. tool_name,
  248. input,
  249. output,
  250. Some(&self.hook_abort_signal),
  251. None,
  252. )
  253. }
  254. }
  255. #[allow(clippy::too_many_lines)]
  256. pub fn run_turn(
  257. &mut self,
  258. user_input: impl Into<String>,
  259. mut prompter: Option<&mut dyn PermissionPrompter>,
  260. ) -> Result<TurnSummary, RuntimeError> {
  261. let user_input = user_input.into();
  262. self.record_turn_started(&user_input);
  263. self.session
  264. .push_user_text(user_input)
  265. .map_err(|error| RuntimeError::new(error.to_string()))?;
  266. let mut assistant_messages = Vec::new();
  267. let mut tool_results = Vec::new();
  268. let mut prompt_cache_events = Vec::new();
  269. let mut iterations = 0;
  270. loop {
  271. iterations += 1;
  272. if iterations > self.max_iterations {
  273. let error = RuntimeError::new(
  274. "conversation loop exceeded the maximum number of iterations",
  275. );
  276. self.record_turn_failed(iterations, &error);
  277. return Err(error);
  278. }
  279. let request = ApiRequest {
  280. system_prompt: self.system_prompt.clone(),
  281. messages: self.session.messages.clone(),
  282. };
  283. let events = match self.api_client.stream(request) {
  284. Ok(events) => events,
  285. Err(error) => {
  286. self.record_turn_failed(iterations, &error);
  287. return Err(error);
  288. }
  289. };
  290. let (assistant_message, usage, turn_prompt_cache_events) =
  291. match build_assistant_message(events) {
  292. Ok(result) => result,
  293. Err(error) => {
  294. self.record_turn_failed(iterations, &error);
  295. return Err(error);
  296. }
  297. };
  298. if let Some(usage) = usage {
  299. self.usage_tracker.record(usage);
  300. }
  301. prompt_cache_events.extend(turn_prompt_cache_events);
  302. let pending_tool_uses = assistant_message
  303. .blocks
  304. .iter()
  305. .filter_map(|block| match block {
  306. ContentBlock::ToolUse { id, name, input } => {
  307. Some((id.clone(), name.clone(), input.clone()))
  308. }
  309. _ => None,
  310. })
  311. .collect::<Vec<_>>();
  312. self.record_assistant_iteration(
  313. iterations,
  314. &assistant_message,
  315. pending_tool_uses.len(),
  316. );
  317. self.session
  318. .push_message(assistant_message.clone())
  319. .map_err(|error| RuntimeError::new(error.to_string()))?;
  320. assistant_messages.push(assistant_message);
  321. if pending_tool_uses.is_empty() {
  322. break;
  323. }
  324. for (tool_use_id, tool_name, input) in pending_tool_uses {
  325. let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input);
  326. let effective_input = pre_hook_result
  327. .updated_input()
  328. .map_or_else(|| input.clone(), ToOwned::to_owned);
  329. let permission_context = PermissionContext::new(
  330. pre_hook_result.permission_override(),
  331. pre_hook_result.permission_reason().map(ToOwned::to_owned),
  332. );
  333. let permission_outcome = if pre_hook_result.is_cancelled() {
  334. PermissionOutcome::Deny {
  335. reason: format_hook_message(
  336. &pre_hook_result,
  337. &format!("PreToolUse hook cancelled tool `{tool_name}`"),
  338. ),
  339. }
  340. } else if pre_hook_result.is_failed() {
  341. PermissionOutcome::Deny {
  342. reason: format_hook_message(
  343. &pre_hook_result,
  344. &format!("PreToolUse hook failed for tool `{tool_name}`"),
  345. ),
  346. }
  347. } else if pre_hook_result.is_denied() {
  348. PermissionOutcome::Deny {
  349. reason: format_hook_message(
  350. &pre_hook_result,
  351. &format!("PreToolUse hook denied tool `{tool_name}`"),
  352. ),
  353. }
  354. } else if let Some(prompt) = prompter.as_mut() {
  355. self.permission_policy.authorize_with_context(
  356. &tool_name,
  357. &effective_input,
  358. &permission_context,
  359. Some(*prompt),
  360. )
  361. } else {
  362. self.permission_policy.authorize_with_context(
  363. &tool_name,
  364. &effective_input,
  365. &permission_context,
  366. None,
  367. )
  368. };
  369. let result_message = match permission_outcome {
  370. PermissionOutcome::Allow => {
  371. self.record_tool_started(iterations, &tool_name);
  372. let (mut output, mut is_error) =
  373. match self.tool_executor.execute(&tool_name, &effective_input) {
  374. Ok(output) => (output, false),
  375. Err(error) => (error.to_string(), true),
  376. };
  377. output = merge_hook_feedback(pre_hook_result.messages(), output, false);
  378. let post_hook_result = if is_error {
  379. self.run_post_tool_use_failure_hook(
  380. &tool_name,
  381. &effective_input,
  382. &output,
  383. )
  384. } else {
  385. self.run_post_tool_use_hook(
  386. &tool_name,
  387. &effective_input,
  388. &output,
  389. false,
  390. )
  391. };
  392. if post_hook_result.is_denied()
  393. || post_hook_result.is_failed()
  394. || post_hook_result.is_cancelled()
  395. {
  396. is_error = true;
  397. }
  398. output = merge_hook_feedback(
  399. post_hook_result.messages(),
  400. output,
  401. post_hook_result.is_denied()
  402. || post_hook_result.is_failed()
  403. || post_hook_result.is_cancelled(),
  404. );
  405. ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error)
  406. }
  407. PermissionOutcome::Deny { reason } => ConversationMessage::tool_result(
  408. tool_use_id,
  409. tool_name,
  410. merge_hook_feedback(pre_hook_result.messages(), reason, true),
  411. true,
  412. ),
  413. };
  414. self.session
  415. .push_message(result_message.clone())
  416. .map_err(|error| RuntimeError::new(error.to_string()))?;
  417. self.record_tool_finished(iterations, &result_message);
  418. tool_results.push(result_message);
  419. }
  420. }
  421. let auto_compaction = self.maybe_auto_compact();
  422. let summary = TurnSummary {
  423. assistant_messages,
  424. tool_results,
  425. prompt_cache_events,
  426. iterations,
  427. usage: self.usage_tracker.cumulative_usage(),
  428. auto_compaction,
  429. };
  430. self.record_turn_completed(&summary);
  431. Ok(summary)
  432. }
    /// Computes a compaction of the current session with `config`.
    ///
    /// The runtime's own session is left untouched; the caller decides what to do
    /// with the returned result.
    #[must_use]
    pub fn compact(&self, config: CompactionConfig) -> CompactionResult {
        compact_session(&self.session, config)
    }
    /// Returns the token estimate for the current session as computed by
    /// `estimate_session_tokens`.
    #[must_use]
    pub fn estimated_tokens(&self) -> usize {
        estimate_session_tokens(&self.session)
    }
    /// Borrows the tracker holding the token usage recorded so far.
    #[must_use]
    pub fn usage(&self) -> &UsageTracker {
        &self.usage_tracker
    }
    /// Borrows the conversation session in its current state.
    #[must_use]
    pub fn session(&self) -> &Session {
        &self.session
    }
    /// Returns a fork of the current session, delegating to `Session::fork` with the
    /// optional `branch_name`.
    #[must_use]
    pub fn fork_session(&self, branch_name: Option<String>) -> Session {
        self.session.fork(branch_name)
    }
    /// Consumes the runtime and returns its session.
    #[must_use]
    pub fn into_session(self) -> Session {
        self.session
    }
  457. fn maybe_auto_compact(&mut self) -> Option<AutoCompactionEvent> {
  458. if self.usage_tracker.cumulative_usage().input_tokens
  459. < self.auto_compaction_input_tokens_threshold
  460. {
  461. return None;
  462. }
  463. let result = compact_session(
  464. &self.session,
  465. CompactionConfig {
  466. max_estimated_tokens: 0,
  467. ..CompactionConfig::default()
  468. },
  469. );
  470. if result.removed_message_count == 0 {
  471. return None;
  472. }
  473. self.session = result.compacted_session;
  474. Some(AutoCompactionEvent {
  475. removed_message_count: result.removed_message_count,
  476. })
  477. }
  478. fn record_turn_started(&self, user_input: &str) {
  479. let Some(session_tracer) = &self.session_tracer else {
  480. return;
  481. };
  482. let mut attributes = Map::new();
  483. attributes.insert(
  484. "user_input".to_string(),
  485. Value::String(user_input.to_string()),
  486. );
  487. session_tracer.record("turn_started", attributes);
  488. }
  489. fn record_assistant_iteration(
  490. &self,
  491. iteration: usize,
  492. assistant_message: &ConversationMessage,
  493. pending_tool_use_count: usize,
  494. ) {
  495. let Some(session_tracer) = &self.session_tracer else {
  496. return;
  497. };
  498. let mut attributes = Map::new();
  499. attributes.insert("iteration".to_string(), Value::from(iteration as u64));
  500. attributes.insert(
  501. "assistant_blocks".to_string(),
  502. Value::from(assistant_message.blocks.len() as u64),
  503. );
  504. attributes.insert(
  505. "pending_tool_use_count".to_string(),
  506. Value::from(pending_tool_use_count as u64),
  507. );
  508. session_tracer.record("assistant_iteration_completed", attributes);
  509. }
  510. fn record_tool_started(&self, iteration: usize, tool_name: &str) {
  511. let Some(session_tracer) = &self.session_tracer else {
  512. return;
  513. };
  514. let mut attributes = Map::new();
  515. attributes.insert("iteration".to_string(), Value::from(iteration as u64));
  516. attributes.insert(
  517. "tool_name".to_string(),
  518. Value::String(tool_name.to_string()),
  519. );
  520. session_tracer.record("tool_execution_started", attributes);
  521. }
  522. fn record_tool_finished(&self, iteration: usize, result_message: &ConversationMessage) {
  523. let Some(session_tracer) = &self.session_tracer else {
  524. return;
  525. };
  526. let Some(ContentBlock::ToolResult {
  527. tool_name,
  528. is_error,
  529. ..
  530. }) = result_message.blocks.first()
  531. else {
  532. return;
  533. };
  534. let mut attributes = Map::new();
  535. attributes.insert("iteration".to_string(), Value::from(iteration as u64));
  536. attributes.insert("tool_name".to_string(), Value::String(tool_name.clone()));
  537. attributes.insert("is_error".to_string(), Value::Bool(*is_error));
  538. session_tracer.record("tool_execution_finished", attributes);
  539. }
  540. fn record_turn_completed(&self, summary: &TurnSummary) {
  541. let Some(session_tracer) = &self.session_tracer else {
  542. return;
  543. };
  544. let mut attributes = Map::new();
  545. attributes.insert(
  546. "iterations".to_string(),
  547. Value::from(summary.iterations as u64),
  548. );
  549. attributes.insert(
  550. "assistant_messages".to_string(),
  551. Value::from(summary.assistant_messages.len() as u64),
  552. );
  553. attributes.insert(
  554. "tool_results".to_string(),
  555. Value::from(summary.tool_results.len() as u64),
  556. );
  557. attributes.insert(
  558. "prompt_cache_events".to_string(),
  559. Value::from(summary.prompt_cache_events.len() as u64),
  560. );
  561. session_tracer.record("turn_completed", attributes);
  562. }
  563. fn record_turn_failed(&self, iteration: usize, error: &RuntimeError) {
  564. let Some(session_tracer) = &self.session_tracer else {
  565. return;
  566. };
  567. let mut attributes = Map::new();
  568. attributes.insert("iteration".to_string(), Value::from(iteration as u64));
  569. attributes.insert("error".to_string(), Value::String(error.to_string()));
  570. session_tracer.record("turn_failed", attributes);
  571. }
  572. }
  573. #[must_use]
  574. pub fn auto_compaction_threshold_from_env() -> u32 {
  575. parse_auto_compaction_threshold(
  576. std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR)
  577. .ok()
  578. .as_deref(),
  579. )
  580. }
  581. #[must_use]
  582. fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 {
  583. value
  584. .and_then(|raw| raw.trim().parse::<u32>().ok())
  585. .filter(|threshold| *threshold > 0)
  586. .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD)
  587. }
  588. fn build_assistant_message(
  589. events: Vec<AssistantEvent>,
  590. ) -> Result<
  591. (
  592. ConversationMessage,
  593. Option<TokenUsage>,
  594. Vec<PromptCacheEvent>,
  595. ),
  596. RuntimeError,
  597. > {
  598. let mut text = String::new();
  599. let mut blocks = Vec::new();
  600. let mut prompt_cache_events = Vec::new();
  601. let mut finished = false;
  602. let mut usage = None;
  603. for event in events {
  604. match event {
  605. AssistantEvent::TextDelta(delta) => text.push_str(&delta),
  606. AssistantEvent::ToolUse { id, name, input } => {
  607. flush_text_block(&mut text, &mut blocks);
  608. blocks.push(ContentBlock::ToolUse { id, name, input });
  609. }
  610. AssistantEvent::Usage(value) => usage = Some(value),
  611. AssistantEvent::PromptCache(event) => prompt_cache_events.push(event),
  612. AssistantEvent::MessageStop => {
  613. finished = true;
  614. }
  615. }
  616. }
  617. flush_text_block(&mut text, &mut blocks);
  618. if !finished {
  619. return Err(RuntimeError::new(
  620. "assistant stream ended without a message stop event",
  621. ));
  622. }
  623. if blocks.is_empty() {
  624. return Err(RuntimeError::new("assistant stream produced no content"));
  625. }
  626. Ok((
  627. ConversationMessage::assistant_with_usage(blocks, usage),
  628. usage,
  629. prompt_cache_events,
  630. ))
  631. }
  632. fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) {
  633. if !text.is_empty() {
  634. blocks.push(ContentBlock::Text {
  635. text: std::mem::take(text),
  636. });
  637. }
  638. }
  639. fn format_hook_message(result: &HookRunResult, fallback: &str) -> String {
  640. if result.messages().is_empty() {
  641. fallback.to_string()
  642. } else {
  643. result.messages().join("\n")
  644. }
  645. }
  646. fn merge_hook_feedback(messages: &[String], output: String, is_error: bool) -> String {
  647. if messages.is_empty() {
  648. return output;
  649. }
  650. let mut sections = Vec::new();
  651. if !output.trim().is_empty() {
  652. sections.push(output);
  653. }
  654. let label = if is_error {
  655. "Hook feedback (error)"
  656. } else {
  657. "Hook feedback"
  658. };
  659. sections.push(format!("{label}:\n{}", messages.join("\n")));
  660. sections.join("\n\n")
  661. }
  662. type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>;
  663. #[derive(Default)]
  664. pub struct StaticToolExecutor {
  665. handlers: BTreeMap<String, ToolHandler>,
  666. }
  667. impl StaticToolExecutor {
  668. #[must_use]
  669. pub fn new() -> Self {
  670. Self::default()
  671. }
  672. #[must_use]
  673. pub fn register(
  674. mut self,
  675. tool_name: impl Into<String>,
  676. handler: impl FnMut(&str) -> Result<String, ToolError> + 'static,
  677. ) -> Self {
  678. self.handlers.insert(tool_name.into(), Box::new(handler));
  679. self
  680. }
  681. }
  682. impl ToolExecutor for StaticToolExecutor {
  683. fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
  684. self.handlers
  685. .get_mut(tool_name)
  686. .ok_or_else(|| ToolError::new(format!("unknown tool: {tool_name}")))?(input)
  687. }
  688. }
  689. #[cfg(test)]
  690. mod tests {
  691. use super::{
  692. build_assistant_message, parse_auto_compaction_threshold, ApiClient, ApiRequest,
  693. AssistantEvent, AutoCompactionEvent, ConversationRuntime, PromptCacheEvent, RuntimeError,
  694. StaticToolExecutor, ToolExecutor, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
  695. };
  696. use crate::compact::CompactionConfig;
  697. use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig};
  698. use crate::permissions::{
  699. PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter,
  700. PermissionRequest,
  701. };
  702. use crate::prompt::{ProjectContext, SystemPromptBuilder};
  703. use crate::session::{ContentBlock, MessageRole, Session};
  704. use crate::usage::TokenUsage;
  705. use crate::ToolError;
  706. use std::fs;
  707. use std::path::PathBuf;
  708. use std::sync::Arc;
  709. use std::time::{SystemTime, UNIX_EPOCH};
  710. use telemetry::{MemoryTelemetrySink, SessionTracer, TelemetryEvent};
    /// Test double that replays a fixed two-call script and fails on any further call.
    struct ScriptedApiClient {
        /// Number of `stream` calls received so far.
        call_count: usize,
    }
    impl ApiClient for ScriptedApiClient {
        fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
            self.call_count += 1;
            match self.call_count {
                // First call: the request must contain a user message; reply with
                // text plus a request to run the `add` tool.
                1 => {
                    assert!(request
                        .messages
                        .iter()
                        .any(|message| message.role == MessageRole::User));
                    Ok(vec![
                        AssistantEvent::TextDelta("Let me calculate that.".to_string()),
                        AssistantEvent::ToolUse {
                            id: "tool-1".to_string(),
                            name: "add".to_string(),
                            input: "2,2".to_string(),
                        },
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 20,
                            output_tokens: 6,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 2,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                // Second call: the history must end with the tool result; reply with
                // the final answer and one prompt-cache event.
                2 => {
                    let last_message = request
                        .messages
                        .last()
                        .expect("tool result should be present");
                    assert_eq!(last_message.role, MessageRole::Tool);
                    Ok(vec![
                        AssistantEvent::TextDelta("The answer is 4.".to_string()),
                        AssistantEvent::Usage(TokenUsage {
                            input_tokens: 24,
                            output_tokens: 4,
                            cache_creation_input_tokens: 1,
                            cache_read_input_tokens: 3,
                        }),
                        AssistantEvent::PromptCache(PromptCacheEvent {
                            unexpected: true,
                            reason:
                                "cache read tokens dropped while prompt fingerprint remained stable"
                                    .to_string(),
                            previous_cache_read_input_tokens: 6_000,
                            current_cache_read_input_tokens: 1_000,
                            token_drop: 5_000,
                        }),
                        AssistantEvent::MessageStop,
                    ])
                }
                // Any further call means the runtime looped more often than scripted.
                _ => Err(RuntimeError::new("unexpected extra API call")),
            }
        }
    }
    /// Test prompter that approves every request and asserts it is for the `add` tool.
    struct PromptAllowOnce;
    impl PermissionPrompter for PromptAllowOnce {
        fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision {
            assert_eq!(request.tool_name, "add");
            PermissionPromptDecision::Allow
        }
    }
    // End-to-end happy path: user prompt -> assistant tool request -> tool result ->
    // final assistant answer, checking the turn summary and the session contents.
    #[test]
    fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() {
        let api_client = ScriptedApiClient { call_count: 0 };
        // `add` sums a comma-separated list of integers.
        let tool_executor = StaticToolExecutor::new().register("add", |input| {
            let total = input
                .split(',')
                .map(|part| part.parse::<i32>().expect("input must be valid integer"))
                .sum::<i32>();
            Ok(total.to_string())
        });
        let permission_policy = PermissionPolicy::new(PermissionMode::WorkspaceWrite);
        let system_prompt = SystemPromptBuilder::new()
            .with_project_context(ProjectContext {
                cwd: PathBuf::from("/tmp/project"),
                current_date: "2026-03-31".to_string(),
                git_status: None,
                git_diff: None,
                instruction_files: Vec::new(),
            })
            .with_os("linux", "6.8")
            .build();
        let mut runtime = ConversationRuntime::new(
            Session::new(),
            api_client,
            tool_executor,
            permission_policy,
            system_prompt,
        );
        let summary = runtime
            .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
            .expect("conversation loop should succeed");
        // Two model calls: one requesting the tool, one giving the final answer.
        assert_eq!(summary.iterations, 2);
        assert_eq!(summary.assistant_messages.len(), 2);
        assert_eq!(summary.tool_results.len(), 1);
        assert_eq!(summary.prompt_cache_events.len(), 1);
        // Session order: user, assistant (text + tool use), tool result, assistant.
        assert_eq!(runtime.session().messages.len(), 4);
        // 6 + 4 output tokens across the two scripted calls.
        assert_eq!(summary.usage.output_tokens, 10);
        assert_eq!(summary.auto_compaction, None);
        assert!(matches!(
            runtime.session().messages[1].blocks[1],
            ContentBlock::ToolUse { .. }
        ));
        assert!(matches!(
            runtime.session().messages[2].blocks[0],
            ContentBlock::ToolResult {
                is_error: false,
                ..
            }
        ));
    }
  /// Checks that a runtime configured with a session tracer sends the turn,
  /// assistant-iteration and tool-execution trace events to the telemetry sink.
  #[test]
  fn records_runtime_session_trace_events() {
      // given
      let sink = Arc::new(MemoryTelemetrySink::default());
      let tracer = SessionTracer::new("session-runtime", sink.clone());
      let executor = StaticToolExecutor::new().register("add", |_input| Ok("4".to_string()));
      let mut runtime = ConversationRuntime::new(
          Session::new(),
          ScriptedApiClient { call_count: 0 },
          executor,
          PermissionPolicy::new(PermissionMode::WorkspaceWrite),
          vec!["system".to_string()],
      )
      .with_session_tracer(tracer);

      // when
      runtime
          .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce))
          .expect("conversation loop should succeed");

      // then: collect the names of the session-trace events, ignoring other telemetry
      let events = sink.events();
      let mut trace_names: Vec<&str> = Vec::new();
      for event in events.iter() {
          if let TelemetryEvent::SessionTrace(trace) = event {
              trace_names.push(trace.name.as_str());
          }
      }
      for expected in [
          "turn_started",
          "assistant_iteration_completed",
          "tool_execution_started",
          "tool_execution_finished",
          "turn_completed",
      ] {
          assert!(trace_names.contains(&expected));
      }
  }
  /// Checks that when the permission prompter denies a tool call, the turn still
  /// completes and the denial reason is recorded as an error tool result.
  #[test]
  fn records_denied_tool_results_when_prompt_rejects() {
      /// Prompter that denies every request with the reason "not now".
      struct RejectPrompter;
      impl PermissionPrompter for RejectPrompter {
          fn decide(&mut self, _request: &PermissionRequest) -> PermissionPromptDecision {
              let reason = "not now".to_string();
              PermissionPromptDecision::Deny { reason }
          }
      }

      /// Requests the `blocked` tool until a tool message is present in the
      /// request, then answers with plain text.
      struct SingleCallApiClient;
      impl ApiClient for SingleCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let tool_message_present = request
                  .messages
                  .iter()
                  .any(|message| message.role == MessageRole::Tool);
              let events = if tool_message_present {
                  vec![
                      AssistantEvent::TextDelta("I could not use the tool.".to_string()),
                      AssistantEvent::MessageStop,
                  ]
              } else {
                  vec![
                      AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "blocked".to_string(),
                          input: "secret".to_string(),
                      },
                      AssistantEvent::MessageStop,
                  ]
              };
              Ok(events)
          }
      }

      // given
      let mut runtime = ConversationRuntime::new(
          Session::new(),
          SingleCallApiClient,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::WorkspaceWrite),
          vec!["system".to_string()],
      );

      // when
      let turn = runtime.run_turn("use the tool", Some(&mut RejectPrompter));
      let summary = turn.expect("conversation should continue after denied tool");

      // then
      assert_eq!(summary.tool_results.len(), 1);
      let denied_block = &summary.tool_results[0].blocks[0];
      let carries_denial = matches!(
          denied_block,
          ContentBlock::ToolResult { is_error: true, output, .. } if output == "not now"
      );
      assert!(carries_denial);
  }
  /// Checks that a pre-tool hook exiting with status 2 stops the tool from
  /// running and yields an error tool result that mentions the denial.
  #[test]
  fn denies_tool_use_when_pre_tool_hook_blocks() {
      /// Requests the `blocked` tool until a tool message is present in the
      /// request, then answers with plain text.
      struct SingleCallApiClient;
      impl ApiClient for SingleCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let tool_message_present = request
                  .messages
                  .iter()
                  .any(|message| message.role == MessageRole::Tool);
              let events = if tool_message_present {
                  vec![
                      AssistantEvent::TextDelta("blocked".to_string()),
                      AssistantEvent::MessageStop,
                  ]
              } else {
                  vec![
                      AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "blocked".to_string(),
                          input: r#"{"path":"secret.txt"}"#.to_string(),
                      },
                      AssistantEvent::MessageStop,
                  ]
              };
              Ok(events)
          }
      }

      // given: a single pre-tool hook that prints a message and exits with status 2
      let hooks = RuntimeHookConfig::new(
          vec![shell_snippet("printf 'blocked by hook'; exit 2")],
          Vec::new(),
          Vec::new(),
      );
      let executor = StaticToolExecutor::new().register("blocked", |_input| {
          panic!("tool should not execute when hook denies")
      });
      let mut runtime = ConversationRuntime::new_with_features(
          Session::new(),
          SingleCallApiClient,
          executor,
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
          &RuntimeFeatureConfig::default().with_hooks(hooks),
      );

      // when
      let turn = runtime.run_turn("use the tool", None);
      let summary = turn.expect("conversation should continue after hook denial");

      // then
      assert_eq!(summary.tool_results.len(), 1);
      let (is_error, output) = match &summary.tool_results[0].blocks[0] {
          ContentBlock::ToolResult {
              is_error, output, ..
          } => (*is_error, output),
          _ => panic!("expected tool result block"),
      };
      assert!(
          is_error,
          "hook denial should produce an error result: {output}"
      );
      let mentions_denial = output.contains("denied tool");
      let mentions_hook_text = output.contains("blocked by hook");
      assert!(
          mentions_denial || mentions_hook_text,
          "unexpected hook denial output: {output:?}"
      );
  }
  /// Checks that a pre-tool hook exiting with status 1 stops the tool from
  /// running and yields an error tool result that mentions the failure.
  #[test]
  fn denies_tool_use_when_pre_tool_hook_fails() {
      /// Requests the `blocked` tool until a tool message is present in the
      /// request, then answers with plain text.
      struct SingleCallApiClient;
      impl ApiClient for SingleCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let tool_message_present = request
                  .messages
                  .iter()
                  .any(|message| message.role == MessageRole::Tool);
              let events = if tool_message_present {
                  vec![
                      AssistantEvent::TextDelta("failed".to_string()),
                      AssistantEvent::MessageStop,
                  ]
              } else {
                  vec![
                      AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "blocked".to_string(),
                          input: r#"{"path":"secret.txt"}"#.to_string(),
                      },
                      AssistantEvent::MessageStop,
                  ]
              };
              Ok(events)
          }
      }

      // given: a single pre-tool hook that prints a message and exits with status 1
      let hooks = RuntimeHookConfig::new(
          vec![shell_snippet("printf 'broken hook'; exit 1")],
          Vec::new(),
          Vec::new(),
      );
      let executor = StaticToolExecutor::new().register("blocked", |_input| {
          panic!("tool should not execute when hook fails")
      });
      let mut runtime = ConversationRuntime::new_with_features(
          Session::new(),
          SingleCallApiClient,
          executor,
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
          &RuntimeFeatureConfig::default().with_hooks(hooks),
      );

      // when
      let turn = runtime.run_turn("use the tool", None);
      let summary = turn.expect("conversation should continue after hook failure");

      // then
      assert_eq!(summary.tool_results.len(), 1);
      let (is_error, output) = match &summary.tool_results[0].blocks[0] {
          ContentBlock::ToolResult {
              is_error, output, ..
          } => (*is_error, output),
          _ => panic!("expected tool result block"),
      };
      assert!(
          is_error,
          "hook failure should produce an error result: {output}"
      );
      let mentions_status = output.contains("exited with status 1");
      let mentions_hook_text = output.contains("broken hook");
      assert!(
          mentions_status || mentions_hook_text,
          "unexpected hook failure output: {output:?}"
      );
  }
  /// Checks that, for a successful tool call, the tool result stays non-error
  /// and contains the tool value plus the pre-hook and post-hook output.
  #[test]
  fn appends_post_tool_hook_feedback_to_tool_result() {
      /// Scripted client: the first call requests `add`, the second must see a
      /// tool message and then finishes; any further call is an error.
      struct TwoCallApiClient {
          calls: usize,
      }
      impl ApiClient for TwoCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              self.calls += 1;
              match self.calls {
                  1 => {
                      let tool_call = AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "add".to_string(),
                          input: r#"{"lhs":2,"rhs":2}"#.to_string(),
                      };
                      Ok(vec![tool_call, AssistantEvent::MessageStop])
                  }
                  2 => {
                      let tool_message_present = request
                          .messages
                          .iter()
                          .any(|message| message.role == MessageRole::Tool);
                      assert!(tool_message_present);
                      Ok(vec![
                          AssistantEvent::TextDelta("done".to_string()),
                          AssistantEvent::MessageStop,
                      ])
                  }
                  _ => Err(RuntimeError::new("unexpected extra API call")),
              }
          }
      }

      // given: one pre-tool hook and one post-tool hook, no failure hook
      let hooks = RuntimeHookConfig::new(
          vec![shell_snippet("printf 'pre hook ran'")],
          vec![shell_snippet("printf 'post hook ran'")],
          Vec::new(),
      );
      let executor = StaticToolExecutor::new().register("add", |_input| Ok("4".to_string()));
      let mut runtime = ConversationRuntime::new_with_features(
          Session::new(),
          TwoCallApiClient { calls: 0 },
          executor,
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
          &RuntimeFeatureConfig::default().with_hooks(hooks),
      );

      // when
      let turn = runtime.run_turn("use add", None);
      let summary = turn.expect("tool loop succeeds");

      // then
      assert_eq!(summary.tool_results.len(), 1);
      let (is_error, output) = match &summary.tool_results[0].blocks[0] {
          ContentBlock::ToolResult {
              is_error, output, ..
          } => (*is_error, output),
          _ => panic!("expected tool result block"),
      };
      assert!(
          !is_error,
          "post hook should preserve non-error result: {output:?}"
      );
      assert!(
          output.contains('4'),
          "tool output missing value: {output:?}"
      );
      assert!(
          output.contains("pre hook ran"),
          "tool output missing pre hook feedback: {output:?}"
      );
      assert!(
          output.contains("post hook ran"),
          "tool output missing post hook feedback: {output:?}"
      );
  }
  /// Checks that, for a failing tool, the error result contains the tool's
  /// failure reason and the failure-hook output but not the post-hook output.
  #[test]
  fn appends_post_tool_use_failure_hook_feedback_to_tool_result() {
      /// Scripted client: the first call requests `fail`, the second must see a
      /// tool message and then finishes; any further call is an error.
      struct TwoCallApiClient {
          calls: usize,
      }
      impl ApiClient for TwoCallApiClient {
          fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
              self.calls += 1;
              match self.calls {
                  1 => {
                      let tool_call = AssistantEvent::ToolUse {
                          id: "tool-1".to_string(),
                          name: "fail".to_string(),
                          input: r#"{"path":"README.md"}"#.to_string(),
                      };
                      Ok(vec![tool_call, AssistantEvent::MessageStop])
                  }
                  2 => {
                      let tool_message_present = request
                          .messages
                          .iter()
                          .any(|message| message.role == MessageRole::Tool);
                      assert!(tool_message_present);
                      Ok(vec![
                          AssistantEvent::TextDelta("done".to_string()),
                          AssistantEvent::MessageStop,
                      ])
                  }
                  _ => Err(RuntimeError::new("unexpected extra API call")),
              }
          }
      }

      // given: no pre-tool hook, one post-tool hook and one failure hook
      let hooks = RuntimeHookConfig::new(
          Vec::new(),
          vec![shell_snippet("printf 'post hook should not run'")],
          vec![shell_snippet("printf 'failure hook ran'")],
      );
      let executor = StaticToolExecutor::new()
          .register("fail", |_input| Err(ToolError::new("tool exploded")));
      let mut runtime = ConversationRuntime::new_with_features(
          Session::new(),
          TwoCallApiClient { calls: 0 },
          executor,
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
          &RuntimeFeatureConfig::default().with_hooks(hooks),
      );

      // when
      let turn = runtime.run_turn("use fail", None);
      let summary = turn.expect("tool loop succeeds");

      // then
      assert_eq!(summary.tool_results.len(), 1);
      let (is_error, output) = match &summary.tool_results[0].blocks[0] {
          ContentBlock::ToolResult {
              is_error, output, ..
          } => (*is_error, output),
          _ => panic!("expected tool result block"),
      };
      assert!(
          is_error,
          "failure hook path should preserve error result: {output:?}"
      );
      assert!(
          output.contains("tool exploded"),
          "tool output missing failure reason: {output:?}"
      );
      assert!(
          output.contains("failure hook ran"),
          "tool output missing failure hook feedback: {output:?}"
      );
      assert!(
          !output.contains("post hook should not run"),
          "normal post hook should not run on tool failure: {output:?}"
      );
  }
  /// Checks that a runtime built from a session that already holds an assistant
  /// message with usage reports one turn and 21 cumulative tokens.
  #[test]
  fn reconstructs_usage_tracker_from_restored_session() {
      /// Client that always answers with the text "done".
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let reply = AssistantEvent::TextDelta("done".to_string());
              Ok(vec![reply, AssistantEvent::MessageStop])
          }
      }

      // given: a session that already holds one assistant message with usage attached
      let recorded_usage = TokenUsage {
          input_tokens: 11,
          output_tokens: 7,
          cache_creation_input_tokens: 2,
          cache_read_input_tokens: 1,
      };
      let earlier_message = crate::session::ConversationMessage::assistant_with_usage(
          vec![ContentBlock::Text {
              text: "earlier".to_string(),
          }],
          Some(recorded_usage),
      );
      let mut session = Session::new();
      session.messages.push(earlier_message);

      // when
      let runtime = ConversationRuntime::new(
          session,
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      // then: 11 + 7 + 2 + 1 tokens from the single stored message
      let tracker = runtime.usage();
      assert_eq!(tracker.turns(), 1);
      assert_eq!(tracker.cumulative_usage().total_tokens(), 21);
  }
  /// Runs three turns, compacts the session, and checks the summary text, the
  /// leading system message, the preserved session id and the compaction record.
  #[test]
  fn compacts_session_after_turns() {
      /// Client that always answers with the text "done".
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let reply = AssistantEvent::TextDelta("done".to_string());
              Ok(vec![reply, AssistantEvent::MessageStop])
          }
      }

      // given: a runtime that has completed three turns
      let mut runtime = ConversationRuntime::new(
          Session::new(),
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );
      for (prompt, failure) in [("a", "turn a"), ("b", "turn b"), ("c", "turn c")] {
          runtime.run_turn(prompt, None).expect(failure);
      }

      // when
      let config = CompactionConfig {
          preserve_recent_messages: 2,
          max_estimated_tokens: 1,
      };
      let result = runtime.compact(config);

      // then
      assert!(result.summary.contains("Conversation summary"));
      let compacted = &result.compacted_session;
      assert_eq!(compacted.messages[0].role, MessageRole::System);
      assert_eq!(compacted.session_id, runtime.session().session_id);
      assert!(compacted.compaction.is_some());
  }
  /// Runs one turn on a session configured with a persistence path, reloads
  /// the session from that path, and checks the reloaded messages and
  /// session id.
  ///
  /// The temp file is removed before any outcome is asserted, so a failing
  /// turn or reload no longer leaves the file behind in the temp directory.
  #[test]
  fn persists_conversation_turn_messages_to_jsonl_session() {
      /// Client that always answers with the text "done".
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              Ok(vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::MessageStop,
              ])
          }
      }

      // given
      let path = temp_session_path("persisted-turn");
      let session = Session::new().with_persistence_path(path.clone());
      let mut runtime = ConversationRuntime::new(
          session,
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      // when: capture every outcome first so cleanup always runs
      let turn = runtime.run_turn("persist this turn", None);
      let reload = Session::load_from_path(&path);
      let cleanup = fs::remove_file(&path);

      // then: report failures in the same order as the original steps
      turn.expect("turn should succeed");
      let restored = reload.expect("persisted session should reload");
      cleanup.expect("temp session file should be removable");
      assert_eq!(restored.messages.len(), 2);
      assert_eq!(restored.messages[0].role, MessageRole::User);
      assert_eq!(restored.messages[1].role, MessageRole::Assistant);
      assert_eq!(restored.session_id, runtime.session().session_id);
  }
  /// Checks that forking copies the messages, assigns a new session id, records
  /// the parent id and branch name, and leaves the runtime's session unforked.
  #[test]
  fn forks_runtime_session_without_mutating_original() {
      // given
      let mut session = Session::new();
      session
          .push_user_text("branch me")
          .expect("message should append");
      let runtime = ConversationRuntime::new(
          session.clone(),
          ScriptedApiClient { call_count: 0 },
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      // when
      let forked = runtime.fork_session(Some("alt-path".to_string()));

      // then
      assert_eq!(forked.messages, session.messages);
      assert_ne!(forked.session_id, session.session_id);
      let recorded_fork = forked
          .fork
          .as_ref()
          .map(|fork| (fork.parent_session_id.as_str(), fork.branch_name.as_deref()));
      let expected_fork = Some((session.session_id.as_str(), Some("alt-path")));
      assert_eq!(recorded_fork, expected_fork);
      assert!(runtime.session().fork.is_none());
  }
  /// Builds a path inside the system temp directory named
  /// `runtime-conversation-{label}-{nanos}.json`, where `nanos` is the current
  /// time in nanoseconds since the Unix epoch.
  ///
  /// # Panics
  ///
  /// Panics if the system clock reports a time before the Unix epoch.
  fn temp_session_path(label: &str) -> PathBuf {
      let since_epoch = SystemTime::now()
          .duration_since(UNIX_EPOCH)
          .expect("system time should be after epoch");
      let file_name = format!(
          "runtime-conversation-{label}-{}.json",
          since_epoch.as_nanos()
      );
      let mut path = std::env::temp_dir();
      path.push(file_name);
      path
  }
  /// Adapts a hook script for Windows by turning every single quote into a
  /// double quote.
  #[cfg(windows)]
  fn shell_snippet(script: &str) -> String {
      script
          .chars()
          .map(|ch| if ch == '\'' { '"' } else { ch })
          .collect()
  }

  /// Returns the hook script unchanged on non-Windows targets.
  #[cfg(not(windows))]
  fn shell_snippet(script: &str) -> String {
      String::from(script)
  }
  /// Checks that a turn reporting 120,000 input tokens against a 100,000 token
  /// threshold triggers auto-compaction and leaves a leading system message.
  #[test]
  fn auto_compacts_when_cumulative_input_threshold_is_crossed() {
      /// Client that answers "done" and reports 120,000 input tokens per call.
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let usage = TokenUsage {
                  input_tokens: 120_000,
                  output_tokens: 4,
                  cache_creation_input_tokens: 0,
                  cache_read_input_tokens: 0,
              };
              Ok(vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::Usage(usage),
                  AssistantEvent::MessageStop,
              ])
          }
      }

      // given: a session pre-populated with two user/assistant exchanges
      let assistant_text = |text: &str| {
          crate::session::ConversationMessage::assistant(vec![ContentBlock::Text {
              text: text.to_string(),
          }])
      };
      let mut session = Session::new();
      session.messages = vec![
          crate::session::ConversationMessage::user_text("one"),
          assistant_text("two"),
          crate::session::ConversationMessage::user_text("three"),
          assistant_text("four"),
      ];
      let mut runtime = ConversationRuntime::new(
          session,
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      )
      .with_auto_compaction_input_tokens_threshold(100_000);

      // when
      let turn = runtime.run_turn("trigger", None);
      let summary = turn.expect("turn should succeed");

      // then
      let expected_event = AutoCompactionEvent {
          removed_message_count: 2,
      };
      assert_eq!(summary.auto_compaction, Some(expected_event));
      assert_eq!(runtime.session().messages[0].role, MessageRole::System);
  }
  /// Checks that a turn reporting 99,999 input tokens against a 100,000 token
  /// threshold does not compact and leaves both turn messages in the session.
  #[test]
  fn skips_auto_compaction_below_threshold() {
      /// Client that answers "done" and reports 99,999 input tokens per call.
      struct SimpleApi;
      impl ApiClient for SimpleApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let usage = TokenUsage {
                  input_tokens: 99_999,
                  output_tokens: 4,
                  cache_creation_input_tokens: 0,
                  cache_read_input_tokens: 0,
              };
              Ok(vec![
                  AssistantEvent::TextDelta("done".to_string()),
                  AssistantEvent::Usage(usage),
                  AssistantEvent::MessageStop,
              ])
          }
      }

      // given
      let mut runtime = ConversationRuntime::new(
          Session::new(),
          SimpleApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      )
      .with_auto_compaction_input_tokens_threshold(100_000);

      // when
      let turn = runtime.run_turn("trigger", None);
      let summary = turn.expect("turn should succeed");

      // then
      assert_eq!(summary.auto_compaction, None);
      assert_eq!(runtime.session().messages.len(), 2);
  }
  /// Checks that `parse_auto_compaction_threshold` returns the default for a
  /// missing, zero or non-numeric value and the parsed number otherwise.
  #[test]
  fn auto_compaction_threshold_defaults_and_parses_values() {
      // given: each raw input paired with the threshold it should produce
      let cases = [
          (None, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD),
          (Some("4321"), 4321),
          (Some("0"), DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD),
          (
              Some("not-a-number"),
              DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD,
          ),
      ];

      // when / then
      for (raw, expected) in cases {
          assert_eq!(parse_auto_compaction_threshold(raw), expected);
      }
  }
  /// Checks that `build_assistant_message` rejects an event list that has text
  /// but no `MessageStop` event.
  #[test]
  fn build_assistant_message_requires_message_stop_event() {
      // given: text content without a terminating stop event
      let events = vec![AssistantEvent::TextDelta("hello".to_string())];

      // when
      let outcome = build_assistant_message(events);
      let error = outcome.expect_err("assistant messages should require a stop event");

      // then
      let rendered = error.to_string();
      assert!(rendered.contains("assistant stream ended without a message stop event"));
  }
  /// Checks that `build_assistant_message` rejects an event list holding only
  /// a `MessageStop` event.
  #[test]
  fn build_assistant_message_requires_content() {
      // given: a stop event with no preceding content
      let events = vec![AssistantEvent::MessageStop];

      // when
      let outcome = build_assistant_message(events);
      let error = outcome.expect_err("assistant messages should require content");

      // then
      let rendered = error.to_string();
      assert!(rendered.contains("assistant stream produced no content"));
  }
  /// Checks that executing a tool that was never registered fails with the
  /// message `unknown tool: missing`.
  #[test]
  fn static_tool_executor_rejects_unknown_tools() {
      // given: an executor with no registered tools
      let mut executor = StaticToolExecutor::new();

      // when
      let outcome = executor.execute("missing", "{}");
      let error = outcome.expect_err("unregistered tools should fail");

      // then
      let rendered = error.to_string();
      assert_eq!(rendered, "unknown tool: missing");
  }
  /// Checks that `run_turn` returns an error once a client that always requests
  /// another tool call exceeds the configured limit of one iteration.
  #[test]
  fn run_turn_errors_when_max_iterations_is_exceeded() {
      /// Client that requests the `echo` tool on every call and never finishes.
      struct LoopingApi;
      impl ApiClient for LoopingApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let tool_call = AssistantEvent::ToolUse {
                  id: "tool-1".to_string(),
                  name: "echo".to_string(),
                  input: "payload".to_string(),
              };
              Ok(vec![tool_call, AssistantEvent::MessageStop])
          }
      }

      // given
      let echo_executor =
          StaticToolExecutor::new().register("echo", |input| Ok(input.to_string()));
      let mut runtime = ConversationRuntime::new(
          Session::new(),
          LoopingApi,
          echo_executor,
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      )
      .with_max_iterations(1);

      // when
      let turn = runtime.run_turn("loop", None);
      let error = turn.expect_err("conversation loop should stop after the configured limit");

      // then
      let rendered = error.to_string();
      assert!(rendered.contains("conversation loop exceeded the maximum number of iterations"));
  }
  /// Checks that an error returned by the API client comes back from
  /// `run_turn` with its message unchanged.
  #[test]
  fn run_turn_propagates_api_errors() {
      /// Client whose every call fails with "upstream failed".
      struct FailingApi;
      impl ApiClient for FailingApi {
          fn stream(
              &mut self,
              _request: ApiRequest,
          ) -> Result<Vec<AssistantEvent>, RuntimeError> {
              let failure = RuntimeError::new("upstream failed");
              Err(failure)
          }
      }

      // given
      let mut runtime = ConversationRuntime::new(
          Session::new(),
          FailingApi,
          StaticToolExecutor::new(),
          PermissionPolicy::new(PermissionMode::DangerFullAccess),
          vec!["system".to_string()],
      );

      // when
      let turn = runtime.run_turn("hello", None);
      let error = turn.expect_err("API failures should propagate");

      // then
      let rendered = error.to_string();
      assert_eq!(rendered, "upstream failed");
  }
  1515. }