client_integration.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. use std::collections::HashMap;
  2. use std::sync::Arc;
  3. use std::time::Duration;
  4. use api::{
  5. AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
  6. InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
  7. StreamEvent, ToolChoice, ToolDefinition,
  8. };
  9. use serde_json::json;
  10. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  11. use tokio::net::TcpListener;
  12. use tokio::sync::Mutex;
  13. #[tokio::test]
  14. async fn send_message_posts_json_and_parses_response() {
  15. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  16. let body = concat!(
  17. "{",
  18. "\"id\":\"msg_test\",",
  19. "\"type\":\"message\",",
  20. "\"role\":\"assistant\",",
  21. "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
  22. "\"model\":\"claude-3-7-sonnet-latest\",",
  23. "\"stop_reason\":\"end_turn\",",
  24. "\"stop_sequence\":null,",
  25. "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
  26. "\"request_id\":\"req_body_123\"",
  27. "}"
  28. );
  29. let server = spawn_server(
  30. state.clone(),
  31. vec![http_response("200 OK", "application/json", body)],
  32. )
  33. .await;
  34. let client = AnthropicClient::new("test-key")
  35. .with_auth_token(Some("proxy-token".to_string()))
  36. .with_base_url(server.base_url());
  37. let response = client
  38. .send_message(&sample_request(false))
  39. .await
  40. .expect("request should succeed");
  41. assert_eq!(response.id, "msg_test");
  42. assert_eq!(response.total_tokens(), 16);
  43. assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
  44. assert_eq!(
  45. response.content,
  46. vec![OutputContentBlock::Text {
  47. text: "Hello from Claude".to_string(),
  48. }]
  49. );
  50. let captured = state.lock().await;
  51. let request = captured.first().expect("server should capture request");
  52. assert_eq!(request.method, "POST");
  53. assert_eq!(request.path, "/v1/messages");
  54. assert_eq!(
  55. request.headers.get("x-api-key").map(String::as_str),
  56. Some("test-key")
  57. );
  58. assert_eq!(
  59. request.headers.get("authorization").map(String::as_str),
  60. Some("Bearer proxy-token")
  61. );
  62. let body: serde_json::Value =
  63. serde_json::from_str(&request.body).expect("request body should be json");
  64. assert_eq!(
  65. body.get("model").and_then(serde_json::Value::as_str),
  66. Some("claude-3-7-sonnet-latest")
  67. );
  68. assert!(body.get("stream").is_none());
  69. assert_eq!(body["tools"][0]["name"], json!("get_weather"));
  70. assert_eq!(body["tool_choice"]["type"], json!("auto"));
  71. }
  72. #[tokio::test]
  73. async fn send_message_parses_response_with_thinking_blocks() {
  74. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  75. let body = concat!(
  76. "{",
  77. "\"id\":\"msg_thinking\",",
  78. "\"type\":\"message\",",
  79. "\"role\":\"assistant\",",
  80. "\"content\":[",
  81. "{\"type\":\"thinking\",\"thinking\":\"step 1\",\"signature\":\"sig_123\"},",
  82. "{\"type\":\"text\",\"text\":\"Final answer\"}",
  83. "],",
  84. "\"model\":\"claude-3-7-sonnet-latest\",",
  85. "\"stop_reason\":\"end_turn\",",
  86. "\"stop_sequence\":null,",
  87. "\"usage\":{\"input_tokens\":12,\"output_tokens\":4}",
  88. "}"
  89. );
  90. let server = spawn_server(
  91. state,
  92. vec![http_response("200 OK", "application/json", body)],
  93. )
  94. .await;
  95. let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
  96. let response = client
  97. .send_message(&sample_request(false))
  98. .await
  99. .expect("request should succeed");
  100. assert_eq!(response.content.len(), 2);
  101. assert!(matches!(
  102. &response.content[0],
  103. OutputContentBlock::Thinking { thinking, signature }
  104. if thinking == "step 1" && signature.as_deref() == Some("sig_123")
  105. ));
  106. assert!(matches!(
  107. &response.content[1],
  108. OutputContentBlock::Text { text } if text == "Final answer"
  109. ));
  110. }
  111. #[tokio::test]
  112. async fn stream_message_parses_sse_events_with_tool_use() {
  113. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  114. let sse = concat!(
  115. "event: message_start\n",
  116. "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
  117. "event: content_block_start\n",
  118. "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
  119. "event: content_block_delta\n",
  120. "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
  121. "event: content_block_stop\n",
  122. "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
  123. "event: message_delta\n",
  124. "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
  125. "event: message_stop\n",
  126. "data: {\"type\":\"message_stop\"}\n\n",
  127. "data: [DONE]\n\n"
  128. );
  129. let server = spawn_server(
  130. state.clone(),
  131. vec![http_response_with_headers(
  132. "200 OK",
  133. "text/event-stream",
  134. sse,
  135. &[("request-id", "req_stream_456")],
  136. )],
  137. )
  138. .await;
  139. let client = AnthropicClient::new("test-key")
  140. .with_auth_token(Some("proxy-token".to_string()))
  141. .with_base_url(server.base_url());
  142. let mut stream = client
  143. .stream_message(&sample_request(false))
  144. .await
  145. .expect("stream should start");
  146. assert_eq!(stream.request_id(), Some("req_stream_456"));
  147. let mut events = Vec::new();
  148. while let Some(event) = stream
  149. .next_event()
  150. .await
  151. .expect("stream event should parse")
  152. {
  153. events.push(event);
  154. }
  155. assert_eq!(events.len(), 6);
  156. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  157. assert!(matches!(
  158. events[1],
  159. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  160. content_block: OutputContentBlock::ToolUse { .. },
  161. ..
  162. })
  163. ));
  164. assert!(matches!(
  165. events[2],
  166. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  167. delta: ContentBlockDelta::InputJsonDelta { .. },
  168. ..
  169. })
  170. ));
  171. assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
  172. assert!(matches!(
  173. events[4],
  174. StreamEvent::MessageDelta(MessageDeltaEvent { .. })
  175. ));
  176. assert!(matches!(events[5], StreamEvent::MessageStop(_)));
  177. match &events[1] {
  178. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  179. content_block: OutputContentBlock::ToolUse { name, input, .. },
  180. ..
  181. }) => {
  182. assert_eq!(name, "get_weather");
  183. assert_eq!(input, &json!({}));
  184. }
  185. other => panic!("expected tool_use block, got {other:?}"),
  186. }
  187. let captured = state.lock().await;
  188. let request = captured.first().expect("server should capture request");
  189. assert!(request.body.contains("\"stream\":true"));
  190. }
  191. #[tokio::test]
  192. async fn stream_message_parses_sse_events_with_thinking_blocks() {
  193. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  194. let sse = concat!(
  195. "event: message_start\n",
  196. "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream_thinking\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
  197. "event: content_block_start\n",
  198. "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\n",
  199. "event: content_block_delta\n",
  200. "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"step 1\"}}\n\n",
  201. "event: content_block_delta\n",
  202. "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"signature_delta\",\"signature\":\"sig_123\"}}\n\n",
  203. "event: content_block_stop\n",
  204. "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
  205. "event: content_block_start\n",
  206. "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"text\",\"text\":\"Final answer\"}}\n\n",
  207. "event: content_block_stop\n",
  208. "data: {\"type\":\"content_block_stop\",\"index\":1}\n\n",
  209. "event: message_delta\n",
  210. "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
  211. "event: message_stop\n",
  212. "data: {\"type\":\"message_stop\"}\n\n",
  213. "data: [DONE]\n\n"
  214. );
  215. let server = spawn_server(
  216. state,
  217. vec![http_response("200 OK", "text/event-stream", sse)],
  218. )
  219. .await;
  220. let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
  221. let mut stream = client
  222. .stream_message(&sample_request(false))
  223. .await
  224. .expect("stream should start");
  225. let mut events = Vec::new();
  226. while let Some(event) = stream
  227. .next_event()
  228. .await
  229. .expect("stream event should parse")
  230. {
  231. events.push(event);
  232. }
  233. assert_eq!(events.len(), 9);
  234. assert!(matches!(
  235. &events[1],
  236. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  237. content_block: OutputContentBlock::Thinking { thinking, signature },
  238. ..
  239. }) if thinking.is_empty() && signature.is_none()
  240. ));
  241. assert!(matches!(
  242. &events[2],
  243. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  244. delta: ContentBlockDelta::ThinkingDelta { thinking },
  245. ..
  246. }) if thinking == "step 1"
  247. ));
  248. assert!(matches!(
  249. &events[3],
  250. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  251. delta: ContentBlockDelta::SignatureDelta { signature },
  252. ..
  253. }) if signature == "sig_123"
  254. ));
  255. assert!(matches!(
  256. &events[5],
  257. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  258. content_block: OutputContentBlock::Text { text },
  259. ..
  260. }) if text == "Final answer"
  261. ));
  262. assert!(matches!(events[6], StreamEvent::ContentBlockStop(_)));
  263. assert!(matches!(events[7], StreamEvent::MessageDelta(_)));
  264. assert!(matches!(events[8], StreamEvent::MessageStop(_)));
  265. }
  266. #[tokio::test]
  267. async fn retries_retryable_failures_before_succeeding() {
  268. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  269. let server = spawn_server(
  270. state.clone(),
  271. vec![
  272. http_response(
  273. "429 Too Many Requests",
  274. "application/json",
  275. "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
  276. ),
  277. http_response(
  278. "200 OK",
  279. "application/json",
  280. "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
  281. ),
  282. ],
  283. )
  284. .await;
  285. let client = AnthropicClient::new("test-key")
  286. .with_base_url(server.base_url())
  287. .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
  288. let response = client
  289. .send_message(&sample_request(false))
  290. .await
  291. .expect("retry should eventually succeed");
  292. assert_eq!(response.total_tokens(), 5);
  293. assert_eq!(state.lock().await.len(), 2);
  294. }
  295. #[tokio::test]
  296. async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
  297. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  298. let server = spawn_server(
  299. state.clone(),
  300. vec![
  301. http_response(
  302. "503 Service Unavailable",
  303. "application/json",
  304. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
  305. ),
  306. http_response(
  307. "503 Service Unavailable",
  308. "application/json",
  309. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
  310. ),
  311. ],
  312. )
  313. .await;
  314. let client = AnthropicClient::new("test-key")
  315. .with_base_url(server.base_url())
  316. .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
  317. let error = client
  318. .send_message(&sample_request(false))
  319. .await
  320. .expect_err("persistent 503 should fail");
  321. match error {
  322. ApiError::RetriesExhausted {
  323. attempts,
  324. last_error,
  325. } => {
  326. assert_eq!(attempts, 2);
  327. assert!(matches!(
  328. *last_error,
  329. ApiError::Api {
  330. status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
  331. retryable: true,
  332. ..
  333. }
  334. ));
  335. }
  336. other => panic!("expected retries exhausted, got {other:?}"),
  337. }
  338. }
  339. #[tokio::test]
  340. #[ignore = "requires ANTHROPIC_API_KEY and network access"]
  341. async fn live_stream_smoke_test() {
  342. let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set");
  343. let mut stream = client
  344. .stream_message(&MessageRequest {
  345. model: std::env::var("ANTHROPIC_MODEL")
  346. .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
  347. max_tokens: 32,
  348. messages: vec![InputMessage::user_text(
  349. "Reply with exactly: hello from rust",
  350. )],
  351. system: None,
  352. tools: None,
  353. tool_choice: None,
  354. stream: false,
  355. })
  356. .await
  357. .expect("live stream should start");
  358. while let Some(_event) = stream
  359. .next_event()
  360. .await
  361. .expect("live stream should yield events")
  362. {}
  363. }
  364. #[derive(Debug, Clone, PartialEq, Eq)]
  365. struct CapturedRequest {
  366. method: String,
  367. path: String,
  368. headers: HashMap<String, String>,
  369. body: String,
  370. }
  371. struct TestServer {
  372. base_url: String,
  373. join_handle: tokio::task::JoinHandle<()>,
  374. }
  375. impl TestServer {
  376. fn base_url(&self) -> String {
  377. self.base_url.clone()
  378. }
  379. }
  380. impl Drop for TestServer {
  381. fn drop(&mut self) {
  382. self.join_handle.abort();
  383. }
  384. }
  385. async fn spawn_server(
  386. state: Arc<Mutex<Vec<CapturedRequest>>>,
  387. responses: Vec<String>,
  388. ) -> TestServer {
  389. let listener = TcpListener::bind("127.0.0.1:0")
  390. .await
  391. .expect("listener should bind");
  392. let address = listener
  393. .local_addr()
  394. .expect("listener should have local addr");
  395. let join_handle = tokio::spawn(async move {
  396. for response in responses {
  397. let (mut socket, _) = listener.accept().await.expect("server should accept");
  398. let mut buffer = Vec::new();
  399. let mut header_end = None;
  400. loop {
  401. let mut chunk = [0_u8; 1024];
  402. let read = socket
  403. .read(&mut chunk)
  404. .await
  405. .expect("request read should succeed");
  406. if read == 0 {
  407. break;
  408. }
  409. buffer.extend_from_slice(&chunk[..read]);
  410. if let Some(position) = find_header_end(&buffer) {
  411. header_end = Some(position);
  412. break;
  413. }
  414. }
  415. let header_end = header_end.expect("request should include headers");
  416. let (header_bytes, remaining) = buffer.split_at(header_end);
  417. let header_text =
  418. String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
  419. let mut lines = header_text.split("\r\n");
  420. let request_line = lines.next().expect("request line should exist");
  421. let mut parts = request_line.split_whitespace();
  422. let method = parts.next().expect("method should exist").to_string();
  423. let path = parts.next().expect("path should exist").to_string();
  424. let mut headers = HashMap::new();
  425. let mut content_length = 0_usize;
  426. for line in lines {
  427. if line.is_empty() {
  428. continue;
  429. }
  430. let (name, value) = line.split_once(':').expect("header should have colon");
  431. let value = value.trim().to_string();
  432. if name.eq_ignore_ascii_case("content-length") {
  433. content_length = value.parse().expect("content length should parse");
  434. }
  435. headers.insert(name.to_ascii_lowercase(), value);
  436. }
  437. let mut body = remaining[4..].to_vec();
  438. while body.len() < content_length {
  439. let mut chunk = vec![0_u8; content_length - body.len()];
  440. let read = socket
  441. .read(&mut chunk)
  442. .await
  443. .expect("body read should succeed");
  444. if read == 0 {
  445. break;
  446. }
  447. body.extend_from_slice(&chunk[..read]);
  448. }
  449. state.lock().await.push(CapturedRequest {
  450. method,
  451. path,
  452. headers,
  453. body: String::from_utf8(body).expect("body should be utf8"),
  454. });
  455. socket
  456. .write_all(response.as_bytes())
  457. .await
  458. .expect("response write should succeed");
  459. }
  460. });
  461. TestServer {
  462. base_url: format!("http://{address}"),
  463. join_handle,
  464. }
  465. }
  466. fn find_header_end(bytes: &[u8]) -> Option<usize> {
  467. bytes.windows(4).position(|window| window == b"\r\n\r\n")
  468. }
  469. fn http_response(status: &str, content_type: &str, body: &str) -> String {
  470. http_response_with_headers(status, content_type, body, &[])
  471. }
  472. fn http_response_with_headers(
  473. status: &str,
  474. content_type: &str,
  475. body: &str,
  476. headers: &[(&str, &str)],
  477. ) -> String {
  478. let mut extra_headers = String::new();
  479. for (name, value) in headers {
  480. use std::fmt::Write as _;
  481. write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed");
  482. }
  483. format!(
  484. "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
  485. body.len()
  486. )
  487. }
  488. fn sample_request(stream: bool) -> MessageRequest {
  489. MessageRequest {
  490. model: "claude-3-7-sonnet-latest".to_string(),
  491. max_tokens: 64,
  492. messages: vec![InputMessage {
  493. role: "user".to_string(),
  494. content: vec![
  495. InputContentBlock::Text {
  496. text: "Say hello".to_string(),
  497. },
  498. InputContentBlock::ToolResult {
  499. tool_use_id: "toolu_prev".to_string(),
  500. content: vec![api::ToolResultContentBlock::Json {
  501. value: json!({"forecast": "sunny"}),
  502. }],
  503. is_error: false,
  504. },
  505. ],
  506. }],
  507. system: Some("Use tools when needed".to_string()),
  508. tools: Some(vec![ToolDefinition {
  509. name: "get_weather".to_string(),
  510. description: Some("Fetches the weather".to_string()),
  511. input_schema: json!({
  512. "type": "object",
  513. "properties": {"city": {"type": "string"}},
  514. "required": ["city"]
  515. }),
  516. }]),
  517. tool_choice: Some(ToolChoice::Auto),
  518. stream,
  519. }
  520. }