|
@@ -1,994 +1,142 @@
|
|
|
-use std::collections::VecDeque;
|
|
|
|
|
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|
|
|
|
-
|
|
|
|
|
-use runtime::{
|
|
|
|
|
- load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
|
|
|
|
|
- OAuthTokenExchangeRequest,
|
|
|
|
|
-};
|
|
|
|
|
-use serde::Deserialize;
|
|
|
|
|
-
|
|
|
|
|
use crate::error::ApiError;
|
|
use crate::error::ApiError;
|
|
|
-use crate::sse::SseParser;
|
|
|
|
|
|
|
+use crate::providers::anthropic::{self, AnthropicClient, AuthSource};
|
|
|
|
|
+use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig};
|
|
|
|
|
+use crate::providers::{self, Provider, ProviderKind};
|
|
|
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
|
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
|
|
|
|
|
|
|
-const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
|
|
|
|
|
-const ANTHROPIC_VERSION: &str = "2023-06-01";
|
|
|
|
|
-const REQUEST_ID_HEADER: &str = "request-id";
|
|
|
|
|
-const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
|
|
|
|
|
-const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
|
|
|
|
|
-const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
|
|
|
|
|
-const DEFAULT_MAX_RETRIES: u32 = 2;
|
|
|
|
|
-
|
|
|
|
|
-#[derive(Debug, Clone, PartialEq, Eq)]
|
|
|
|
|
-pub enum AuthSource {
|
|
|
|
|
- None,
|
|
|
|
|
- ApiKey(String),
|
|
|
|
|
- BearerToken(String),
|
|
|
|
|
- ApiKeyAndBearer {
|
|
|
|
|
- api_key: String,
|
|
|
|
|
- bearer_token: String,
|
|
|
|
|
- },
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-impl AuthSource {
|
|
|
|
|
- pub fn from_env() -> Result<Self, ApiError> {
|
|
|
|
|
- let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
|
|
|
|
|
- let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
|
|
|
|
|
- match (api_key, auth_token) {
|
|
|
|
|
- (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
|
|
|
|
|
- api_key,
|
|
|
|
|
- bearer_token,
|
|
|
|
|
- }),
|
|
|
|
|
- (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
|
|
|
|
|
- (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
|
|
|
|
|
- (None, None) => Err(ApiError::MissingApiKey),
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn api_key(&self) -> Option<&str> {
|
|
|
|
|
- match self {
|
|
|
|
|
- Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
|
|
|
|
|
- Self::None | Self::BearerToken(_) => None,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn bearer_token(&self) -> Option<&str> {
|
|
|
|
|
- match self {
|
|
|
|
|
- Self::BearerToken(token)
|
|
|
|
|
- | Self::ApiKeyAndBearer {
|
|
|
|
|
- bearer_token: token,
|
|
|
|
|
- ..
|
|
|
|
|
- } => Some(token),
|
|
|
|
|
- Self::None | Self::ApiKey(_) => None,
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn masked_authorization_header(&self) -> &'static str {
|
|
|
|
|
- if self.bearer_token().is_some() {
|
|
|
|
|
- "Bearer [REDACTED]"
|
|
|
|
|
- } else {
|
|
|
|
|
- "<absent>"
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
|
|
|
|
- if let Some(api_key) = self.api_key() {
|
|
|
|
|
- request_builder = request_builder.header("x-api-key", api_key);
|
|
|
|
|
- }
|
|
|
|
|
- if let Some(token) = self.bearer_token() {
|
|
|
|
|
- request_builder = request_builder.bearer_auth(token);
|
|
|
|
|
- }
|
|
|
|
|
- request_builder
|
|
|
|
|
- }
|
|
|
|
|
|
|
+async fn send_via_provider<P: Provider>(
|
|
|
|
|
+ provider: &P,
|
|
|
|
|
+ request: &MessageRequest,
|
|
|
|
|
+) -> Result<MessageResponse, ApiError> {
|
|
|
|
|
+ provider.send_message(request).await
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
|
|
|
|
|
-pub struct OAuthTokenSet {
|
|
|
|
|
- pub access_token: String,
|
|
|
|
|
- pub refresh_token: Option<String>,
|
|
|
|
|
- pub expires_at: Option<u64>,
|
|
|
|
|
- #[serde(default)]
|
|
|
|
|
- pub scopes: Vec<String>,
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-impl From<OAuthTokenSet> for AuthSource {
|
|
|
|
|
- fn from(value: OAuthTokenSet) -> Self {
|
|
|
|
|
- Self::BearerToken(value.access_token)
|
|
|
|
|
- }
|
|
|
|
|
|
|
+async fn stream_via_provider<P: Provider>(
|
|
|
|
|
+ provider: &P,
|
|
|
|
|
+ request: &MessageRequest,
|
|
|
|
|
+) -> Result<P::Stream, ApiError> {
|
|
|
|
|
+ provider.stream_message(request).await
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug, Clone)]
|
|
#[derive(Debug, Clone)]
|
|
|
-pub struct AnthropicClient {
|
|
|
|
|
- http: reqwest::Client,
|
|
|
|
|
- auth: AuthSource,
|
|
|
|
|
- base_url: String,
|
|
|
|
|
- max_retries: u32,
|
|
|
|
|
- initial_backoff: Duration,
|
|
|
|
|
- max_backoff: Duration,
|
|
|
|
|
|
|
+pub enum ProviderClient {
|
|
|
|
|
+ Anthropic(AnthropicClient),
|
|
|
|
|
+ Xai(OpenAiCompatClient),
|
|
|
|
|
+ OpenAi(OpenAiCompatClient),
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-impl AnthropicClient {
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn new(api_key: impl Into<String>) -> Self {
|
|
|
|
|
- Self {
|
|
|
|
|
- http: reqwest::Client::new(),
|
|
|
|
|
- auth: AuthSource::ApiKey(api_key.into()),
|
|
|
|
|
- base_url: DEFAULT_BASE_URL.to_string(),
|
|
|
|
|
- max_retries: DEFAULT_MAX_RETRIES,
|
|
|
|
|
- initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
|
|
|
|
- max_backoff: DEFAULT_MAX_BACKOFF,
|
|
|
|
|
- }
|
|
|
|
|
|
|
+impl ProviderClient {
|
|
|
|
|
+ pub fn from_model(model: &str) -> Result<Self, ApiError> {
|
|
|
|
|
+ Self::from_model_with_anthropic_auth(model, None)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn from_auth(auth: AuthSource) -> Self {
|
|
|
|
|
- Self {
|
|
|
|
|
- http: reqwest::Client::new(),
|
|
|
|
|
- auth,
|
|
|
|
|
- base_url: DEFAULT_BASE_URL.to_string(),
|
|
|
|
|
- max_retries: DEFAULT_MAX_RETRIES,
|
|
|
|
|
- initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
|
|
|
|
- max_backoff: DEFAULT_MAX_BACKOFF,
|
|
|
|
|
|
|
+ pub fn from_model_with_anthropic_auth(
|
|
|
|
|
+ model: &str,
|
|
|
|
|
+ anthropic_auth: Option<AuthSource>,
|
|
|
|
|
+ ) -> Result<Self, ApiError> {
|
|
|
|
|
+ let resolved_model = providers::resolve_model_alias(model);
|
|
|
|
|
+ match providers::detect_provider_kind(&resolved_model) {
|
|
|
|
|
+ ProviderKind::Anthropic => Ok(Self::Anthropic(
|
|
|
|
|
+ anthropic_auth
|
|
|
|
|
+ .map(AnthropicClient::from_auth)
|
|
|
|
|
+ .unwrap_or(AnthropicClient::from_env()?),
|
|
|
|
|
+ )),
|
|
|
|
|
+ ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
|
|
|
|
|
+ OpenAiCompatConfig::xai(),
|
|
|
|
|
+ )?)),
|
|
|
|
|
+ ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env(
|
|
|
|
|
+ OpenAiCompatConfig::openai(),
|
|
|
|
|
+ )?)),
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- pub fn from_env() -> Result<Self, ApiError> {
|
|
|
|
|
- Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
|
|
|
|
|
- self.auth = auth;
|
|
|
|
|
- self
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
#[must_use]
|
|
#[must_use]
|
|
|
- pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
|
|
|
|
|
- match (
|
|
|
|
|
- self.auth.api_key().map(ToOwned::to_owned),
|
|
|
|
|
- auth_token.filter(|token| !token.is_empty()),
|
|
|
|
|
- ) {
|
|
|
|
|
- (Some(api_key), Some(bearer_token)) => {
|
|
|
|
|
- self.auth = AuthSource::ApiKeyAndBearer {
|
|
|
|
|
- api_key,
|
|
|
|
|
- bearer_token,
|
|
|
|
|
- };
|
|
|
|
|
- }
|
|
|
|
|
- (Some(api_key), None) => {
|
|
|
|
|
- self.auth = AuthSource::ApiKey(api_key);
|
|
|
|
|
- }
|
|
|
|
|
- (None, Some(bearer_token)) => {
|
|
|
|
|
- self.auth = AuthSource::BearerToken(bearer_token);
|
|
|
|
|
- }
|
|
|
|
|
- (None, None) => {
|
|
|
|
|
- self.auth = AuthSource::None;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ pub const fn provider_kind(&self) -> ProviderKind {
|
|
|
|
|
+ match self {
|
|
|
|
|
+ Self::Anthropic(_) => ProviderKind::Anthropic,
|
|
|
|
|
+ Self::Xai(_) => ProviderKind::Xai,
|
|
|
|
|
+ Self::OpenAi(_) => ProviderKind::OpenAi,
|
|
|
}
|
|
}
|
|
|
- self
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
|
|
|
|
|
- self.base_url = base_url.into();
|
|
|
|
|
- self
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn with_retry_policy(
|
|
|
|
|
- mut self,
|
|
|
|
|
- max_retries: u32,
|
|
|
|
|
- initial_backoff: Duration,
|
|
|
|
|
- max_backoff: Duration,
|
|
|
|
|
- ) -> Self {
|
|
|
|
|
- self.max_retries = max_retries;
|
|
|
|
|
- self.initial_backoff = initial_backoff;
|
|
|
|
|
- self.max_backoff = max_backoff;
|
|
|
|
|
- self
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[must_use]
|
|
|
|
|
- pub fn auth_source(&self) -> &AuthSource {
|
|
|
|
|
- &self.auth
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
pub async fn send_message(
|
|
pub async fn send_message(
|
|
|
&self,
|
|
&self,
|
|
|
request: &MessageRequest,
|
|
request: &MessageRequest,
|
|
|
) -> Result<MessageResponse, ApiError> {
|
|
) -> Result<MessageResponse, ApiError> {
|
|
|
- let request = MessageRequest {
|
|
|
|
|
- stream: false,
|
|
|
|
|
- ..request.clone()
|
|
|
|
|
- };
|
|
|
|
|
- let response = self.send_with_retry(&request).await?;
|
|
|
|
|
- let request_id = request_id_from_headers(response.headers());
|
|
|
|
|
- let mut response = response
|
|
|
|
|
- .json::<MessageResponse>()
|
|
|
|
|
- .await
|
|
|
|
|
- .map_err(ApiError::from)?;
|
|
|
|
|
- if response.request_id.is_none() {
|
|
|
|
|
- response.request_id = request_id;
|
|
|
|
|
|
|
+ match self {
|
|
|
|
|
+ Self::Anthropic(client) => send_via_provider(client, request).await,
|
|
|
|
|
+ Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
|
|
|
}
|
|
}
|
|
|
- Ok(response)
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
pub async fn stream_message(
|
|
pub async fn stream_message(
|
|
|
&self,
|
|
&self,
|
|
|
request: &MessageRequest,
|
|
request: &MessageRequest,
|
|
|
) -> Result<MessageStream, ApiError> {
|
|
) -> Result<MessageStream, ApiError> {
|
|
|
- let response = self
|
|
|
|
|
- .send_with_retry(&request.clone().with_streaming())
|
|
|
|
|
- .await?;
|
|
|
|
|
- Ok(MessageStream {
|
|
|
|
|
- request_id: request_id_from_headers(response.headers()),
|
|
|
|
|
- response,
|
|
|
|
|
- parser: SseParser::new(),
|
|
|
|
|
- pending: VecDeque::new(),
|
|
|
|
|
- done: false,
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- pub async fn exchange_oauth_code(
|
|
|
|
|
- &self,
|
|
|
|
|
- config: &OAuthConfig,
|
|
|
|
|
- request: &OAuthTokenExchangeRequest,
|
|
|
|
|
- ) -> Result<OAuthTokenSet, ApiError> {
|
|
|
|
|
- let response = self
|
|
|
|
|
- .http
|
|
|
|
|
- .post(&config.token_url)
|
|
|
|
|
- .header("content-type", "application/x-www-form-urlencoded")
|
|
|
|
|
- .form(&request.form_params())
|
|
|
|
|
- .send()
|
|
|
|
|
- .await
|
|
|
|
|
- .map_err(ApiError::from)?;
|
|
|
|
|
- let response = expect_success(response).await?;
|
|
|
|
|
- response
|
|
|
|
|
- .json::<OAuthTokenSet>()
|
|
|
|
|
- .await
|
|
|
|
|
- .map_err(ApiError::from)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- pub async fn refresh_oauth_token(
|
|
|
|
|
- &self,
|
|
|
|
|
- config: &OAuthConfig,
|
|
|
|
|
- request: &OAuthRefreshRequest,
|
|
|
|
|
- ) -> Result<OAuthTokenSet, ApiError> {
|
|
|
|
|
- let response = self
|
|
|
|
|
- .http
|
|
|
|
|
- .post(&config.token_url)
|
|
|
|
|
- .header("content-type", "application/x-www-form-urlencoded")
|
|
|
|
|
- .form(&request.form_params())
|
|
|
|
|
- .send()
|
|
|
|
|
- .await
|
|
|
|
|
- .map_err(ApiError::from)?;
|
|
|
|
|
- let response = expect_success(response).await?;
|
|
|
|
|
- response
|
|
|
|
|
- .json::<OAuthTokenSet>()
|
|
|
|
|
- .await
|
|
|
|
|
- .map_err(ApiError::from)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- async fn send_with_retry(
|
|
|
|
|
- &self,
|
|
|
|
|
- request: &MessageRequest,
|
|
|
|
|
- ) -> Result<reqwest::Response, ApiError> {
|
|
|
|
|
- let mut attempts = 0;
|
|
|
|
|
- let mut last_error: Option<ApiError>;
|
|
|
|
|
-
|
|
|
|
|
- loop {
|
|
|
|
|
- attempts += 1;
|
|
|
|
|
- match self.send_raw_request(request).await {
|
|
|
|
|
- Ok(response) => match expect_success(response).await {
|
|
|
|
|
- Ok(response) => return Ok(response),
|
|
|
|
|
- Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
|
|
|
|
- last_error = Some(error);
|
|
|
|
|
- }
|
|
|
|
|
- Err(error) => return Err(error),
|
|
|
|
|
- },
|
|
|
|
|
- Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
|
|
|
|
|
- last_error = Some(error);
|
|
|
|
|
- }
|
|
|
|
|
- Err(error) => return Err(error),
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if attempts > self.max_retries {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- Err(ApiError::RetriesExhausted {
|
|
|
|
|
- attempts,
|
|
|
|
|
- last_error: Box::new(last_error.expect("retry loop must capture an error")),
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- async fn send_raw_request(
|
|
|
|
|
- &self,
|
|
|
|
|
- request: &MessageRequest,
|
|
|
|
|
- ) -> Result<reqwest::Response, ApiError> {
|
|
|
|
|
- let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
|
|
|
|
|
- let request_builder = self
|
|
|
|
|
- .http
|
|
|
|
|
- .post(&request_url)
|
|
|
|
|
- .header("anthropic-version", ANTHROPIC_VERSION)
|
|
|
|
|
- .header("content-type", "application/json");
|
|
|
|
|
- let mut request_builder = self.auth.apply(request_builder);
|
|
|
|
|
-
|
|
|
|
|
- request_builder = request_builder.json(request);
|
|
|
|
|
- request_builder.send().await.map_err(ApiError::from)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
|
|
|
|
- let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
|
|
|
|
|
- return Err(ApiError::BackoffOverflow {
|
|
|
|
|
- attempt,
|
|
|
|
|
- base_delay: self.initial_backoff,
|
|
|
|
|
- });
|
|
|
|
|
- };
|
|
|
|
|
- Ok(self
|
|
|
|
|
- .initial_backoff
|
|
|
|
|
- .checked_mul(multiplier)
|
|
|
|
|
- .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-impl AuthSource {
|
|
|
|
|
- pub fn from_env_or_saved() -> Result<Self, ApiError> {
|
|
|
|
|
- if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
|
|
|
|
- return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
|
|
|
|
- Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
|
|
|
|
|
- api_key,
|
|
|
|
|
- bearer_token,
|
|
|
|
|
- }),
|
|
|
|
|
- None => Ok(Self::ApiKey(api_key)),
|
|
|
|
|
- };
|
|
|
|
|
- }
|
|
|
|
|
- if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
|
|
|
|
- return Ok(Self::BearerToken(bearer_token));
|
|
|
|
|
- }
|
|
|
|
|
- match load_saved_oauth_token() {
|
|
|
|
|
- Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
|
|
|
|
|
- if token_set.refresh_token.is_some() {
|
|
|
|
|
- Err(ApiError::Auth(
|
|
|
|
|
- "saved OAuth token is expired; load runtime OAuth config to refresh it"
|
|
|
|
|
- .to_string(),
|
|
|
|
|
- ))
|
|
|
|
|
- } else {
|
|
|
|
|
- Err(ApiError::ExpiredOAuthToken)
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
|
|
|
|
|
- Ok(None) => Err(ApiError::MissingApiKey),
|
|
|
|
|
- Err(error) => Err(error),
|
|
|
|
|
|
|
+ match self {
|
|
|
|
|
+ Self::Anthropic(client) => stream_via_provider(client, request)
|
|
|
|
|
+ .await
|
|
|
|
|
+ .map(MessageStream::Anthropic),
|
|
|
|
|
+ Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
|
|
|
|
|
+ .await
|
|
|
|
|
+ .map(MessageStream::OpenAiCompat),
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-#[must_use]
|
|
|
|
|
-pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
|
|
|
|
|
- token_set
|
|
|
|
|
- .expires_at
|
|
|
|
|
- .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
|
|
|
|
|
- let Some(token_set) = load_saved_oauth_token()? else {
|
|
|
|
|
- return Ok(None);
|
|
|
|
|
- };
|
|
|
|
|
- resolve_saved_oauth_token_set(config, token_set).map(Some)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
|
|
|
|
|
-where
|
|
|
|
|
- F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
|
|
|
|
|
-{
|
|
|
|
|
- if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
|
|
|
|
|
- return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
|
|
|
|
- Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
|
|
|
|
|
- api_key,
|
|
|
|
|
- bearer_token,
|
|
|
|
|
- }),
|
|
|
|
|
- None => Ok(AuthSource::ApiKey(api_key)),
|
|
|
|
|
- };
|
|
|
|
|
- }
|
|
|
|
|
- if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
|
|
|
|
|
- return Ok(AuthSource::BearerToken(bearer_token));
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- let Some(token_set) = load_saved_oauth_token()? else {
|
|
|
|
|
- return Err(ApiError::MissingApiKey);
|
|
|
|
|
- };
|
|
|
|
|
- if !oauth_token_is_expired(&token_set) {
|
|
|
|
|
- return Ok(AuthSource::BearerToken(token_set.access_token));
|
|
|
|
|
- }
|
|
|
|
|
- if token_set.refresh_token.is_none() {
|
|
|
|
|
- return Err(ApiError::ExpiredOAuthToken);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- let Some(config) = load_oauth_config()? else {
|
|
|
|
|
- return Err(ApiError::Auth(
|
|
|
|
|
- "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
|
|
|
|
|
- ));
|
|
|
|
|
- };
|
|
|
|
|
- Ok(AuthSource::from(resolve_saved_oauth_token_set(
|
|
|
|
|
- &config, token_set,
|
|
|
|
|
- )?))
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-fn resolve_saved_oauth_token_set(
|
|
|
|
|
- config: &OAuthConfig,
|
|
|
|
|
- token_set: OAuthTokenSet,
|
|
|
|
|
-) -> Result<OAuthTokenSet, ApiError> {
|
|
|
|
|
- if !oauth_token_is_expired(&token_set) {
|
|
|
|
|
- return Ok(token_set);
|
|
|
|
|
- }
|
|
|
|
|
- let Some(refresh_token) = token_set.refresh_token.clone() else {
|
|
|
|
|
- return Err(ApiError::ExpiredOAuthToken);
|
|
|
|
|
- };
|
|
|
|
|
- let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
|
|
|
|
|
- let refreshed = client_runtime_block_on(async {
|
|
|
|
|
- client
|
|
|
|
|
- .refresh_oauth_token(
|
|
|
|
|
- config,
|
|
|
|
|
- &OAuthRefreshRequest::from_config(
|
|
|
|
|
- config,
|
|
|
|
|
- refresh_token,
|
|
|
|
|
- Some(token_set.scopes.clone()),
|
|
|
|
|
- ),
|
|
|
|
|
- )
|
|
|
|
|
- .await
|
|
|
|
|
- })?;
|
|
|
|
|
- let resolved = OAuthTokenSet {
|
|
|
|
|
- access_token: refreshed.access_token,
|
|
|
|
|
- refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
|
|
|
|
|
- expires_at: refreshed.expires_at,
|
|
|
|
|
- scopes: refreshed.scopes,
|
|
|
|
|
- };
|
|
|
|
|
- save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
|
|
- access_token: resolved.access_token.clone(),
|
|
|
|
|
- refresh_token: resolved.refresh_token.clone(),
|
|
|
|
|
- expires_at: resolved.expires_at,
|
|
|
|
|
- scopes: resolved.scopes.clone(),
|
|
|
|
|
- })
|
|
|
|
|
- .map_err(ApiError::from)?;
|
|
|
|
|
- Ok(resolved)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
|
|
|
|
|
-where
|
|
|
|
|
- F: std::future::Future<Output = Result<T, ApiError>>,
|
|
|
|
|
-{
|
|
|
|
|
- tokio::runtime::Runtime::new()
|
|
|
|
|
- .map_err(ApiError::from)?
|
|
|
|
|
- .block_on(future)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
|
|
|
|
|
- let token_set = load_oauth_credentials().map_err(ApiError::from)?;
|
|
|
|
|
- Ok(token_set.map(|token_set| OAuthTokenSet {
|
|
|
|
|
- access_token: token_set.access_token,
|
|
|
|
|
- refresh_token: token_set.refresh_token,
|
|
|
|
|
- expires_at: token_set.expires_at,
|
|
|
|
|
- scopes: token_set.scopes,
|
|
|
|
|
- }))
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-fn now_unix_timestamp() -> u64 {
|
|
|
|
|
- SystemTime::now()
|
|
|
|
|
- .duration_since(UNIX_EPOCH)
|
|
|
|
|
- .map_or(0, |duration| duration.as_secs())
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
|
|
|
|
- match std::env::var(key) {
|
|
|
|
|
- Ok(value) if !value.is_empty() => Ok(Some(value)),
|
|
|
|
|
- Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
|
|
|
|
|
- Err(error) => Err(ApiError::from(error)),
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-#[cfg(test)]
|
|
|
|
|
-fn read_api_key() -> Result<String, ApiError> {
|
|
|
|
|
- let auth = AuthSource::from_env_or_saved()?;
|
|
|
|
|
- auth.api_key()
|
|
|
|
|
- .or_else(|| auth.bearer_token())
|
|
|
|
|
- .map(ToOwned::to_owned)
|
|
|
|
|
- .ok_or(ApiError::MissingApiKey)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-#[cfg(test)]
|
|
|
|
|
-fn read_auth_token() -> Option<String> {
|
|
|
|
|
- read_env_non_empty("ANTHROPIC_AUTH_TOKEN")
|
|
|
|
|
- .ok()
|
|
|
|
|
- .and_then(std::convert::identity)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-#[must_use]
|
|
|
|
|
-pub fn read_base_url() -> String {
|
|
|
|
|
- std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
|
|
|
|
|
- headers
|
|
|
|
|
- .get(REQUEST_ID_HEADER)
|
|
|
|
|
- .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
|
|
|
|
|
- .and_then(|value| value.to_str().ok())
|
|
|
|
|
- .map(ToOwned::to_owned)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
#[derive(Debug)]
|
|
#[derive(Debug)]
|
|
|
-pub struct MessageStream {
|
|
|
|
|
- request_id: Option<String>,
|
|
|
|
|
- response: reqwest::Response,
|
|
|
|
|
- parser: SseParser,
|
|
|
|
|
- pending: VecDeque<StreamEvent>,
|
|
|
|
|
- done: bool,
|
|
|
|
|
|
|
+pub enum MessageStream {
|
|
|
|
|
+ Anthropic(anthropic::MessageStream),
|
|
|
|
|
+ OpenAiCompat(openai_compat::MessageStream),
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
impl MessageStream {
|
|
impl MessageStream {
|
|
|
#[must_use]
|
|
#[must_use]
|
|
|
pub fn request_id(&self) -> Option<&str> {
|
|
pub fn request_id(&self) -> Option<&str> {
|
|
|
- self.request_id.as_deref()
|
|
|
|
|
|
|
+ match self {
|
|
|
|
|
+ Self::Anthropic(stream) => stream.request_id(),
|
|
|
|
|
+ Self::OpenAiCompat(stream) => stream.request_id(),
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
|
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
|
|
- loop {
|
|
|
|
|
- if let Some(event) = self.pending.pop_front() {
|
|
|
|
|
- return Ok(Some(event));
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if self.done {
|
|
|
|
|
- let remaining = self.parser.finish()?;
|
|
|
|
|
- self.pending.extend(remaining);
|
|
|
|
|
- if let Some(event) = self.pending.pop_front() {
|
|
|
|
|
- return Ok(Some(event));
|
|
|
|
|
- }
|
|
|
|
|
- return Ok(None);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- match self.response.chunk().await? {
|
|
|
|
|
- Some(chunk) => {
|
|
|
|
|
- self.pending.extend(self.parser.push(&chunk)?);
|
|
|
|
|
- }
|
|
|
|
|
- None => {
|
|
|
|
|
- self.done = true;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ match self {
|
|
|
|
|
+ Self::Anthropic(stream) => stream.next_event().await,
|
|
|
|
|
+ Self::OpenAiCompat(stream) => stream.next_event().await,
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
|
|
|
|
|
- let status = response.status();
|
|
|
|
|
- if status.is_success() {
|
|
|
|
|
- return Ok(response);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- let body = response.text().await.unwrap_or_else(|_| String::new());
|
|
|
|
|
- let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
|
|
|
|
|
- let retryable = is_retryable_status(status);
|
|
|
|
|
-
|
|
|
|
|
- Err(ApiError::Api {
|
|
|
|
|
- status,
|
|
|
|
|
- error_type: parsed_error
|
|
|
|
|
- .as_ref()
|
|
|
|
|
- .map(|error| error.error.error_type.clone()),
|
|
|
|
|
- message: parsed_error
|
|
|
|
|
- .as_ref()
|
|
|
|
|
- .map(|error| error.error.message.clone()),
|
|
|
|
|
- body,
|
|
|
|
|
- retryable,
|
|
|
|
|
- })
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
|
|
|
|
|
- matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-#[derive(Debug, Deserialize)]
|
|
|
|
|
-struct AnthropicErrorEnvelope {
|
|
|
|
|
- error: AnthropicErrorBody,
|
|
|
|
|
|
|
+pub use anthropic::{
|
|
|
|
|
+ oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet,
|
|
|
|
|
+};
|
|
|
|
|
+#[must_use]
|
|
|
|
|
+pub fn read_base_url() -> String {
|
|
|
|
|
+ anthropic::read_base_url()
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-#[derive(Debug, Deserialize)]
|
|
|
|
|
-struct AnthropicErrorBody {
|
|
|
|
|
- #[serde(rename = "type")]
|
|
|
|
|
- error_type: String,
|
|
|
|
|
- message: String,
|
|
|
|
|
|
|
+#[must_use]
|
|
|
|
|
+pub fn read_xai_base_url() -> String {
|
|
|
|
|
+ openai_compat::read_base_url(OpenAiCompatConfig::xai())
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
#[cfg(test)]
|
|
|
mod tests {
|
|
mod tests {
|
|
|
- use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
|
|
|
|
|
- use std::io::{Read, Write};
|
|
|
|
|
- use std::net::TcpListener;
|
|
|
|
|
- use std::sync::{Mutex, OnceLock};
|
|
|
|
|
- use std::thread;
|
|
|
|
|
- use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
|
|
|
|
-
|
|
|
|
|
- use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
|
|
|
|
|
-
|
|
|
|
|
- use crate::client::{
|
|
|
|
|
- now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
|
|
|
|
|
- resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
|
|
|
|
|
- };
|
|
|
|
|
- use crate::types::{ContentBlockDelta, MessageRequest};
|
|
|
|
|
-
|
|
|
|
|
- fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
|
|
|
|
- static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
|
|
|
|
- LOCK.get_or_init(|| Mutex::new(()))
|
|
|
|
|
- .lock()
|
|
|
|
|
- .expect("env lock")
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- fn temp_config_home() -> std::path::PathBuf {
|
|
|
|
|
- std::env::temp_dir().join(format!(
|
|
|
|
|
- "api-oauth-test-{}-{}",
|
|
|
|
|
- std::process::id(),
|
|
|
|
|
- SystemTime::now()
|
|
|
|
|
- .duration_since(UNIX_EPOCH)
|
|
|
|
|
- .expect("time")
|
|
|
|
|
- .as_nanos()
|
|
|
|
|
- ))
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- fn sample_oauth_config(token_url: String) -> OAuthConfig {
|
|
|
|
|
- OAuthConfig {
|
|
|
|
|
- client_id: "runtime-client".to_string(),
|
|
|
|
|
- authorize_url: "https://console.test/oauth/authorize".to_string(),
|
|
|
|
|
- token_url,
|
|
|
|
|
- callback_port: Some(4545),
|
|
|
|
|
- manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
|
|
|
|
|
- scopes: vec!["org:read".to_string(), "user:write".to_string()],
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- fn spawn_token_server(response_body: &'static str) -> String {
|
|
|
|
|
- let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
|
|
|
|
|
- let address = listener.local_addr().expect("local addr");
|
|
|
|
|
- thread::spawn(move || {
|
|
|
|
|
- let (mut stream, _) = listener.accept().expect("accept connection");
|
|
|
|
|
- let mut buffer = [0_u8; 4096];
|
|
|
|
|
- let _ = stream.read(&mut buffer).expect("read request");
|
|
|
|
|
- let response = format!(
|
|
|
|
|
- "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
|
|
|
|
|
- response_body.len(),
|
|
|
|
|
- response_body
|
|
|
|
|
- );
|
|
|
|
|
- stream
|
|
|
|
|
- .write_all(response.as_bytes())
|
|
|
|
|
- .expect("write response");
|
|
|
|
|
- });
|
|
|
|
|
- format!("http://{address}/oauth/token")
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn read_api_key_requires_presence() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
- let error = super::read_api_key().expect_err("missing key should error");
|
|
|
|
|
- assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn read_api_key_requires_non_empty_value() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- let error = super::read_api_key().expect_err("empty key should error");
|
|
|
|
|
- assert!(matches!(error, crate::error::ApiError::MissingApiKey));
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn read_api_key_prefers_api_key_env() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
|
|
|
|
- std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
|
|
|
|
- assert_eq!(
|
|
|
|
|
- super::read_api_key().expect("api key should load"),
|
|
|
|
|
- "legacy-key"
|
|
|
|
|
- );
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn read_auth_token_reads_auth_token_env() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
|
|
|
|
- assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn oauth_token_maps_to_bearer_auth_source() {
|
|
|
|
|
- let auth = AuthSource::from(OAuthTokenSet {
|
|
|
|
|
- access_token: "access-token".to_string(),
|
|
|
|
|
- refresh_token: Some("refresh".to_string()),
|
|
|
|
|
- expires_at: Some(123),
|
|
|
|
|
- scopes: vec!["scope:a".to_string()],
|
|
|
|
|
- });
|
|
|
|
|
- assert_eq!(auth.bearer_token(), Some("access-token"));
|
|
|
|
|
- assert_eq!(auth.api_key(), None);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn auth_source_from_env_combines_api_key_and_bearer_token() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
|
|
|
|
|
- std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
|
|
|
|
|
- let auth = AuthSource::from_env().expect("env auth");
|
|
|
|
|
- assert_eq!(auth.api_key(), Some("legacy-key"));
|
|
|
|
|
- assert_eq!(auth.bearer_token(), Some("auth-token"));
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn auth_source_from_saved_oauth_when_env_absent() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- let config_home = temp_config_home();
|
|
|
|
|
- std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
|
|
- access_token: "saved-access-token".to_string(),
|
|
|
|
|
- refresh_token: Some("refresh".to_string()),
|
|
|
|
|
- expires_at: Some(now_unix_timestamp() + 300),
|
|
|
|
|
- scopes: vec!["scope:a".to_string()],
|
|
|
|
|
- })
|
|
|
|
|
- .expect("save oauth credentials");
|
|
|
|
|
-
|
|
|
|
|
- let auth = AuthSource::from_env_or_saved().expect("saved auth");
|
|
|
|
|
- assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
|
|
|
|
-
|
|
|
|
|
- clear_oauth_credentials().expect("clear credentials");
|
|
|
|
|
- std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
- std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind};
|
|
|
|
|
|
|
|
#[test]
|
|
#[test]
|
|
|
- fn oauth_token_expiry_uses_expires_at_timestamp() {
|
|
|
|
|
- assert!(oauth_token_is_expired(&OAuthTokenSet {
|
|
|
|
|
- access_token: "access-token".to_string(),
|
|
|
|
|
- refresh_token: None,
|
|
|
|
|
- expires_at: Some(1),
|
|
|
|
|
- scopes: Vec::new(),
|
|
|
|
|
- }));
|
|
|
|
|
- assert!(!oauth_token_is_expired(&OAuthTokenSet {
|
|
|
|
|
- access_token: "access-token".to_string(),
|
|
|
|
|
- refresh_token: None,
|
|
|
|
|
- expires_at: Some(now_unix_timestamp() + 60),
|
|
|
|
|
- scopes: Vec::new(),
|
|
|
|
|
- }));
|
|
|
|
|
|
|
+ fn resolves_existing_and_grok_aliases() {
|
|
|
|
|
+ assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
|
|
|
|
|
+ assert_eq!(resolve_model_alias("grok"), "grok-3");
|
|
|
|
|
+ assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
#[test]
|
|
#[test]
|
|
|
- fn resolve_saved_oauth_token_refreshes_expired_credentials() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- let config_home = temp_config_home();
|
|
|
|
|
- std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
|
|
- access_token: "expired-access-token".to_string(),
|
|
|
|
|
- refresh_token: Some("refresh-token".to_string()),
|
|
|
|
|
- expires_at: Some(1),
|
|
|
|
|
- scopes: vec!["scope:a".to_string()],
|
|
|
|
|
- })
|
|
|
|
|
- .expect("save expired oauth credentials");
|
|
|
|
|
-
|
|
|
|
|
- let token_url = spawn_token_server(
|
|
|
|
|
- "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
|
|
|
|
- );
|
|
|
|
|
- let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
|
|
|
|
- .expect("resolve refreshed token")
|
|
|
|
|
- .expect("token set present");
|
|
|
|
|
- assert_eq!(resolved.access_token, "refreshed-token");
|
|
|
|
|
- let stored = runtime::load_oauth_credentials()
|
|
|
|
|
- .expect("load stored credentials")
|
|
|
|
|
- .expect("stored token set");
|
|
|
|
|
- assert_eq!(stored.access_token, "refreshed-token");
|
|
|
|
|
-
|
|
|
|
|
- clear_oauth_credentials().expect("clear credentials");
|
|
|
|
|
- std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
- std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- let config_home = temp_config_home();
|
|
|
|
|
- std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
|
|
- access_token: "saved-access-token".to_string(),
|
|
|
|
|
- refresh_token: Some("refresh".to_string()),
|
|
|
|
|
- expires_at: Some(now_unix_timestamp() + 300),
|
|
|
|
|
- scopes: vec!["scope:a".to_string()],
|
|
|
|
|
- })
|
|
|
|
|
- .expect("save oauth credentials");
|
|
|
|
|
-
|
|
|
|
|
- let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
|
|
|
|
|
- .expect("startup auth");
|
|
|
|
|
- assert_eq!(auth.bearer_token(), Some("saved-access-token"));
|
|
|
|
|
-
|
|
|
|
|
- clear_oauth_credentials().expect("clear credentials");
|
|
|
|
|
- std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
- std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- let config_home = temp_config_home();
|
|
|
|
|
- std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
|
|
- access_token: "expired-access-token".to_string(),
|
|
|
|
|
- refresh_token: Some("refresh-token".to_string()),
|
|
|
|
|
- expires_at: Some(1),
|
|
|
|
|
- scopes: vec!["scope:a".to_string()],
|
|
|
|
|
- })
|
|
|
|
|
- .expect("save expired oauth credentials");
|
|
|
|
|
-
|
|
|
|
|
- let error =
|
|
|
|
|
- resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
|
|
|
|
|
- assert!(
|
|
|
|
|
- matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
|
|
|
|
|
- );
|
|
|
|
|
-
|
|
|
|
|
- let stored = runtime::load_oauth_credentials()
|
|
|
|
|
- .expect("load stored credentials")
|
|
|
|
|
- .expect("stored token set");
|
|
|
|
|
- assert_eq!(stored.access_token, "expired-access-token");
|
|
|
|
|
- assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
|
|
|
|
-
|
|
|
|
|
- clear_oauth_credentials().expect("clear credentials");
|
|
|
|
|
- std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
- std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
|
|
|
|
|
- let _guard = env_lock();
|
|
|
|
|
- let config_home = temp_config_home();
|
|
|
|
|
- std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
|
|
|
|
|
- std::env::remove_var("ANTHROPIC_API_KEY");
|
|
|
|
|
- save_oauth_credentials(&runtime::OAuthTokenSet {
|
|
|
|
|
- access_token: "expired-access-token".to_string(),
|
|
|
|
|
- refresh_token: Some("refresh-token".to_string()),
|
|
|
|
|
- expires_at: Some(1),
|
|
|
|
|
- scopes: vec!["scope:a".to_string()],
|
|
|
|
|
- })
|
|
|
|
|
- .expect("save expired oauth credentials");
|
|
|
|
|
-
|
|
|
|
|
- let token_url = spawn_token_server(
|
|
|
|
|
- "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
|
|
|
|
|
- );
|
|
|
|
|
- let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
|
|
|
|
|
- .expect("resolve refreshed token")
|
|
|
|
|
- .expect("token set present");
|
|
|
|
|
- assert_eq!(resolved.access_token, "refreshed-token");
|
|
|
|
|
- assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
|
|
|
|
|
- let stored = runtime::load_oauth_credentials()
|
|
|
|
|
- .expect("load stored credentials")
|
|
|
|
|
- .expect("stored token set");
|
|
|
|
|
- assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
|
|
|
|
-
|
|
|
|
|
- clear_oauth_credentials().expect("clear credentials");
|
|
|
|
|
- std::env::remove_var("CLAUDE_CONFIG_HOME");
|
|
|
|
|
- std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn message_request_stream_helper_sets_stream_true() {
|
|
|
|
|
- let request = MessageRequest {
|
|
|
|
|
- model: "claude-opus-4-6".to_string(),
|
|
|
|
|
- max_tokens: 64,
|
|
|
|
|
- messages: vec![],
|
|
|
|
|
- system: None,
|
|
|
|
|
- tools: None,
|
|
|
|
|
- tool_choice: None,
|
|
|
|
|
- stream: false,
|
|
|
|
|
- };
|
|
|
|
|
-
|
|
|
|
|
- assert!(request.with_streaming().stream);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn backoff_doubles_until_maximum() {
|
|
|
|
|
- let client = AnthropicClient::new("test-key").with_retry_policy(
|
|
|
|
|
- 3,
|
|
|
|
|
- Duration::from_millis(10),
|
|
|
|
|
- Duration::from_millis(25),
|
|
|
|
|
- );
|
|
|
|
|
- assert_eq!(
|
|
|
|
|
- client.backoff_for_attempt(1).expect("attempt 1"),
|
|
|
|
|
- Duration::from_millis(10)
|
|
|
|
|
- );
|
|
|
|
|
- assert_eq!(
|
|
|
|
|
- client.backoff_for_attempt(2).expect("attempt 2"),
|
|
|
|
|
- Duration::from_millis(20)
|
|
|
|
|
- );
|
|
|
|
|
- assert_eq!(
|
|
|
|
|
- client.backoff_for_attempt(3).expect("attempt 3"),
|
|
|
|
|
- Duration::from_millis(25)
|
|
|
|
|
- );
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn retryable_statuses_are_detected() {
|
|
|
|
|
- assert!(super::is_retryable_status(
|
|
|
|
|
- reqwest::StatusCode::TOO_MANY_REQUESTS
|
|
|
|
|
- ));
|
|
|
|
|
- assert!(super::is_retryable_status(
|
|
|
|
|
- reqwest::StatusCode::INTERNAL_SERVER_ERROR
|
|
|
|
|
- ));
|
|
|
|
|
- assert!(!super::is_retryable_status(
|
|
|
|
|
- reqwest::StatusCode::UNAUTHORIZED
|
|
|
|
|
- ));
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn tool_delta_variant_round_trips() {
|
|
|
|
|
- let delta = ContentBlockDelta::InputJsonDelta {
|
|
|
|
|
- partial_json: "{\"city\":\"Paris\"}".to_string(),
|
|
|
|
|
- };
|
|
|
|
|
- let encoded = serde_json::to_string(&delta).expect("delta should serialize");
|
|
|
|
|
- let decoded: ContentBlockDelta =
|
|
|
|
|
- serde_json::from_str(&encoded).expect("delta should deserialize");
|
|
|
|
|
- assert_eq!(decoded, delta);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn request_id_uses_primary_or_fallback_header() {
|
|
|
|
|
- let mut headers = reqwest::header::HeaderMap::new();
|
|
|
|
|
- headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
|
|
|
|
|
- assert_eq!(
|
|
|
|
|
- super::request_id_from_headers(&headers).as_deref(),
|
|
|
|
|
- Some("req_primary")
|
|
|
|
|
- );
|
|
|
|
|
-
|
|
|
|
|
- headers.clear();
|
|
|
|
|
- headers.insert(
|
|
|
|
|
- ALT_REQUEST_ID_HEADER,
|
|
|
|
|
- "req_fallback".parse().expect("header"),
|
|
|
|
|
- );
|
|
|
|
|
- assert_eq!(
|
|
|
|
|
- super::request_id_from_headers(&headers).as_deref(),
|
|
|
|
|
- Some("req_fallback")
|
|
|
|
|
- );
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- #[test]
|
|
|
|
|
- fn auth_source_applies_headers() {
|
|
|
|
|
- let auth = AuthSource::ApiKeyAndBearer {
|
|
|
|
|
- api_key: "test-key".to_string(),
|
|
|
|
|
- bearer_token: "proxy-token".to_string(),
|
|
|
|
|
- };
|
|
|
|
|
- let request = auth
|
|
|
|
|
- .apply(reqwest::Client::new().post("https://example.test"))
|
|
|
|
|
- .build()
|
|
|
|
|
- .expect("request build");
|
|
|
|
|
- let headers = request.headers();
|
|
|
|
|
- assert_eq!(
|
|
|
|
|
- headers.get("x-api-key").and_then(|v| v.to_str().ok()),
|
|
|
|
|
- Some("test-key")
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ fn provider_detection_prefers_model_family() {
|
|
|
|
|
+ assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai);
|
|
|
assert_eq!(
|
|
assert_eq!(
|
|
|
- headers.get("authorization").and_then(|v| v.to_str().ok()),
|
|
|
|
|
- Some("Bearer proxy-token")
|
|
|
|
|
|
|
+ detect_provider_kind("claude-sonnet-4-6"),
|
|
|
|
|
+ ProviderKind::Anthropic
|
|
|
);
|
|
);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|