|
|
@@ -0,0 +1,679 @@
|
|
|
+use std::fs;
|
|
|
+use std::path::{Path, PathBuf};
|
|
|
+use std::sync::{Arc, Mutex};
|
|
|
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|
|
+
|
|
|
+use serde::{Deserialize, Serialize};
|
|
|
+
|
|
|
+use crate::types::{MessageRequest, MessageResponse, Usage};
|
|
|
+
|
|
|
+const DEFAULT_COMPLETION_TTL_SECS: u64 = 30;
|
|
|
+const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60;
|
|
|
+const DEFAULT_BREAK_MIN_DROP: u32 = 2_000;
|
|
|
+const MAX_SANITIZED_LENGTH: usize = 80;
|
|
|
+const REQUEST_FINGERPRINT_VERSION: u32 = 1;
|
|
|
+const REQUEST_FINGERPRINT_PREFIX: &str = "v1";
|
|
|
+const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
|
|
|
+const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
|
|
|
+
|
|
|
+#[derive(Debug, Clone)]
|
|
|
+pub struct PromptCacheConfig {
|
|
|
+ pub session_id: String,
|
|
|
+ pub completion_ttl: Duration,
|
|
|
+ pub prompt_ttl: Duration,
|
|
|
+ pub cache_break_min_drop: u32,
|
|
|
+}
|
|
|
+
|
|
|
+impl PromptCacheConfig {
|
|
|
+ #[must_use]
|
|
|
+ pub fn new(session_id: impl Into<String>) -> Self {
|
|
|
+ Self {
|
|
|
+ session_id: session_id.into(),
|
|
|
+ completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS),
|
|
|
+ prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS),
|
|
|
+ cache_break_min_drop: DEFAULT_BREAK_MIN_DROP,
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl Default for PromptCacheConfig {
|
|
|
+ fn default() -> Self {
|
|
|
+ Self::new("default")
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
|
+pub struct PromptCachePaths {
|
|
|
+ pub root: PathBuf,
|
|
|
+ pub session_dir: PathBuf,
|
|
|
+ pub completion_dir: PathBuf,
|
|
|
+ pub session_state_path: PathBuf,
|
|
|
+ pub stats_path: PathBuf,
|
|
|
+}
|
|
|
+
|
|
|
+impl PromptCachePaths {
|
|
|
+ #[must_use]
|
|
|
+ pub fn for_session(session_id: &str) -> Self {
|
|
|
+ let root = base_cache_root();
|
|
|
+ let session_dir = root.join(sanitize_path_segment(session_id));
|
|
|
+ let completion_dir = session_dir.join("completions");
|
|
|
+ Self {
|
|
|
+ root,
|
|
|
+ session_state_path: session_dir.join("session-state.json"),
|
|
|
+ stats_path: session_dir.join("stats.json"),
|
|
|
+ session_dir,
|
|
|
+ completion_dir,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf {
|
|
|
+ self.completion_dir.join(format!("{request_hash}.json"))
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
|
|
|
+pub struct PromptCacheStats {
|
|
|
+ pub tracked_requests: u64,
|
|
|
+ pub completion_cache_hits: u64,
|
|
|
+ pub completion_cache_misses: u64,
|
|
|
+ pub completion_cache_writes: u64,
|
|
|
+ pub expected_invalidations: u64,
|
|
|
+ pub unexpected_cache_breaks: u64,
|
|
|
+ pub total_cache_creation_input_tokens: u64,
|
|
|
+ pub total_cache_read_input_tokens: u64,
|
|
|
+ pub last_cache_creation_input_tokens: Option<u32>,
|
|
|
+ pub last_cache_read_input_tokens: Option<u32>,
|
|
|
+ pub last_request_hash: Option<String>,
|
|
|
+ pub last_completion_cache_key: Option<String>,
|
|
|
+ pub last_break_reason: Option<String>,
|
|
|
+ pub last_cache_source: Option<String>,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
|
+pub struct CacheBreakEvent {
|
|
|
+ pub unexpected: bool,
|
|
|
+ pub reason: String,
|
|
|
+ pub previous_cache_read_input_tokens: u32,
|
|
|
+ pub current_cache_read_input_tokens: u32,
|
|
|
+ pub token_drop: u32,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, PartialEq, Eq)]
|
|
|
+pub struct PromptCacheRecord {
|
|
|
+ pub cache_break: Option<CacheBreakEvent>,
|
|
|
+ pub stats: PromptCacheStats,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone)]
|
|
|
+pub struct PromptCache {
|
|
|
+ inner: Arc<Mutex<PromptCacheInner>>,
|
|
|
+}
|
|
|
+
|
|
|
+impl PromptCache {
|
|
|
+ #[must_use]
|
|
|
+ pub fn new(session_id: impl Into<String>) -> Self {
|
|
|
+ Self::with_config(PromptCacheConfig::new(session_id))
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn with_config(config: PromptCacheConfig) -> Self {
|
|
|
+ let paths = PromptCachePaths::for_session(&config.session_id);
|
|
|
+ let stats = read_json::<PromptCacheStats>(&paths.stats_path).unwrap_or_default();
|
|
|
+ let previous = read_json::<TrackedPromptState>(&paths.session_state_path);
|
|
|
+ Self {
|
|
|
+ inner: Arc::new(Mutex::new(PromptCacheInner {
|
|
|
+ config,
|
|
|
+ paths,
|
|
|
+ stats,
|
|
|
+ previous,
|
|
|
+ })),
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn paths(&self) -> PromptCachePaths {
|
|
|
+ self.lock().paths.clone()
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn stats(&self) -> PromptCacheStats {
|
|
|
+ self.lock().stats.clone()
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
|
|
|
+ let request_hash = request_hash_hex(request);
|
|
|
+ let (paths, ttl) = {
|
|
|
+ let inner = self.lock();
|
|
|
+ (inner.paths.clone(), inner.config.completion_ttl)
|
|
|
+ };
|
|
|
+ let entry_path = paths.completion_entry_path(&request_hash);
|
|
|
+ let entry = read_json::<CompletionCacheEntry>(&entry_path);
|
|
|
+ let Some(entry) = entry else {
|
|
|
+ let mut inner = self.lock();
|
|
|
+ inner.stats.completion_cache_misses += 1;
|
|
|
+ inner.stats.last_completion_cache_key = Some(request_hash);
|
|
|
+ persist_state(&inner);
|
|
|
+ return None;
|
|
|
+ };
|
|
|
+
|
|
|
+ if entry.fingerprint_version != current_fingerprint_version() {
|
|
|
+ let mut inner = self.lock();
|
|
|
+ inner.stats.completion_cache_misses += 1;
|
|
|
+ inner.stats.last_completion_cache_key = Some(request_hash.clone());
|
|
|
+ let _ = fs::remove_file(entry_path);
|
|
|
+ persist_state(&inner);
|
|
|
+ return None;
|
|
|
+ }
|
|
|
+
|
|
|
+ let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs();
|
|
|
+ let mut inner = self.lock();
|
|
|
+ inner.stats.last_completion_cache_key = Some(request_hash.clone());
|
|
|
+ if expired {
|
|
|
+ inner.stats.completion_cache_misses += 1;
|
|
|
+ let _ = fs::remove_file(entry_path);
|
|
|
+ persist_state(&inner);
|
|
|
+ return None;
|
|
|
+ }
|
|
|
+
|
|
|
+ inner.stats.completion_cache_hits += 1;
|
|
|
+ apply_usage_to_stats(
|
|
|
+ &mut inner.stats,
|
|
|
+ &entry.response.usage,
|
|
|
+ &request_hash,
|
|
|
+ "completion-cache",
|
|
|
+ );
|
|
|
+ inner.previous = Some(TrackedPromptState::from_usage(
|
|
|
+ request,
|
|
|
+ &entry.response.usage,
|
|
|
+ ));
|
|
|
+ persist_state(&inner);
|
|
|
+ Some(entry.response)
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn record_response(
|
|
|
+ &self,
|
|
|
+ request: &MessageRequest,
|
|
|
+ response: &MessageResponse,
|
|
|
+ ) -> PromptCacheRecord {
|
|
|
+ self.record_usage_internal(request, &response.usage, Some(response))
|
|
|
+ }
|
|
|
+
|
|
|
+ pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
|
|
|
+ self.record_usage_internal(request, usage, None)
|
|
|
+ }
|
|
|
+
|
|
|
+ fn record_usage_internal(
|
|
|
+ &self,
|
|
|
+ request: &MessageRequest,
|
|
|
+ usage: &Usage,
|
|
|
+ response: Option<&MessageResponse>,
|
|
|
+ ) -> PromptCacheRecord {
|
|
|
+ let request_hash = request_hash_hex(request);
|
|
|
+ let mut inner = self.lock();
|
|
|
+ let previous = inner.previous.clone();
|
|
|
+ let current = TrackedPromptState::from_usage(request, usage);
|
|
|
+ let cache_break = detect_cache_break(&inner.config, previous.as_ref(), ¤t);
|
|
|
+
|
|
|
+ inner.stats.tracked_requests += 1;
|
|
|
+ apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response");
|
|
|
+ if let Some(event) = &cache_break {
|
|
|
+ if event.unexpected {
|
|
|
+ inner.stats.unexpected_cache_breaks += 1;
|
|
|
+ } else {
|
|
|
+ inner.stats.expected_invalidations += 1;
|
|
|
+ }
|
|
|
+ inner.stats.last_break_reason = Some(event.reason.clone());
|
|
|
+ }
|
|
|
+
|
|
|
+ inner.previous = Some(current);
|
|
|
+ if let Some(response) = response {
|
|
|
+ write_completion_entry(&inner.paths, &request_hash, response);
|
|
|
+ inner.stats.completion_cache_writes += 1;
|
|
|
+ }
|
|
|
+ persist_state(&inner);
|
|
|
+
|
|
|
+ PromptCacheRecord {
|
|
|
+ cache_break,
|
|
|
+ stats: inner.stats.clone(),
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> {
|
|
|
+ self.inner
|
|
|
+ .lock()
|
|
|
+ .unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug)]
|
|
|
+struct PromptCacheInner {
|
|
|
+ config: PromptCacheConfig,
|
|
|
+ paths: PromptCachePaths,
|
|
|
+ stats: PromptCacheStats,
|
|
|
+ previous: Option<TrackedPromptState>,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
|
+struct CompletionCacheEntry {
|
|
|
+ cached_at_unix_secs: u64,
|
|
|
+ #[serde(default = "current_fingerprint_version")]
|
|
|
+ fingerprint_version: u32,
|
|
|
+ response: MessageResponse,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
|
+struct TrackedPromptState {
|
|
|
+ observed_at_unix_secs: u64,
|
|
|
+ #[serde(default = "current_fingerprint_version")]
|
|
|
+ fingerprint_version: u32,
|
|
|
+ request_hash: u64,
|
|
|
+ model_hash: u64,
|
|
|
+ system_hash: u64,
|
|
|
+ tools_hash: u64,
|
|
|
+ messages_hash: u64,
|
|
|
+ cache_read_input_tokens: u32,
|
|
|
+}
|
|
|
+
|
|
|
+impl TrackedPromptState {
|
|
|
+ fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
|
|
|
+ let hashes = RequestHashes::from_request(request);
|
|
|
+ Self {
|
|
|
+ observed_at_unix_secs: now_unix_secs(),
|
|
|
+ fingerprint_version: current_fingerprint_version(),
|
|
|
+ request_hash: hashes.request_hash,
|
|
|
+ model_hash: hashes.model_hash,
|
|
|
+ system_hash: hashes.system_hash,
|
|
|
+ tools_hash: hashes.tools_hash,
|
|
|
+ messages_hash: hashes.messages_hash,
|
|
|
+ cache_read_input_tokens: usage.cache_read_input_tokens,
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Copy)]
|
|
|
+struct RequestHashes {
|
|
|
+ request_hash: u64,
|
|
|
+ model_hash: u64,
|
|
|
+ system_hash: u64,
|
|
|
+ tools_hash: u64,
|
|
|
+ messages_hash: u64,
|
|
|
+}
|
|
|
+
|
|
|
+impl RequestHashes {
|
|
|
+ fn from_request(request: &MessageRequest) -> Self {
|
|
|
+ Self {
|
|
|
+ request_hash: hash_serializable(request),
|
|
|
+ model_hash: hash_serializable(&request.model),
|
|
|
+ system_hash: hash_serializable(&request.system),
|
|
|
+ tools_hash: hash_serializable(&request.tools),
|
|
|
+ messages_hash: hash_serializable(&request.messages),
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+fn detect_cache_break(
|
|
|
+ config: &PromptCacheConfig,
|
|
|
+ previous: Option<&TrackedPromptState>,
|
|
|
+ current: &TrackedPromptState,
|
|
|
+) -> Option<CacheBreakEvent> {
|
|
|
+ let previous = previous?;
|
|
|
+ if previous.fingerprint_version != current.fingerprint_version {
|
|
|
+ return Some(CacheBreakEvent {
|
|
|
+ unexpected: false,
|
|
|
+ reason: format!(
|
|
|
+ "fingerprint version changed (v{} -> v{})",
|
|
|
+ previous.fingerprint_version, current.fingerprint_version
|
|
|
+ ),
|
|
|
+ previous_cache_read_input_tokens: previous.cache_read_input_tokens,
|
|
|
+ current_cache_read_input_tokens: current.cache_read_input_tokens,
|
|
|
+ token_drop: previous
|
|
|
+ .cache_read_input_tokens
|
|
|
+ .saturating_sub(current.cache_read_input_tokens),
|
|
|
+ });
|
|
|
+ }
|
|
|
+ let token_drop = previous
|
|
|
+ .cache_read_input_tokens
|
|
|
+ .saturating_sub(current.cache_read_input_tokens);
|
|
|
+ if token_drop < config.cache_break_min_drop {
|
|
|
+ return None;
|
|
|
+ }
|
|
|
+
|
|
|
+ let mut reasons = Vec::new();
|
|
|
+ if previous.model_hash != current.model_hash {
|
|
|
+ reasons.push("model changed");
|
|
|
+ }
|
|
|
+ if previous.system_hash != current.system_hash {
|
|
|
+ reasons.push("system prompt changed");
|
|
|
+ }
|
|
|
+ if previous.tools_hash != current.tools_hash {
|
|
|
+ reasons.push("tool definitions changed");
|
|
|
+ }
|
|
|
+ if previous.messages_hash != current.messages_hash {
|
|
|
+ reasons.push("message payload changed");
|
|
|
+ }
|
|
|
+
|
|
|
+ let elapsed = current
|
|
|
+ .observed_at_unix_secs
|
|
|
+ .saturating_sub(previous.observed_at_unix_secs);
|
|
|
+
|
|
|
+ let (unexpected, reason) = if reasons.is_empty() {
|
|
|
+ if elapsed > config.prompt_ttl.as_secs() {
|
|
|
+ (
|
|
|
+ false,
|
|
|
+ format!("possible prompt cache TTL expiry after {elapsed}s"),
|
|
|
+ )
|
|
|
+ } else {
|
|
|
+ (
|
|
|
+ true,
|
|
|
+ "cache read tokens dropped while prompt fingerprint remained stable".to_string(),
|
|
|
+ )
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ (false, reasons.join(", "))
|
|
|
+ };
|
|
|
+
|
|
|
+ Some(CacheBreakEvent {
|
|
|
+ unexpected,
|
|
|
+ reason,
|
|
|
+ previous_cache_read_input_tokens: previous.cache_read_input_tokens,
|
|
|
+ current_cache_read_input_tokens: current.cache_read_input_tokens,
|
|
|
+ token_drop,
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+fn apply_usage_to_stats(
|
|
|
+ stats: &mut PromptCacheStats,
|
|
|
+ usage: &Usage,
|
|
|
+ request_hash: &str,
|
|
|
+ source: &str,
|
|
|
+) {
|
|
|
+ stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens);
|
|
|
+ stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens);
|
|
|
+ stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens);
|
|
|
+ stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens);
|
|
|
+ stats.last_request_hash = Some(request_hash.to_string());
|
|
|
+ stats.last_cache_source = Some(source.to_string());
|
|
|
+}
|
|
|
+
|
|
|
+fn persist_state(inner: &PromptCacheInner) {
|
|
|
+ let _ = ensure_cache_dirs(&inner.paths);
|
|
|
+ let _ = write_json(&inner.paths.stats_path, &inner.stats);
|
|
|
+ if let Some(previous) = &inner.previous {
|
|
|
+ let _ = write_json(&inner.paths.session_state_path, previous);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+fn write_completion_entry(
|
|
|
+ paths: &PromptCachePaths,
|
|
|
+ request_hash: &str,
|
|
|
+ response: &MessageResponse,
|
|
|
+) {
|
|
|
+ let _ = ensure_cache_dirs(paths);
|
|
|
+ let entry = CompletionCacheEntry {
|
|
|
+ cached_at_unix_secs: now_unix_secs(),
|
|
|
+ fingerprint_version: current_fingerprint_version(),
|
|
|
+ response: response.clone(),
|
|
|
+ };
|
|
|
+ let _ = write_json(&paths.completion_entry_path(request_hash), &entry);
|
|
|
+}
|
|
|
+
|
|
|
+fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> {
|
|
|
+ fs::create_dir_all(&paths.completion_dir)
|
|
|
+}
|
|
|
+
|
|
|
+fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> {
|
|
|
+ let json = serde_json::to_vec_pretty(value)
|
|
|
+ .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?;
|
|
|
+ fs::write(path, json)
|
|
|
+}
|
|
|
+
|
|
|
+fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> Option<T> {
|
|
|
+ let bytes = fs::read(path).ok()?;
|
|
|
+ serde_json::from_slice(&bytes).ok()
|
|
|
+}
|
|
|
+
|
|
|
+fn request_hash_hex(request: &MessageRequest) -> String {
|
|
|
+ format!(
|
|
|
+ "{REQUEST_FINGERPRINT_PREFIX}-{:016x}",
|
|
|
+ hash_serializable(request)
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+fn hash_serializable<T: Serialize>(value: &T) -> u64 {
|
|
|
+ let json = serde_json::to_vec(value).unwrap_or_default();
|
|
|
+ stable_hash_bytes(&json)
|
|
|
+}
|
|
|
+
|
|
|
+fn sanitize_path_segment(value: &str) -> String {
|
|
|
+ let sanitized: String = value
|
|
|
+ .chars()
|
|
|
+ .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
|
|
|
+ .collect();
|
|
|
+ if sanitized.len() <= MAX_SANITIZED_LENGTH {
|
|
|
+ return sanitized;
|
|
|
+ }
|
|
|
+ let suffix = format!("-{:x}", hash_string(value));
|
|
|
+ format!(
|
|
|
+ "{}{}",
|
|
|
+ &sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())],
|
|
|
+ suffix
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+fn hash_string(value: &str) -> u64 {
|
|
|
+ stable_hash_bytes(value.as_bytes())
|
|
|
+}
|
|
|
+
|
|
|
+fn base_cache_root() -> PathBuf {
|
|
|
+ if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") {
|
|
|
+ return PathBuf::from(config_home)
|
|
|
+ .join("cache")
|
|
|
+ .join("prompt-cache");
|
|
|
+ }
|
|
|
+ if let Some(home) = std::env::var_os("HOME") {
|
|
|
+ return PathBuf::from(home)
|
|
|
+ .join(".claude")
|
|
|
+ .join("cache")
|
|
|
+ .join("prompt-cache");
|
|
|
+ }
|
|
|
+ std::env::temp_dir().join("claude-prompt-cache")
|
|
|
+}
|
|
|
+
|
|
|
+fn now_unix_secs() -> u64 {
|
|
|
+ SystemTime::now()
|
|
|
+ .duration_since(UNIX_EPOCH)
|
|
|
+ .map_or(0, |duration| duration.as_secs())
|
|
|
+}
|
|
|
+
|
|
|
+const fn current_fingerprint_version() -> u32 {
|
|
|
+ REQUEST_FINGERPRINT_VERSION
|
|
|
+}
|
|
|
+
|
|
|
+fn stable_hash_bytes(bytes: &[u8]) -> u64 {
|
|
|
+ let mut hash = FNV_OFFSET_BASIS;
|
|
|
+ for byte in bytes {
|
|
|
+ hash ^= u64::from(*byte);
|
|
|
+ hash = hash.wrapping_mul(FNV_PRIME);
|
|
|
+ }
|
|
|
+ hash
|
|
|
+}
|
|
|
+
|
|
|
+#[cfg(test)]
|
|
|
+mod tests {
|
|
|
+ use std::sync::{Mutex, OnceLock};
|
|
|
+ use std::time::{SystemTime, UNIX_EPOCH};
|
|
|
+
|
|
|
+ use super::{
|
|
|
+ detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
|
|
|
+ PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
|
|
|
+ };
|
|
|
+ use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
|
|
|
+
|
|
|
+ fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
|
|
+ static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
|
|
+ LOCK.get_or_init(|| Mutex::new(()))
|
|
|
+ .lock()
|
|
|
+ .unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn path_builder_sanitizes_session_identifier() {
|
|
|
+ let paths = PromptCachePaths::for_session("session:/with spaces");
|
|
|
+ let session_dir = paths
|
|
|
+ .session_dir
|
|
|
+ .file_name()
|
|
|
+ .and_then(|value| value.to_str())
|
|
|
+ .expect("session dir name");
|
|
|
+ assert_eq!(session_dir, "session--with-spaces");
|
|
|
+ assert!(paths.completion_dir.ends_with("completions"));
|
|
|
+ assert!(paths.stats_path.ends_with("stats.json"));
|
|
|
+ assert!(paths.session_state_path.ends_with("session-state.json"));
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn request_fingerprint_drives_unexpected_break_detection() {
|
|
|
+ let request = sample_request("same");
|
|
|
+ let previous = TrackedPromptState::from_usage(
|
|
|
+ &request,
|
|
|
+ &Usage {
|
|
|
+ input_tokens: 0,
|
|
|
+ cache_creation_input_tokens: 0,
|
|
|
+ cache_read_input_tokens: 6_000,
|
|
|
+ output_tokens: 0,
|
|
|
+ },
|
|
|
+ );
|
|
|
+ let current = TrackedPromptState::from_usage(
|
|
|
+ &request,
|
|
|
+ &Usage {
|
|
|
+ input_tokens: 0,
|
|
|
+ cache_creation_input_tokens: 0,
|
|
|
+ cache_read_input_tokens: 1_000,
|
|
|
+ output_tokens: 0,
|
|
|
+ },
|
|
|
+ );
|
|
|
+ let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t)
|
|
|
+ .expect("break should be detected");
|
|
|
+ assert!(event.unexpected);
|
|
|
+ assert!(event.reason.contains("stable"));
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn changed_prompt_marks_break_as_expected() {
|
|
|
+ let previous_request = sample_request("first");
|
|
|
+ let current_request = sample_request("second");
|
|
|
+ let previous = TrackedPromptState::from_usage(
|
|
|
+ &previous_request,
|
|
|
+ &Usage {
|
|
|
+ input_tokens: 0,
|
|
|
+ cache_creation_input_tokens: 0,
|
|
|
+ cache_read_input_tokens: 6_000,
|
|
|
+ output_tokens: 0,
|
|
|
+ },
|
|
|
+ );
|
|
|
+ let current = TrackedPromptState::from_usage(
|
|
|
+ ¤t_request,
|
|
|
+ &Usage {
|
|
|
+ input_tokens: 0,
|
|
|
+ cache_creation_input_tokens: 0,
|
|
|
+ cache_read_input_tokens: 1_000,
|
|
|
+ output_tokens: 0,
|
|
|
+ },
|
|
|
+ );
|
|
|
+ let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t)
|
|
|
+ .expect("break should be detected");
|
|
|
+ assert!(!event.unexpected);
|
|
|
+ assert!(event.reason.contains("message payload changed"));
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn completion_cache_round_trip_persists_recent_response() {
|
|
|
+ let _guard = env_lock();
|
|
|
+ let temp_root = std::env::temp_dir().join(format!(
|
|
|
+ "prompt-cache-test-{}-{}",
|
|
|
+ std::process::id(),
|
|
|
+ SystemTime::now()
|
|
|
+ .duration_since(UNIX_EPOCH)
|
|
|
+ .expect("time")
|
|
|
+ .as_nanos()
|
|
|
+ ));
|
|
|
+ std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
|
|
+ let cache = PromptCache::new("unit-test-session");
|
|
|
+ let request = sample_request("cache me");
|
|
|
+ let response = sample_response(42, 12, "cached");
|
|
|
+
|
|
|
+ assert!(cache.lookup_completion(&request).is_none());
|
|
|
+ let record = cache.record_response(&request, &response);
|
|
|
+ assert!(record.cache_break.is_none());
|
|
|
+
|
|
|
+ let cached = cache
|
|
|
+ .lookup_completion(&request)
|
|
|
+ .expect("cached response should load");
|
|
|
+ assert_eq!(cached.content, response.content);
|
|
|
+
|
|
|
+ let stats = cache.stats();
|
|
|
+ assert_eq!(stats.completion_cache_hits, 1);
|
|
|
+ assert_eq!(stats.completion_cache_misses, 1);
|
|
|
+ assert_eq!(stats.completion_cache_writes, 1);
|
|
|
+
|
|
|
+ let persisted = read_json::<super::PromptCacheStats>(&cache.paths().stats_path)
|
|
|
+ .expect("stats should persist");
|
|
|
+ assert_eq!(persisted.completion_cache_hits, 1);
|
|
|
+
|
|
|
+ std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
|
|
+ std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn sanitize_path_caps_long_values() {
|
|
|
+ let long_value = "x".repeat(200);
|
|
|
+ let sanitized = sanitize_path_segment(&long_value);
|
|
|
+ assert!(sanitized.len() <= 80);
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn request_hashes_are_versioned_and_stable() {
|
|
|
+ let request = sample_request("stable");
|
|
|
+ let first = request_hash_hex(&request);
|
|
|
+ let second = request_hash_hex(&request);
|
|
|
+ assert_eq!(first, second);
|
|
|
+ assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX));
|
|
|
+ }
|
|
|
+
|
|
|
+ fn sample_request(text: &str) -> MessageRequest {
|
|
|
+ MessageRequest {
|
|
|
+ model: "claude-3-7-sonnet-latest".to_string(),
|
|
|
+ max_tokens: 64,
|
|
|
+ messages: vec![InputMessage::user_text(text)],
|
|
|
+ system: Some("system".to_string()),
|
|
|
+ tools: None,
|
|
|
+ tool_choice: None,
|
|
|
+ stream: false,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fn sample_response(
|
|
|
+ cache_read_input_tokens: u32,
|
|
|
+ output_tokens: u32,
|
|
|
+ text: &str,
|
|
|
+ ) -> MessageResponse {
|
|
|
+ MessageResponse {
|
|
|
+ id: "msg_test".to_string(),
|
|
|
+ kind: "message".to_string(),
|
|
|
+ role: "assistant".to_string(),
|
|
|
+ content: vec![OutputContentBlock::Text {
|
|
|
+ text: text.to_string(),
|
|
|
+ }],
|
|
|
+ model: "claude-3-7-sonnet-latest".to_string(),
|
|
|
+ stop_reason: Some("end_turn".to_string()),
|
|
|
+ stop_sequence: None,
|
|
|
+ usage: Usage {
|
|
|
+ input_tokens: 10,
|
|
|
+ cache_creation_input_tokens: 5,
|
|
|
+ cache_read_input_tokens,
|
|
|
+ output_tokens,
|
|
|
+ },
|
|
|
+ request_id: Some("req_test".to_string()),
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|