openai_compat_integration.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. use std::collections::HashMap;
  2. use std::ffi::OsString;
  3. use std::sync::Arc;
  4. use std::sync::{Mutex as StdMutex, OnceLock};
  5. use api::{
  6. ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
  7. InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
  8. OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
  9. };
  10. use serde_json::json;
  11. use tokio::io::{AsyncReadExt, AsyncWriteExt};
  12. use tokio::net::TcpListener;
  13. use tokio::sync::Mutex;
  14. #[tokio::test]
  15. async fn send_message_uses_openai_compatible_endpoint_and_auth() {
  16. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  17. let body = concat!(
  18. "{",
  19. "\"id\":\"chatcmpl_test\",",
  20. "\"model\":\"grok-3\",",
  21. "\"choices\":[{",
  22. "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},",
  23. "\"finish_reason\":\"stop\"",
  24. "}],",
  25. "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}",
  26. "}"
  27. );
  28. let server = spawn_server(
  29. state.clone(),
  30. vec![http_response("200 OK", "application/json", body)],
  31. )
  32. .await;
  33. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  34. .with_base_url(server.base_url());
  35. let response = client
  36. .send_message(&sample_request(false))
  37. .await
  38. .expect("request should succeed");
  39. assert_eq!(response.model, "grok-3");
  40. assert_eq!(response.total_tokens(), 16);
  41. assert_eq!(
  42. response.content,
  43. vec![OutputContentBlock::Text {
  44. text: "Hello from Grok".to_string(),
  45. }]
  46. );
  47. let captured = state.lock().await;
  48. let request = captured.first().expect("server should capture request");
  49. assert_eq!(request.path, "/chat/completions");
  50. assert_eq!(
  51. request.headers.get("authorization").map(String::as_str),
  52. Some("Bearer xai-test-key")
  53. );
  54. let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
  55. assert_eq!(body["model"], json!("grok-3"));
  56. assert_eq!(body["messages"][0]["role"], json!("system"));
  57. assert_eq!(body["tools"][0]["type"], json!("function"));
  58. }
  59. #[tokio::test]
  60. async fn send_message_accepts_full_chat_completions_endpoint_override() {
  61. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  62. let body = concat!(
  63. "{",
  64. "\"id\":\"chatcmpl_full_endpoint\",",
  65. "\"model\":\"grok-3\",",
  66. "\"choices\":[{",
  67. "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},",
  68. "\"finish_reason\":\"stop\"",
  69. "}],",
  70. "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}",
  71. "}"
  72. );
  73. let server = spawn_server(
  74. state.clone(),
  75. vec![http_response("200 OK", "application/json", body)],
  76. )
  77. .await;
  78. let endpoint_url = format!("{}/chat/completions", server.base_url());
  79. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  80. .with_base_url(endpoint_url);
  81. let response = client
  82. .send_message(&sample_request(false))
  83. .await
  84. .expect("request should succeed");
  85. assert_eq!(response.total_tokens(), 10);
  86. let captured = state.lock().await;
  87. let request = captured.first().expect("server should capture request");
  88. assert_eq!(request.path, "/chat/completions");
  89. }
  90. #[tokio::test]
  91. async fn stream_message_normalizes_text_and_multiple_tool_calls() {
  92. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  93. let sse = concat!(
  94. "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
  95. "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n",
  96. "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
  97. "data: [DONE]\n\n"
  98. );
  99. let server = spawn_server(
  100. state.clone(),
  101. vec![http_response_with_headers(
  102. "200 OK",
  103. "text/event-stream",
  104. sse,
  105. &[("x-request-id", "req_grok_stream")],
  106. )],
  107. )
  108. .await;
  109. let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
  110. .with_base_url(server.base_url());
  111. let mut stream = client
  112. .stream_message(&sample_request(false))
  113. .await
  114. .expect("stream should start");
  115. assert_eq!(stream.request_id(), Some("req_grok_stream"));
  116. let mut events = Vec::new();
  117. while let Some(event) = stream.next_event().await.expect("event should parse") {
  118. events.push(event);
  119. }
  120. assert!(matches!(events[0], StreamEvent::MessageStart(_)));
  121. assert!(matches!(
  122. events[1],
  123. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  124. content_block: OutputContentBlock::Text { .. },
  125. ..
  126. })
  127. ));
  128. assert!(matches!(
  129. events[2],
  130. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  131. delta: ContentBlockDelta::TextDelta { .. },
  132. ..
  133. })
  134. ));
  135. assert!(matches!(
  136. events[3],
  137. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  138. index: 1,
  139. content_block: OutputContentBlock::ToolUse { .. },
  140. })
  141. ));
  142. assert!(matches!(
  143. events[4],
  144. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  145. index: 1,
  146. delta: ContentBlockDelta::InputJsonDelta { .. },
  147. })
  148. ));
  149. assert!(matches!(
  150. events[5],
  151. StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  152. index: 2,
  153. content_block: OutputContentBlock::ToolUse { .. },
  154. })
  155. ));
  156. assert!(matches!(
  157. events[6],
  158. StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  159. index: 2,
  160. delta: ContentBlockDelta::InputJsonDelta { .. },
  161. })
  162. ));
  163. assert!(matches!(
  164. events[7],
  165. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 })
  166. ));
  167. assert!(matches!(
  168. events[8],
  169. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 })
  170. ));
  171. assert!(matches!(
  172. events[9],
  173. StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
  174. ));
  175. assert!(matches!(events[10], StreamEvent::MessageDelta(_)));
  176. assert!(matches!(events[11], StreamEvent::MessageStop(_)));
  177. let captured = state.lock().await;
  178. let request = captured.first().expect("captured request");
  179. assert_eq!(request.path, "/chat/completions");
  180. assert!(request.body.contains("\"stream\":true"));
  181. }
  182. #[allow(clippy::await_holding_lock)]
  183. #[tokio::test]
  184. async fn provider_client_dispatches_xai_requests_from_env() {
  185. let _lock = env_lock();
  186. let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key");
  187. let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
  188. let server = spawn_server(
  189. state.clone(),
  190. vec![http_response(
  191. "200 OK",
  192. "application/json",
  193. "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}",
  194. )],
  195. )
  196. .await;
  197. let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url());
  198. let client =
  199. ProviderClient::from_model("grok").expect("xAI provider client should be constructed");
  200. assert!(matches!(client, ProviderClient::Xai(_)));
  201. let response = client
  202. .send_message(&sample_request(false))
  203. .await
  204. .expect("provider-dispatched request should succeed");
  205. assert_eq!(response.total_tokens(), 13);
  206. let captured = state.lock().await;
  207. let request = captured.first().expect("captured request");
  208. assert_eq!(request.path, "/chat/completions");
  209. assert_eq!(
  210. request.headers.get("authorization").map(String::as_str),
  211. Some("Bearer xai-test-key")
  212. );
  213. }
  214. #[derive(Debug, Clone, PartialEq, Eq)]
  215. struct CapturedRequest {
  216. path: String,
  217. headers: HashMap<String, String>,
  218. body: String,
  219. }
  220. struct TestServer {
  221. base_url: String,
  222. join_handle: tokio::task::JoinHandle<()>,
  223. }
  224. impl TestServer {
  225. fn base_url(&self) -> String {
  226. self.base_url.clone()
  227. }
  228. }
  229. impl Drop for TestServer {
  230. fn drop(&mut self) {
  231. self.join_handle.abort();
  232. }
  233. }
  234. async fn spawn_server(
  235. state: Arc<Mutex<Vec<CapturedRequest>>>,
  236. responses: Vec<String>,
  237. ) -> TestServer {
  238. let listener = TcpListener::bind("127.0.0.1:0")
  239. .await
  240. .expect("listener should bind");
  241. let address = listener.local_addr().expect("listener addr");
  242. let join_handle = tokio::spawn(async move {
  243. for response in responses {
  244. let (mut socket, _) = listener.accept().await.expect("accept");
  245. let mut buffer = Vec::new();
  246. let mut header_end = None;
  247. loop {
  248. let mut chunk = [0_u8; 1024];
  249. let read = socket.read(&mut chunk).await.expect("read request");
  250. if read == 0 {
  251. break;
  252. }
  253. buffer.extend_from_slice(&chunk[..read]);
  254. if let Some(position) = find_header_end(&buffer) {
  255. header_end = Some(position);
  256. break;
  257. }
  258. }
  259. let header_end = header_end.expect("headers should exist");
  260. let (header_bytes, remaining) = buffer.split_at(header_end);
  261. let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
  262. let mut lines = header_text.split("\r\n");
  263. let request_line = lines.next().expect("request line");
  264. let path = request_line
  265. .split_whitespace()
  266. .nth(1)
  267. .expect("path")
  268. .to_string();
  269. let mut headers = HashMap::new();
  270. let mut content_length = 0_usize;
  271. for line in lines {
  272. if line.is_empty() {
  273. continue;
  274. }
  275. let (name, value) = line.split_once(':').expect("header");
  276. let value = value.trim().to_string();
  277. if name.eq_ignore_ascii_case("content-length") {
  278. content_length = value.parse().expect("content length");
  279. }
  280. headers.insert(name.to_ascii_lowercase(), value);
  281. }
  282. let mut body = remaining[4..].to_vec();
  283. while body.len() < content_length {
  284. let mut chunk = vec![0_u8; content_length - body.len()];
  285. let read = socket.read(&mut chunk).await.expect("read body");
  286. if read == 0 {
  287. break;
  288. }
  289. body.extend_from_slice(&chunk[..read]);
  290. }
  291. state.lock().await.push(CapturedRequest {
  292. path,
  293. headers,
  294. body: String::from_utf8(body).expect("utf8 body"),
  295. });
  296. socket
  297. .write_all(response.as_bytes())
  298. .await
  299. .expect("write response");
  300. }
  301. });
  302. TestServer {
  303. base_url: format!("http://{address}"),
  304. join_handle,
  305. }
  306. }
  307. fn find_header_end(bytes: &[u8]) -> Option<usize> {
  308. bytes.windows(4).position(|window| window == b"\r\n\r\n")
  309. }
  310. fn http_response(status: &str, content_type: &str, body: &str) -> String {
  311. http_response_with_headers(status, content_type, body, &[])
  312. }
  313. fn http_response_with_headers(
  314. status: &str,
  315. content_type: &str,
  316. body: &str,
  317. headers: &[(&str, &str)],
  318. ) -> String {
  319. let mut extra_headers = String::new();
  320. for (name, value) in headers {
  321. use std::fmt::Write as _;
  322. write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write");
  323. }
  324. format!(
  325. "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
  326. body.len()
  327. )
  328. }
  329. fn sample_request(stream: bool) -> MessageRequest {
  330. MessageRequest {
  331. model: "grok-3".to_string(),
  332. max_tokens: 64,
  333. messages: vec![InputMessage {
  334. role: "user".to_string(),
  335. content: vec![InputContentBlock::Text {
  336. text: "Say hello".to_string(),
  337. }],
  338. }],
  339. system: Some("Use tools when needed".to_string()),
  340. tools: Some(vec![ToolDefinition {
  341. name: "weather".to_string(),
  342. description: Some("Fetches weather".to_string()),
  343. input_schema: json!({
  344. "type": "object",
  345. "properties": {"city": {"type": "string"}},
  346. "required": ["city"]
  347. }),
  348. }]),
  349. tool_choice: Some(ToolChoice::Auto),
  350. stream,
  351. }
  352. }
  353. fn env_lock() -> std::sync::MutexGuard<'static, ()> {
  354. static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
  355. LOCK.get_or_init(|| StdMutex::new(()))
  356. .lock()
  357. .unwrap_or_else(std::sync::PoisonError::into_inner)
  358. }
  359. struct ScopedEnvVar {
  360. key: &'static str,
  361. previous: Option<OsString>,
  362. }
  363. impl ScopedEnvVar {
  364. fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
  365. let previous = std::env::var_os(key);
  366. std::env::set_var(key, value);
  367. Self { key, previous }
  368. }
  369. }
  370. impl Drop for ScopedEnvVar {
  371. fn drop(&mut self) {
  372. match &self.previous {
  373. Some(value) => std::env::set_var(self.key, value),
  374. None => std::env::remove_var(self.key),
  375. }
  376. }
  377. }