prompt_cache.rs 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. use std::fs;
  2. use std::path::{Path, PathBuf};
  3. use std::sync::{Arc, Mutex};
  4. use std::time::{Duration, SystemTime, UNIX_EPOCH};
  5. use serde::{Deserialize, Serialize};
  6. use crate::types::{MessageRequest, MessageResponse, Usage};
  /// Default lifetime, in seconds, of a completion entry stored on disk.
  const DEFAULT_COMPLETION_TTL_SECS: u64 = 30;
  /// Default prompt-cache TTL in seconds (five minutes); used when classifying cache breaks.
  const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60;
  /// Default minimum drop in cache-read tokens that is reported as a cache break.
  const DEFAULT_BREAK_MIN_DROP: u32 = 2_000;
  /// Upper bound, in bytes, on the length of a sanitized path segment.
  const MAX_SANITIZED_LENGTH: usize = 80;
  /// Version number stamped into persisted completion entries and tracked prompt state.
  const REQUEST_FINGERPRINT_VERSION: u32 = 1;
  /// Prefix of every request-hash string; keep it in step with `REQUEST_FINGERPRINT_VERSION`.
  const REQUEST_FINGERPRINT_PREFIX: &str = "v1";
  /// Initial state of the 64-bit FNV-1a hash.
  const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
  /// Multiplier applied after each byte by the 64-bit FNV-1a hash.
  const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
  15. #[derive(Debug, Clone)]
  16. pub struct PromptCacheConfig {
  17. pub session_id: String,
  18. pub completion_ttl: Duration,
  19. pub prompt_ttl: Duration,
  20. pub cache_break_min_drop: u32,
  21. }
  22. impl PromptCacheConfig {
  23. #[must_use]
  24. pub fn new(session_id: impl Into<String>) -> Self {
  25. Self {
  26. session_id: session_id.into(),
  27. completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS),
  28. prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS),
  29. cache_break_min_drop: DEFAULT_BREAK_MIN_DROP,
  30. }
  31. }
  32. }
  33. impl Default for PromptCacheConfig {
  34. fn default() -> Self {
  35. Self::new("default")
  36. }
  37. }
  38. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  39. pub struct PromptCachePaths {
  40. pub root: PathBuf,
  41. pub session_dir: PathBuf,
  42. pub completion_dir: PathBuf,
  43. pub session_state_path: PathBuf,
  44. pub stats_path: PathBuf,
  45. }
  46. impl PromptCachePaths {
  47. #[must_use]
  48. pub fn for_session(session_id: &str) -> Self {
  49. let root = base_cache_root();
  50. let session_dir = root.join(sanitize_path_segment(session_id));
  51. let completion_dir = session_dir.join("completions");
  52. Self {
  53. root,
  54. session_state_path: session_dir.join("session-state.json"),
  55. stats_path: session_dir.join("stats.json"),
  56. session_dir,
  57. completion_dir,
  58. }
  59. }
  60. #[must_use]
  61. pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf {
  62. self.completion_dir.join(format!("{request_hash}.json"))
  63. }
  64. }
  65. #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
  66. pub struct PromptCacheStats {
  67. pub tracked_requests: u64,
  68. pub completion_cache_hits: u64,
  69. pub completion_cache_misses: u64,
  70. pub completion_cache_writes: u64,
  71. pub expected_invalidations: u64,
  72. pub unexpected_cache_breaks: u64,
  73. pub total_cache_creation_input_tokens: u64,
  74. pub total_cache_read_input_tokens: u64,
  75. pub last_cache_creation_input_tokens: Option<u32>,
  76. pub last_cache_read_input_tokens: Option<u32>,
  77. pub last_request_hash: Option<String>,
  78. pub last_completion_cache_key: Option<String>,
  79. pub last_break_reason: Option<String>,
  80. pub last_cache_source: Option<String>,
  81. }
  82. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  83. pub struct CacheBreakEvent {
  84. pub unexpected: bool,
  85. pub reason: String,
  86. pub previous_cache_read_input_tokens: u32,
  87. pub current_cache_read_input_tokens: u32,
  88. pub token_drop: u32,
  89. }
  90. #[derive(Debug, Clone, PartialEq, Eq)]
  91. pub struct PromptCacheRecord {
  92. pub cache_break: Option<CacheBreakEvent>,
  93. pub stats: PromptCacheStats,
  94. }
  95. #[derive(Debug, Clone)]
  96. pub struct PromptCache {
  97. inner: Arc<Mutex<PromptCacheInner>>,
  98. }
  99. impl PromptCache {
  100. #[must_use]
  101. pub fn new(session_id: impl Into<String>) -> Self {
  102. Self::with_config(PromptCacheConfig::new(session_id))
  103. }
  104. #[must_use]
  105. pub fn with_config(config: PromptCacheConfig) -> Self {
  106. let paths = PromptCachePaths::for_session(&config.session_id);
  107. let stats = read_json::<PromptCacheStats>(&paths.stats_path).unwrap_or_default();
  108. let previous = read_json::<TrackedPromptState>(&paths.session_state_path);
  109. Self {
  110. inner: Arc::new(Mutex::new(PromptCacheInner {
  111. config,
  112. paths,
  113. stats,
  114. previous,
  115. })),
  116. }
  117. }
  118. #[must_use]
  119. pub fn paths(&self) -> PromptCachePaths {
  120. self.lock().paths.clone()
  121. }
  122. #[must_use]
  123. pub fn stats(&self) -> PromptCacheStats {
  124. self.lock().stats.clone()
  125. }
  126. #[must_use]
  127. pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
  128. let request_hash = request_hash_hex(request);
  129. let (paths, ttl) = {
  130. let inner = self.lock();
  131. (inner.paths.clone(), inner.config.completion_ttl)
  132. };
  133. let entry_path = paths.completion_entry_path(&request_hash);
  134. let entry = read_json::<CompletionCacheEntry>(&entry_path);
  135. let Some(entry) = entry else {
  136. let mut inner = self.lock();
  137. inner.stats.completion_cache_misses += 1;
  138. inner.stats.last_completion_cache_key = Some(request_hash);
  139. persist_state(&inner);
  140. return None;
  141. };
  142. if entry.fingerprint_version != current_fingerprint_version() {
  143. let mut inner = self.lock();
  144. inner.stats.completion_cache_misses += 1;
  145. inner.stats.last_completion_cache_key = Some(request_hash.clone());
  146. let _ = fs::remove_file(entry_path);
  147. persist_state(&inner);
  148. return None;
  149. }
  150. let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs();
  151. let mut inner = self.lock();
  152. inner.stats.last_completion_cache_key = Some(request_hash.clone());
  153. if expired {
  154. inner.stats.completion_cache_misses += 1;
  155. let _ = fs::remove_file(entry_path);
  156. persist_state(&inner);
  157. return None;
  158. }
  159. inner.stats.completion_cache_hits += 1;
  160. apply_usage_to_stats(
  161. &mut inner.stats,
  162. &entry.response.usage,
  163. &request_hash,
  164. "completion-cache",
  165. );
  166. inner.previous = Some(TrackedPromptState::from_usage(
  167. request,
  168. &entry.response.usage,
  169. ));
  170. persist_state(&inner);
  171. Some(entry.response)
  172. }
  173. #[must_use]
  174. pub fn record_response(
  175. &self,
  176. request: &MessageRequest,
  177. response: &MessageResponse,
  178. ) -> PromptCacheRecord {
  179. self.record_usage_internal(request, &response.usage, Some(response))
  180. }
  181. #[must_use]
  182. pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
  183. self.record_usage_internal(request, usage, None)
  184. }
  185. fn record_usage_internal(
  186. &self,
  187. request: &MessageRequest,
  188. usage: &Usage,
  189. response: Option<&MessageResponse>,
  190. ) -> PromptCacheRecord {
  191. let request_hash = request_hash_hex(request);
  192. let mut inner = self.lock();
  193. let previous = inner.previous.clone();
  194. let current = TrackedPromptState::from_usage(request, usage);
  195. let cache_break = detect_cache_break(&inner.config, previous.as_ref(), &current);
  196. inner.stats.tracked_requests += 1;
  197. apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response");
  198. if let Some(event) = &cache_break {
  199. if event.unexpected {
  200. inner.stats.unexpected_cache_breaks += 1;
  201. } else {
  202. inner.stats.expected_invalidations += 1;
  203. }
  204. inner.stats.last_break_reason = Some(event.reason.clone());
  205. }
  206. inner.previous = Some(current);
  207. if let Some(response) = response {
  208. write_completion_entry(&inner.paths, &request_hash, response);
  209. inner.stats.completion_cache_writes += 1;
  210. }
  211. persist_state(&inner);
  212. PromptCacheRecord {
  213. cache_break,
  214. stats: inner.stats.clone(),
  215. }
  216. }
  217. fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> {
  218. self.inner
  219. .lock()
  220. .unwrap_or_else(std::sync::PoisonError::into_inner)
  221. }
  222. }
  223. #[derive(Debug)]
  224. struct PromptCacheInner {
  225. config: PromptCacheConfig,
  226. paths: PromptCachePaths,
  227. stats: PromptCacheStats,
  228. previous: Option<TrackedPromptState>,
  229. }
  230. #[derive(Debug, Clone, Serialize, Deserialize)]
  231. struct CompletionCacheEntry {
  232. cached_at_unix_secs: u64,
  233. #[serde(default = "current_fingerprint_version")]
  234. fingerprint_version: u32,
  235. response: MessageResponse,
  236. }
  237. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  238. struct TrackedPromptState {
  239. observed_at_unix_secs: u64,
  240. #[serde(default = "current_fingerprint_version")]
  241. fingerprint_version: u32,
  242. model_hash: u64,
  243. system_hash: u64,
  244. tools_hash: u64,
  245. messages_hash: u64,
  246. cache_read_input_tokens: u32,
  247. }
  248. impl TrackedPromptState {
  249. fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
  250. let hashes = RequestFingerprints::from_request(request);
  251. Self {
  252. observed_at_unix_secs: now_unix_secs(),
  253. fingerprint_version: current_fingerprint_version(),
  254. model_hash: hashes.model,
  255. system_hash: hashes.system,
  256. tools_hash: hashes.tools,
  257. messages_hash: hashes.messages,
  258. cache_read_input_tokens: usage.cache_read_input_tokens,
  259. }
  260. }
  261. }
  262. #[derive(Debug, Clone, Copy)]
  263. struct RequestFingerprints {
  264. model: u64,
  265. system: u64,
  266. tools: u64,
  267. messages: u64,
  268. }
  269. impl RequestFingerprints {
  270. fn from_request(request: &MessageRequest) -> Self {
  271. Self {
  272. model: hash_serializable(&request.model),
  273. system: hash_serializable(&request.system),
  274. tools: hash_serializable(&request.tools),
  275. messages: hash_serializable(&request.messages),
  276. }
  277. }
  278. }
  279. fn detect_cache_break(
  280. config: &PromptCacheConfig,
  281. previous: Option<&TrackedPromptState>,
  282. current: &TrackedPromptState,
  283. ) -> Option<CacheBreakEvent> {
  284. let previous = previous?;
  285. if previous.fingerprint_version != current.fingerprint_version {
  286. return Some(CacheBreakEvent {
  287. unexpected: false,
  288. reason: format!(
  289. "fingerprint version changed (v{} -> v{})",
  290. previous.fingerprint_version, current.fingerprint_version
  291. ),
  292. previous_cache_read_input_tokens: previous.cache_read_input_tokens,
  293. current_cache_read_input_tokens: current.cache_read_input_tokens,
  294. token_drop: previous
  295. .cache_read_input_tokens
  296. .saturating_sub(current.cache_read_input_tokens),
  297. });
  298. }
  299. let token_drop = previous
  300. .cache_read_input_tokens
  301. .saturating_sub(current.cache_read_input_tokens);
  302. if token_drop < config.cache_break_min_drop {
  303. return None;
  304. }
  305. let mut reasons = Vec::new();
  306. if previous.model_hash != current.model_hash {
  307. reasons.push("model changed");
  308. }
  309. if previous.system_hash != current.system_hash {
  310. reasons.push("system prompt changed");
  311. }
  312. if previous.tools_hash != current.tools_hash {
  313. reasons.push("tool definitions changed");
  314. }
  315. if previous.messages_hash != current.messages_hash {
  316. reasons.push("message payload changed");
  317. }
  318. let elapsed = current
  319. .observed_at_unix_secs
  320. .saturating_sub(previous.observed_at_unix_secs);
  321. let (unexpected, reason) = if reasons.is_empty() {
  322. if elapsed > config.prompt_ttl.as_secs() {
  323. (
  324. false,
  325. format!("possible prompt cache TTL expiry after {elapsed}s"),
  326. )
  327. } else {
  328. (
  329. true,
  330. "cache read tokens dropped while prompt fingerprint remained stable".to_string(),
  331. )
  332. }
  333. } else {
  334. (false, reasons.join(", "))
  335. };
  336. Some(CacheBreakEvent {
  337. unexpected,
  338. reason,
  339. previous_cache_read_input_tokens: previous.cache_read_input_tokens,
  340. current_cache_read_input_tokens: current.cache_read_input_tokens,
  341. token_drop,
  342. })
  343. }
  344. fn apply_usage_to_stats(
  345. stats: &mut PromptCacheStats,
  346. usage: &Usage,
  347. request_hash: &str,
  348. source: &str,
  349. ) {
  350. stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens);
  351. stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens);
  352. stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens);
  353. stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens);
  354. stats.last_request_hash = Some(request_hash.to_string());
  355. stats.last_cache_source = Some(source.to_string());
  356. }
  357. fn persist_state(inner: &PromptCacheInner) {
  358. let _ = ensure_cache_dirs(&inner.paths);
  359. let _ = write_json(&inner.paths.stats_path, &inner.stats);
  360. if let Some(previous) = &inner.previous {
  361. let _ = write_json(&inner.paths.session_state_path, previous);
  362. }
  363. }
  364. fn write_completion_entry(
  365. paths: &PromptCachePaths,
  366. request_hash: &str,
  367. response: &MessageResponse,
  368. ) {
  369. let _ = ensure_cache_dirs(paths);
  370. let entry = CompletionCacheEntry {
  371. cached_at_unix_secs: now_unix_secs(),
  372. fingerprint_version: current_fingerprint_version(),
  373. response: response.clone(),
  374. };
  375. let _ = write_json(&paths.completion_entry_path(request_hash), &entry);
  376. }
  377. fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> {
  378. fs::create_dir_all(&paths.completion_dir)
  379. }
  380. fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> {
  381. let json = serde_json::to_vec_pretty(value)
  382. .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?;
  383. fs::write(path, json)
  384. }
  385. fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> Option<T> {
  386. let bytes = fs::read(path).ok()?;
  387. serde_json::from_slice(&bytes).ok()
  388. }
  389. fn request_hash_hex(request: &MessageRequest) -> String {
  390. format!(
  391. "{REQUEST_FINGERPRINT_PREFIX}-{:016x}",
  392. hash_serializable(request)
  393. )
  394. }
  395. fn hash_serializable<T: Serialize>(value: &T) -> u64 {
  396. let json = serde_json::to_vec(value).unwrap_or_default();
  397. stable_hash_bytes(&json)
  398. }
  399. fn sanitize_path_segment(value: &str) -> String {
  400. let sanitized: String = value
  401. .chars()
  402. .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
  403. .collect();
  404. if sanitized.len() <= MAX_SANITIZED_LENGTH {
  405. return sanitized;
  406. }
  407. let suffix = format!("-{:x}", hash_string(value));
  408. format!(
  409. "{}{}",
  410. &sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())],
  411. suffix
  412. )
  413. }
  414. fn hash_string(value: &str) -> u64 {
  415. stable_hash_bytes(value.as_bytes())
  416. }
  /// Returns the directory under which all prompt-cache sessions are stored.
  ///
  /// Resolution order:
  /// 1. `$CLAUDE_CONFIG_HOME/cache/prompt-cache`
  /// 2. `$HOME/.claude/cache/prompt-cache`
  /// 3. `<system temp dir>/claude-prompt-cache`
  ///
  /// A variable that is set but empty is treated as unset; otherwise the
  /// result would be a relative path and the cache would land in whatever
  /// the current working directory happens to be.
  fn base_cache_root() -> PathBuf {
      let non_empty_var = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());

      if let Some(config_home) = non_empty_var("CLAUDE_CONFIG_HOME") {
          return PathBuf::from(config_home)
              .join("cache")
              .join("prompt-cache");
      }
      if let Some(home) = non_empty_var("HOME") {
          return PathBuf::from(home)
              .join(".claude")
              .join("cache")
              .join("prompt-cache");
      }
      std::env::temp_dir().join("claude-prompt-cache")
  }
  /// Returns the current time as whole seconds since the Unix epoch, or `0`
  /// when the system clock reports a time before the epoch.
  fn now_unix_secs() -> u64 {
      match SystemTime::now().duration_since(UNIX_EPOCH) {
          Ok(since_epoch) => since_epoch.as_secs(),
          Err(_) => 0,
      }
  }
  436. const fn current_fingerprint_version() -> u32 {
  437. REQUEST_FINGERPRINT_VERSION
  438. }
  439. fn stable_hash_bytes(bytes: &[u8]) -> u64 {
  440. let mut hash = FNV_OFFSET_BASIS;
  441. for byte in bytes {
  442. hash ^= u64::from(*byte);
  443. hash = hash.wrapping_mul(FNV_PRIME);
  444. }
  445. hash
  446. }
  #[cfg(test)]
  mod tests {
      use std::path::PathBuf;
      use std::sync::{Mutex, OnceLock};
      use std::time::{Duration, SystemTime, UNIX_EPOCH};

      use super::{
          detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
          PromptCacheConfig, PromptCachePaths, PromptCacheStats, TrackedPromptState,
          REQUEST_FINGERPRINT_PREFIX,
      };
      use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};

      /// Serializes tests that read or modify `CLAUDE_CONFIG_HOME`; the
      /// environment is process-wide and tests run in parallel.
      fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
          static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
          LOCK.get_or_init(|| Mutex::new(()))
              .lock()
              .unwrap_or_else(std::sync::PoisonError::into_inner)
      }

      /// Points `CLAUDE_CONFIG_HOME` at a unique temporary directory for as
      /// long as the value lives.
      ///
      /// Dropping it unsets the variable and deletes the directory, so the
      /// cleanup also happens when an assertion fails. The environment lock
      /// is the last field and is therefore released only after the cleanup.
      struct ScopedCacheRoot {
          root: PathBuf,
          _env_guard: std::sync::MutexGuard<'static, ()>,
      }

      impl ScopedCacheRoot {
          fn new(label: &str) -> Self {
              let env_guard = test_env_lock();
              let nanos = SystemTime::now()
                  .duration_since(UNIX_EPOCH)
                  .expect("time")
                  .as_nanos();
              let root = std::env::temp_dir().join(format!(
                  "prompt-cache-{label}-{}-{nanos}",
                  std::process::id()
              ));
              std::env::set_var("CLAUDE_CONFIG_HOME", &root);
              Self {
                  root,
                  _env_guard: env_guard,
              }
          }
      }

      impl Drop for ScopedCacheRoot {
          fn drop(&mut self) {
              std::env::remove_var("CLAUDE_CONFIG_HOME");
              // Best effort: panicking in `drop` during an unwind would abort.
              let _ = std::fs::remove_dir_all(&self.root);
          }
      }

      #[test]
      fn path_builder_sanitizes_session_identifier() {
          // `for_session` reads the environment, so keep other tests from
          // changing it meanwhile.
          let _guard = test_env_lock();
          let paths = PromptCachePaths::for_session("session:/with spaces");
          let session_dir = paths
              .session_dir
              .file_name()
              .and_then(|value| value.to_str())
              .expect("session dir name");
          assert_eq!(session_dir, "session--with-spaces");
          assert!(paths.completion_dir.ends_with("completions"));
          assert!(paths.stats_path.ends_with("stats.json"));
          assert!(paths.session_state_path.ends_with("session-state.json"));
      }

      #[test]
      fn request_fingerprint_drives_unexpected_break_detection() {
          let request = sample_request("same");
          let previous = TrackedPromptState::from_usage(&request, &cache_read_usage(6_000));
          let current = TrackedPromptState::from_usage(&request, &cache_read_usage(1_000));
          let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
              .expect("break should be detected");
          assert!(event.unexpected);
          assert!(event.reason.contains("stable"));
      }

      #[test]
      fn changed_prompt_marks_break_as_expected() {
          let previous_request = sample_request("first");
          let current_request = sample_request("second");
          let previous =
              TrackedPromptState::from_usage(&previous_request, &cache_read_usage(6_000));
          let current = TrackedPromptState::from_usage(&current_request, &cache_read_usage(1_000));
          let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
              .expect("break should be detected");
          assert!(!event.unexpected);
          assert!(event.reason.contains("message payload changed"));
      }

      #[test]
      fn completion_cache_round_trip_persists_recent_response() {
          let _root = ScopedCacheRoot::new("test");
          let cache = PromptCache::new("unit-test-session");
          let request = sample_request("cache me");
          let response = sample_response(42, 12, "cached");

          assert!(cache.lookup_completion(&request).is_none());
          let record = cache.record_response(&request, &response);
          assert!(record.cache_break.is_none());

          let cached = cache
              .lookup_completion(&request)
              .expect("cached response should load");
          assert_eq!(cached.content, response.content);

          let stats = cache.stats();
          assert_eq!(stats.completion_cache_hits, 1);
          assert_eq!(stats.completion_cache_misses, 1);
          assert_eq!(stats.completion_cache_writes, 1);

          let persisted = read_json::<PromptCacheStats>(&cache.paths().stats_path)
              .expect("stats should persist");
          assert_eq!(persisted.completion_cache_hits, 1);
      }

      #[test]
      fn distinct_requests_do_not_collide_in_completion_cache() {
          let _root = ScopedCacheRoot::new("distinct");
          let cache = PromptCache::new("distinct-request-session");
          let first_request = sample_request("first");
          let second_request = sample_request("second");
          let response = sample_response(42, 12, "cached");

          let _ = cache.record_response(&first_request, &response);
          assert!(cache.lookup_completion(&second_request).is_none());
      }

      #[test]
      fn expired_completion_entries_are_not_reused() {
          let _root = ScopedCacheRoot::new("expired");
          let cache = PromptCache::with_config(PromptCacheConfig {
              session_id: "expired-session".to_string(),
              completion_ttl: Duration::ZERO,
              ..PromptCacheConfig::default()
          });
          let request = sample_request("expire me");
          let response = sample_response(7, 3, "stale");

          let _ = cache.record_response(&request, &response);
          assert!(cache.lookup_completion(&request).is_none());

          let stats = cache.stats();
          assert_eq!(stats.completion_cache_hits, 0);
          assert_eq!(stats.completion_cache_misses, 1);
      }

      #[test]
      fn sanitize_path_caps_long_values() {
          let long_value = "x".repeat(200);
          let sanitized = sanitize_path_segment(&long_value);
          assert!(sanitized.len() <= 80);
      }

      #[test]
      fn request_hashes_are_versioned_and_stable() {
          let request = sample_request("stable");
          let first = request_hash_hex(&request);
          let second = request_hash_hex(&request);
          assert_eq!(first, second);
          assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX));
      }

      /// A usage whose only non-zero figure is the cache-read token count.
      fn cache_read_usage(cache_read_input_tokens: u32) -> Usage {
          Usage {
              input_tokens: 0,
              cache_creation_input_tokens: 0,
              cache_read_input_tokens,
              output_tokens: 0,
          }
      }

      /// A minimal request whose single user message is `text`.
      fn sample_request(text: &str) -> MessageRequest {
          MessageRequest {
              model: "claude-3-7-sonnet-latest".to_string(),
              max_tokens: 64,
              messages: vec![InputMessage::user_text(text)],
              system: Some("system".to_string()),
              tools: None,
              tool_choice: None,
              stream: false,
          }
      }

      /// A minimal assistant response carrying `text` and the given token counts.
      fn sample_response(
          cache_read_input_tokens: u32,
          output_tokens: u32,
          text: &str,
      ) -> MessageResponse {
          MessageResponse {
              id: "msg_test".to_string(),
              kind: "message".to_string(),
              role: "assistant".to_string(),
              content: vec![OutputContentBlock::Text {
                  text: text.to_string(),
              }],
              model: "claude-3-7-sonnet-latest".to_string(),
              stop_reason: Some("end_turn".to_string()),
              stop_sequence: None,
              usage: Usage {
                  input_tokens: 10,
                  cache_creation_input_tokens: 5,
                  cache_read_input_tokens,
                  output_tokens,
              },
              request_id: Some("req_test".to_string()),
          }
      }
  }