|
|
@@ -1,18 +1,22 @@
|
|
|
use std::collections::VecDeque;
|
|
|
+use std::sync::{Arc, Mutex};
|
|
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|
|
|
|
|
+use runtime::format_usd;
|
|
|
use runtime::{
|
|
|
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
|
|
|
OAuthTokenExchangeRequest,
|
|
|
};
|
|
|
use serde::Deserialize;
|
|
|
-use telemetry::SessionTracer;
|
|
|
+use serde_json::{Map, Value};
|
|
|
+use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, SessionTracer};
|
|
|
|
|
|
use crate::error::ApiError;
|
|
|
+use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
|
|
|
|
|
use super::{Provider, ProviderFuture};
|
|
|
use crate::sse::SseParser;
|
|
|
-use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
|
|
+use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage};
|
|
|
|
|
|
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
|
|
const ANTHROPIC_VERSION: &str = "2023-06-01";
|
|
|
@@ -114,6 +118,10 @@ pub struct AnthropicClient {
|
|
|
max_retries: u32,
|
|
|
initial_backoff: Duration,
|
|
|
max_backoff: Duration,
|
|
|
+ request_profile: AnthropicRequestProfile,
|
|
|
+ session_tracer: Option<SessionTracer>,
|
|
|
+ prompt_cache: Option<PromptCache>,
|
|
|
+ last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
|
|
|
}
|
|
|
|
|
|
impl AnthropicClient {
|
|
|
@@ -126,6 +134,10 @@ impl AnthropicClient {
|
|
|
max_retries: DEFAULT_MAX_RETRIES,
|
|
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
|
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
|
|
+ request_profile: AnthropicRequestProfile::default(),
|
|
|
+ session_tracer: None,
|
|
|
+ prompt_cache: None,
|
|
|
+ last_prompt_cache_record: Arc::new(Mutex::new(None)),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -138,6 +150,10 @@ impl AnthropicClient {
|
|
|
max_retries: DEFAULT_MAX_RETRIES,
|
|
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
|
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
|
|
+ request_profile: AnthropicRequestProfile::default(),
|
|
|
+ session_tracer: None,
|
|
|
+ prompt_cache: None,
|
|
|
+ last_prompt_cache_record: Arc::new(Mutex::new(None)),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -196,7 +212,66 @@ impl AnthropicClient {
|
|
|
}
|
|
|
|
|
|
#[must_use]
|
|
|
- pub fn with_session_tracer(self, _session_tracer: SessionTracer) -> Self {
|
|
|
+ pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self {
|
|
|
+ self.session_tracer = Some(session_tracer);
|
|
|
+ self
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn with_client_identity(mut self, client_identity: ClientIdentity) -> Self {
|
|
|
+ self.request_profile.client_identity = client_identity;
|
|
|
+ self
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
|
|
|
+ self.request_profile = self.request_profile.with_beta(beta);
|
|
|
+ self
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn with_extra_body_param(mut self, key: impl Into<String>, value: Value) -> Self {
|
|
|
+ self.request_profile = self.request_profile.with_extra_body(key, value);
|
|
|
+ self
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self {
|
|
|
+ self.prompt_cache = Some(prompt_cache);
|
|
|
+ self
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
|
|
|
+ self.prompt_cache.as_ref().map(PromptCache::stats)
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn request_profile(&self) -> &AnthropicRequestProfile {
|
|
|
+ &self.request_profile
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn session_tracer(&self) -> Option<&SessionTracer> {
|
|
|
+ self.session_tracer.as_ref()
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn prompt_cache(&self) -> Option<&PromptCache> {
|
|
|
+ self.prompt_cache.as_ref()
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
|
|
|
+ self.last_prompt_cache_record
|
|
|
+ .lock()
|
|
|
+ .unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
|
+ .take()
|
|
|
+ }
|
|
|
+
|
|
|
+ #[must_use]
|
|
|
+ pub fn with_request_profile(mut self, request_profile: AnthropicRequestProfile) -> Self {
|
|
|
+ self.request_profile = request_profile;
|
|
|
self
|
|
|
}
|
|
|
|
|
|
@@ -213,6 +288,13 @@ impl AnthropicClient {
|
|
|
stream: false,
|
|
|
..request.clone()
|
|
|
};
|
|
|
+
|
|
|
+ if let Some(prompt_cache) = &self.prompt_cache {
|
|
|
+ if let Some(response) = prompt_cache.lookup_completion(&request) {
|
|
|
+ return Ok(response);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
let response = self.send_with_retry(&request).await?;
|
|
|
let request_id = request_id_from_headers(response.headers());
|
|
|
let mut response = response
|
|
|
@@ -222,6 +304,30 @@ impl AnthropicClient {
|
|
|
if response.request_id.is_none() {
|
|
|
response.request_id = request_id;
|
|
|
}
|
|
|
+
|
|
|
+ if let Some(prompt_cache) = &self.prompt_cache {
|
|
|
+ let record = prompt_cache.record_response(&request, &response);
|
|
|
+ self.store_last_prompt_cache_record(record);
|
|
|
+ }
|
|
|
+ if let Some(session_tracer) = &self.session_tracer {
|
|
|
+ session_tracer.record_analytics(
|
|
|
+ AnalyticsEvent::new("api", "message_usage")
|
|
|
+ .with_property(
|
|
|
+ "request_id",
|
|
|
+ response
|
|
|
+ .request_id
|
|
|
+ .clone()
|
|
|
+ .map_or(Value::Null, Value::String),
|
|
|
+ )
|
|
|
+ .with_property("total_tokens", Value::from(response.total_tokens()))
|
|
|
+ .with_property(
|
|
|
+ "estimated_cost_usd",
|
|
|
+ Value::String(format_usd(
|
|
|
+ response.usage.estimated_cost_usd(&response.model).total_cost_usd(),
|
|
|
+ )),
|
|
|
+ ),
|
|
|
+ );
|
|
|
+ }
|
|
|
Ok(response)
|
|
|
}
|
|
|
|
|
|
@@ -238,6 +344,11 @@ impl AnthropicClient {
|
|
|
parser: SseParser::new(),
|
|
|
pending: VecDeque::new(),
|
|
|
done: false,
|
|
|
+ request: request.clone(),
|
|
|
+ prompt_cache: self.prompt_cache.clone(),
|
|
|
+ latest_usage: None,
|
|
|
+ usage_recorded: false,
|
|
|
+ last_prompt_cache_record: Arc::clone(&self.last_prompt_cache_record),
|
|
|
})
|
|
|
}
|
|
|
|
|
|
@@ -290,18 +401,46 @@ impl AnthropicClient {
|
|
|
|
|
|
loop {
|
|
|
attempts += 1;
|
|
|
+ if let Some(session_tracer) = &self.session_tracer {
|
|
|
+ session_tracer.record_http_request_started(
|
|
|
+ attempts,
|
|
|
+ "POST",
|
|
|
+ "/v1/messages",
|
|
|
+ Map::new(),
|
|
|
+ );
|
|
|
+ }
|
|
|
match self.send_raw_request(request).await {
|
|
|
Ok(response) => match expect_success(response).await {
|
|
|
- Ok(response) => return Ok(response),
|
|
|
+ Ok(response) => {
|
|
|
+ if let Some(session_tracer) = &self.session_tracer {
|
|
|
+ session_tracer.record_http_request_succeeded(
|
|
|
+ attempts,
|
|
|
+ "POST",
|
|
|
+ "/v1/messages",
|
|
|
+ response.status().as_u16(),
|
|
|
+ request_id_from_headers(response.headers()),
|
|
|
+ Map::new(),
|
|
|
+ );
|
|
|
+ }
|
|
|
+ return Ok(response);
|
|
|
+ }
|
|
|
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
|
|
+ self.record_request_failure(attempts, &error);
|
|
|
last_error = Some(error);
|
|
|
}
|
|
|
- Err(error) => return Err(error),
|
|
|
+ Err(error) => {
|
|
|
+ self.record_request_failure(attempts, &error);
|
|
|
+ return Err(error);
|
|
|
+ }
|
|
|
},
|
|
|
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
|
|
+ self.record_request_failure(attempts, &error);
|
|
|
last_error = Some(error);
|
|
|
}
|
|
|
- Err(error) => return Err(error),
|
|
|
+ Err(error) => {
|
|
|
+ self.record_request_failure(attempts, &error);
|
|
|
+ return Err(error);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
if attempts > self.max_retries {
|
|
|
@@ -325,14 +464,37 @@ impl AnthropicClient {
|
|
|
let request_builder = self
|
|
|
.http
|
|
|
.post(&request_url)
|
|
|
- .header("anthropic-version", ANTHROPIC_VERSION)
|
|
|
.header("content-type", "application/json");
|
|
|
let mut request_builder = self.auth.apply(request_builder);
|
|
|
+ for (header_name, header_value) in self.request_profile.header_pairs() {
|
|
|
+ request_builder = request_builder.header(header_name, header_value);
|
|
|
+ }
|
|
|
|
|
|
- request_builder = request_builder.json(request);
|
|
|
+ let request_body = self.request_profile.render_json_body(request)?;
|
|
|
+ request_builder = request_builder.json(&request_body);
|
|
|
request_builder.send().await.map_err(ApiError::from)
|
|
|
}
|
|
|
|
|
|
+ fn record_request_failure(&self, attempt: u32, error: &ApiError) {
|
|
|
+ if let Some(session_tracer) = &self.session_tracer {
|
|
|
+ session_tracer.record_http_request_failed(
|
|
|
+ attempt,
|
|
|
+ "POST",
|
|
|
+ "/v1/messages",
|
|
|
+ error.to_string(),
|
|
|
+ error.is_retryable(),
|
|
|
+ Map::new(),
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ fn store_last_prompt_cache_record(&self, record: PromptCacheRecord) {
|
|
|
+ *self
|
|
|
+ .last_prompt_cache_record
|
|
|
+ .lock()
|
|
|
+ .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
|
|
|
+ }
|
|
|
+
|
|
|
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
|
|
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
|
|
|
return Err(ApiError::BackoffOverflow {
|
|
|
@@ -571,6 +733,11 @@ pub struct MessageStream {
|
|
|
parser: SseParser,
|
|
|
pending: VecDeque<StreamEvent>,
|
|
|
done: bool,
|
|
|
+ request: MessageRequest,
|
|
|
+ prompt_cache: Option<PromptCache>,
|
|
|
+ latest_usage: Option<Usage>,
|
|
|
+ usage_recorded: bool,
|
|
|
+ last_prompt_cache_record: Arc<Mutex<Option<PromptCacheRecord>>>,
|
|
|
}
|
|
|
|
|
|
impl MessageStream {
|
|
|
@@ -582,6 +749,7 @@ impl MessageStream {
|
|
|
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
|
|
loop {
|
|
|
if let Some(event) = self.pending.pop_front() {
|
|
|
+ self.observe_event(&event);
|
|
|
return Ok(Some(event));
|
|
|
}
|
|
|
|
|
|
@@ -604,6 +772,29 @@ impl MessageStream {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ fn observe_event(&mut self, event: &StreamEvent) {
|
|
|
+ match event {
|
|
|
+ StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
|
|
|
+ self.latest_usage = Some(usage.clone());
|
|
|
+ }
|
|
|
+ StreamEvent::MessageStop(_) => {
|
|
|
+ if !self.usage_recorded {
|
|
|
+ if let (Some(prompt_cache), Some(usage)) =
|
|
|
+ (&self.prompt_cache, self.latest_usage.as_ref())
|
|
|
+ {
|
|
|
+ let record = prompt_cache.record_usage(&self.request, usage);
|
|
|
+ *self
|
|
|
+ .last_prompt_cache_record
|
|
|
+ .lock()
|
|
|
+ .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record);
|
|
|
+ }
|
|
|
+ self.usage_recorded = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _ => {}
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
|