Browse Source

wip: grok provider abstraction

Yeachan-Heo 2 tháng trước cách đây
mục cha
commit
40008b6513

+ 28 - 3
rust/crates/api/src/providers/openai_compat.rs

@@ -185,7 +185,7 @@ impl OpenAiCompatClient {
         &self,
         request: &MessageRequest,
     ) -> Result<reqwest::Response, ApiError> {
-        let request_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
+        let request_url = chat_completions_endpoint(&self.base_url);
         self.http
             .post(&request_url)
             .header("content-type", "application/json")
@@ -866,6 +866,15 @@ pub fn read_base_url(config: OpenAiCompatConfig) -> String {
     std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
 }
 
+fn chat_completions_endpoint(base_url: &str) -> String {
+    let trimmed = base_url.trim_end_matches('/');
+    if trimmed.ends_with("/chat/completions") {
+        trimmed.to_string()
+    } else {
+        format!("{trimmed}/chat/completions")
+    }
+}
+
 fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
     headers
         .get(REQUEST_ID_HEADER)
@@ -927,8 +936,8 @@ impl StringExt for String {
 #[cfg(test)]
 mod tests {
     use super::{
-        build_chat_completion_request, normalize_finish_reason, openai_tool_choice,
-        parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
+        build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
+        openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
     };
     use crate::error::ApiError;
     use crate::types::{
@@ -1010,6 +1019,22 @@ mod tests {
         ));
     }
 
+    #[test]
+    fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
+        assert_eq!(
+            chat_completions_endpoint("https://api.x.ai/v1"),
+            "https://api.x.ai/v1/chat/completions"
+        );
+        assert_eq!(
+            chat_completions_endpoint("https://api.x.ai/v1/"),
+            "https://api.x.ai/v1/chat/completions"
+        );
+        assert_eq!(
+            chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
+            "https://api.x.ai/v1/chat/completions"
+        );
+    }
+
     fn env_lock() -> std::sync::MutexGuard<'static, ()> {
         static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
         LOCK.get_or_init(|| Mutex::new(()))

+ 35 - 0
rust/crates/api/tests/openai_compat_integration.rs

@@ -62,6 +62,41 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() {
     assert_eq!(body["tools"][0]["type"], json!("function"));
 }
 
+#[tokio::test]
+async fn send_message_accepts_full_chat_completions_endpoint_override() {
+    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
+    let body = concat!(
+        "{",
+        "\"id\":\"chatcmpl_full_endpoint\",",
+        "\"model\":\"grok-3\",",
+        "\"choices\":[{",
+        "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},",
+        "\"finish_reason\":\"stop\"",
+        "}],",
+        "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}",
+        "}"
+    );
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response("200 OK", "application/json", body)],
+    )
+    .await;
+
+    let endpoint_url = format!("{}/chat/completions", server.base_url());
+    let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
+        .with_base_url(endpoint_url);
+    let response = client
+        .send_message(&sample_request(false))
+        .await
+        .expect("request should succeed");
+
+    assert_eq!(response.total_tokens(), 10);
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("server should capture request");
+    assert_eq!(request.path, "/chat/completions");
+}
+
 #[tokio::test]
 async fn stream_message_normalizes_text_and_multiple_tool_calls() {
     let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));

+ 2 - 1
rust/crates/rusty-claude-cli/src/main.rs

@@ -1907,13 +1907,14 @@ fn build_runtime(
     permission_mode: PermissionMode,
 ) -> Result<ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
 {
+    let feature_config = build_runtime_feature_config()?;
     Ok(ConversationRuntime::new_with_features(
         session,
         ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
         CliToolExecutor::new(allowed_tools, emit_output),
         permission_policy(permission_mode),
         system_prompt,
-        build_runtime_feature_config()?,
+        &feature_config,
     ))
 }