client_integration.rs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. use std::collections::HashMap;
  2. use std::sync::Arc;
  3. use api::{AnthropicClient, InputMessage, MessageRequest, OutputContentBlock, StreamEvent};
  4. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  5. use tokio::net::TcpListener;
  6. use tokio::sync::Mutex;
  7. #[tokio::test]
  8. async fn send_message_posts_json_and_parses_response() {
  9. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  10. let body = concat!(
  11. "{",
  12. "\"id\":\"msg_test\",",
  13. "\"type\":\"message\",",
  14. "\"role\":\"assistant\",",
  15. "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
  16. "\"model\":\"claude-3-7-sonnet-latest\",",
  17. "\"stop_reason\":\"end_turn\",",
  18. "\"stop_sequence\":null,",
  19. "\"usage\":{\"input_tokens\":12,\"output_tokens\":4}",
  20. "}"
  21. );
  22. let server = spawn_server(state.clone(), http_response("application/json", body)).await;
  23. let client = AnthropicClient::new("test-key")
  24. .with_auth_token(Some("proxy-token".to_string()))
  25. .with_base_url(server.base_url());
  26. let response = client
  27. .send_message(&sample_request(false))
  28. .await
  29. .expect("request should succeed");
  30. assert_eq!(response.id, "msg_test");
  31. assert_eq!(
  32. response.content,
  33. vec![OutputContentBlock::Text {
  34. text: "Hello from Claude".to_string(),
  35. }]
  36. );
  37. let captured = state.lock().await;
  38. let request = captured.first().expect("server should capture request");
  39. assert_eq!(request.method, "POST");
  40. assert_eq!(request.path, "/v1/messages");
  41. assert_eq!(
  42. request.headers.get("x-api-key").map(String::as_str),
  43. Some("test-key")
  44. );
  45. assert_eq!(
  46. request.headers.get("authorization").map(String::as_str),
  47. Some("Bearer proxy-token")
  48. );
  49. assert_eq!(
  50. request.headers.get("anthropic-version").map(String::as_str),
  51. Some("2023-06-01")
  52. );
  53. let body: serde_json::Value =
  54. serde_json::from_str(&request.body).expect("request body should be json");
  55. assert_eq!(
  56. body.get("model").and_then(serde_json::Value::as_str),
  57. Some("claude-3-7-sonnet-latest")
  58. );
  59. assert!(
  60. body.get("stream").is_none(),
  61. "non-stream request should omit stream=false"
  62. );
  63. }
  64. #[tokio::test]
  65. async fn stream_message_parses_sse_events() {
  66. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  67. let sse = concat!(
  68. "event: message_start\n",
  69. "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
  70. "event: content_block_start\n",
  71. "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n",
  72. "event: content_block_delta\n",
  73. "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
  74. "event: content_block_stop\n",
  75. "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
  76. "event: message_stop\n",
  77. "data: {\"type\":\"message_stop\"}\n\n",
  78. "data: [DONE]\n\n"
  79. );
  80. let server = spawn_server(state.clone(), http_response("text/event-stream", sse)).await;
  81. let client = AnthropicClient::new("test-key")
  82. .with_auth_token(Some("proxy-token".to_string()))
  83. .with_base_url(server.base_url());
  84. let mut stream = client
  85. .stream_message(&sample_request(false))
  86. .await
  87. .expect("stream should start");
  88. let mut events = Vec::new();
  89. while let Some(event) = stream
  90. .next_event()
  91. .await
  92. .expect("stream event should parse")
  93. {
  94. events.push(event);
  95. }
  96. assert_eq!(events.len(), 5);
  97. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  98. assert!(matches!(events[1], StreamEvent::ContentBlockStart(_)));
  99. assert!(matches!(events[2], StreamEvent::ContentBlockDelta(_)));
  100. assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
  101. assert!(matches!(events[4], StreamEvent::MessageStop(_)));
  102. let captured = state.lock().await;
  103. let request = captured.first().expect("server should capture request");
  104. assert!(request.body.contains("\"stream\":true"));
  105. }
  106. #[tokio::test]
  107. #[ignore = "requires ANTHROPIC_API_KEY and network access"]
  108. async fn live_stream_smoke_test() {
  109. let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set");
  110. let mut stream = client
  111. .stream_message(&MessageRequest {
  112. model: std::env::var("ANTHROPIC_MODEL")
  113. .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
  114. max_tokens: 32,
  115. messages: vec![InputMessage::user_text(
  116. "Reply with exactly: hello from rust",
  117. )],
  118. system: None,
  119. stream: false,
  120. })
  121. .await
  122. .expect("live stream should start");
  123. let mut saw_start = false;
  124. let mut saw_follow_up = false;
  125. let mut event_kinds = Vec::new();
  126. while let Some(event) = stream
  127. .next_event()
  128. .await
  129. .expect("live stream should yield events")
  130. {
  131. match event {
  132. StreamEvent::MessageStart(_) => {
  133. saw_start = true;
  134. event_kinds.push("message_start");
  135. }
  136. StreamEvent::ContentBlockStart(_) => {
  137. saw_follow_up = true;
  138. event_kinds.push("content_block_start");
  139. }
  140. StreamEvent::ContentBlockDelta(_) => {
  141. saw_follow_up = true;
  142. event_kinds.push("content_block_delta");
  143. }
  144. StreamEvent::ContentBlockStop(_) => {
  145. saw_follow_up = true;
  146. event_kinds.push("content_block_stop");
  147. }
  148. StreamEvent::MessageStop(_) => {
  149. saw_follow_up = true;
  150. event_kinds.push("message_stop");
  151. }
  152. }
  153. }
  154. assert!(
  155. saw_start,
  156. "expected a message_start event; got {event_kinds:?}"
  157. );
  158. assert!(
  159. saw_follow_up,
  160. "expected at least one follow-up stream event; got {event_kinds:?}"
  161. );
  162. }
  163. #[derive(Debug, Clone, PartialEq, Eq)]
  164. struct CapturedRequest {
  165. method: String,
  166. path: String,
  167. headers: HashMap<String, String>,
  168. body: String,
  169. }
  170. struct TestServer {
  171. base_url: String,
  172. join_handle: tokio::task::JoinHandle<()>,
  173. }
  174. impl TestServer {
  175. fn base_url(&self) -> String {
  176. self.base_url.clone()
  177. }
  178. }
  179. impl Drop for TestServer {
  180. fn drop(&mut self) {
  181. self.join_handle.abort();
  182. }
  183. }
  184. async fn spawn_server(state: Arc<Mutex<Vec<CapturedRequest>>>, response: String) -> TestServer {
  185. let listener = TcpListener::bind("127.0.0.1:0")
  186. .await
  187. .expect("listener should bind");
  188. let address = listener
  189. .local_addr()
  190. .expect("listener should have local addr");
  191. let join_handle = tokio::spawn(async move {
  192. let (mut socket, _) = listener.accept().await.expect("server should accept");
  193. let mut buffer = Vec::new();
  194. let mut header_end = None;
  195. loop {
  196. let mut chunk = [0_u8; 1024];
  197. let read = socket
  198. .read(&mut chunk)
  199. .await
  200. .expect("request read should succeed");
  201. if read == 0 {
  202. break;
  203. }
  204. buffer.extend_from_slice(&chunk[..read]);
  205. if let Some(position) = find_header_end(&buffer) {
  206. header_end = Some(position);
  207. break;
  208. }
  209. }
  210. let header_end = header_end.expect("request should include headers");
  211. let (header_bytes, remaining) = buffer.split_at(header_end);
  212. let header_text = String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
  213. let mut lines = header_text.split("\r\n");
  214. let request_line = lines.next().expect("request line should exist");
  215. let mut parts = request_line.split_whitespace();
  216. let method = parts.next().expect("method should exist").to_string();
  217. let path = parts.next().expect("path should exist").to_string();
  218. let mut headers = HashMap::new();
  219. let mut content_length = 0_usize;
  220. for line in lines {
  221. if line.is_empty() {
  222. continue;
  223. }
  224. let (name, value) = line.split_once(':').expect("header should have colon");
  225. let value = value.trim().to_string();
  226. if name.eq_ignore_ascii_case("content-length") {
  227. content_length = value.parse().expect("content length should parse");
  228. }
  229. headers.insert(name.to_ascii_lowercase(), value);
  230. }
  231. let mut body = remaining[4..].to_vec();
  232. while body.len() < content_length {
  233. let mut chunk = vec![0_u8; content_length - body.len()];
  234. let read = socket
  235. .read(&mut chunk)
  236. .await
  237. .expect("body read should succeed");
  238. if read == 0 {
  239. break;
  240. }
  241. body.extend_from_slice(&chunk[..read]);
  242. }
  243. state.lock().await.push(CapturedRequest {
  244. method,
  245. path,
  246. headers,
  247. body: String::from_utf8(body).expect("body should be utf8"),
  248. });
  249. socket
  250. .write_all(response.as_bytes())
  251. .await
  252. .expect("response write should succeed");
  253. });
  254. TestServer {
  255. base_url: format!("http://{address}"),
  256. join_handle,
  257. }
  258. }
  259. fn find_header_end(bytes: &[u8]) -> Option<usize> {
  260. bytes.windows(4).position(|window| window == b"\r\n\r\n")
  261. }
  262. fn http_response(content_type: &str, body: &str) -> String {
  263. format!(
  264. "HTTP/1.1 200 OK\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}",
  265. body.len()
  266. )
  267. }
  268. fn sample_request(stream: bool) -> MessageRequest {
  269. MessageRequest {
  270. model: "claude-3-7-sonnet-latest".to_string(),
  271. max_tokens: 64,
  272. messages: vec![InputMessage::user_text("Say hello")],
  273. system: None,
  274. stream,
  275. }
  276. }