Selaa lähdekoodia

fix: restore anthropic request profile integration

YeonGyu-Kim 2 kuukautta sitten
vanhempi
commit
de589d47a5
2 muutettua tiedostoa jossa 204 lisäystä ja 8 poistoa
  1. 5 0
      rust/crates/api/src/lib.rs
  2. 199 8
      rust/crates/api/src/providers/anthropic.rs

+ 5 - 0
rust/crates/api/src/lib.rs

@@ -1,5 +1,6 @@
 mod client;
 mod error;
+mod prompt_cache;
 mod providers;
 mod sse;
 mod types;
@@ -9,6 +10,10 @@ pub use client::{
     resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
 };
 pub use error::ApiError;
+pub use prompt_cache::{
+    CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
+    PromptCacheStats,
+};
 pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource};
 pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
 pub use providers::{

+ 199 - 8
rust/crates/api/src/providers/anthropic.rs

@@ -1,18 +1,22 @@
 use std::collections::VecDeque;
+use std::sync::{Arc, Mutex};
 use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
+use runtime::format_usd;
 use runtime::{
     load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
     OAuthTokenExchangeRequest,
 };
 use serde::Deserialize;
-use telemetry::SessionTracer;
+use serde_json::{Map, Value};
+use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, SessionTracer};
 
 use crate::error::ApiError;
+use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
 
 use super::{Provider, ProviderFuture};
 use crate::sse::SseParser;
-use crate::types::{MessageRequest, MessageResponse, StreamEvent};
+use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
 
 pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
 const ANTHROPIC_VERSION: &str = "2023-06-01";
@@ -114,6 +118,10 @@ pub struct AnthropicClient {
     max_retries: u32,
     initial_backoff: Duration,
     max_backoff: Duration,
+    request_profile: AnthropicRequestProfile,
+    session_tracer: Option<SessionTracer>,
+    prompt_cache: Option<PromptCache>,
+    last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
 }
 
 impl AnthropicClient {
@@ -126,6 +134,10 @@ impl AnthropicClient {
             max_retries: DEFAULT_MAX_RETRIES,
             initial_backoff: DEFAULT_INITIAL_BACKOFF,
             max_backoff: DEFAULT_MAX_BACKOFF,
+            request_profile: AnthropicRequestProfile::default(),
+            session_tracer: None,
+            prompt_cache: None,
+            last_prompt_cache_record: Arc::new(Mutex::new(None)),
         }
     }
 
@@ -138,6 +150,10 @@ impl AnthropicClient {
             max_retries: DEFAULT_MAX_RETRIES,
             initial_backoff: DEFAULT_INITIAL_BACKOFF,
             max_backoff: DEFAULT_MAX_BACKOFF,
+            request_profile: AnthropicRequestProfile::default(),
+            session_tracer: None,
+            prompt_cache: None,
+            last_prompt_cache_record: Arc::new(Mutex::new(None)),
         }
     }
 
@@ -196,7 +212,66 @@ impl AnthropicClient {
     }
 
     #[must_use]
-    pub fn with_session_tracer(self, _session_tracer: SessionTracer) -> Self {
+    pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
+        self.session_tracer = Some(session_tracer);
+        self
+    }
+
+    #[must_use]
+    pub fn with_client_identity(mut self, client_identity: ClientIdentity) -> Self {
+        self.request_profile.client_identity = client_identity;
+        self
+    }
+
+    #[must_use]
+    pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
+        self.request_profile = self.request_profile.with_beta(beta);
+        self
+    }
+
+    #[must_use]
+    pub fn with_extra_body_param(mut self, key: impl Into<String>, value: Value) -> Self {
+        self.request_profile = self.request_profile.with_extra_body(key, value);
+        self
+    }
+
+    #[must_use]
+    pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
+        self.prompt_cache = Some(prompt_cache);
+        self
+    }
+
+    #[must_use]
+    pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
+        self.prompt_cache.as_ref().map(PromptCache::stats)
+    }
+
+    #[must_use]
+    pub fn request_profile(&self) -> &AnthropicRequestProfile {
+        &self.request_profile
+    }
+
+    #[must_use]
+    pub fn session_tracer(&self) -> Option<&SessionTracer> {
+        self.session_tracer.as_ref()
+    }
+
+    #[must_use]
+    pub fn prompt_cache(&self) -> Option<&PromptCache> {
+        self.prompt_cache.as_ref()
+    }
+
+    #[must_use]
+    pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
+        self.last_prompt_cache_record
+            .lock()
+            .unwrap_or_else(std::sync::PoisonError::into_inner)
+            .take()
+    }
+
+    #[must_use]
+    pub fn with_request_profile(mut self, request_profile: AnthropicRequestProfile) -> Self {
+        self.request_profile = request_profile;
         self
     }
 
@@ -213,6 +288,13 @@ impl AnthropicClient {
             stream: false,
             ..request.clone()
         };
+
+        if let Some(prompt_cache) = &self.prompt_cache {
+            if let Some(response) = prompt_cache.lookup_completion(&request) {
+                return Ok(response);
+            }
+        }
+
         let response = self.send_with_retry(&request).await?;
         let request_id = request_id_from_headers(response.headers());
         let mut response = response
