openai_compat_integration.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. use std::collections::HashMap;
  2. use std::ffi::OsString;
  3. use std::sync::Arc;
  4. use std::sync::{Mutex as StdMutex, OnceLock};
  5. use api::{
  6. ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
  7. InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OpenAiCompatClient,
  8. OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, ToolChoice,
  9. ToolDefinition,
  10. };
  11. use serde_json::json;
  12. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  13. use tokio::net::TcpListener;
  14. use tokio::sync::Mutex;
  15. #[tokio::test]
  16. async fn send_message_uses_openai_compatible_endpoint_and_auth() {
  17. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  18. let body = concat!(
  19. "{",
  20. "\"id\":\"chatcmpl_test\",",
  21. "\"model\":\"grok-3\",",
  22. "\"choices\":[{",
  23. "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},",
  24. "\"finish_reason\":\"stop\"",
  25. "}],",
  26. "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}",
  27. "}"
  28. );
  29. let server = spawn_server(
  30. state.clone(),
  31. vec![http_response("200 OK", "application/json", body)],
  32. )
  33. .await;
  34. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  35. .with_base_url(server.base_url());
  36. let response = client
  37. .send_message(&sample_request(false))
  38. .await
  39. .expect("request should succeed");
  40. assert_eq!(response.model, "grok-3");
  41. assert_eq!(response.total_tokens(), 16);
  42. assert_eq!(
  43. response.content,
  44. vec![OutputContentBlock::Text {
  45. text: "Hello from Grok".to_string(),
  46. }]
  47. );
  48. let captured = state.lock().await;
  49. let request = captured.first().expect("server should capture request");
  50. assert_eq!(request.path, "/chat/completions");
  51. assert_eq!(
  52. request.headers.get("authorization").map(String::as_str),
  53. Some("Bearer xai-test-key")
  54. );
  55. let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
  56. assert_eq!(body["model"], json!("grok-3"));
  57. assert_eq!(body["messages"][0]["role"], json!("system"));
  58. assert_eq!(body["tools"][0]["type"], json!("function"));
  59. }
  60. #[tokio::test]
  61. async fn send_message_accepts_full_chat_completions_endpoint_override() {
  62. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  63. let body = concat!(
  64. "{",
  65. "\"id\":\"chatcmpl_full_endpoint\",",
  66. "\"model\":\"grok-3\",",
  67. "\"choices\":[{",
  68. "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},",
  69. "\"finish_reason\":\"stop\"",
  70. "}],",
  71. "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}",
  72. "}"
  73. );
  74. let server = spawn_server(
  75. state.clone(),
  76. vec![http_response("200 OK", "application/json", body)],
  77. )
  78. .await;
  79. let endpoint_url = format!("{}/chat/completions", server.base_url());
  80. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  81. .with_base_url(endpoint_url);
  82. let response = client
  83. .send_message(&sample_request(false))
  84. .await
  85. .expect("request should succeed");
  86. assert_eq!(response.total_tokens(), 10);
  87. let captured = state.lock().await;
  88. let request = captured.first().expect("server should capture request");
  89. assert_eq!(request.path, "/chat/completions");
  90. }
  91. #[tokio::test]
  92. async fn stream_message_normalizes_text_and_multiple_tool_calls() {
  93. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  94. let sse = concat!(
  95. "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
  96. "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n",
  97. "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
  98. "data: [DONE]\n\n"
  99. );
  100. let server = spawn_server(
  101. state.clone(),
  102. vec![http_response_with_headers(
  103. "200 OK",
  104. "text/event-stream",
  105. sse,
  106. &[("x-request-id", "req_grok_stream")],
  107. )],
  108. )
  109. .await;
  110. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  111. .with_base_url(server.base_url());
  112. let mut stream = client
  113. .stream_message(&sample_request(false))
  114. .await
  115. .expect("stream should start");
  116. assert_eq!(stream.request_id(), Some("req_grok_stream"));
  117. let mut events = Vec::new();
  118. while let Some(event) = stream.next_event().await.expect("event should parse") {
  119. events.push(event);
  120. }
  121. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  122. assert!(matches!(
  123. events[1],
  124. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  125. content_block: OutputContentBlock::Text { .. },
  126. ..
  127. })
  128. ));
  129. assert!(matches!(
  130. events[2],
  131. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  132. delta: ContentBlockDelta::TextDelta { .. },
  133. ..
  134. })
  135. ));
  136. assert!(matches!(
  137. events[3],
  138. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  139. index: 1,
  140. content_block: OutputContentBlock::ToolUse { .. },
  141. })
  142. ));
  143. assert!(matches!(
  144. events[4],
  145. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  146. index: 1,
  147. delta: ContentBlockDelta::InputJsonDelta { .. },
  148. })
  149. ));
  150. assert!(matches!(
  151. events[5],
  152. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  153. index: 2,
  154. content_block: OutputContentBlock::ToolUse { .. },
  155. })
  156. ));
  157. assert!(matches!(
  158. events[6],
  159. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  160. index: 2,
  161. delta: ContentBlockDelta::InputJsonDelta { .. },
  162. })
  163. ));
  164. assert!(matches!(
  165. events[7],
  166. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 })
  167. ));
  168. assert!(matches!(
  169. events[8],
  170. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 })
  171. ));
  172. assert!(matches!(
  173. events[9],
  174. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
  175. ));
  176. assert!(matches!(events[10], StreamEvent::MessageDelta(_)));
  177. assert!(matches!(events[11], StreamEvent::MessageStop(_)));
  178. let captured = state.lock().await;
  179. let request = captured.first().expect("captured request");
  180. assert_eq!(request.path, "/chat/completions");
  181. assert!(request.body.contains("\"stream\":true"));
  182. }
  183. #[allow(clippy::await_holding_lock)]
  184. #[tokio::test]
  185. async fn openai_streaming_requests_opt_into_usage_chunks() {
  186. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  187. let sse = concat!(
  188. "data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n",
  189. "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
  190. "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n",
  191. "data: [DONE]\n\n"
  192. );
  193. let server = spawn_server(
  194. state.clone(),
  195. vec![http_response_with_headers(
  196. "200 OK",
  197. "text/event-stream",
  198. sse,
  199. &[("x-request-id", "req_openai_stream")],
  200. )],
  201. )
  202. .await;
  203. let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
  204. .with_base_url(server.base_url());
  205. let mut stream = client
  206. .stream_message(&sample_request(false))
  207. .await
  208. .expect("stream should start");
  209. assert_eq!(stream.request_id(), Some("req_openai_stream"));
  210. let mut events = Vec::new();
  211. while let Some(event) = stream.next_event().await.expect("event should parse") {
  212. events.push(event);
  213. }
  214. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  215. assert!(matches!(
  216. events[1],
  217. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  218. content_block: OutputContentBlock::Text { .. },
  219. ..
  220. })
  221. ));
  222. assert!(matches!(
  223. events[2],
  224. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  225. delta: ContentBlockDelta::TextDelta { .. },
  226. ..
  227. })
  228. ));
  229. assert!(matches!(
  230. events[3],
  231. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
  232. ));
  233. assert!(matches!(
  234. events[4],
  235. StreamEvent::MessageDelta(MessageDeltaEvent { .. })
  236. ));
  237. assert!(matches!(events[5], StreamEvent::MessageStop(_)));
  238. match &events[4] {
  239. StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
  240. assert_eq!(usage.input_tokens, 9);
  241. assert_eq!(usage.output_tokens, 4);
  242. }
  243. other => panic!("expected message delta, got {other:?}"),
  244. }
  245. let captured = state.lock().await;
  246. let request = captured.first().expect("captured request");
  247. assert_eq!(request.path, "/chat/completions");
  248. let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
  249. assert_eq!(body["stream"], json!(true));
  250. assert_eq!(body["stream_options"], json!({"include_usage": true}));
  251. }
  252. #[allow(clippy::await_holding_lock)]
  253. #[tokio::test]
  254. async fn provider_client_dispatches_xai_requests_from_env() {
  255. let _lock = env_lock();
  256. let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key");
  257. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  258. let server = spawn_server(
  259. state.clone(),
  260. vec![http_response(
  261. "200 OK",
  262. "application/json",
  263. "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}",
  264. )],
  265. )
  266. .await;
  267. let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url());
  268. let client =
  269. ProviderClient::from_model("grok").expect("xAI provider client should be constructed");
  270. assert!(matches!(client, ProviderClient::Xai(_)));
  271. let response = client
  272. .send_message(&sample_request(false))
  273. .await
  274. .expect("provider-dispatched request should succeed");
  275. assert_eq!(response.total_tokens(), 13);
  276. let captured = state.lock().await;
  277. let request = captured.first().expect("captured request");
  278. assert_eq!(request.path, "/chat/completions");
  279. assert_eq!(
  280. request.headers.get("authorization").map(String::as_str),
  281. Some("Bearer xai-test-key")
  282. );
  283. }
  284. #[derive(Debug, Clone, PartialEq, Eq)]
  285. struct CapturedRequest {
  286. path: String,
  287. headers: HashMap<String, String>,
  288. body: String,
  289. }
  290. struct TestServer {
  291. base_url: String,
  292. join_handle: tokio::task::JoinHandle<()>,
  293. }
  294. impl TestServer {
  295. fn base_url(&self) -> String {
  296. self.base_url.clone()
  297. }
  298. }
  299. impl Drop for TestServer {
  300. fn drop(&mut self) {
  301. self.join_handle.abort();
  302. }
  303. }
  304. async fn spawn_server(
  305. state: Arc<Mutex<Vec<CapturedRequest>>>,
  306. responses: Vec<String>,
  307. ) -> TestServer {
  308. let listener = TcpListener::bind("127.0.0.1:0")
  309. .await
  310. .expect("listener should bind");
  311. let address = listener.local_addr().expect("listener addr");
  312. let join_handle = tokio::spawn(async move {
  313. for response in responses {
  314. let (mut socket, _) = listener.accept().await.expect("accept");
  315. let mut buffer = Vec::new();
  316. let mut header_end = None;
  317. loop {
  318. let mut chunk = [0_u8; 1024];
  319. let read = socket.read(&mut chunk).await.expect("read request");
  320. if read == 0 {
  321. break;
  322. }
  323. buffer.extend_from_slice(&chunk[..read]);
  324. if let Some(position) = find_header_end(&buffer) {
  325. header_end = Some(position);
  326. break;
  327. }
  328. }
  329. let header_end = header_end.expect("headers should exist");
  330. let (header_bytes, remaining) = buffer.split_at(header_end);
  331. let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
  332. let mut lines = header_text.split("\r\n");
  333. let request_line = lines.next().expect("request line");
  334. let path = request_line
  335. .split_whitespace()
  336. .nth(1)
  337. .expect("path")
  338. .to_string();
  339. let mut headers = HashMap::new();
  340. let mut content_length = 0_usize;
  341. for line in lines {
  342. if line.is_empty() {
  343. continue;
  344. }
  345. let (name, value) = line.split_once(':').expect("header");
  346. let value = value.trim().to_string();
  347. if name.eq_ignore_ascii_case("content-length") {
  348. content_length = value.parse().expect("content length");
  349. }
  350. headers.insert(name.to_ascii_lowercase(), value);
  351. }
  352. let mut body = remaining[4..].to_vec();
  353. while body.len() < content_length {
  354. let mut chunk = vec![0_u8; content_length - body.len()];
  355. let read = socket.read(&mut chunk).await.expect("read body");
  356. if read == 0 {
  357. break;
  358. }
  359. body.extend_from_slice(&chunk[..read]);
  360. }
  361. state.lock().await.push(CapturedRequest {
  362. path,
  363. headers,
  364. body: String::from_utf8(body).expect("utf8 body"),
  365. });
  366. socket
  367. .write_all(response.as_bytes())
  368. .await
  369. .expect("write response");
  370. }
  371. });
  372. TestServer {
  373. base_url: format!("http://{address}"),
  374. join_handle,
  375. }
  376. }
  377. fn find_header_end(bytes: &[u8]) -> Option<usize> {
  378. bytes.windows(4).position(|window| window == b"\r\n\r\n")
  379. }
  380. fn http_response(status: &str, content_type: &str, body: &str) -> String {
  381. http_response_with_headers(status, content_type, body, &[])
  382. }
  383. fn http_response_with_headers(
  384. status: &str,
  385. content_type: &str,
  386. body: &str,
  387. headers: &[(&str, &str)],
  388. ) -> String {
  389. let mut extra_headers = String::new();
  390. for (name, value) in headers {
  391. use std::fmt::Write as _;
  392. write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write");
  393. }
  394. format!(
  395. "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
  396. body.len()
  397. )
  398. }
  399. fn sample_request(stream: bool) -> MessageRequest {
  400. MessageRequest {
  401. model: "grok-3".to_string(),
  402. max_tokens: 64,
  403. messages: vec![InputMessage {
  404. role: "user".to_string(),
  405. content: vec![InputContentBlock::Text {
  406. text: "Say hello".to_string(),
  407. }],
  408. }],
  409. system: Some("Use tools when needed".to_string()),
  410. tools: Some(vec![ToolDefinition {
  411. name: "weather".to_string(),
  412. description: Some("Fetches weather".to_string()),
  413. input_schema: json!({
  414. "type": "object",
  415. "properties": {"city": {"type": "string"}},
  416. "required": ["city"]
  417. }),
  418. }]),
  419. tool_choice: Some(ToolChoice::Auto),
  420. stream,
  421. }
  422. }
  423. fn env_lock() -> std::sync::MutexGuard<'static, ()> {
  424. static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
  425. LOCK.get_or_init(|| StdMutex::new(()))
  426. .lock()
  427. .unwrap_or_else(std::sync::PoisonError::into_inner)
  428. }
  429. struct ScopedEnvVar {
  430. key: &'static str,
  431. previous: Option<OsString>,
  432. }
  433. impl ScopedEnvVar {
  434. fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
  435. let previous = std::env::var_os(key);
  436. std::env::set_var(key, value);
  437. Self { key, previous }
  438. }
  439. }
  440. impl Drop for ScopedEnvVar {
  441. fn drop(&mut self) {
  442. match &self.previous {
  443. Some(value) => std::env::set_var(self.key, value),
  444. None => std::env::remove_var(self.key),
  445. }
  446. }
  447. }