client_integration.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. use std::collections::HashMap;
  2. use std::sync::Arc;
  3. use std::time::Duration;
  4. use api::{
  5. AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
  6. InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
  7. StreamEvent, ToolChoice, ToolDefinition,
  8. };
  9. use serde_json::json;
  10. use telemetry::{ClientIdentity, MemoryTelemetrySink, SessionTracer, TelemetryEvent};
  11. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  12. use tokio::net::TcpListener;
  13. use tokio::sync::Mutex;
  14. #[tokio::test]
  15. async fn send_message_posts_json_and_parses_response() {
  16. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  17. let body = concat!(
  18. "{",
  19. "\"id\":\"msg_test\",",
  20. "\"type\":\"message\",",
  21. "\"role\":\"assistant\",",
  22. "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
  23. "\"model\":\"claude-3-7-sonnet-latest\",",
  24. "\"stop_reason\":\"end_turn\",",
  25. "\"stop_sequence\":null,",
  26. "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
  27. "\"request_id\":\"req_body_123\"",
  28. "}"
  29. );
  30. let server = spawn_server(
  31. state.clone(),
  32. vec![http_response("200 OK", "application/json", body)],
  33. )
  34. .await;
  35. let client = AnthropicClient::new("test-key")
  36. .with_auth_token(Some("proxy-token".to_string()))
  37. .with_base_url(server.base_url());
  38. let response = client
  39. .send_message(&sample_request(false))
  40. .await
  41. .expect("request should succeed");
  42. assert_eq!(response.id, "msg_test");
  43. assert_eq!(response.total_tokens(), 16);
  44. assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
  45. assert_eq!(
  46. response.content,
  47. vec![OutputContentBlock::Text {
  48. text: "Hello from Claude".to_string(),
  49. }]
  50. );
  51. let captured = state.lock().await;
  52. let request = captured.first().expect("server should capture request");
  53. assert_eq!(request.method, "POST");
  54. assert_eq!(request.path, "/v1/messages");
  55. assert_eq!(
  56. request.headers.get("x-api-key").map(String::as_str),
  57. Some("test-key")
  58. );
  59. assert_eq!(
  60. request.headers.get("authorization").map(String::as_str),
  61. Some("Bearer proxy-token")
  62. );
  63. assert_eq!(
  64. request.headers.get("anthropic-version").map(String::as_str),
  65. Some("2023-06-01")
  66. );
  67. assert_eq!(
  68. request.headers.get("user-agent").map(String::as_str),
  69. Some("claude-code/0.1.0")
  70. );
  71. assert_eq!(
  72. request.headers.get("anthropic-beta").map(String::as_str),
  73. Some("claude-code-20250219,prompt-caching-scope-2026-01-05")
  74. );
  75. let body: serde_json::Value =
  76. serde_json::from_str(&request.body).expect("request body should be json");
  77. assert_eq!(
  78. body.get("model").and_then(serde_json::Value::as_str),
  79. Some("claude-3-7-sonnet-latest")
  80. );
  81. assert!(body.get("stream").is_none());
  82. assert_eq!(body["tools"][0]["name"], json!("get_weather"));
  83. assert_eq!(body["tool_choice"]["type"], json!("auto"));
  84. assert_eq!(
  85. body["betas"],
  86. json!(["claude-code-20250219", "prompt-caching-scope-2026-01-05"])
  87. );
  88. }
  89. #[tokio::test]
  90. async fn send_message_applies_request_profile_and_records_telemetry() {
  91. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  92. let server = spawn_server(
  93. state.clone(),
  94. vec![http_response_with_headers(
  95. "200 OK",
  96. "application/json",
  97. concat!(
  98. "{",
  99. "\"id\":\"msg_profile\",",
  100. "\"type\":\"message\",",
  101. "\"role\":\"assistant\",",
  102. "\"content\":[{\"type\":\"text\",\"text\":\"ok\"}],",
  103. "\"model\":\"claude-3-7-sonnet-latest\",",
  104. "\"stop_reason\":\"end_turn\",",
  105. "\"stop_sequence\":null,",
  106. "\"usage\":{\"input_tokens\":1,\"cache_creation_input_tokens\":2,\"cache_read_input_tokens\":3,\"output_tokens\":1}",
  107. "}"
  108. ),
  109. &[("request-id", "req_profile_123")],
  110. )],
  111. )
  112. .await;
  113. let sink = Arc::new(MemoryTelemetrySink::default());
  114. let client = AnthropicClient::new("test-key")
  115. .with_base_url(server.base_url())
  116. .with_client_identity(ClientIdentity::new("claude-code", "9.9.9").with_runtime("rust-cli"))
  117. .with_beta("tools-2026-04-01")
  118. .with_extra_body_param("metadata", json!({"source": "clawd-code"}))
  119. .with_session_tracer(SessionTracer::new("session-telemetry", sink.clone()));
  120. let response = client
  121. .send_message(&sample_request(false))
  122. .await
  123. .expect("request should succeed");
  124. assert_eq!(response.request_id.as_deref(), Some("req_profile_123"));
  125. let captured = state.lock().await;
  126. let request = captured.first().expect("server should capture request");
  127. assert_eq!(
  128. request.headers.get("anthropic-beta").map(String::as_str),
  129. Some("claude-code-20250219,prompt-caching-scope-2026-01-05,tools-2026-04-01")
  130. );
  131. assert_eq!(
  132. request.headers.get("user-agent").map(String::as_str),
  133. Some("claude-code/9.9.9")
  134. );
  135. let body: serde_json::Value =
  136. serde_json::from_str(&request.body).expect("request body should be json");
  137. assert_eq!(body["metadata"]["source"], json!("clawd-code"));
  138. assert_eq!(
  139. body["betas"],
  140. json!([
  141. "claude-code-20250219",
  142. "prompt-caching-scope-2026-01-05",
  143. "tools-2026-04-01"
  144. ])
  145. );
  146. let events = sink.events();
  147. assert_eq!(events.len(), 6);
  148. assert!(matches!(
  149. &events[0],
  150. TelemetryEvent::HttpRequestStarted {
  151. session_id,
  152. attempt: 1,
  153. method,
  154. path,
  155. ..
  156. } if session_id == "session-telemetry" && method == "POST" && path == "/v1/messages"
  157. ));
  158. assert!(matches!(
  159. &events[1],
  160. TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_started"
  161. ));
  162. assert!(matches!(
  163. &events[2],
  164. TelemetryEvent::HttpRequestSucceeded {
  165. request_id,
  166. status: 200,
  167. ..
  168. } if request_id.as_deref() == Some("req_profile_123")
  169. ));
  170. assert!(matches!(
  171. &events[3],
  172. TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_succeeded"
  173. ));
  174. assert!(matches!(
  175. &events[4],
  176. TelemetryEvent::Analytics(event)
  177. if event.namespace == "api"
  178. && event.action == "message_usage"
  179. && event.properties.get("request_id") == Some(&json!("req_profile_123"))
  180. && event.properties.get("total_tokens") == Some(&json!(7))
  181. && event.properties.get("estimated_cost_usd") == Some(&json!("$0.0001"))
  182. ));
  183. assert!(matches!(
  184. &events[5],
  185. TelemetryEvent::SessionTrace(trace) if trace.name == "analytics"
  186. ));
  187. }
  188. #[tokio::test]
  189. async fn stream_message_parses_sse_events_with_tool_use() {
  190. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  191. let sse = concat!(
  192. "event: message_start\n",
  193. "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
  194. "event: content_block_start\n",
  195. "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
  196. "event: content_block_delta\n",
  197. "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
  198. "event: content_block_stop\n",
  199. "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
  200. "event: message_delta\n",
  201. "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
  202. "event: message_stop\n",
  203. "data: {\"type\":\"message_stop\"}\n\n",
  204. "data: [DONE]\n\n"
  205. );
  206. let server = spawn_server(
  207. state.clone(),
  208. vec![http_response_with_headers(
  209. "200 OK",
  210. "text/event-stream",
  211. sse,
  212. &[("request-id", "req_stream_456")],
  213. )],
  214. )
  215. .await;
  216. let client = AnthropicClient::new("test-key")
  217. .with_auth_token(Some("proxy-token".to_string()))
  218. .with_base_url(server.base_url());
  219. let mut stream = client
  220. .stream_message(&sample_request(false))
  221. .await
  222. .expect("stream should start");
  223. assert_eq!(stream.request_id(), Some("req_stream_456"));
  224. let mut events = Vec::new();
  225. while let Some(event) = stream
  226. .next_event()
  227. .await
  228. .expect("stream event should parse")
  229. {
  230. events.push(event);
  231. }
  232. assert_eq!(events.len(), 6);
  233. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  234. assert!(matches!(
  235. events[1],
  236. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  237. content_block: OutputContentBlock::ToolUse { .. },
  238. ..
  239. })
  240. ));
  241. assert!(matches!(
  242. events[2],
  243. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  244. delta: ContentBlockDelta::InputJsonDelta { .. },
  245. ..
  246. })
  247. ));
  248. assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
  249. assert!(matches!(
  250. events[4],
  251. StreamEvent::MessageDelta(MessageDeltaEvent { .. })
  252. ));
  253. assert!(matches!(events[5], StreamEvent::MessageStop(_)));
  254. match &events[1] {
  255. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  256. content_block: OutputContentBlock::ToolUse { name, input, .. },
  257. ..
  258. }) => {
  259. assert_eq!(name, "get_weather");
  260. assert_eq!(input, &json!({}));
  261. }
  262. other => panic!("expected tool_use block, got {other:?}"),
  263. }
  264. let captured = state.lock().await;
  265. let request = captured.first().expect("server should capture request");
  266. assert!(request.body.contains("\"stream\":true"));
  267. }
  268. #[tokio::test]
  269. async fn retries_retryable_failures_before_succeeding() {
  270. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  271. let server = spawn_server(
  272. state.clone(),
  273. vec![
  274. http_response(
  275. "429 Too Many Requests",
  276. "application/json",
  277. "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
  278. ),
  279. http_response(
  280. "200 OK",
  281. "application/json",
  282. "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
  283. ),
  284. ],
  285. )
  286. .await;
  287. let client = AnthropicClient::new("test-key")
  288. .with_base_url(server.base_url())
  289. .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
  290. let response = client
  291. .send_message(&sample_request(false))
  292. .await
  293. .expect("retry should eventually succeed");
  294. assert_eq!(response.total_tokens(), 5);
  295. assert_eq!(state.lock().await.len(), 2);
  296. }
  297. #[tokio::test]
  298. async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
  299. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  300. let server = spawn_server(
  301. state.clone(),
  302. vec![
  303. http_response(
  304. "503 Service Unavailable",
  305. "application/json",
  306. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
  307. ),
  308. http_response(
  309. "503 Service Unavailable",
  310. "application/json",
  311. "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
  312. ),
  313. ],
  314. )
  315. .await;
  316. let client = AnthropicClient::new("test-key")
  317. .with_base_url(server.base_url())
  318. .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
  319. let error = client
  320. .send_message(&sample_request(false))
  321. .await
  322. .expect_err("persistent 503 should fail");
  323. match error {
  324. ApiError::RetriesExhausted {
  325. attempts,
  326. last_error,
  327. } => {
  328. assert_eq!(attempts, 2);
  329. assert!(matches!(
  330. *last_error,
  331. ApiError::Api {
  332. status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
  333. retryable: true,
  334. ..
  335. }
  336. ));
  337. }
  338. other => panic!("expected retries exhausted, got {other:?}"),
  339. }
  340. }
  341. #[tokio::test]
  342. #[ignore = "requires ANTHROPIC_API_KEY and network access"]
  343. async fn live_stream_smoke_test() {
  344. let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set");
  345. let mut stream = client
  346. .stream_message(&MessageRequest {
  347. model: std::env::var("ANTHROPIC_MODEL")
  348. .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
  349. max_tokens: 32,
  350. messages: vec![InputMessage::user_text(
  351. "Reply with exactly: hello from rust",
  352. )],
  353. system: None,
  354. tools: None,
  355. tool_choice: None,
  356. stream: false,
  357. })
  358. .await
  359. .expect("live stream should start");
  360. while let Some(_event) = stream
  361. .next_event()
  362. .await
  363. .expect("live stream should yield events")
  364. {}
  365. }
  /// One HTTP request as recorded by the test server.
  #[derive(Debug, Clone, PartialEq, Eq)]
  struct CapturedRequest {
      /// First token of the request line, e.g. `POST`.
      method: String,
      /// Second token of the request line, e.g. `/v1/messages`.
      path: String,
      /// Request headers; names are stored lower-cased, values trimmed.
      headers: HashMap<String, String>,
      /// Request body decoded as UTF-8.
      body: String,
  }
  373. struct TestServer {
  374. base_url: String,
  375. join_handle: tokio::task::JoinHandle<()>,
  376. }
  377. impl TestServer {
  378. fn base_url(&self) -> String {
  379. self.base_url.clone()
  380. }
  381. }
  382. impl Drop for TestServer {
  383. fn drop(&mut self) {
  384. self.join_handle.abort();
  385. }
  386. }
  387. async fn spawn_server(
  388. state: Arc<Mutex<Vec<CapturedRequest>>>,
  389. responses: Vec<String>,
  390. ) -> TestServer {
  391. let listener = TcpListener::bind("127.0.0.1:0")
  392. .await
  393. .expect("listener should bind");
  394. let address = listener
  395. .local_addr()
  396. .expect("listener should have local addr");
  397. let join_handle = tokio::spawn(async move {
  398. for response in responses {
  399. let (mut socket, _) = listener.accept().await.expect("server should accept");
  400. let mut buffer = Vec::new();
  401. let mut header_end = None;
  402. loop {
  403. let mut chunk = [0_u8; 1024];
  404. let read = socket
  405. .read(&mut chunk)
  406. .await
  407. .expect("request read should succeed");
  408. if read == 0 {
  409. break;
  410. }
  411. buffer.extend_from_slice(&chunk[..read]);
  412. if let Some(position) = find_header_end(&buffer) {
  413. header_end = Some(position);
  414. break;
  415. }
  416. }
  417. let header_end = header_end.expect("request should include headers");
  418. let (header_bytes, remaining) = buffer.split_at(header_end);
  419. let header_text =
  420. String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
  421. let mut lines = header_text.split("\r\n");
  422. let request_line = lines.next().expect("request line should exist");
  423. let mut parts = request_line.split_whitespace();
  424. let method = parts.next().expect("method should exist").to_string();
  425. let path = parts.next().expect("path should exist").to_string();
  426. let mut headers = HashMap::new();
  427. let mut content_length = 0_usize;
  428. for line in lines {
  429. if line.is_empty() {
  430. continue;
  431. }
  432. let (name, value) = line.split_once(':').expect("header should have colon");
  433. let value = value.trim().to_string();
  434. if name.eq_ignore_ascii_case("content-length") {
  435. content_length = value.parse().expect("content length should parse");
  436. }
  437. headers.insert(name.to_ascii_lowercase(), value);
  438. }
  439. let mut body = remaining[4..].to_vec();
  440. while body.len() < content_length {
  441. let mut chunk = vec![0_u8; content_length - body.len()];
  442. let read = socket
  443. .read(&mut chunk)
  444. .await
  445. .expect("body read should succeed");
  446. if read == 0 {
  447. break;
  448. }
  449. body.extend_from_slice(&chunk[..read]);
  450. }
  451. state.lock().await.push(CapturedRequest {
  452. method,
  453. path,
  454. headers,
  455. body: String::from_utf8(body).expect("body should be utf8"),
  456. });
  457. socket
  458. .write_all(response.as_bytes())
  459. .await
  460. .expect("response write should succeed");
  461. }
  462. });
  463. TestServer {
  464. base_url: format!("http://{address}"),
  465. join_handle,
  466. }
  467. }
  /// Returns the byte offset of the first `\r\n\r\n` sequence in `bytes`, or
  /// `None` if the sequence does not occur.
  fn find_header_end(bytes: &[u8]) -> Option<usize> {
      const TERMINATOR: &[u8] = b"\r\n\r\n";
      // Only offsets that leave room for a full terminator are candidates.
      let candidate_count = bytes.len().saturating_sub(TERMINATOR.len() - 1);
      (0..candidate_count).find(|&start| bytes[start..].starts_with(TERMINATOR))
  }
  471. fn http_response(status: &str, content_type: &str, body: &str) -> String {
  472. http_response_with_headers(status, content_type, body, &[])
  473. }
  /// Builds a complete HTTP/1.1 response as a string.
  ///
  /// The header order is `content-type`, then each `(name, value)` pair from
  /// `headers` in the given order, then `content-length` (byte length of `body`)
  /// and `connection: close`, followed by a blank line and `body`.
  fn http_response_with_headers(
      status: &str,
      content_type: &str,
      body: &str,
      headers: &[(&str, &str)],
  ) -> String {
      use std::fmt::Write as _;

      let extra_headers = headers
          .iter()
          .fold(String::new(), |mut rendered, (name, value)| {
              write!(rendered, "{name}: {value}\r\n").expect("header write should succeed");
              rendered
          });
      let content_length = body.len();
      format!(
          "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
          content_length
      )
  }
  490. fn sample_request(stream: bool) -> MessageRequest {
  491. MessageRequest {
  492. model: "claude-3-7-sonnet-latest".to_string(),
  493. max_tokens: 64,
  494. messages: vec![InputMessage {
  495. role: "user".to_string(),
  496. content: vec![
  497. InputContentBlock::Text {
  498. text: "Say hello".to_string(),
  499. },
  500. InputContentBlock::ToolResult {
  501. tool_use_id: "toolu_prev".to_string(),
  502. content: vec![api::ToolResultContentBlock::Json {
  503. value: json!({"forecast": "sunny"}),
  504. }],
  505. is_error: false,
  506. },
  507. ],
  508. }],
  509. system: Some("Use tools when needed".to_string()),
  510. tools: Some(vec![ToolDefinition {
  511. name: "get_weather".to_string(),
  512. description: Some("Fetches the weather".to_string()),
  513. input_schema: json!({
  514. "type": "object",
  515. "properties": {"city": {"type": "string"}},
  516. "required": ["city"]
  517. }),
  518. }]),
  519. tool_choice: Some(ToolChoice::Auto),
  520. stream,
  521. }
  522. }