@@ -222,6 +304,30 @@ impl AnthropicClient {
         if response.request_id.is_none() {
             response.request_id = request_id;
         }
+
+        if let Some(prompt_cache) = &self.prompt_cache {
+            let record = prompt_cache.record_response(&request, &response);
+            self.store_last_prompt_cache_record(record);
+        }
+        if let Some(session_tracer) = &self.session_tracer {
+            session_tracer.record_analytics(
+                AnalyticsEvent::new("api", "message_usage")
+                    .with_property(
+                        "request_id",
+                        response
+                            .request_id
+                            .clone()
+                            .map_or(Value::Null, Value::String),
+                    )
+                    .with_property("total_tokens", Value::from(response.total_tokens()))
+                    .with_property(
+                        "estimated_cost_usd",
+                        Value::String(format_usd(
+                            response.usage.estimated_cost_usd(&response.model).total_cost_usd(),
+                        )),
+                    ),
+            );
+        }
         Ok(response)
     }
 
@@ -238,6 +344,11 @@ impl AnthropicClient {
             parser: SseParser::new(),
             pending: VecDeque::new(),
             done: false,
+            request: request.clone(),
+            prompt_cache: self.prompt_cache.clone(),
+            latest_usage: None,
+            usage_recorded: false,
+            last_prompt_cache_record: Arc::clone(&self.last_prompt_cache_record),
         })
     }
 
@@ -290,18 +401,46 @@ impl AnthropicClient {
 
         loop {
             attempts += 1;
+            if let Some(session_tracer) = &self.session_tracer {
+                session_tracer.record_http_request_started(
+                    attempts,
+                    "POST",
+                    "/v1/messages",
+                    Map::new(),
+                );
+            }
             match self.send_raw_request(request).await {
                 Ok(response) => match expect_success(response).await {
-                    Ok(response) => return Ok(response),
+                    Ok(response) => {
+                        if let Some(session_tracer) = &self.session_tracer {
+                            session_tracer.record_http_request_succeeded(
+                                attempts,
+                                "POST",
+                                "/v1/messages",
+                                response.status().as_u16(),
+                                request_id_from_headers(response.headers()),
+                                Map::new(),
+                            );
+                        }
+                        return Ok(response);
+                    }
                     Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
+                        self.record_request_failure(attempts, &error);
                         last_error = Some(error);
                     }
-                    Err(error) => return Err(error),
+                    Err(error) => {
+                        self.record_request_failure(attempts, &error);
+                        return Err(error);
+                    }
                 },
                 Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
+                    self.record_request_failure(attempts, &error);
                     last_error = Some(error);
                 }
-                Err(error) => return Err(error),
+                Err(error) => {
+                    self.record_request_failure(attempts, &error);
+                    return Err(error);
+                }
             }
 
             if attempts > self.max_retries {
@@ -325,14 +464,37 @@ impl AnthropicClient {
         let request_builder = self
             .http
             .post(&request_url)
-            .header("anthropic-version", ANTHROPIC_VERSION)
             .header("content-type", "application/json");
         let mut request_builder = self.auth.apply(request_builder);
+        for (header_name, header_value) in self.request_profile.header_pairs() {
+            request_builder = request_builder.header(header_name, header_value);
+        }
 
-        request_builder = request_builder.json(request);
+        let request_body = self.request_profile.render_json_body(request)?;
+        request_builder = request_builder.json(&request_body);
         request_builder.send().await.map_err(ApiError::from)
     }
 
+    fn record_request_failure(&self, attempt: u32, error: &ApiError) {
+        if let Some(session_tracer) = &self.session_tracer {
+            session_tracer.record_http_request_failed(
+                attempt,
+                "POST",
+                "/v1/messages",
+                error.to_string(),
+                error.is_retryable(),
+                Map::new(),
+            );
+        }
+    }
+
+    fn store_last_prompt_cache_record(&self, record: PromptCacheRecord) {
+        *self
+            .last_prompt_cache_record
+            .lock()
+            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
+    }
+
     fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
         let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
             return Err(ApiError::BackoffOverflow {
@@ -571,6 +733,11 @@ pub struct MessageStream {
     parser: SseParser,
     pending: VecDeque<StreamEvent>,
     done: bool,
+    request: MessageRequest,
+    prompt_cache: Option<PromptCache>,
+    latest_usage: Option<Usage>,
+    usage_recorded: bool,
+    last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
 }
 
 impl MessageStream {
@@ -582,6 +749,7 @@ impl MessageStream {
     pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
         loop {
             if let Some(event) = self.pending.pop_front() {
+                self.observe_event(&event);
                 return Ok(Some(event));
             }
 
@@ -604,6 +772,29 @@ impl MessageStream {
             }
         }
     }
+
+    fn observe_event(&mut self, event: &StreamEvent) {
+        match event {
+            StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
+                self.latest_usage = Some(usage.clone());
+            }
+            StreamEvent::MessageStop(_) => {
+                if !self.usage_recorded {
+                    if let (Some(prompt_cache), Some(usage)) =
+                        (&self.prompt_cache, self.latest_usage.as_ref())
+                    {
+                        let record = prompt_cache.record_usage(&self.request, usage);
+                        *self
+                            .last_prompt_cache_record
+                            .lock()
+                            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
+                    }
+                    self.usage_recorded = true;
+                }
+            }
+            _ => {}
+        }
+    }
 }
 
 async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {