// openai_compat_integration.rs
// Integration tests that drive `OpenAiCompatClient` against a local mock HTTP server.

  1. use std::collections::HashMap;
  2. use std::sync::Arc;
  3. use api::{
  4. ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
  5. InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
  6. OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition,
  7. };
  8. use serde_json::json;
  9. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  10. use tokio::net::TcpListener;
  11. use tokio::sync::Mutex;
  12. #[tokio::test]
  13. async fn send_message_uses_openai_compatible_endpoint_and_auth() {
  14. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  15. let body = concat!(
  16. "{",
  17. "\"id\":\"chatcmpl_test\",",
  18. "\"model\":\"grok-3\",",
  19. "\"choices\":[{",
  20. "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},",
  21. "\"finish_reason\":\"stop\"",
  22. "}],",
  23. "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}",
  24. "}"
  25. );
  26. let server = spawn_server(
  27. state.clone(),
  28. vec![http_response("200 OK", "application/json", body)],
  29. )
  30. .await;
  31. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  32. .with_base_url(server.base_url());
  33. let response = client
  34. .send_message(&sample_request(false))
  35. .await
  36. .expect("request should succeed");
  37. assert_eq!(response.model, "grok-3");
  38. assert_eq!(response.total_tokens(), 16);
  39. assert_eq!(
  40. response.content,
  41. vec![OutputContentBlock::Text {
  42. text: "Hello from Grok".to_string(),
  43. }]
  44. );
  45. let captured = state.lock().await;
  46. let request = captured.first().expect("server should capture request");
  47. assert_eq!(request.path, "/chat/completions");
  48. assert_eq!(
  49. request.headers.get("authorization").map(String::as_str),
  50. Some("Bearer xai-test-key")
  51. );
  52. let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
  53. assert_eq!(body["model"], json!("grok-3"));
  54. assert_eq!(body["messages"][0]["role"], json!("system"));
  55. assert_eq!(body["tools"][0]["type"], json!("function"));
  56. }
  57. #[tokio::test]
  58. async fn stream_message_normalizes_text_and_multiple_tool_calls() {
  59. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  60. let sse = concat!(
  61. "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
  62. "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n",
  63. "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
  64. "data: [DONE]\n\n"
  65. );
  66. let server = spawn_server(
  67. state.clone(),
  68. vec![http_response_with_headers(
  69. "200 OK",
  70. "text/event-stream",
  71. sse,
  72. &[("x-request-id", "req_grok_stream")],
  73. )],
  74. )
  75. .await;
  76. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  77. .with_base_url(server.base_url());
  78. let mut stream = client
  79. .stream_message(&sample_request(false))
  80. .await
  81. .expect("stream should start");
  82. assert_eq!(stream.request_id(), Some("req_grok_stream"));
  83. let mut events = Vec::new();
  84. while let Some(event) = stream.next_event().await.expect("event should parse") {
  85. events.push(event);
  86. }
  87. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  88. assert!(matches!(
  89. events[1],
  90. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  91. content_block: OutputContentBlock::Text { .. },
  92. ..
  93. })
  94. ));
  95. assert!(matches!(
  96. events[2],
  97. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  98. delta: ContentBlockDelta::TextDelta { .. },
  99. ..
  100. })
  101. ));
  102. assert!(matches!(
  103. events[3],
  104. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  105. index: 1,
  106. content_block: OutputContentBlock::ToolUse { .. },
  107. })
  108. ));
  109. assert!(matches!(
  110. events[4],
  111. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  112. index: 1,
  113. delta: ContentBlockDelta::InputJsonDelta { .. },
  114. })
  115. ));
  116. assert!(matches!(
  117. events[5],
  118. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  119. index: 2,
  120. content_block: OutputContentBlock::ToolUse { .. },
  121. })
  122. ));
  123. assert!(matches!(
  124. events[6],
  125. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  126. index: 2,
  127. delta: ContentBlockDelta::InputJsonDelta { .. },
  128. })
  129. ));
  130. assert!(matches!(
  131. events[7],
  132. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 })
  133. ));
  134. assert!(matches!(
  135. events[8],
  136. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 })
  137. ));
  138. assert!(matches!(
  139. events[9],
  140. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
  141. ));
  142. assert!(matches!(events[10], StreamEvent::MessageDelta(_)));
  143. assert!(matches!(events[11], StreamEvent::MessageStop(_)));
  144. let captured = state.lock().await;
  145. let request = captured.first().expect("captured request");
  146. assert_eq!(request.path, "/chat/completions");
  147. assert!(request.body.contains("\"stream\":true"));
  148. }
  /// One HTTP request as recorded by the mock server in `spawn_server`.
  #[derive(Debug, Clone, PartialEq, Eq)]
  struct CapturedRequest {
      /// Second whitespace-separated token of the request line (the request target).
      path: String,
      /// Header values keyed by lower-cased header name.
      headers: HashMap<String, String>,
      /// Request body decoded as UTF-8.
      body: String,
  }
  155. struct TestServer {
  156. base_url: String,
  157. join_handle: tokio::task::JoinHandle<()>,
  158. }
  159. impl TestServer {
  160. fn base_url(&self) -> String {
  161. self.base_url.clone()
  162. }
  163. }
  164. impl Drop for TestServer {
  165. fn drop(&mut self) {
  166. self.join_handle.abort();
  167. }
  168. }
  169. async fn spawn_server(
  170. state: Arc<Mutex<Vec<CapturedRequest>>>,
  171. responses: Vec<String>,
  172. ) -> TestServer {
  173. let listener = TcpListener::bind("127.0.0.1:0")
  174. .await
  175. .expect("listener should bind");
  176. let address = listener.local_addr().expect("listener addr");
  177. let join_handle = tokio::spawn(async move {
  178. for response in responses {
  179. let (mut socket, _) = listener.accept().await.expect("accept");
  180. let mut buffer = Vec::new();
  181. let mut header_end = None;
  182. loop {
  183. let mut chunk = [0_u8; 1024];
  184. let read = socket.read(&mut chunk).await.expect("read request");
  185. if read == 0 {
  186. break;
  187. }
  188. buffer.extend_from_slice(&chunk[..read]);
  189. if let Some(position) = find_header_end(&buffer) {
  190. header_end = Some(position);
  191. break;
  192. }
  193. }
  194. let header_end = header_end.expect("headers should exist");
  195. let (header_bytes, remaining) = buffer.split_at(header_end);
  196. let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
  197. let mut lines = header_text.split("\r\n");
  198. let request_line = lines.next().expect("request line");
  199. let path = request_line
  200. .split_whitespace()
  201. .nth(1)
  202. .expect("path")
  203. .to_string();
  204. let mut headers = HashMap::new();
  205. let mut content_length = 0_usize;
  206. for line in lines {
  207. if line.is_empty() {
  208. continue;
  209. }
  210. let (name, value) = line.split_once(':').expect("header");
  211. let value = value.trim().to_string();
  212. if name.eq_ignore_ascii_case("content-length") {
  213. content_length = value.parse().expect("content length");
  214. }
  215. headers.insert(name.to_ascii_lowercase(), value);
  216. }
  217. let mut body = remaining[4..].to_vec();
  218. while body.len() < content_length {
  219. let mut chunk = vec![0_u8; content_length - body.len()];
  220. let read = socket.read(&mut chunk).await.expect("read body");
  221. if read == 0 {
  222. break;
  223. }
  224. body.extend_from_slice(&chunk[..read]);
  225. }
  226. state.lock().await.push(CapturedRequest {
  227. path,
  228. headers,
  229. body: String::from_utf8(body).expect("utf8 body"),
  230. });
  231. socket
  232. .write_all(response.as_bytes())
  233. .await
  234. .expect("write response");
  235. }
  236. });
  237. TestServer {
  238. base_url: format!("http://{address}"),
  239. join_handle,
  240. }
  241. }
  /// Returns the byte offset of the first `\r\n\r\n` in `bytes`, or `None` when
  /// the sequence does not occur.
  fn find_header_end(bytes: &[u8]) -> Option<usize> {
      const TERMINATOR: &[u8] = b"\r\n\r\n";
      bytes
          .windows(TERMINATOR.len())
          .enumerate()
          .find(|(_, window)| *window == TERMINATOR)
          .map(|(offset, _)| offset)
  }
  245. fn http_response(status: &str, content_type: &str, body: &str) -> String {
  246. http_response_with_headers(status, content_type, body, &[])
  247. }
  /// Builds a raw HTTP/1.1 response as a single string.
  ///
  /// The head contains, in order: the status line, `content-type`, each entry of
  /// `headers` as `name: value`, `content-length` (the byte length of `body`) and
  /// `connection: close`. The body follows the blank line unchanged.
  fn http_response_with_headers(
      status: &str,
      content_type: &str,
      body: &str,
      headers: &[(&str, &str)],
  ) -> String {
      let extra_headers: String = headers
          .iter()
          .map(|(name, value)| format!("{name}: {value}\r\n"))
          .collect();
      format!(
          "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
          body.len()
      )
  }
  264. fn sample_request(stream: bool) -> MessageRequest {
  265. MessageRequest {
  266. model: "grok-3".to_string(),
  267. max_tokens: 64,
  268. messages: vec![InputMessage {
  269. role: "user".to_string(),
  270. content: vec![InputContentBlock::Text {
  271. text: "Say hello".to_string(),
  272. }],
  273. }],
  274. system: Some("Use tools when needed".to_string()),
  275. tools: Some(vec![ToolDefinition {
  276. name: "weather".to_string(),
  277. description: Some("Fetches weather".to_string()),
  278. input_schema: json!({
  279. "type": "object",
  280. "properties": {"city": {"type": "string"}},
  281. "required": ["city"]
  282. }),
  283. }]),
  284. tool_choice: Some(ToolChoice::Auto),
  285. stream,
  286. }
  287. }