| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416 |
- use std::collections::HashMap;
- use std::ffi::OsString;
- use std::sync::Arc;
- use std::sync::{Mutex as StdMutex, OnceLock};
- use api::{
- ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
- InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
- OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
- };
- use serde_json::json;
- use tokio::io::{AsyncReadExt, AsyncWriteExt};
- use tokio::net::TcpListener;
- use tokio::sync::Mutex;
- #[tokio::test]
- async fn send_message_uses_openai_compatible_endpoint_and_auth() {
- let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
- let body = concat!(
- "{",
- "\"id\":\"chatcmpl_test\",",
- "\"model\":\"grok-3\",",
- "\"choices\":[{",
- "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},",
- "\"finish_reason\":\"stop\"",
- "}],",
- "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}",
- "}"
- );
- let server = spawn_server(
- state.clone(),
- vec![http_response("200 OK", "application/json", body)],
- )
- .await;
- let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
- .with_base_url(server.base_url());
- let response = client
- .send_message(&sample_request(false))
- .await
- .expect("request should succeed");
- assert_eq!(response.model, "grok-3");
- assert_eq!(response.total_tokens(), 16);
- assert_eq!(
- response.content,
- vec![OutputContentBlock::Text {
- text: "Hello from Grok".to_string(),
- }]
- );
- let captured = state.lock().await;
- let request = captured.first().expect("server should capture request");
- assert_eq!(request.path, "/chat/completions");
- assert_eq!(
- request.headers.get("authorization").map(String::as_str),
- Some("Bearer xai-test-key")
- );
- let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
- assert_eq!(body["model"], json!("grok-3"));
- assert_eq!(body["messages"][0]["role"], json!("system"));
- assert_eq!(body["tools"][0]["type"], json!("function"));
- }
- #[tokio::test]
- async fn send_message_accepts_full_chat_completions_endpoint_override() {
- let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
- let body = concat!(
- "{",
- "\"id\":\"chatcmpl_full_endpoint\",",
- "\"model\":\"grok-3\",",
- "\"choices\":[{",
- "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},",
- "\"finish_reason\":\"stop\"",
- "}],",
- "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}",
- "}"
- );
- let server = spawn_server(
- state.clone(),
- vec![http_response("200 OK", "application/json", body)],
- )
- .await;
- let endpoint_url = format!("{}/chat/completions", server.base_url());
- let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
- .with_base_url(endpoint_url);
- let response = client
- .send_message(&sample_request(false))
- .await
- .expect("request should succeed");
- assert_eq!(response.total_tokens(), 10);
- let captured = state.lock().await;
- let request = captured.first().expect("server should capture request");
- assert_eq!(request.path, "/chat/completions");
- }
- #[tokio::test]
- async fn stream_message_normalizes_text_and_multiple_tool_calls() {
- let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
- let sse = concat!(
- "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
- "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n",
- "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
- "data: [DONE]\n\n"
- );
- let server = spawn_server(
- state.clone(),
- vec![http_response_with_headers(
- "200 OK",
- "text/event-stream",
- sse,
- &[("x-request-id", "req_grok_stream")],
- )],
- )
- .await;
- let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
- .with_base_url(server.base_url());
- let mut stream = client
- .stream_message(&sample_request(false))
- .await
- .expect("stream should start");
- assert_eq!(stream.request_id(), Some("req_grok_stream"));
- let mut events = Vec::new();
- while let Some(event) = stream.next_event().await.expect("event should parse") {
- events.push(event);
- }
- assert!(matches!(events[0], StreamEvent::MessageStart(_)));
- assert!(matches!(
- events[1],
- StreamEvent::ContentBlockStart(ContentBlockStartEvent {
- content_block: OutputContentBlock::Text { .. },
- ..
- })
- ));
- assert!(matches!(
- events[2],
- StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
- delta: ContentBlockDelta::TextDelta { .. },
- ..
- })
- ));
- assert!(matches!(
- events[3],
- StreamEvent::ContentBlockStart(ContentBlockStartEvent {
- index: 1,
- content_block: OutputContentBlock::ToolUse { .. },
- })
- ));
- assert!(matches!(
- events[4],
- StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
- index: 1,
- delta: ContentBlockDelta::InputJsonDelta { .. },
- })
- ));
- assert!(matches!(
- events[5],
- StreamEvent::ContentBlockStart(ContentBlockStartEvent {
- index: 2,
- content_block: OutputContentBlock::ToolUse { .. },
- })
- ));
- assert!(matches!(
- events[6],
- StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
- index: 2,
- delta: ContentBlockDelta::InputJsonDelta { .. },
- })
- ));
- assert!(matches!(
- events[7],
- StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 })
- ));
- assert!(matches!(
- events[8],
- StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 })
- ));
- assert!(matches!(
- events[9],
- StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
- ));
- assert!(matches!(events[10], StreamEvent::MessageDelta(_)));
- assert!(matches!(events[11], StreamEvent::MessageStop(_)));
- let captured = state.lock().await;
- let request = captured.first().expect("captured request");
- assert_eq!(request.path, "/chat/completions");
- assert!(request.body.contains("\"stream\":true"));
- }
- #[allow(clippy::await_holding_lock)]
- #[tokio::test]
- async fn provider_client_dispatches_xai_requests_from_env() {
- let _lock = env_lock();
- let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key");
- let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
- let server = spawn_server(
- state.clone(),
- vec![http_response(
- "200 OK",
- "application/json",
- "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}",
- )],
- )
- .await;
- let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url());
- let client =
- ProviderClient::from_model("grok").expect("xAI provider client should be constructed");
- assert!(matches!(client, ProviderClient::Xai(_)));
- let response = client
- .send_message(&sample_request(false))
- .await
- .expect("provider-dispatched request should succeed");
- assert_eq!(response.total_tokens(), 13);
- let captured = state.lock().await;
- let request = captured.first().expect("captured request");
- assert_eq!(request.path, "/chat/completions");
- assert_eq!(
- request.headers.get("authorization").map(String::as_str),
- Some("Bearer xai-test-key")
- );
- }
/// One HTTP request as recorded by the local test server.
#[derive(Clone, Debug, Eq, PartialEq)]
struct CapturedRequest {
    /// Second whitespace-separated token of the request line (the request target).
    path: String,
    /// Request headers; names are stored lower-cased, values trimmed.
    headers: HashMap<String, String>,
    /// Request body decoded as UTF-8.
    body: String,
}
- struct TestServer {
- base_url: String,
- join_handle: tokio::task::JoinHandle<()>,
- }
- impl TestServer {
- fn base_url(&self) -> String {
- self.base_url.clone()
- }
- }
- impl Drop for TestServer {
- fn drop(&mut self) {
- self.join_handle.abort();
- }
- }
- async fn spawn_server(
- state: Arc<Mutex<Vec<CapturedRequest>>>,
- responses: Vec<String>,
- ) -> TestServer {
- let listener = TcpListener::bind("127.0.0.1:0")
- .await
- .expect("listener should bind");
- let address = listener.local_addr().expect("listener addr");
- let join_handle = tokio::spawn(async move {
- for response in responses {
- let (mut socket, _) = listener.accept().await.expect("accept");
- let mut buffer = Vec::new();
- let mut header_end = None;
- loop {
- let mut chunk = [0_u8; 1024];
- let read = socket.read(&mut chunk).await.expect("read request");
- if read == 0 {
- break;
- }
- buffer.extend_from_slice(&chunk[..read]);
- if let Some(position) = find_header_end(&buffer) {
- header_end = Some(position);
- break;
- }
- }
- let header_end = header_end.expect("headers should exist");
- let (header_bytes, remaining) = buffer.split_at(header_end);
- let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
- let mut lines = header_text.split("\r\n");
- let request_line = lines.next().expect("request line");
- let path = request_line
- .split_whitespace()
- .nth(1)
- .expect("path")
- .to_string();
- let mut headers = HashMap::new();
- let mut content_length = 0_usize;
- for line in lines {
- if line.is_empty() {
- continue;
- }
- let (name, value) = line.split_once(':').expect("header");
- let value = value.trim().to_string();
- if name.eq_ignore_ascii_case("content-length") {
- content_length = value.parse().expect("content length");
- }
- headers.insert(name.to_ascii_lowercase(), value);
- }
- let mut body = remaining[4..].to_vec();
- while body.len() < content_length {
- let mut chunk = vec![0_u8; content_length - body.len()];
- let read = socket.read(&mut chunk).await.expect("read body");
- if read == 0 {
- break;
- }
- body.extend_from_slice(&chunk[..read]);
- }
- state.lock().await.push(CapturedRequest {
- path,
- headers,
- body: String::from_utf8(body).expect("utf8 body"),
- });
- socket
- .write_all(response.as_bytes())
- .await
- .expect("write response");
- }
- });
- TestServer {
- base_url: format!("http://{address}"),
- join_handle,
- }
- }
/// Returns the byte offset at which the first `\r\n\r\n` sequence starts in
/// `bytes`, or `None` when the sequence does not occur.
fn find_header_end(bytes: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    bytes
        .windows(TERMINATOR.len())
        .enumerate()
        .find_map(|(offset, window)| (window == TERMINATOR).then_some(offset))
}
- fn http_response(status: &str, content_type: &str, body: &str) -> String {
- http_response_with_headers(status, content_type, body, &[])
- }
/// Builds a raw HTTP/1.1 response string.
///
/// The head consists of the status line, `content-type`, each `(name, value)`
/// pair from `headers` in the given order, `content-length` (the byte length
/// of `body`) and `connection: close`, followed by a blank line and `body`.
fn http_response_with_headers(
    status: &str,
    content_type: &str,
    body: &str,
    headers: &[(&str, &str)],
) -> String {
    use std::fmt::Write as _;

    // Render the caller-supplied headers as consecutive `name: value\r\n` lines.
    let extra_headers = headers
        .iter()
        .fold(String::new(), |mut rendered, (name, value)| {
            write!(rendered, "{name}: {value}\r\n").expect("header write");
            rendered
        });

    format!(
        "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
        body.len()
    )
}
- fn sample_request(stream: bool) -> MessageRequest {
- MessageRequest {
- model: "grok-3".to_string(),
- max_tokens: 64,
- messages: vec![InputMessage {
- role: "user".to_string(),
- content: vec![InputContentBlock::Text {
- text: "Say hello".to_string(),
- }],
- }],
- system: Some("Use tools when needed".to_string()),
- tools: Some(vec![ToolDefinition {
- name: "weather".to_string(),
- description: Some("Fetches weather".to_string()),
- input_schema: json!({
- "type": "object",
- "properties": {"city": {"type": "string"}},
- "required": ["city"]
- }),
- }]),
- tool_choice: Some(ToolChoice::Auto),
- stream,
- }
- }
/// Acquires the process-wide lock used to serialize tests that modify
/// environment variables.
///
/// A poisoned lock is recovered rather than propagated, so a panic in one
/// holder does not prevent later callers from acquiring the guard.
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
    static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
    let mutex = LOCK.get_or_init(|| StdMutex::new(()));
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Guard that overrides an environment variable and puts the prior state back
/// when dropped.
struct ScopedEnvVar {
    /// Name of the overridden variable.
    key: &'static str,
    /// Value the variable held before the override, `None` if it was unset.
    previous: Option<OsString>,
}

impl ScopedEnvVar {
    /// Records the current value of `key`, then sets `key` to `value` for the
    /// lifetime of the returned guard.
    fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
        let prior = std::env::var_os(key);
        std::env::set_var(key, value);
        Self {
            key,
            previous: prior,
        }
    }
}

impl Drop for ScopedEnvVar {
    /// Restores the recorded value, or removes the variable if it was unset
    /// before the override.
    fn drop(&mut self) {
        if let Some(original) = &self.previous {
            std::env::set_var(self.key, original);
        } else {
            std::env::remove_var(self.key);
        }
    }
}
|