|
@@ -141,6 +141,7 @@ impl PromptCache {
|
|
|
self.lock().stats.clone()
|
|
self.lock().stats.clone()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ #[must_use]
|
|
|
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
|
|
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
|
|
|
let request_hash = request_hash_hex(request);
|
|
let request_hash = request_hash_hex(request);
|
|
|
let (paths, ttl) = {
|
|
let (paths, ttl) = {
|
|
@@ -191,6 +192,7 @@ impl PromptCache {
|
|
|
Some(entry.response)
|
|
Some(entry.response)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ #[must_use]
|
|
|
pub fn record_response(
|
|
pub fn record_response(
|
|
|
&self,
|
|
&self,
|
|
|
request: &MessageRequest,
|
|
request: &MessageRequest,
|
|
@@ -199,6 +201,7 @@ impl PromptCache {
|
|
|
self.record_usage_internal(request, &response.usage, Some(response))
|
|
self.record_usage_internal(request, &response.usage, Some(response))
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ #[must_use]
|
|
|
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
|
|
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
|
|
|
self.record_usage_internal(request, usage, None)
|
|
self.record_usage_internal(request, usage, None)
|
|
|
}
|
|
}
|
|
@@ -267,7 +270,6 @@ struct TrackedPromptState {
|
|
|
observed_at_unix_secs: u64,
|
|
observed_at_unix_secs: u64,
|
|
|
#[serde(default = "current_fingerprint_version")]
|
|
#[serde(default = "current_fingerprint_version")]
|
|
|
fingerprint_version: u32,
|
|
fingerprint_version: u32,
|
|
|
- request_hash: u64,
|
|
|
|
|
model_hash: u64,
|
|
model_hash: u64,
|
|
|
system_hash: u64,
|
|
system_hash: u64,
|
|
|
tools_hash: u64,
|
|
tools_hash: u64,
|
|
@@ -277,37 +279,34 @@ struct TrackedPromptState {
|
|
|
|
|
|
|
|
impl TrackedPromptState {
|
|
impl TrackedPromptState {
|
|
|
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
|
|
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
|
|
|
- let hashes = RequestHashes::from_request(request);
|
|
|
|
|
|
|
+ let hashes = RequestFingerprints::from_request(request);
|
|
|
Self {
|
|
Self {
|
|
|
observed_at_unix_secs: now_unix_secs(),
|
|
observed_at_unix_secs: now_unix_secs(),
|
|
|
fingerprint_version: current_fingerprint_version(),
|
|
fingerprint_version: current_fingerprint_version(),
|
|
|
- request_hash: hashes.request_hash,
|
|
|
|
|
- model_hash: hashes.model_hash,
|
|
|
|
|
- system_hash: hashes.system_hash,
|
|
|
|
|
- tools_hash: hashes.tools_hash,
|
|
|
|
|
- messages_hash: hashes.messages_hash,
|
|
|
|
|
|
|
+ model_hash: hashes.model,
|
|
|
|
|
+ system_hash: hashes.system,
|
|
|
|
|
+ tools_hash: hashes.tools,
|
|
|
|
|
+ messages_hash: hashes.messages,
|
|
|
cache_read_input_tokens: usage.cache_read_input_tokens,
|
|
cache_read_input_tokens: usage.cache_read_input_tokens,
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug, Clone, Copy)]
|
|
#[derive(Debug, Clone, Copy)]
|
|
|
-struct RequestHashes {
|
|
|
|
|
- request_hash: u64,
|
|
|
|
|
- model_hash: u64,
|
|
|
|
|
- system_hash: u64,
|
|
|
|
|
- tools_hash: u64,
|
|
|
|
|
- messages_hash: u64,
|
|
|
|
|
|
|
+struct RequestFingerprints {
|
|
|
|
|
+ model: u64,
|
|
|
|
|
+ system: u64,
|
|
|
|
|
+ tools: u64,
|
|
|
|
|
+ messages: u64,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-impl RequestHashes {
|
|
|
|
|
|
|
+impl RequestFingerprints {
|
|
|
fn from_request(request: &MessageRequest) -> Self {
|
|
fn from_request(request: &MessageRequest) -> Self {
|
|
|
Self {
|
|
Self {
|
|
|
- request_hash: hash_serializable(request),
|
|
|
|
|
- model_hash: hash_serializable(&request.model),
|
|
|
|
|
- system_hash: hash_serializable(&request.system),
|
|
|
|
|
- tools_hash: hash_serializable(&request.tools),
|
|
|
|
|
- messages_hash: hash_serializable(&request.messages),
|
|
|
|
|
|
|
+ model: hash_serializable(&request.model),
|
|
|
|
|
+ system: hash_serializable(&request.system),
|
|
|
|
|
+ tools: hash_serializable(&request.tools),
|
|
|
|
|
+ messages: hash_serializable(&request.messages),
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -501,22 +500,15 @@ fn stable_hash_bytes(bytes: &[u8]) -> u64 {
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
#[cfg(test)]
|
|
|
mod tests {
|
|
mod tests {
|
|
|
- use std::sync::{Mutex, OnceLock};
|
|
|
|
|
- use std::time::{SystemTime, UNIX_EPOCH};
|
|
|
|
|
|
|
+ use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|
|
|
|
|
|
|
use super::{
|
|
use super::{
|
|
|
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
|
|
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
|
|
|
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
|
|
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
|
|
|
};
|
|
};
|
|
|
|
|
+ use crate::test_env_lock;
|
|
|
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
|
|
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
|
|
|
|
|
|
|
|
- fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
|
|
|
|
- static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
|
|
|
|
- LOCK.get_or_init(|| Mutex::new(()))
|
|
|
|
|
- .lock()
|
|
|
|
|
- .unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
#[test]
|
|
#[test]
|
|
|
fn path_builder_sanitizes_session_identifier() {
|
|
fn path_builder_sanitizes_session_identifier() {
|
|
|
let paths = PromptCachePaths::for_session("session:/with spaces");
|
|
let paths = PromptCachePaths::for_session("session:/with spaces");
|
|
@@ -588,7 +580,7 @@ mod tests {
|
|
|
|
|
|
|
|
#[test]
|
|
#[test]
|
|
|
fn completion_cache_round_trip_persists_recent_response() {
|
|
fn completion_cache_round_trip_persists_recent_response() {
|
|
|
- let _guard = env_lock();
|
|
|
|
|
|
|
+ let _guard = test_env_lock();
|
|
|
let temp_root = std::env::temp_dir().join(format!(
|
|
let temp_root = std::env::temp_dir().join(format!(
|
|
|
"prompt-cache-test-{}-{}",
|
|
"prompt-cache-test-{}-{}",
|
|
|
std::process::id(),
|
|
std::process::id(),
|
|
@@ -624,6 +616,62 @@ mod tests {
|
|
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn distinct_requests_do_not_collide_in_completion_cache() {
|
|
|
|
|
+ let _guard = test_env_lock();
|
|
|
|
|
+ let temp_root = std::env::temp_dir().join(format!(
|
|
|
|
|
+ "prompt-cache-distinct-{}-{}",
|
|
|
|
|
+ std::process::id(),
|
|
|
|
|
+ SystemTime::now()
|
|
|
|
|
+ .duration_since(UNIX_EPOCH)
|
|
|
|
|
+ .expect("time")
|
|
|
|
|
+ .as_nanos()
|
|
|
|
|
+ ));
|
|
|
|
|
+ std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
|
|
|
|
+ let cache = PromptCache::new("distinct-request-session");
|
|
|
|
|
+ let first_request = sample_request("first");
|
|
|
|
|
+ let second_request = sample_request("second");
|
|
|
|
|
+
|
|
|
|
|
+ let response = sample_response(42, 12, "cached");
|
|
|
|
|
+ let _ = cache.record_response(&first_request, &response);
|
|
|
|
|
+
|
|
|
|
|
+ assert!(cache.lookup_completion(&second_request).is_none());
|
|
|
|
|
+
|
|
|
|
|
+ std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
|
|
|
|
+ std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn expired_completion_entries_are_not_reused() {
|
|
|
|
|
+ let _guard = test_env_lock();
|
|
|
|
|
+ let temp_root = std::env::temp_dir().join(format!(
|
|
|
|
|
+ "prompt-cache-expired-{}-{}",
|
|
|
|
|
+ std::process::id(),
|
|
|
|
|
+ SystemTime::now()
|
|
|
|
|
+ .duration_since(UNIX_EPOCH)
|
|
|
|
|
+ .expect("time")
|
|
|
|
|
+ .as_nanos()
|
|
|
|
|
+ ));
|
|
|
|
|
+ std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
|
|
|
|
+ let cache = PromptCache::with_config(PromptCacheConfig {
|
|
|
|
|
+ session_id: "expired-session".to_string(),
|
|
|
|
|
+ completion_ttl: Duration::ZERO,
|
|
|
|
|
+ ..PromptCacheConfig::default()
|
|
|
|
|
+ });
|
|
|
|
|
+ let request = sample_request("expire me");
|
|
|
|
|
+ let response = sample_response(7, 3, "stale");
|
|
|
|
|
+
|
|
|
|
|
+ let _ = cache.record_response(&request, &response);
|
|
|
|
|
+
|
|
|
|
|
+ assert!(cache.lookup_completion(&request).is_none());
|
|
|
|
|
+ let stats = cache.stats();
|
|
|
|
|
+ assert_eq!(stats.completion_cache_hits, 0);
|
|
|
|
|
+ assert_eq!(stats.completion_cache_misses, 1);
|
|
|
|
|
+
|
|
|
|
|
+ std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
|
|
|
|
+ std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
#[test]
|
|
#[test]
|
|
|
fn sanitize_path_caps_long_values() {
|
|
fn sanitize_path_caps_long_values() {
|
|
|
let long_value = "x".repeat(200);
|
|
let long_value = "x".repeat(200);
|