Explorar o código

wip: cache-tracking progress

Yeachan-Heo hai 2 meses
pai
achega
26344c578b

+ 11 - 18
rust/crates/api/src/client.rs

@@ -689,7 +689,6 @@ mod tests {
     use std::io::{Read, Write};
     use std::net::TcpListener;
     use std::sync::atomic::{AtomicU64, Ordering};
-    use std::sync::{Mutex, OnceLock};
     use std::thread;
     use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
@@ -699,15 +698,9 @@ mod tests {
         now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
         resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
     };
+    use crate::test_env_lock;
     use crate::types::{ContentBlockDelta, MessageRequest};
 
-    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
-        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
-        LOCK.get_or_init(|| Mutex::new(()))
-            .lock()
-            .unwrap_or_else(std::sync::PoisonError::into_inner)
-    }
-
     fn temp_config_home() -> std::path::PathBuf {
         static NEXT_ID: AtomicU64 = AtomicU64::new(0);
         std::env::temp_dir().join(format!(
@@ -753,7 +746,7 @@ mod tests {
 
     #[test]
     fn read_api_key_requires_presence() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
         std::env::remove_var("ANTHROPIC_API_KEY");
         std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -763,7 +756,7 @@ mod tests {
 
     #[test]
     fn read_api_key_requires_non_empty_value() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
         std::env::remove_var("ANTHROPIC_API_KEY");
         let error = super::read_api_key().expect_err("empty key should error");
@@ -773,7 +766,7 @@ mod tests {
 
     #[test]
     fn read_api_key_prefers_api_key_env() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
         std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
         assert_eq!(
@@ -786,7 +779,7 @@ mod tests {
 
     #[test]
     fn read_auth_token_reads_auth_token_env() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
         assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -806,7 +799,7 @@ mod tests {
 
     #[test]
     fn auth_source_from_env_combines_api_key_and_bearer_token() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
         std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
         let auth = AuthSource::from_env().expect("env auth");
@@ -818,7 +811,7 @@ mod tests {
 
     #[test]
     fn auth_source_from_saved_oauth_when_env_absent() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         let config_home = temp_config_home();
         std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -857,7 +850,7 @@ mod tests {
 
     #[test]
     fn resolve_saved_oauth_token_refreshes_expired_credentials() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         let config_home = temp_config_home();
         std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -889,7 +882,7 @@ mod tests {
 
     #[test]
     fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         let config_home = temp_config_home();
         std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -913,7 +906,7 @@ mod tests {
 
     #[test]
     fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         let config_home = temp_config_home();
         std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
@@ -945,7 +938,7 @@ mod tests {
 
     #[test]
     fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         let config_home = temp_config_home();
         std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");

+ 8 - 0
rust/crates/api/src/lib.rs

@@ -20,3 +20,11 @@ pub use types::{
     MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
     ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
 };
+
+#[cfg(test)]
+pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
+    static LOCK: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
+    LOCK.get_or_init(|| std::sync::Mutex::new(()))
+        .lock()
+        .unwrap_or_else(std::sync::PoisonError::into_inner)
+}

+ 77 - 29
rust/crates/api/src/prompt_cache.rs

@@ -141,6 +141,7 @@ impl PromptCache {
         self.lock().stats.clone()
     }
 
+    #[must_use]
     pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
         let request_hash = request_hash_hex(request);
         let (paths, ttl) = {
@@ -191,6 +192,7 @@ impl PromptCache {
         Some(entry.response)
     }
 
+    #[must_use]
     pub fn record_response(
         &self,
         request: &MessageRequest,
@@ -199,6 +201,7 @@ impl PromptCache {
         self.record_usage_internal(request, &response.usage, Some(response))
     }
 
+    #[must_use]
     pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
         self.record_usage_internal(request, usage, None)
     }
@@ -267,7 +270,6 @@ struct TrackedPromptState {
     observed_at_unix_secs: u64,
     #[serde(default = "current_fingerprint_version")]
     fingerprint_version: u32,
-    request_hash: u64,
     model_hash: u64,
     system_hash: u64,
     tools_hash: u64,
@@ -277,37 +279,34 @@ struct TrackedPromptState {
 
 impl TrackedPromptState {
     fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
-        let hashes = RequestHashes::from_request(request);
+        let hashes = RequestFingerprints::from_request(request);
         Self {
             observed_at_unix_secs: now_unix_secs(),
             fingerprint_version: current_fingerprint_version(),
-            request_hash: hashes.request_hash,
-            model_hash: hashes.model_hash,
-            system_hash: hashes.system_hash,
-            tools_hash: hashes.tools_hash,
-            messages_hash: hashes.messages_hash,
+            model_hash: hashes.model,
+            system_hash: hashes.system,
+            tools_hash: hashes.tools,
+            messages_hash: hashes.messages,
             cache_read_input_tokens: usage.cache_read_input_tokens,
         }
     }
 }
 
 #[derive(Debug, Clone, Copy)]
-struct RequestHashes {
-    request_hash: u64,
-    model_hash: u64,
-    system_hash: u64,
-    tools_hash: u64,
-    messages_hash: u64,
+struct RequestFingerprints {
+    model: u64,
+    system: u64,
+    tools: u64,
+    messages: u64,
 }
 
-impl RequestHashes {
+impl RequestFingerprints {
     fn from_request(request: &MessageRequest) -> Self {
         Self {
-            request_hash: hash_serializable(request),
-            model_hash: hash_serializable(&request.model),
-            system_hash: hash_serializable(&request.system),
-            tools_hash: hash_serializable(&request.tools),
-            messages_hash: hash_serializable(&request.messages),
+            model: hash_serializable(&request.model),
+            system: hash_serializable(&request.system),
+            tools: hash_serializable(&request.tools),
+            messages: hash_serializable(&request.messages),
         }
     }
 }
@@ -501,22 +500,15 @@ fn stable_hash_bytes(bytes: &[u8]) -> u64 {
 
 #[cfg(test)]
 mod tests {
-    use std::sync::{Mutex, OnceLock};
-    use std::time::{SystemTime, UNIX_EPOCH};
+    use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
     use super::{
         detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
         PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
     };
+    use crate::test_env_lock;
     use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
 
-    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
-        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
-        LOCK.get_or_init(|| Mutex::new(()))
-            .lock()
-            .unwrap_or_else(std::sync::PoisonError::into_inner)
-    }
-
     #[test]
     fn path_builder_sanitizes_session_identifier() {
         let paths = PromptCachePaths::for_session("session:/with spaces");
@@ -588,7 +580,7 @@ mod tests {
 
     #[test]
     fn completion_cache_round_trip_persists_recent_response() {
-        let _guard = env_lock();
+        let _guard = test_env_lock();
         let temp_root = std::env::temp_dir().join(format!(
             "prompt-cache-test-{}-{}",
             std::process::id(),
@@ -624,6 +616,62 @@ mod tests {
         std::env::remove_var("CLAUDE_CONFIG_HOME");
     }
 
+    #[test]
+    fn distinct_requests_do_not_collide_in_completion_cache() {
+        let _guard = test_env_lock();
+        let temp_root = std::env::temp_dir().join(format!(
+            "prompt-cache-distinct-{}-{}",
+            std::process::id(),
+            SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .expect("time")
+                .as_nanos()
+        ));
+        std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
+        let cache = PromptCache::new("distinct-request-session");
+        let first_request = sample_request("first");
+        let second_request = sample_request("second");
+
+        let response = sample_response(42, 12, "cached");
+        let _ = cache.record_response(&first_request, &response);
+
+        assert!(cache.lookup_completion(&second_request).is_none());
+
+        std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
+        std::env::remove_var("CLAUDE_CONFIG_HOME");
+    }
+
+    #[test]
+    fn expired_completion_entries_are_not_reused() {
+        let _guard = test_env_lock();
+        let temp_root = std::env::temp_dir().join(format!(
+            "prompt-cache-expired-{}-{}",
+            std::process::id(),
+            SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .expect("time")
+                .as_nanos()
+        ));
+        std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
+        let cache = PromptCache::with_config(PromptCacheConfig {
+            session_id: "expired-session".to_string(),
+            completion_ttl: Duration::ZERO,
+            ..PromptCacheConfig::default()
+        });
+        let request = sample_request("expire me");
+        let response = sample_response(7, 3, "stale");
+
+        let _ = cache.record_response(&request, &response);
+
+        assert!(cache.lookup_completion(&request).is_none());
+        let stats = cache.stats();
+        assert_eq!(stats.completion_cache_hits, 0);
+        assert_eq!(stats.completion_cache_misses, 1);
+
+        std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
+        std::env::remove_var("CLAUDE_CONFIG_HOME");
+    }
+
     #[test]
     fn sanitize_path_caps_long_values() {
         let long_value = "x".repeat(200);

+ 17 - 11
rust/crates/api/tests/client_integration.rs

@@ -84,6 +84,7 @@ async fn send_message_posts_json_and_parses_response() {
 }
 
 #[tokio::test]
+#[allow(clippy::await_holding_lock)]
 async fn stream_message_parses_sse_events_with_tool_use() {
     let _guard = env_lock();
     let temp_root = std::env::temp_dir().join(format!(
@@ -180,12 +181,15 @@ async fn stream_message_parses_sse_events_with_tool_use() {
     let request = captured.first().expect("server should capture request");
     assert!(request.body.contains("\"stream\":true"));
 
-    let stats = client
+    let cache_stats = client
         .prompt_cache_stats()
         .expect("prompt cache stats should exist");
-    assert_eq!(stats.tracked_requests, 1);
-    assert_eq!(stats.last_cache_read_input_tokens, Some(0));
-    assert_eq!(stats.last_cache_source.as_deref(), Some("api-response"));
+    assert_eq!(cache_stats.tracked_requests, 1);
+    assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0));
+    assert_eq!(
+        cache_stats.last_cache_source.as_deref(),
+        Some("api-response")
+    );
 
     std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
     std::env::remove_var("CLAUDE_CONFIG_HOME");
@@ -273,6 +277,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
 }
 
 #[tokio::test]
+#[allow(clippy::await_holding_lock)]
 async fn send_message_reuses_recent_completion_cache_entries() {
     let _guard = env_lock();
     let temp_root = std::env::temp_dir().join(format!(
@@ -312,18 +317,19 @@ async fn send_message_reuses_recent_completion_cache_entries() {
     assert_eq!(first.content, second.content);
     assert_eq!(state.lock().await.len(), 1);
 
-    let stats = client
+    let cache_stats = client
         .prompt_cache_stats()
         .expect("prompt cache stats should exist");
-    assert_eq!(stats.completion_cache_hits, 1);
-    assert_eq!(stats.completion_cache_misses, 1);
-    assert_eq!(stats.completion_cache_writes, 1);
+    assert_eq!(cache_stats.completion_cache_hits, 1);
+    assert_eq!(cache_stats.completion_cache_misses, 1);
+    assert_eq!(cache_stats.completion_cache_writes, 1);
 
     std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
     std::env::remove_var("CLAUDE_CONFIG_HOME");
 }
 
 #[tokio::test]
+#[allow(clippy::await_holding_lock)]
 async fn send_message_tracks_unexpected_prompt_cache_breaks() {
     let _guard = env_lock();
     let temp_root = std::env::temp_dir().join(format!(
@@ -372,12 +378,12 @@ async fn send_message_tracks_unexpected_prompt_cache_breaks() {
         .await
         .expect("second response should succeed");
 
-    let stats = client
+    let cache_stats = client
         .prompt_cache_stats()
         .expect("prompt cache stats should exist");
-    assert_eq!(stats.unexpected_cache_breaks, 1);
+    assert_eq!(cache_stats.unexpected_cache_breaks, 1);
     assert_eq!(
-        stats.last_break_reason.as_deref(),
+        cache_stats.last_break_reason.as_deref(),
         Some("cache read tokens dropped while prompt fingerprint remained stable")
     );