client_integration.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. use std::collections::HashMap;
  2. use std::sync::Arc;
  3. use std::sync::{Mutex as StdMutex, OnceLock};
  4. use std::time::Duration;
  5. use api::{
  6. AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
  7. InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
  8. PromptCache, StreamEvent, ToolChoice, ToolDefinition,
  9. };
  10. use serde_json::json;
  11. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  12. use tokio::net::TcpListener;
  13. use tokio::sync::Mutex;
  /// Serialises tests that mutate process-wide environment variables.
  ///
  /// The returned guard holds a global lock until it is dropped. Poisoning is
  /// deliberately ignored: the mutex guards no data (`()`), so a lock left
  /// poisoned by a panicking test is still safe to reuse.
  fn env_lock() -> std::sync::MutexGuard<'static, ()> {
      static ENV_MUTEX: OnceLock<StdMutex<()>> = OnceLock::new();
      let mutex = ENV_MUTEX.get_or_init(StdMutex::default);
      match mutex.lock() {
          Ok(guard) => guard,
          Err(poisoned) => poisoned.into_inner(),
      }
  }
  20. #[tokio::test]
  21. async fn send_message_posts_json_and_parses_response() {
  22. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  23. let body = concat!(
  24. "{",
  25. "\"id\":\"msg_test\",",
  26. "\"type\":\"message\",",
  27. "\"role\":\"assistant\",",
  28. "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
  29. "\"model\":\"claude-3-7-sonnet-latest\",",
  30. "\"stop_reason\":\"end_turn\",",
  31. "\"stop_sequence\":null,",
  32. "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
  33. "\"request_id\":\"req_body_123\"",
  34. "}"
  35. );
  36. let server = spawn_server(
  37. state.clone(),
  38. vec![http_response("200 OK", "application/json", body)],
  39. )
  40. .await;
  41. let client = AnthropicClient::new("test-key")
  42. .with_auth_token(Some("proxy-token".to_string()))
  43. .with_base_url(server.base_url());
  44. let response = client
  45. .send_message(&sample_request(false))
  46. .await
  47. .expect("request should succeed");
  48. assert_eq!(response.id, "msg_test");
  49. assert_eq!(response.total_tokens(), 16);
  50. assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
  51. assert_eq!(
  52. response.content,
  53. vec![OutputContentBlock::Text {
  54. text: "Hello from Claude".to_string(),
  55. }]
  56. );
  57. let captured = state.lock().await;
  58. let request = captured.first().expect("server should capture request");
  59. assert_eq!(request.method, "POST");
  60. assert_eq!(request.path, "/v1/messages");
  61. assert_eq!(
  62. request.headers.get("x-api-key").map(String::as_str),
  63. Some("test-key")
  64. );
  65. assert_eq!(
  66. request.headers.get("authorization").map(String::as_str),
  67. Some("Bearer proxy-token")
  68. );
  69. let body: serde_json::Value =
  70. serde_json::from_str(&request.body).expect("request body should be json");
  71. assert_eq!(
  72. body.get("model").and_then(serde_json::Value::as_str),
  73. Some("claude-3-7-sonnet-latest")
  74. );
  75. assert!(body.get("stream").is_none());
  76. assert_eq!(body["tools"][0]["name"], json!("get_weather"));
  77. assert_eq!(body["tool_choice"]["type"], json!("auto"));
  78. }
  79. #[tokio::test]
  80. async fn stream_message_parses_sse_events_with_tool_use() {
  81. let _guard = env_lock();
  82. let temp_root = std::env::temp_dir().join(format!(
  83. "api-stream-cache-{}-{}",
  84. std::process::id(),
  85. std::time::SystemTime::now()
  86. .duration_since(std::time::UNIX_EPOCH)
  87. .expect("time")
  88. .as_nanos()
  89. ));
  90. std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
  91. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  92. let sse = concat!(
  93. "event: message_start\n",
  94. "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
  95. "event: content_block_start\n",
  96. "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
  97. "event: content_block_delta\n",
  98. "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
  99. "event: content_block_stop\n",
  100. "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
  101. "event: message_delta\n",
  102. "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
  103. "event: message_stop\n",
  104. "data: {\"type\":\"message_stop\"}\n\n",
  105. "data: [DONE]\n\n"
  106. );
  107. let server = spawn_server(
  108. state.clone(),
  109. vec![http_response_with_headers(
  110. "200 OK",
  111. "text/event-stream",
  112. sse,
  113. &[("request-id", "req_stream_456")],
  114. )],
  115. )
  116. .await;
  117. let client = AnthropicClient::new("test-key")
  118. .with_auth_token(Some("proxy-token".to_string()))
  119. .with_base_url(server.base_url())
  120. .with_prompt_cache(PromptCache::new("stream-session"));
  121. let mut stream = client
  122. .stream_message(&sample_request(false))
  123. .await
  124. .expect("stream should start");
  125. assert_eq!(stream.request_id(), Some("req_stream_456"));
  126. let mut events = Vec::new();
  127. while let Some(event) = stream
  128. .next_event()
  129. .await
  130. .expect("stream event should parse")
  131. {
  132. events.push(event);
  133. }
  134. assert_eq!(events.len(), 6);
  135. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  136. assert!(matches!(
  137. events[1],
  138. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  139. content_block: OutputContentBlock::ToolUse { .. },
  140. ..
  141. })
  142. ));
  143. assert!(matches!(
  144. events[2],
  145. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  146. delta: ContentBlockDelta::InputJsonDelta { .. },
  147. ..
  148. })
  149. ));
  150. assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
  151. assert!(matches!(
  152. events[4],
  153. StreamEvent::MessageDelta(MessageDeltaEvent { .. })
  154. ));
  155. assert!(matches!(events[5], StreamEvent::MessageStop(_)));
  156. match &events[1] {
  157. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  158. content_block: OutputContentBlock::ToolUse { name, input, .. },
  159. ..
  160. }) => {
  161. assert_eq!(name, "get_weather");
  162. assert_eq!(input, &json!({}));
  163. }
  164. other => panic!("expected tool_use block, got {other:?}"),
  165. }
  166. let captured = state.lock().await;
  167. let request = captured.first().expect("server should capture request");
  168. assert!(request.body.contains("\"stream\":true"));
  169. let stats = client
  170. .prompt_cache_stats()
  171. .expect("prompt cache stats should exist");
  172. assert_eq!(stats.tracked_requests, 1);
  173. assert_eq!(stats.last_cache_read_input_tokens, Some(0));
  174. assert_eq!(stats.last_cache_source.as_deref(), Some("api-response"));
  175. std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
  176. std::env::remove_var("CLAUDE_CONFIG_HOME");
  177. }
  178. #[tokio::test]
  179. async fn retries_retryable_failures_before_succeeding() {
  180. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  181. let server = spawn_server(
  182. state.clone(),
  183. vec![
  184. http_response(
  185. "429 Too Many Requests",
  186. "application/json",
  187. "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
  188. ),
  189. http_response(
  190. "200 OK",
  191. "application/json",
  192. "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
  193. ),
  194. ],
  195. )
  196. .await;
  197. let client = AnthropicClient::new("test-key")
  198. .with_base_url(server.base_url())
  199. .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
  200. let response = client
  201. .send_message(&sample_request(false))
  202. .await
  203. .expect("retry should eventually succeed");
  204. assert_eq!(response.total_tokens(), 5);
  205. assert_eq!(state.lock().await.len(), 2);
  206. }
  207. #[tokio::test]
  208. async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
  209. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  210. let server = spawn_server(
  211. state.clone(),
  212. vec![
  213. http_response(
  214. "503 Service Unavailable",
  215. "application/json",
  216. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
  217. ),
  218. http_response(
  219. "503 Service Unavailable",
  220. "application/json",
  221. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
  222. ),
  223. ],
  224. )
  225. .await;
  226. let client = AnthropicClient::new("test-key")
  227. .with_base_url(server.base_url())
  228. .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
  229. let error = client
  230. .send_message(&sample_request(false))
  231. .await
  232. .expect_err("persistent 503 should fail");
  233. match error {
  234. ApiError::RetriesExhausted {
  235. attempts,
  236. last_error,
  237. } => {
  238. assert_eq!(attempts, 2);
  239. assert!(matches!(
  240. *last_error,
  241. ApiError::Api {
  242. status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
  243. retryable: true,
  244. ..
  245. }
  246. ));
  247. }
  248. other => panic!("expected retries exhausted, got {other:?}"),
  249. }
  250. }
  251. #[tokio::test]
  252. async fn send_message_reuses_recent_completion_cache_entries() {
  253. let _guard = env_lock();
  254. let temp_root = std::env::temp_dir().join(format!(
  255. "api-prompt-cache-{}-{}",
  256. std::process::id(),
  257. std::time::SystemTime::now()
  258. .duration_since(std::time::UNIX_EPOCH)
  259. .expect("time")
  260. .as_nanos()
  261. ));
  262. std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
  263. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  264. let server = spawn_server(
  265. state.clone(),
  266. vec![http_response(
  267. "200 OK",
  268. "application/json",
  269. "{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}",
  270. )],
  271. )
  272. .await;
  273. let client = AnthropicClient::new("test-key")
  274. .with_base_url(server.base_url())
  275. .with_prompt_cache(PromptCache::new("integration-session"));
  276. let first = client
  277. .send_message(&sample_request(false))
  278. .await
  279. .expect("first request should succeed");
  280. let second = client
  281. .send_message(&sample_request(false))
  282. .await
  283. .expect("second request should reuse cache");
  284. assert_eq!(first.content, second.content);
  285. assert_eq!(state.lock().await.len(), 1);
  286. let stats = client
  287. .prompt_cache_stats()
  288. .expect("prompt cache stats should exist");
  289. assert_eq!(stats.completion_cache_hits, 1);
  290. assert_eq!(stats.completion_cache_misses, 1);
  291. assert_eq!(stats.completion_cache_writes, 1);
  292. std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
  293. std::env::remove_var("CLAUDE_CONFIG_HOME");
  294. }
  295. #[tokio::test]
  296. async fn send_message_tracks_unexpected_prompt_cache_breaks() {
  297. let _guard = env_lock();
  298. let temp_root = std::env::temp_dir().join(format!(
  299. "api-prompt-break-{}-{}",
  300. std::process::id(),
  301. std::time::SystemTime::now()
  302. .duration_since(std::time::UNIX_EPOCH)
  303. .expect("time")
  304. .as_nanos()
  305. ));
  306. std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
  307. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  308. let server = spawn_server(
  309. state,
  310. vec![
  311. http_response(
  312. "200 OK",
  313. "application/json",
  314. "{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}",
  315. ),
  316. http_response(
  317. "200 OK",
  318. "application/json",
  319. "{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}",
  320. ),
  321. ],
  322. )
  323. .await;
  324. let request = sample_request(false);
  325. let client = AnthropicClient::new("test-key")
  326. .with_base_url(server.base_url())
  327. .with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig {
  328. session_id: "break-session".to_string(),
  329. completion_ttl: Duration::from_secs(0),
  330. ..api::PromptCacheConfig::default()
  331. }));
  332. client
  333. .send_message(&request)
  334. .await
  335. .expect("first response should succeed");
  336. client
  337. .send_message(&request)
  338. .await
  339. .expect("second response should succeed");
  340. let stats = client
  341. .prompt_cache_stats()
  342. .expect("prompt cache stats should exist");
  343. assert_eq!(stats.unexpected_cache_breaks, 1);
  344. assert_eq!(
  345. stats.last_break_reason.as_deref(),
  346. Some("cache read tokens dropped while prompt fingerprint remained stable")
  347. );
  348. std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
  349. std::env::remove_var("CLAUDE_CONFIG_HOME");
  350. }
  351. #[tokio::test]
  352. #[ignore = "requires ANTHROPIC_API_KEY and network access"]
  353. async fn live_stream_smoke_test() {
  354. let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set");
  355. let mut stream = client
  356. .stream_message(&MessageRequest {
  357. model: std::env::var("ANTHROPIC_MODEL")
  358. .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
  359. max_tokens: 32,
  360. messages: vec![InputMessage::user_text(
  361. "Reply with exactly: hello from rust",
  362. )],
  363. system: None,
  364. tools: None,
  365. tool_choice: None,
  366. stream: false,
  367. })
  368. .await
  369. .expect("live stream should start");
  370. while let Some(_event) = stream
  371. .next_event()
  372. .await
  373. .expect("live stream should yield events")
  374. {}
  375. }
  /// One HTTP request as recorded by the stub server in `spawn_server`.
  #[derive(Clone, Debug, Eq, PartialEq)]
  struct CapturedRequest {
      /// Method token from the request line, e.g. `POST`.
      method: String,
      /// Request target from the request line, e.g. `/v1/messages`.
      path: String,
      /// Trimmed header values keyed by lower-cased header name.
      headers: HashMap<String, String>,
      /// Request body decoded as UTF-8.
      body: String,
  }
  383. struct TestServer {
  384. base_url: String,
  385. join_handle: tokio::task::JoinHandle<()>,
  386. }
  387. impl TestServer {
  388. fn base_url(&self) -> String {
  389. self.base_url.clone()
  390. }
  391. }
  392. impl Drop for TestServer {
  393. fn drop(&mut self) {
  394. self.join_handle.abort();
  395. }
  396. }
  397. async fn spawn_server(
  398. state: Arc<Mutex<Vec<CapturedRequest>>>,
  399. responses: Vec<String>,
  400. ) -> TestServer {
  401. let listener = TcpListener::bind("127.0.0.1:0")
  402. .await
  403. .expect("listener should bind");
  404. let address = listener
  405. .local_addr()
  406. .expect("listener should have local addr");
  407. let join_handle = tokio::spawn(async move {
  408. for response in responses {
  409. let (mut socket, _) = listener.accept().await.expect("server should accept");
  410. let mut buffer = Vec::new();
  411. let mut header_end = None;
  412. loop {
  413. let mut chunk = [0_u8; 1024];
  414. let read = socket
  415. .read(&mut chunk)
  416. .await
  417. .expect("request read should succeed");
  418. if read == 0 {
  419. break;
  420. }
  421. buffer.extend_from_slice(&chunk[..read]);
  422. if let Some(position) = find_header_end(&buffer) {
  423. header_end = Some(position);
  424. break;
  425. }
  426. }
  427. let header_end = header_end.expect("request should include headers");
  428. let (header_bytes, remaining) = buffer.split_at(header_end);
  429. let header_text =
  430. String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
  431. let mut lines = header_text.split("\r\n");
  432. let request_line = lines.next().expect("request line should exist");
  433. let mut parts = request_line.split_whitespace();
  434. let method = parts.next().expect("method should exist").to_string();
  435. let path = parts.next().expect("path should exist").to_string();
  436. let mut headers = HashMap::new();
  437. let mut content_length = 0_usize;
  438. for line in lines {
  439. if line.is_empty() {
  440. continue;
  441. }
  442. let (name, value) = line.split_once(':').expect("header should have colon");
  443. let value = value.trim().to_string();
  444. if name.eq_ignore_ascii_case("content-length") {
  445. content_length = value.parse().expect("content length should parse");
  446. }
  447. headers.insert(name.to_ascii_lowercase(), value);
  448. }
  449. let mut body = remaining[4..].to_vec();
  450. while body.len() < content_length {
  451. let mut chunk = vec![0_u8; content_length - body.len()];
  452. let read = socket
  453. .read(&mut chunk)
  454. .await
  455. .expect("body read should succeed");
  456. if read == 0 {
  457. break;
  458. }
  459. body.extend_from_slice(&chunk[..read]);
  460. }
  461. state.lock().await.push(CapturedRequest {
  462. method,
  463. path,
  464. headers,
  465. body: String::from_utf8(body).expect("body should be utf8"),
  466. });
  467. socket
  468. .write_all(response.as_bytes())
  469. .await
  470. .expect("response write should succeed");
  471. }
  472. });
  473. TestServer {
  474. base_url: format!("http://{address}"),
  475. join_handle,
  476. }
  477. }
  /// Returns the byte offset at which the `\r\n\r\n` ending an HTTP head starts,
  /// or `None` if that terminator is not (yet) present in `bytes`.
  fn find_header_end(bytes: &[u8]) -> Option<usize> {
      const HEAD_TERMINATOR: &[u8] = b"\r\n\r\n";
      bytes
          .windows(HEAD_TERMINATOR.len())
          .enumerate()
          .find_map(|(offset, window)| (window == HEAD_TERMINATOR).then_some(offset))
  }
  481. fn http_response(status: &str, content_type: &str, body: &str) -> String {
  482. http_response_with_headers(status, content_type, body, &[])
  483. }
  /// Renders a complete HTTP/1.1 response: status line, `content-type`, each entry
  /// of `headers` in order, a byte-accurate `content-length`, `connection: close`,
  /// a blank line and then `body`.
  fn http_response_with_headers(
      status: &str,
      content_type: &str,
      body: &str,
      headers: &[(&str, &str)],
  ) -> String {
      use std::fmt::Write as _;
      let extra_headers = headers
          .iter()
          .fold(String::new(), |mut rendered, (name, value)| {
              write!(rendered, "{name}: {value}\r\n").expect("header write should succeed");
              rendered
          });
      // `str::len` counts UTF-8 bytes, which is what `content-length` must report.
      let content_length = body.len();
      format!(
          "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {content_length}\r\nconnection: close\r\n\r\n{body}"
      )
  }
  500. fn sample_request(stream: bool) -> MessageRequest {
  501. MessageRequest {
  502. model: "claude-3-7-sonnet-latest".to_string(),
  503. max_tokens: 64,
  504. messages: vec![InputMessage {
  505. role: "user".to_string(),
  506. content: vec![
  507. InputContentBlock::Text {
  508. text: "Say hello".to_string(),
  509. },
  510. InputContentBlock::ToolResult {
  511. tool_use_id: "toolu_prev".to_string(),
  512. content: vec![api::ToolResultContentBlock::Json {
  513. value: json!({"forecast": "sunny"}),
  514. }],
  515. is_error: false,
  516. },
  517. ],
  518. }],
  519. system: Some("Use tools when needed".to_string()),
  520. tools: Some(vec![ToolDefinition {
  521. name: "get_weather".to_string(),
  522. description: Some("Fetches the weather".to_string()),
  523. input_schema: json!({
  524. "type": "object",
  525. "properties": {"city": {"type": "string"}},
  526. "required": ["city"]
  527. }),
  528. }]),
  529. tool_choice: Some(ToolChoice::Auto),
  530. stream,
  531. }
  532. }