client_integration.rs 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. use std::collections::HashMap;
  2. use std::sync::Arc;
  3. use std::sync::{Mutex as StdMutex, OnceLock};
  4. use std::time::Duration;
  5. use api::{
  6. AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
  7. InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
  8. PromptCache, StreamEvent, ToolChoice, ToolDefinition,
  9. };
  10. use serde_json::json;
  11. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  12. use tokio::net::TcpListener;
  13. use tokio::sync::Mutex;
  /// Returns a process-wide guard that serializes the tests in this file which
  /// mutate the `CLAUDE_CONFIG_HOME` environment variable.
  ///
  /// A poisoned mutex (a previous holder panicked) is recovered instead of
  /// propagated, so one failing test does not cascade into the others.
  fn env_lock() -> std::sync::MutexGuard<'static, ()> {
      static ENV_MUTEX: OnceLock<StdMutex<()>> = OnceLock::new();
      let mutex: &'static StdMutex<()> = ENV_MUTEX.get_or_init(StdMutex::default);
      match mutex.lock() {
          Ok(guard) => guard,
          // The guarded value is `()`, so a panicking holder cannot leave it corrupted.
          Err(poisoned) => poisoned.into_inner(),
      }
  }
  20. #[tokio::test]
  21. async fn send_message_posts_json_and_parses_response() {
  22. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  23. let body = concat!(
  24. "{",
  25. "\"id\":\"msg_test\",",
  26. "\"type\":\"message\",",
  27. "\"role\":\"assistant\",",
  28. "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
  29. "\"model\":\"claude-3-7-sonnet-latest\",",
  30. "\"stop_reason\":\"end_turn\",",
  31. "\"stop_sequence\":null,",
  32. "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
  33. "\"request_id\":\"req_body_123\"",
  34. "}"
  35. );
  36. let server = spawn_server(
  37. state.clone(),
  38. vec![http_response("200 OK", "application/json", body)],
  39. )
  40. .await;
  41. let client = AnthropicClient::new("test-key")
  42. .with_auth_token(Some("proxy-token".to_string()))
  43. .with_base_url(server.base_url());
  44. let response = client
  45. .send_message(&sample_request(false))
  46. .await
  47. .expect("request should succeed");
  48. assert_eq!(response.id, "msg_test");
  49. assert_eq!(response.total_tokens(), 16);
  50. assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
  51. assert_eq!(
  52. response.content,
  53. vec![OutputContentBlock::Text {
  54. text: "Hello from Claude".to_string(),
  55. }]
  56. );
  57. let captured = state.lock().await;
  58. let request = captured.first().expect("server should capture request");
  59. assert_eq!(request.method, "POST");
  60. assert_eq!(request.path, "/v1/messages");
  61. assert_eq!(
  62. request.headers.get("x-api-key").map(String::as_str),
  63. Some("test-key")
  64. );
  65. assert_eq!(
  66. request.headers.get("authorization").map(String::as_str),
  67. Some("Bearer proxy-token")
  68. );
  69. let body: serde_json::Value =
  70. serde_json::from_str(&request.body).expect("request body should be json");
  71. assert_eq!(
  72. body.get("model").and_then(serde_json::Value::as_str),
  73. Some("claude-3-7-sonnet-latest")
  74. );
  75. assert!(body.get("stream").is_none());
  76. assert_eq!(body["tools"][0]["name"], json!("get_weather"));
  77. assert_eq!(body["tool_choice"]["type"], json!("auto"));
  78. }
  79. #[tokio::test]
  80. #[allow(clippy::await_holding_lock)]
  81. async fn stream_message_parses_sse_events_with_tool_use() {
  82. let _guard = env_lock();
  83. let temp_root = std::env::temp_dir().join(format!(
  84. "api-stream-cache-{}-{}",
  85. std::process::id(),
  86. std::time::SystemTime::now()
  87. .duration_since(std::time::UNIX_EPOCH)
  88. .expect("time")
  89. .as_nanos()
  90. ));
  91. std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
  92. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  93. let sse = concat!(
  94. "event: message_start\n",
  95. "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
  96. "event: content_block_start\n",
  97. "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
  98. "event: content_block_delta\n",
  99. "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
  100. "event: content_block_stop\n",
  101. "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
  102. "event: message_delta\n",
  103. "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
  104. "event: message_stop\n",
  105. "data: {\"type\":\"message_stop\"}\n\n",
  106. "data: [DONE]\n\n"
  107. );
  108. let server = spawn_server(
  109. state.clone(),
  110. vec![http_response_with_headers(
  111. "200 OK",
  112. "text/event-stream",
  113. sse,
  114. &[("request-id", "req_stream_456")],
  115. )],
  116. )
  117. .await;
  118. let client = AnthropicClient::new("test-key")
  119. .with_auth_token(Some("proxy-token".to_string()))
  120. .with_base_url(server.base_url())
  121. .with_prompt_cache(PromptCache::new("stream-session"));
  122. let mut stream = client
  123. .stream_message(&sample_request(false))
  124. .await
  125. .expect("stream should start");
  126. assert_eq!(stream.request_id(), Some("req_stream_456"));
  127. let mut events = Vec::new();
  128. while let Some(event) = stream
  129. .next_event()
  130. .await
  131. .expect("stream event should parse")
  132. {
  133. events.push(event);
  134. }
  135. assert_eq!(events.len(), 6);
  136. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  137. assert!(matches!(
  138. events[1],
  139. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  140. content_block: OutputContentBlock::ToolUse { .. },
  141. ..
  142. })
  143. ));
  144. assert!(matches!(
  145. events[2],
  146. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  147. delta: ContentBlockDelta::InputJsonDelta { .. },
  148. ..
  149. })
  150. ));
  151. assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
  152. assert!(matches!(
  153. events[4],
  154. StreamEvent::MessageDelta(MessageDeltaEvent { .. })
  155. ));
  156. assert!(matches!(events[5], StreamEvent::MessageStop(_)));
  157. match &events[1] {
  158. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  159. content_block: OutputContentBlock::ToolUse { name, input, .. },
  160. ..
  161. }) => {
  162. assert_eq!(name, "get_weather");
  163. assert_eq!(input, &json!({}));
  164. }
  165. other => panic!("expected tool_use block, got {other:?}"),
  166. }
  167. let captured = state.lock().await;
  168. let request = captured.first().expect("server should capture request");
  169. assert!(request.body.contains("\"stream\":true"));
  170. let cache_stats = client
  171. .prompt_cache_stats()
  172. .expect("prompt cache stats should exist");
  173. assert_eq!(cache_stats.tracked_requests, 1);
  174. assert_eq!(cache_stats.last_cache_read_input_tokens, Some(0));
  175. assert_eq!(
  176. cache_stats.last_cache_source.as_deref(),
  177. Some("api-response")
  178. );
  179. std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
  180. std::env::remove_var("CLAUDE_CONFIG_HOME");
  181. }
  182. #[tokio::test]
  183. async fn retries_retryable_failures_before_succeeding() {
  184. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  185. let server = spawn_server(
  186. state.clone(),
  187. vec![
  188. http_response(
  189. "429 Too Many Requests",
  190. "application/json",
  191. "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
  192. ),
  193. http_response(
  194. "200 OK",
  195. "application/json",
  196. "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
  197. ),
  198. ],
  199. )
  200. .await;
  201. let client = AnthropicClient::new("test-key")
  202. .with_base_url(server.base_url())
  203. .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
  204. let response = client
  205. .send_message(&sample_request(false))
  206. .await
  207. .expect("retry should eventually succeed");
  208. assert_eq!(response.total_tokens(), 5);
  209. assert_eq!(state.lock().await.len(), 2);
  210. }
  211. #[tokio::test]
  212. async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
  213. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  214. let server = spawn_server(
  215. state.clone(),
  216. vec![
  217. http_response(
  218. "503 Service Unavailable",
  219. "application/json",
  220. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
  221. ),
  222. http_response(
  223. "503 Service Unavailable",
  224. "application/json",
  225. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
  226. ),
  227. ],
  228. )
  229. .await;
  230. let client = AnthropicClient::new("test-key")
  231. .with_base_url(server.base_url())
  232. .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
  233. let error = client
  234. .send_message(&sample_request(false))
  235. .await
  236. .expect_err("persistent 503 should fail");
  237. match error {
  238. ApiError::RetriesExhausted {
  239. attempts,
  240. last_error,
  241. } => {
  242. assert_eq!(attempts, 2);
  243. assert!(matches!(
  244. *last_error,
  245. ApiError::Api {
  246. status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
  247. retryable: true,
  248. ..
  249. }
  250. ));
  251. }
  252. other => panic!("expected retries exhausted, got {other:?}"),
  253. }
  254. }
  255. #[tokio::test]
  256. #[allow(clippy::await_holding_lock)]
  257. async fn send_message_reuses_recent_completion_cache_entries() {
  258. let _guard = env_lock();
  259. let temp_root = std::env::temp_dir().join(format!(
  260. "api-prompt-cache-{}-{}",
  261. std::process::id(),
  262. std::time::SystemTime::now()
  263. .duration_since(std::time::UNIX_EPOCH)
  264. .expect("time")
  265. .as_nanos()
  266. ));
  267. std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
  268. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  269. let server = spawn_server(
  270. state.clone(),
  271. vec![http_response(
  272. "200 OK",
  273. "application/json",
  274. "{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}",
  275. )],
  276. )
  277. .await;
  278. let client = AnthropicClient::new("test-key")
  279. .with_base_url(server.base_url())
  280. .with_prompt_cache(PromptCache::new("integration-session"));
  281. let first = client
  282. .send_message(&sample_request(false))
  283. .await
  284. .expect("first request should succeed");
  285. let second = client
  286. .send_message(&sample_request(false))
  287. .await
  288. .expect("second request should reuse cache");
  289. assert_eq!(first.content, second.content);
  290. assert_eq!(state.lock().await.len(), 1);
  291. let cache_stats = client
  292. .prompt_cache_stats()
  293. .expect("prompt cache stats should exist");
  294. assert_eq!(cache_stats.completion_cache_hits, 1);
  295. assert_eq!(cache_stats.completion_cache_misses, 1);
  296. assert_eq!(cache_stats.completion_cache_writes, 1);
  297. std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
  298. std::env::remove_var("CLAUDE_CONFIG_HOME");
  299. }
  300. #[tokio::test]
  301. #[allow(clippy::await_holding_lock)]
  302. async fn send_message_tracks_unexpected_prompt_cache_breaks() {
  303. let _guard = env_lock();
  304. let temp_root = std::env::temp_dir().join(format!(
  305. "api-prompt-break-{}-{}",
  306. std::process::id(),
  307. std::time::SystemTime::now()
  308. .duration_since(std::time::UNIX_EPOCH)
  309. .expect("time")
  310. .as_nanos()
  311. ));
  312. std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
  313. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  314. let server = spawn_server(
  315. state,
  316. vec![
  317. http_response(
  318. "200 OK",
  319. "application/json",
  320. "{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}",
  321. ),
  322. http_response(
  323. "200 OK",
  324. "application/json",
  325. "{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}",
  326. ),
  327. ],
  328. )
  329. .await;
  330. let request = sample_request(false);
  331. let client = AnthropicClient::new("test-key")
  332. .with_base_url(server.base_url())
  333. .with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig {
  334. session_id: "break-session".to_string(),
  335. completion_ttl: Duration::from_secs(0),
  336. ..api::PromptCacheConfig::default()
  337. }));
  338. client
  339. .send_message(&request)
  340. .await
  341. .expect("first response should succeed");
  342. client
  343. .send_message(&request)
  344. .await
  345. .expect("second response should succeed");
  346. let cache_stats = client
  347. .prompt_cache_stats()
  348. .expect("prompt cache stats should exist");
  349. assert_eq!(cache_stats.unexpected_cache_breaks, 1);
  350. assert_eq!(
  351. cache_stats.last_break_reason.as_deref(),
  352. Some("cache read tokens dropped while prompt fingerprint remained stable")
  353. );
  354. std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
  355. std::env::remove_var("CLAUDE_CONFIG_HOME");
  356. }
  357. #[tokio::test]
  358. #[ignore = "requires ANTHROPIC_API_KEY and network access"]
  359. async fn live_stream_smoke_test() {
  360. let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set");
  361. let mut stream = client
  362. .stream_message(&MessageRequest {
  363. model: std::env::var("ANTHROPIC_MODEL")
  364. .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
  365. max_tokens: 32,
  366. messages: vec![InputMessage::user_text(
  367. "Reply with exactly: hello from rust",
  368. )],
  369. system: None,
  370. tools: None,
  371. tool_choice: None,
  372. stream: false,
  373. })
  374. .await
  375. .expect("live stream should start");
  376. while let Some(_event) = stream
  377. .next_event()
  378. .await
  379. .expect("live stream should yield events")
  380. {}
  381. }
  /// One HTTP request as recorded by the mock server built in `spawn_server`.
  #[derive(Debug, Clone)]
  #[derive(PartialEq, Eq)]
  struct CapturedRequest {
      /// Method token from the request line, e.g. `POST`.
      method: String,
      /// Request target from the request line, e.g. `/v1/messages`.
      path: String,
      /// Headers keyed by lowercased name, with trimmed values. A repeated header
      /// keeps only its last value.
      headers: HashMap<String, String>,
      /// Request body decoded as UTF-8.
      body: String,
  }
  389. struct TestServer {
  390. base_url: String,
  391. join_handle: tokio::task::JoinHandle<()>,
  392. }
  393. impl TestServer {
  394. fn base_url(&self) -> String {
  395. self.base_url.clone()
  396. }
  397. }
  398. impl Drop for TestServer {
  399. fn drop(&mut self) {
  400. self.join_handle.abort();
  401. }
  402. }
  403. async fn spawn_server(
  404. state: Arc<Mutex<Vec<CapturedRequest>>>,
  405. responses: Vec<String>,
  406. ) -> TestServer {
  407. let listener = TcpListener::bind("127.0.0.1:0")
  408. .await
  409. .expect("listener should bind");
  410. let address = listener
  411. .local_addr()
  412. .expect("listener should have local addr");
  413. let join_handle = tokio::spawn(async move {
  414. for response in responses {
  415. let (mut socket, _) = listener.accept().await.expect("server should accept");
  416. let mut buffer = Vec::new();
  417. let mut header_end = None;
  418. loop {
  419. let mut chunk = [0_u8; 1024];
  420. let read = socket
  421. .read(&mut chunk)
  422. .await
  423. .expect("request read should succeed");
  424. if read == 0 {
  425. break;
  426. }
  427. buffer.extend_from_slice(&chunk[..read]);
  428. if let Some(position) = find_header_end(&buffer) {
  429. header_end = Some(position);
  430. break;
  431. }
  432. }
  433. let header_end = header_end.expect("request should include headers");
  434. let (header_bytes, remaining) = buffer.split_at(header_end);
  435. let header_text =
  436. String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
  437. let mut lines = header_text.split("\r\n");
  438. let request_line = lines.next().expect("request line should exist");
  439. let mut parts = request_line.split_whitespace();
  440. let method = parts.next().expect("method should exist").to_string();
  441. let path = parts.next().expect("path should exist").to_string();
  442. let mut headers = HashMap::new();
  443. let mut content_length = 0_usize;
  444. for line in lines {
  445. if line.is_empty() {
  446. continue;
  447. }
  448. let (name, value) = line.split_once(':').expect("header should have colon");
  449. let value = value.trim().to_string();
  450. if name.eq_ignore_ascii_case("content-length") {
  451. content_length = value.parse().expect("content length should parse");
  452. }
  453. headers.insert(name.to_ascii_lowercase(), value);
  454. }
  455. let mut body = remaining[4..].to_vec();
  456. while body.len() < content_length {
  457. let mut chunk = vec![0_u8; content_length - body.len()];
  458. let read = socket
  459. .read(&mut chunk)
  460. .await
  461. .expect("body read should succeed");
  462. if read == 0 {
  463. break;
  464. }
  465. body.extend_from_slice(&chunk[..read]);
  466. }
  467. state.lock().await.push(CapturedRequest {
  468. method,
  469. path,
  470. headers,
  471. body: String::from_utf8(body).expect("body should be utf8"),
  472. });
  473. socket
  474. .write_all(response.as_bytes())
  475. .await
  476. .expect("response write should succeed");
  477. }
  478. });
  479. TestServer {
  480. base_url: format!("http://{address}"),
  481. join_handle,
  482. }
  483. }
  /// Returns the byte offset at which the first `\r\n\r\n` (end of the HTTP
  /// header section) begins, or `None` if the separator is not present yet.
  fn find_header_end(bytes: &[u8]) -> Option<usize> {
      const SEPARATOR: &[u8; 4] = b"\r\n\r\n";
      let last_candidate_exclusive = bytes.len().saturating_sub(SEPARATOR.len() - 1);
      (0..last_candidate_exclusive).find(|&start| bytes[start..].starts_with(SEPARATOR))
  }
  /// Builds a raw `HTTP/1.1` response with only the standard headers
  /// (`content-type`, `content-length`, `connection: close`).
  fn http_response(status: &str, content_type: &str, body: &str) -> String {
      let no_extra_headers: &[(&str, &str)] = &[];
      http_response_with_headers(status, content_type, body, no_extra_headers)
  }

  /// Builds a raw `HTTP/1.1` response string.
  ///
  /// `headers` are emitted in order between `content-type` and `content-length`.
  /// `content-length` is the body's length in bytes, and `connection: close` tells
  /// the client not to reuse the socket.
  fn http_response_with_headers(
      status: &str,
      content_type: &str,
      body: &str,
      headers: &[(&str, &str)],
  ) -> String {
      let extra_headers: String = headers
          .iter()
          .map(|(name, value)| format!("{name}: {value}\r\n"))
          .collect();
      let content_length = body.len();
      format!(
          "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {content_length}\r\nconnection: close\r\n\r\n{body}"
      )
  }
  506. fn sample_request(stream: bool) -> MessageRequest {
  507. MessageRequest {
  508. model: "claude-3-7-sonnet-latest".to_string(),
  509. max_tokens: 64,
  510. messages: vec![InputMessage {
  511. role: "user".to_string(),
  512. content: vec![
  513. InputContentBlock::Text {
  514. text: "Say hello".to_string(),
  515. },
  516. InputContentBlock::ToolResult {
  517. tool_use_id: "toolu_prev".to_string(),
  518. content: vec![api::ToolResultContentBlock::Json {
  519. value: json!({"forecast": "sunny"}),
  520. }],
  521. is_error: false,
  522. },
  523. ],
  524. }],
  525. system: Some("Use tools when needed".to_string()),
  526. tools: Some(vec![ToolDefinition {
  527. name: "get_weather".to_string(),
  528. description: Some("Fetches the weather".to_string()),
  529. input_schema: json!({
  530. "type": "object",
  531. "properties": {"city": {"type": "string"}},
  532. "required": ["city"]
  533. }),
  534. }]),
  535. tool_choice: Some(ToolChoice::Auto),
  536. stream,
  537. }
  538. }