client.rs 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006
  1. use std::collections::VecDeque;
  2. use std::time::{Duration, SystemTime, UNIX_EPOCH};
  3. use runtime::{
  4. load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
  5. OAuthTokenExchangeRequest,
  6. };
  7. use serde::Deserialize;
  8. use crate::error::ApiError;
  9. use crate::sse::SseParser;
  10. use crate::types::{MessageRequest, MessageResponse, StreamEvent};
  /// Fallback base URL for clients that are not given an explicit one.
  const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";

  /// Value sent in the `anthropic-version` header of every messages request.
  const ANTHROPIC_VERSION: &str = "2023-06-01";

  /// Response header consulted first when extracting a request id.
  const REQUEST_ID_HEADER: &str = "request-id";

  /// Response header consulted when `request-id` is not present.
  const ALT_REQUEST_ID_HEADER: &str = "x-request-id";

  /// Initial retry backoff a freshly constructed client is configured with.
  const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);

  /// Upper bound on the retry backoff a freshly constructed client is configured with.
  const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(2_000);

  /// Number of retries a freshly constructed client is configured with.
  const DEFAULT_MAX_RETRIES: u32 = 2;
  /// Credentials a client attaches to outgoing requests.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub enum AuthSource {
      /// No credentials; requests are sent without auth headers.
      None,
      /// An API key, sent in the `x-api-key` header.
      ApiKey(String),
      /// A bearer token, sent through bearer authorization.
      BearerToken(String),
      /// Both an API key and a bearer token, each sent in its own header.
      ApiKeyAndBearer {
          /// Value for the `x-api-key` header.
          api_key: String,
          /// Value for bearer authorization.
          bearer_token: String,
      },
  }
  28. impl AuthSource {
  29. pub fn from_env() -> Result<Self, ApiError> {
  30. let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
  31. let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
  32. match (api_key, auth_token) {
  33. (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
  34. api_key,
  35. bearer_token,
  36. }),
  37. (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
  38. (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
  39. (None, None) => Err(ApiError::MissingApiKey),
  40. }
  41. }
  42. #[must_use]
  43. pub fn api_key(&self) -> Option<&str> {
  44. match self {
  45. Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
  46. Self::None | Self::BearerToken(_) => None,
  47. }
  48. }
  49. #[must_use]
  50. pub fn bearer_token(&self) -> Option<&str> {
  51. match self {
  52. Self::BearerToken(token)
  53. | Self::ApiKeyAndBearer {
  54. bearer_token: token,
  55. ..
  56. } => Some(token),
  57. Self::None | Self::ApiKey(_) => None,
  58. }
  59. }
  60. #[must_use]
  61. pub fn masked_authorization_header(&self) -> &'static str {
  62. if self.bearer_token().is_some() {
  63. "Bearer [REDACTED]"
  64. } else {
  65. "<absent>"
  66. }
  67. }
  68. pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
  69. if let Some(api_key) = self.api_key() {
  70. request_builder = request_builder.header("x-api-key", api_key);
  71. }
  72. if let Some(token) = self.bearer_token() {
  73. request_builder = request_builder.bearer_auth(token);
  74. }
  75. request_builder
  76. }
  77. }
  78. #[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
  79. pub struct OAuthTokenSet {
  80. pub access_token: String,
  81. pub refresh_token: Option<String>,
  82. pub expires_at: Option<u64>,
  83. #[serde(default)]
  84. pub scopes: Vec<String>,
  85. }
  86. impl From<OAuthTokenSet> for AuthSource {
  87. fn from(value: OAuthTokenSet) -> Self {
  88. Self::BearerToken(value.access_token)
  89. }
  90. }
  91. #[derive(Debug, Clone)]
  92. pub struct AnthropicClient {
  93. http: reqwest::Client,
  94. auth: AuthSource,
  95. base_url: String,
  96. max_retries: u32,
  97. initial_backoff: Duration,
  98. max_backoff: Duration,
  99. }
  100. impl AnthropicClient {
  101. #[must_use]
  102. pub fn new(api_key: impl Into<String>) -> Self {
  103. Self {
  104. http: reqwest::Client::new(),
  105. auth: AuthSource::ApiKey(api_key.into()),
  106. base_url: DEFAULT_BASE_URL.to_string(),
  107. max_retries: DEFAULT_MAX_RETRIES,
  108. initial_backoff: DEFAULT_INITIAL_BACKOFF,
  109. max_backoff: DEFAULT_MAX_BACKOFF,
  110. }
  111. }
  112. #[must_use]
  113. pub fn from_auth(auth: AuthSource) -> Self {
  114. Self {
  115. http: reqwest::Client::new(),
  116. auth,
  117. base_url: DEFAULT_BASE_URL.to_string(),
  118. max_retries: DEFAULT_MAX_RETRIES,
  119. initial_backoff: DEFAULT_INITIAL_BACKOFF,
  120. max_backoff: DEFAULT_MAX_BACKOFF,
  121. }
  122. }
  123. pub fn from_env() -> Result<Self, ApiError> {
  124. Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
  125. }
  126. #[must_use]
  127. pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
  128. self.auth = auth;
  129. self
  130. }
  131. #[must_use]
  132. pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
  133. match (
  134. self.auth.api_key().map(ToOwned::to_owned),
  135. auth_token.filter(|token| !token.is_empty()),
  136. ) {
  137. (Some(api_key), Some(bearer_token)) => {
  138. self.auth = AuthSource::ApiKeyAndBearer {
  139. api_key,
  140. bearer_token,
  141. };
  142. }
  143. (Some(api_key), None) => {
  144. self.auth = AuthSource::ApiKey(api_key);
  145. }
  146. (None, Some(bearer_token)) => {
  147. self.auth = AuthSource::BearerToken(bearer_token);
  148. }
  149. (None, None) => {
  150. self.auth = AuthSource::None;
  151. }
  152. }
  153. self
  154. }
  155. #[must_use]
  156. pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
  157. self.base_url = base_url.into();
  158. self
  159. }
  160. #[must_use]
  161. pub fn with_retry_policy(
  162. mut self,
  163. max_retries: u32,
  164. initial_backoff: Duration,
  165. max_backoff: Duration,
  166. ) -> Self {
  167. self.max_retries = max_retries;
  168. self.initial_backoff = initial_backoff;
  169. self.max_backoff = max_backoff;
  170. self
  171. }
  172. #[must_use]
  173. pub fn auth_source(&self) -> &AuthSource {
  174. &self.auth
  175. }
  176. pub async fn send_message(
  177. &self,
  178. request: &MessageRequest,
  179. ) -> Result<MessageResponse, ApiError> {
  180. let request = MessageRequest {
  181. stream: false,
  182. ..request.clone()
  183. };
  184. let response = self.send_with_retry(&request).await?;
  185. let request_id = request_id_from_headers(response.headers());
  186. let mut response = response
  187. .json::<MessageResponse>()
  188. .await
  189. .map_err(ApiError::from)?;
  190. if response.request_id.is_none() {
  191. response.request_id = request_id;
  192. }
  193. Ok(response)
  194. }
  195. pub async fn stream_message(
  196. &self,
  197. request: &MessageRequest,
  198. ) -> Result<MessageStream, ApiError> {
  199. let response = self
  200. .send_with_retry(&request.clone().with_streaming())
  201. .await?;
  202. Ok(MessageStream {
  203. request_id: request_id_from_headers(response.headers()),
  204. response,
  205. parser: SseParser::new(),
  206. pending: VecDeque::new(),
  207. done: false,
  208. })
  209. }
  210. pub async fn exchange_oauth_code(
  211. &self,
  212. config: &OAuthConfig,
  213. request: &OAuthTokenExchangeRequest,
  214. ) -> Result<OAuthTokenSet, ApiError> {
  215. let response = self
  216. .http
  217. .post(&config.token_url)
  218. .header("content-type", "application/x-www-form-urlencoded")
  219. .form(&request.form_params())
  220. .send()
  221. .await
  222. .map_err(ApiError::from)?;
  223. let response = expect_success(response).await?;
  224. response
  225. .json::<OAuthTokenSet>()
  226. .await
  227. .map_err(ApiError::from)
  228. }
  229. pub async fn refresh_oauth_token(
  230. &self,
  231. config: &OAuthConfig,
  232. request: &OAuthRefreshRequest,
  233. ) -> Result<OAuthTokenSet, ApiError> {
  234. let response = self
  235. .http
  236. .post(&config.token_url)
  237. .header("content-type", "application/x-www-form-urlencoded")
  238. .form(&request.form_params())
  239. .send()
  240. .await
  241. .map_err(ApiError::from)?;
  242. let response = expect_success(response).await?;
  243. response
  244. .json::<OAuthTokenSet>()
  245. .await
  246. .map_err(ApiError::from)
  247. }
  248. async fn send_with_retry(
  249. &self,
  250. request: &MessageRequest,
  251. ) -> Result<reqwest::Response, ApiError> {
  252. let mut attempts = 0;
  253. let mut last_error: Option<ApiError>;
  254. loop {
  255. attempts += 1;
  256. match self.send_raw_request(request).await {
  257. Ok(response) => match expect_success(response).await {
  258. Ok(response) => return Ok(response),
  259. Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
  260. last_error = Some(error);
  261. }
  262. Err(error) => return Err(error),
  263. },
  264. Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
  265. last_error = Some(error);
  266. }
  267. Err(error) => return Err(error),
  268. }
  269. if attempts > self.max_retries {
  270. break;
  271. }
  272. tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
  273. }
  274. Err(ApiError::RetriesExhausted {
  275. attempts,
  276. last_error: Box::new(last_error.expect("retry loop must capture an error")),
  277. })
  278. }
  279. async fn send_raw_request(
  280. &self,
  281. request: &MessageRequest,
  282. ) -> Result<reqwest::Response, ApiError> {
  283. let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
  284. let resolved_base_url = self.base_url.trim_end_matches('/');
  285. eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}");
  286. eprintln!("[anthropic-client] request_url={request_url}");
  287. let request_builder = self
  288. .http
  289. .post(&request_url)
  290. .header("anthropic-version", ANTHROPIC_VERSION)
  291. .header("content-type", "application/json");
  292. let mut request_builder = self.auth.apply(request_builder);
  293. eprintln!(
  294. "[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json",
  295. if self.auth.api_key().is_some() {
  296. "[REDACTED]"
  297. } else {
  298. "<absent>"
  299. },
  300. self.auth.masked_authorization_header()
  301. );
  302. request_builder = request_builder.json(request);
  303. request_builder.send().await.map_err(ApiError::from)
  304. }
  305. fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
  306. let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
  307. return Err(ApiError::BackoffOverflow {
  308. attempt,
  309. base_delay: self.initial_backoff,
  310. });
  311. };
  312. Ok(self
  313. .initial_backoff
  314. .checked_mul(multiplier)
  315. .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
  316. }
  317. }
  318. impl AuthSource {
  319. pub fn from_env_or_saved() -> Result<Self, ApiError> {
  320. if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
  321. return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
  322. Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
  323. api_key,
  324. bearer_token,
  325. }),
  326. None => Ok(Self::ApiKey(api_key)),
  327. };
  328. }
  329. if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
  330. return Ok(Self::BearerToken(bearer_token));
  331. }
  332. match load_saved_oauth_token() {
  333. Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
  334. if token_set.refresh_token.is_some() {
  335. Err(ApiError::Auth(
  336. "saved OAuth token is expired; load runtime OAuth config to refresh it"
  337. .to_string(),
  338. ))
  339. } else {
  340. Err(ApiError::ExpiredOAuthToken)
  341. }
  342. }
  343. Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
  344. Ok(None) => Err(ApiError::MissingApiKey),
  345. Err(error) => Err(error),
  346. }
  347. }
  348. }
  349. #[must_use]
  350. pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
  351. token_set
  352. .expires_at
  353. .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
  354. }
  355. pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
  356. let Some(token_set) = load_saved_oauth_token()? else {
  357. return Ok(None);
  358. };
  359. resolve_saved_oauth_token_set(config, token_set).map(Some)
  360. }
  361. pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
  362. where
  363. F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
  364. {
  365. if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
  366. return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
  367. Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
  368. api_key,
  369. bearer_token,
  370. }),
  371. None => Ok(AuthSource::ApiKey(api_key)),
  372. };
  373. }
  374. if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
  375. return Ok(AuthSource::BearerToken(bearer_token));
  376. }
  377. let Some(token_set) = load_saved_oauth_token()? else {
  378. return Err(ApiError::MissingApiKey);
  379. };
  380. if !oauth_token_is_expired(&token_set) {
  381. return Ok(AuthSource::BearerToken(token_set.access_token));
  382. }
  383. if token_set.refresh_token.is_none() {
  384. return Err(ApiError::ExpiredOAuthToken);
  385. }
  386. let Some(config) = load_oauth_config()? else {
  387. return Err(ApiError::Auth(
  388. "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
  389. ));
  390. };
  391. Ok(AuthSource::from(resolve_saved_oauth_token_set(
  392. &config, token_set,
  393. )?))
  394. }
  395. fn resolve_saved_oauth_token_set(
  396. config: &OAuthConfig,
  397. token_set: OAuthTokenSet,
  398. ) -> Result<OAuthTokenSet, ApiError> {
  399. if !oauth_token_is_expired(&token_set) {
  400. return Ok(token_set);
  401. }
  402. let Some(refresh_token) = token_set.refresh_token.clone() else {
  403. return Err(ApiError::ExpiredOAuthToken);
  404. };
  405. let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
  406. let refreshed = client_runtime_block_on(async {
  407. client
  408. .refresh_oauth_token(
  409. config,
  410. &OAuthRefreshRequest::from_config(
  411. config,
  412. refresh_token,
  413. Some(token_set.scopes.clone()),
  414. ),
  415. )
  416. .await
  417. })?;
  418. let resolved = OAuthTokenSet {
  419. access_token: refreshed.access_token,
  420. refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
  421. expires_at: refreshed.expires_at,
  422. scopes: refreshed.scopes,
  423. };
  424. save_oauth_credentials(&runtime::OAuthTokenSet {
  425. access_token: resolved.access_token.clone(),
  426. refresh_token: resolved.refresh_token.clone(),
  427. expires_at: resolved.expires_at,
  428. scopes: resolved.scopes.clone(),
  429. })
  430. .map_err(ApiError::from)?;
  431. Ok(resolved)
  432. }
  433. fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
  434. where
  435. F: std::future::Future<Output = Result<T, ApiError>>,
  436. {
  437. tokio::runtime::Runtime::new()
  438. .map_err(ApiError::from)?
  439. .block_on(future)
  440. }
  441. fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
  442. let token_set = load_oauth_credentials().map_err(ApiError::from)?;
  443. Ok(token_set.map(|token_set| OAuthTokenSet {
  444. access_token: token_set.access_token,
  445. refresh_token: token_set.refresh_token,
  446. expires_at: token_set.expires_at,
  447. scopes: token_set.scopes,
  448. }))
  449. }
  /// Returns the current time as whole seconds since the Unix epoch, or `0`
  /// when the system clock reports a time before the epoch.
  fn now_unix_timestamp() -> u64 {
      match SystemTime::now().duration_since(UNIX_EPOCH) {
          Ok(elapsed) => elapsed.as_secs(),
          Err(_) => 0,
      }
  }
  455. fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
  456. match std::env::var(key) {
  457. Ok(value) if !value.is_empty() => Ok(Some(value)),
  458. Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
  459. Err(error) => Err(ApiError::from(error)),
  460. }
  461. }
  /// Test helper: resolves credentials and returns the API key, or the bearer
  /// token when no API key is available.
  ///
  /// # Errors
  ///
  /// Propagates resolution errors; returns [`ApiError::MissingApiKey`] when
  /// the resolved source carries neither credential.
  #[cfg(test)]
  fn read_api_key() -> Result<String, ApiError> {
      let auth = AuthSource::from_env_or_saved()?;
      match auth.api_key().or_else(|| auth.bearer_token()) {
          Some(secret) => Ok(secret.to_owned()),
          None => Err(ApiError::MissingApiKey),
      }
  }
  /// Test helper: returns the non-empty `ANTHROPIC_AUTH_TOKEN` value, mapping
  /// both "absent" and "failed to read" to `None`.
  #[cfg(test)]
  fn read_auth_token() -> Option<String> {
      match read_env_non_empty("ANTHROPIC_AUTH_TOKEN") {
          Ok(token) => token,
          Err(_) => None,
      }
  }
  476. fn read_base_url() -> String {
  477. std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
  478. }
  479. fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
  480. headers
  481. .get(REQUEST_ID_HEADER)
  482. .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
  483. .and_then(|value| value.to_str().ok())
  484. .map(ToOwned::to_owned)
  485. }
  486. #[derive(Debug)]
  487. pub struct MessageStream {
  488. request_id: Option<String>,
  489. response: reqwest::Response,
  490. parser: SseParser,
  491. pending: VecDeque<StreamEvent>,
  492. done: bool,
  493. }
  494. impl MessageStream {
  495. #[must_use]
  496. pub fn request_id(&self) -> Option<&str> {
  497. self.request_id.as_deref()
  498. }
  499. pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
  500. loop {
  501. if let Some(event) = self.pending.pop_front() {
  502. return Ok(Some(event));
  503. }
  504. if self.done {
  505. let remaining = self.parser.finish()?;
  506. self.pending.extend(remaining);
  507. if let Some(event) = self.pending.pop_front() {
  508. return Ok(Some(event));
  509. }
  510. return Ok(None);
  511. }
  512. match self.response.chunk().await? {
  513. Some(chunk) => {
  514. self.pending.extend(self.parser.push(&chunk)?);
  515. }
  516. None => {
  517. self.done = true;
  518. }
  519. }
  520. }
  521. }
  522. }
  523. async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
  524. let status = response.status();
  525. if status.is_success() {
  526. return Ok(response);
  527. }
  528. let body = response.text().await.unwrap_or_else(|_| String::new());
  529. let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
  530. let retryable = is_retryable_status(status);
  531. Err(ApiError::Api {
  532. status,
  533. error_type: parsed_error
  534. .as_ref()
  535. .map(|error| error.error.error_type.clone()),
  536. message: parsed_error
  537. .as_ref()
  538. .map(|error| error.error.message.clone()),
  539. body,
  540. retryable,
  541. })
  542. }
  543. const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
  544. matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
  545. }
  546. #[derive(Debug, Deserialize)]
  547. struct AnthropicErrorEnvelope {
  548. error: AnthropicErrorBody,
  549. }
  550. #[derive(Debug, Deserialize)]
  551. struct AnthropicErrorBody {
  552. #[serde(rename = "type")]
  553. error_type: String,
  554. message: String,
  555. }
  556. #[cfg(test)]
  557. mod tests {
  558. use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
  559. use std::io::{Read, Write};
  560. use std::net::TcpListener;
  561. use std::sync::{Mutex, OnceLock};
  562. use std::thread;
  563. use std::time::{Duration, SystemTime, UNIX_EPOCH};
  564. use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
  565. use crate::client::{
  566. now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
  567. resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
  568. };
  569. use crate::types::{ContentBlockDelta, MessageRequest};
  570. fn env_lock() -> std::sync::MutexGuard<'static, ()> {
  571. static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
  572. LOCK.get_or_init(|| Mutex::new(()))
  573. .lock()
  574. .expect("env lock")
  575. }
  576. fn temp_config_home() -> std::path::PathBuf {
  577. std::env::temp_dir().join(format!(
  578. "api-oauth-test-{}-{}",
  579. std::process::id(),
  580. SystemTime::now()
  581. .duration_since(UNIX_EPOCH)
  582. .expect("time")
  583. .as_nanos()
  584. ))
  585. }
  586. fn sample_oauth_config(token_url: String) -> OAuthConfig {
  587. OAuthConfig {
  588. client_id: "runtime-client".to_string(),
  589. authorize_url: "https://console.test/oauth/authorize".to_string(),
  590. token_url,
  591. callback_port: Some(4545),
  592. manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
  593. scopes: vec!["org:read".to_string(), "user:write".to_string()],
  594. }
  595. }
  596. fn spawn_token_server(response_body: &'static str) -> String {
  597. let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
  598. let address = listener.local_addr().expect("local addr");
  599. thread::spawn(move || {
  600. let (mut stream, _) = listener.accept().expect("accept connection");
  601. let mut buffer = [0_u8; 4096];
  602. let _ = stream.read(&mut buffer).expect("read request");
  603. let response = format!(
  604. "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
  605. response_body.len(),
  606. response_body
  607. );
  608. stream
  609. .write_all(response.as_bytes())
  610. .expect("write response");
  611. });
  612. format!("http://{address}/oauth/token")
  613. }
  614. #[test]
  615. fn read_api_key_requires_presence() {
  616. let _guard = env_lock();
  617. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  618. std::env::remove_var("ANTHROPIC_API_KEY");
  619. std::env::remove_var("CLAUDE_CONFIG_HOME");
  620. let error = super::read_api_key().expect_err("missing key should error");
  621. assert!(matches!(error, crate::error::ApiError::MissingApiKey));
  622. }
  623. #[test]
  624. fn read_api_key_requires_non_empty_value() {
  625. let _guard = env_lock();
  626. std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
  627. std::env::remove_var("ANTHROPIC_API_KEY");
  628. let error = super::read_api_key().expect_err("empty key should error");
  629. assert!(matches!(error, crate::error::ApiError::MissingApiKey));
  630. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  631. }
  632. #[test]
  633. fn read_api_key_prefers_api_key_env() {
  634. let _guard = env_lock();
  635. std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
  636. std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
  637. assert_eq!(
  638. super::read_api_key().expect("api key should load"),
  639. "legacy-key"
  640. );
  641. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  642. std::env::remove_var("ANTHROPIC_API_KEY");
  643. }
  644. #[test]
  645. fn read_auth_token_reads_auth_token_env() {
  646. let _guard = env_lock();
  647. std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
  648. assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
  649. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  650. }
  651. #[test]
  652. fn oauth_token_maps_to_bearer_auth_source() {
  653. let auth = AuthSource::from(OAuthTokenSet {
  654. access_token: "access-token".to_string(),
  655. refresh_token: Some("refresh".to_string()),
  656. expires_at: Some(123),
  657. scopes: vec!["scope:a".to_string()],
  658. });
  659. assert_eq!(auth.bearer_token(), Some("access-token"));
  660. assert_eq!(auth.api_key(), None);
  661. }
  662. #[test]
  663. fn auth_source_from_env_combines_api_key_and_bearer_token() {
  664. let _guard = env_lock();
  665. std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
  666. std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
  667. let auth = AuthSource::from_env().expect("env auth");
  668. assert_eq!(auth.api_key(), Some("legacy-key"));
  669. assert_eq!(auth.bearer_token(), Some("auth-token"));
  670. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  671. std::env::remove_var("ANTHROPIC_API_KEY");
  672. }
  673. #[test]
  674. fn auth_source_from_saved_oauth_when_env_absent() {
  675. let _guard = env_lock();
  676. let config_home = temp_config_home();
  677. std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
  678. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  679. std::env::remove_var("ANTHROPIC_API_KEY");
  680. save_oauth_credentials(&runtime::OAuthTokenSet {
  681. access_token: "saved-access-token".to_string(),
  682. refresh_token: Some("refresh".to_string()),
  683. expires_at: Some(now_unix_timestamp() + 300),
  684. scopes: vec!["scope:a".to_string()],
  685. })
  686. .expect("save oauth credentials");
  687. let auth = AuthSource::from_env_or_saved().expect("saved auth");
  688. assert_eq!(auth.bearer_token(), Some("saved-access-token"));
  689. clear_oauth_credentials().expect("clear credentials");
  690. std::env::remove_var("CLAUDE_CONFIG_HOME");
  691. std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
  692. }
  693. #[test]
  694. fn oauth_token_expiry_uses_expires_at_timestamp() {
  695. assert!(oauth_token_is_expired(&OAuthTokenSet {
  696. access_token: "access-token".to_string(),
  697. refresh_token: None,
  698. expires_at: Some(1),
  699. scopes: Vec::new(),
  700. }));
  701. assert!(!oauth_token_is_expired(&OAuthTokenSet {
  702. access_token: "access-token".to_string(),
  703. refresh_token: None,
  704. expires_at: Some(now_unix_timestamp() + 60),
  705. scopes: Vec::new(),
  706. }));
  707. }
  708. #[test]
  709. fn resolve_saved_oauth_token_refreshes_expired_credentials() {
  710. let _guard = env_lock();
  711. let config_home = temp_config_home();
  712. std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
  713. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  714. std::env::remove_var("ANTHROPIC_API_KEY");
  715. save_oauth_credentials(&runtime::OAuthTokenSet {
  716. access_token: "expired-access-token".to_string(),
  717. refresh_token: Some("refresh-token".to_string()),
  718. expires_at: Some(1),
  719. scopes: vec!["scope:a".to_string()],
  720. })
  721. .expect("save expired oauth credentials");
  722. let token_url = spawn_token_server(
  723. "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
  724. );
  725. let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
  726. .expect("resolve refreshed token")
  727. .expect("token set present");
  728. assert_eq!(resolved.access_token, "refreshed-token");
  729. let stored = runtime::load_oauth_credentials()
  730. .expect("load stored credentials")
  731. .expect("stored token set");
  732. assert_eq!(stored.access_token, "refreshed-token");
  733. clear_oauth_credentials().expect("clear credentials");
  734. std::env::remove_var("CLAUDE_CONFIG_HOME");
  735. std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
  736. }
  737. #[test]
  738. fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
  739. let _guard = env_lock();
  740. let config_home = temp_config_home();
  741. std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
  742. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  743. std::env::remove_var("ANTHROPIC_API_KEY");
  744. save_oauth_credentials(&runtime::OAuthTokenSet {
  745. access_token: "saved-access-token".to_string(),
  746. refresh_token: Some("refresh".to_string()),
  747. expires_at: Some(now_unix_timestamp() + 300),
  748. scopes: vec!["scope:a".to_string()],
  749. })
  750. .expect("save oauth credentials");
  751. let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
  752. .expect("startup auth");
  753. assert_eq!(auth.bearer_token(), Some("saved-access-token"));
  754. clear_oauth_credentials().expect("clear credentials");
  755. std::env::remove_var("CLAUDE_CONFIG_HOME");
  756. std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
  757. }
  758. #[test]
  759. fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
  760. let _guard = env_lock();
  761. let config_home = temp_config_home();
  762. std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
  763. std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
  764. std::env::remove_var("ANTHROPIC_API_KEY");
  765. save_oauth_credentials(&runtime::OAuthTokenSet {
  766. access_token: "expired-access-token".to_string(),
  767. refresh_token: Some("refresh-token".to_string()),
  768. expires_at: Some(1),
  769. scopes: vec!["scope:a".to_string()],
  770. })
  771. .expect("save expired oauth credentials");
  772. let error =
  773. resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
  774. assert!(
  775. matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
  776. );
  777. let stored = runtime::load_oauth_credentials()
  778. .expect("load stored credentials")
  779. .expect("stored token set");
  780. assert_eq!(stored.access_token, "expired-access-token");
  781. assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
  782. clear_oauth_credentials().expect("clear credentials");
  783. std::env::remove_var("CLAUDE_CONFIG_HOME");
  784. std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
  785. }
  #[test]
  fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
      // NOTE(review): `env_lock` presumably serializes tests that mutate the
      // process environment; the guard is kept alive until the end of the test.
      let _guard = env_lock();

      // Point credential storage at a scratch directory and make sure neither
      // auth-related environment variable is set.
      let home_dir = temp_config_home();
      std::env::set_var("CLAUDE_CONFIG_HOME", &home_dir);
      std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
      std::env::remove_var("ANTHROPIC_API_KEY");

      // Persist a token set that carries a refresh token and `expires_at: Some(1)`.
      let saved_tokens = runtime::OAuthTokenSet {
          access_token: String::from("expired-access-token"),
          refresh_token: Some(String::from("refresh-token")),
          expires_at: Some(1),
          scopes: vec![String::from("scope:a")],
      };
      save_oauth_credentials(&saved_tokens).expect("save expired oauth credentials");

      // The token endpoint's response body contains no `refresh_token` field.
      let token_url = spawn_token_server(
          "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
      );
      let oauth_config = sample_oauth_config(token_url);

      // The resolved token set has the new access token and the old refresh token.
      let refreshed = resolve_saved_oauth_token(&oauth_config)
          .expect("resolve refreshed token")
          .expect("token set present");
      assert_eq!(refreshed.access_token, "refreshed-token");
      assert_eq!(refreshed.refresh_token.as_deref(), Some("refresh-token"));

      // The refresh token must also still be present in the stored credentials.
      let persisted = runtime::load_oauth_credentials()
          .expect("load stored credentials")
          .expect("stored token set");
      assert_eq!(persisted.refresh_token.as_deref(), Some("refresh-token"));

      // Tear down: drop the credentials, the env override, and the scratch directory.
      clear_oauth_credentials().expect("clear credentials");
      std::env::remove_var("CLAUDE_CONFIG_HOME");
      std::fs::remove_dir_all(home_dir).expect("cleanup temp dir");
  }
  #[test]
  fn message_request_stream_helper_sets_stream_true() {
      // Start from a request whose `stream` flag is explicitly off.
      let non_streaming = MessageRequest {
          model: String::from("claude-3-7-sonnet-latest"),
          max_tokens: 64,
          messages: Vec::new(),
          system: None,
          tools: None,
          tool_choice: None,
          stream: false,
      };

      // `with_streaming` must hand back a request with the flag turned on.
      let streaming = non_streaming.with_streaming();
      assert!(streaming.stream);
  }
  #[test]
  fn backoff_doubles_until_maximum() {
      let first_delay = Duration::from_millis(10);
      let delay_cap = Duration::from_millis(25);
      let client = AnthropicClient::new("test-key").with_retry_policy(3, first_delay, delay_cap);

      // (attempt, label used if the call fails, expected delay in milliseconds).
      // The third row is 25 ms, the configured cap, rather than a doubled 40 ms.
      let cases = &[
          (1, "attempt 1", 10),
          (2, "attempt 2", 20),
          (3, "attempt 3", 25),
      ];
      for &(attempt, label, expected_ms) in cases {
          let delay = client.backoff_for_attempt(attempt).expect(label);
          assert_eq!(delay, Duration::from_millis(expected_ms));
      }
  }
  #[test]
  fn retryable_statuses_are_detected() {
      let rate_limited = reqwest::StatusCode::TOO_MANY_REQUESTS;
      let server_error = reqwest::StatusCode::INTERNAL_SERVER_ERROR;
      let unauthorized = reqwest::StatusCode::UNAUTHORIZED;

      // 429 and 500 are expected to be classified as retryable...
      assert!(super::is_retryable_status(rate_limited));
      assert!(super::is_retryable_status(server_error));
      // ...while 401 is not.
      assert!(!super::is_retryable_status(unauthorized));
  }
  #[test]
  fn tool_delta_variant_round_trips() {
      let original = ContentBlockDelta::InputJsonDelta {
          partial_json: String::from("{\"city\":\"Paris\"}"),
      };

      // Serialize to JSON text, parse it back, and require an equal value.
      let json = serde_json::to_string(&original).expect("delta should serialize");
      let restored =
          serde_json::from_str::<ContentBlockDelta>(&json).expect("delta should deserialize");
      assert_eq!(restored, original);
  }
  #[test]
  fn request_id_uses_primary_or_fallback_header() {
      // Case 1: only the primary request-id header is present.
      let mut primary_only = reqwest::header::HeaderMap::new();
      primary_only.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
      let from_primary = super::request_id_from_headers(&primary_only);
      assert_eq!(from_primary.as_deref(), Some("req_primary"));

      // Case 2: only the alternate header is present, so it must be used instead.
      let mut fallback_only = reqwest::header::HeaderMap::new();
      fallback_only.insert(
          ALT_REQUEST_ID_HEADER,
          "req_fallback".parse().expect("header"),
      );
      let from_fallback = super::request_id_from_headers(&fallback_only);
      assert_eq!(from_fallback.as_deref(), Some("req_fallback"));
  }
  #[test]
  fn auth_source_applies_headers() {
      // An auth source holding both an API key and a bearer token.
      let auth = AuthSource::ApiKeyAndBearer {
          api_key: String::from("test-key"),
          bearer_token: String::from("proxy-token"),
      };

      // Apply it to a request builder and build the request; nothing is sent.
      let builder = reqwest::Client::new().post("https://example.test");
      let request = auth.apply(builder).build().expect("request build");
      let headers = request.headers();

      // Both credentials must appear, each under its own header.
      let api_key_header = headers.get("x-api-key").and_then(|value| value.to_str().ok());
      assert_eq!(api_key_header, Some("test-key"));

      let authorization_header = headers
          .get("authorization")
          .and_then(|value| value.to_str().ok());
      assert_eq!(authorization_header, Some("Bearer proxy-token"));
  }
  909. }