Pārlūkot izejas kodu

feat: provider tests + grok integration

Yeachan-Heo 2 mēneši atpakaļ
vecāks
revīzija
f477dde4a6

+ 4 - 5
rust/crates/api/src/client.rs

@@ -36,11 +36,10 @@ impl ProviderClient {
     ) -> Result<Self, ApiError> {
         let resolved_model = providers::resolve_model_alias(model);
         match providers::detect_provider_kind(&resolved_model) {
-            ProviderKind::Anthropic => Ok(Self::Anthropic(
-                anthropic_auth
-                    .map(AnthropicClient::from_auth)
-                    .unwrap_or(AnthropicClient::from_env()?),
-            )),
+            ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth {
+                Some(auth) => AnthropicClient::from_auth(auth),
+                None => AnthropicClient::from_env()?,
+            })),
             ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
                 OpenAiCompatConfig::xai(),
             )?)),

+ 44 - 3
rust/crates/api/tests/client_integration.rs

@@ -3,9 +3,9 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use api::{
-    AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
-    InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
-    StreamEvent, ToolChoice, ToolDefinition,
+    AnthropicClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent,
+    ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest,
+    OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
 };
 use serde_json::json;
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -195,6 +195,47 @@ async fn retries_retryable_failures_before_succeeding() {
     assert_eq!(state.lock().await.len(), 2);
 }
 
+#[tokio::test]
+async fn provider_client_dispatches_anthropic_requests() {
+    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response(
+            "200 OK",
+            "application/json",
+            "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
+        )],
+    )
+    .await;
+
+    let client = ProviderClient::from_model_with_anthropic_auth(
+        "claude-sonnet-4-6",
+        Some(AuthSource::ApiKey("test-key".to_string())),
+    )
+    .expect("anthropic provider client should be constructed");
+    let client = match client {
+        ProviderClient::Anthropic(client) => {
+            ProviderClient::Anthropic(client.with_base_url(server.base_url()))
+        }
+        other => panic!("expected anthropic provider, got {other:?}"),
+    };
+
+    let response = client
+        .send_message(&sample_request(false))
+        .await
+        .expect("provider-dispatched request should succeed");
+
+    assert_eq!(response.total_tokens(), 5);
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("server should capture request");
+    assert_eq!(request.path, "/v1/messages");
+    assert_eq!(
+        request.headers.get("x-api-key").map(String::as_str),
+        Some("test-key")
+    );
+}
+
 #[tokio::test]
 async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
     let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));

+ 69 - 1
rust/crates/api/tests/openai_compat_integration.rs

@@ -1,10 +1,12 @@
 use std::collections::HashMap;
+use std::ffi::OsString;
 use std::sync::Arc;
+use std::sync::{Mutex as StdMutex, OnceLock};
 
 use api::{
     ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
     InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
-    OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition,
+    OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
 };
 use serde_json::json;
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -158,6 +160,43 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
     assert!(request.body.contains("\"stream\":true"));
 }
 
+#[tokio::test]
+async fn provider_client_dispatches_xai_requests_from_env() {
+    let _lock = env_lock();
+    let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key");
+
+    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response(
+            "200 OK",
+            "application/json",
+            "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}",
+        )],
+    )
+    .await;
+    let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url());
+
+    let client =
+        ProviderClient::from_model("grok").expect("xAI provider client should be constructed");
+    assert!(matches!(client, ProviderClient::Xai(_)));
+
+    let response = client
+        .send_message(&sample_request(false))
+        .await
+        .expect("provider-dispatched request should succeed");
+
+    assert_eq!(response.total_tokens(), 13);
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("captured request");
+    assert_eq!(request.path, "/chat/completions");
+    assert_eq!(
+        request.headers.get("authorization").map(String::as_str),
+        Some("Bearer xai-test-key")
+    );
+}
+
 #[derive(Debug, Clone, PartialEq, Eq)]
 struct CapturedRequest {
     path: String,
@@ -310,3 +349,32 @@ fn sample_request(stream: bool) -> MessageRequest {
         stream,
     }
 }
+
+fn env_lock() -> std::sync::MutexGuard<'static, ()> {
+    static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
+    LOCK.get_or_init(|| StdMutex::new(()))
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+struct ScopedEnvVar {
+    key: &'static str,
+    previous: Option<OsString>,
+}
+
+impl ScopedEnvVar {
+    fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
+        let previous = std::env::var_os(key);
+        std::env::set_var(key, value);
+        Self { key, previous }
+    }
+}
+
+impl Drop for ScopedEnvVar {
+    fn drop(&mut self) {
+        match &self.previous {
+            Some(value) => std::env::set_var(self.key, value),
+            None => std::env::remove_var(self.key),
+        }
+    }
+}

+ 86 - 0
rust/crates/api/tests/provider_client_integration.rs

@@ -0,0 +1,86 @@
+use std::ffi::OsString;
+use std::sync::{Mutex, OnceLock};
+
+use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind};
+
+#[test]
+fn provider_client_routes_grok_aliases_through_xai() {
+    let _lock = env_lock();
+    let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key"));
+
+    let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve");
+
+    assert_eq!(client.provider_kind(), ProviderKind::Xai);
+}
+
+#[test]
+fn provider_client_reports_missing_xai_credentials_for_grok_models() {
+    let _lock = env_lock();
+    let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None);
+
+    let error = ProviderClient::from_model("grok-3")
+        .expect_err("grok requests without XAI_API_KEY should fail fast");
+
+    match error {
+        ApiError::MissingCredentials { provider, env_vars } => {
+            assert_eq!(provider, "xAI");
+            assert_eq!(env_vars, &["XAI_API_KEY"]);
+        }
+        other => panic!("expected missing xAI credentials, got {other:?}"),
+    }
+}
+
+#[test]
+fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() {
+    let _lock = env_lock();
+    let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None);
+    let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None);
+
+    let client = ProviderClient::from_model_with_anthropic_auth(
+        "claude-sonnet-4-6",
+        Some(AuthSource::ApiKey("anthropic-test-key".to_string())),
+    )
+    .expect("explicit anthropic auth should avoid env lookup");
+
+    assert_eq!(client.provider_kind(), ProviderKind::Anthropic);
+}
+
+#[test]
+fn read_xai_base_url_prefers_env_override() {
+    let _lock = env_lock();
+    let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1"));
+
+    assert_eq!(read_xai_base_url(), "https://example.xai.test/v1");
+}
+
+fn env_lock() -> std::sync::MutexGuard<'static, ()> {
+    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    LOCK.get_or_init(|| Mutex::new(()))
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+struct EnvVarGuard {
+    key: &'static str,
+    original: Option<OsString>,
+}
+
+impl EnvVarGuard {
+    fn set(key: &'static str, value: Option<&str>) -> Self {
+        let original = std::env::var_os(key);
+        match value {
+            Some(value) => std::env::set_var(key, value),
+            None => std::env::remove_var(key),
+        }
+        Self { key, original }
+    }
+}
+
+impl Drop for EnvVarGuard {
+    fn drop(&mut self) {
+        match &self.original {
+            Some(value) => std::env::set_var(self.key, value),
+            None => std::env::remove_var(self.key),
+        }
+    }
+}

+ 6 - 6
rust/crates/runtime/src/conversation.rs

@@ -118,7 +118,7 @@ where
             tool_executor,
             permission_policy,
             system_prompt,
-            RuntimeFeatureConfig::default(),
+            &RuntimeFeatureConfig::default(),
         )
     }
 
@@ -129,7 +129,7 @@ where
         tool_executor: T,
         permission_policy: PermissionPolicy,
         system_prompt: Vec<String>,
-        feature_config: RuntimeFeatureConfig,
+        feature_config: &RuntimeFeatureConfig,
     ) -> Self {
         let usage_tracker = UsageTracker::from_session(&session);
         Self {
@@ -140,7 +140,7 @@ where
             system_prompt,
             max_iterations: usize::MAX,
             usage_tracker,
-            hook_runner: HookRunner::from_feature_config(&feature_config),
+            hook_runner: HookRunner::from_feature_config(feature_config),
         }
     }
 
@@ -609,7 +609,7 @@ mod tests {
             }),
             PermissionPolicy::new(PermissionMode::DangerFullAccess),
             vec!["system".to_string()],
-            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
+            &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                 vec![shell_snippet("printf 'blocked by hook'; exit 2")],
                 Vec::new(),
             )),
@@ -675,7 +675,7 @@ mod tests {
             StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())),
             PermissionPolicy::new(PermissionMode::DangerFullAccess),
             vec!["system".to_string()],
-            RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
+            &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new(
                 vec![shell_snippet("printf 'pre hook ran'")],
                 vec![shell_snippet("printf 'post hook ran'")],
             )),
@@ -697,7 +697,7 @@ mod tests {
             "post hook should preserve non-error result: {output:?}"
         );
         assert!(
-            output.contains("4"),
+            output.contains('4'),
             "tool output missing value: {output:?}"
         );
         assert!(

+ 35 - 27
rust/crates/runtime/src/hooks.rs

@@ -51,6 +51,16 @@ pub struct HookRunner {
     config: RuntimeHookConfig,
 }
 
+#[derive(Debug, Clone, Copy)]
+struct HookCommandRequest<'a> {
+    event: HookEvent,
+    tool_name: &'a str,
+    tool_input: &'a str,
+    tool_output: Option<&'a str>,
+    is_error: bool,
+    payload: &'a str,
+}
+
 impl HookRunner {
     #[must_use]
     pub fn new(config: RuntimeHookConfig) -> Self {
@@ -118,14 +128,16 @@ impl HookRunner {
         let mut messages = Vec::new();
 
         for command in commands {
-            match self.run_command(
+            match Self::run_command(
                 command,
-                event,
-                tool_name,
-                tool_input,
-                tool_output,
-                is_error,
-                &payload,
+                HookCommandRequest {
+                    event,
+                    tool_name,
+                    tool_input,
+                    tool_output,
+                    is_error,
+                    payload: &payload,
+                },
             ) {
                 HookCommandOutcome::Allow { message } => {
                     if let Some(message) = message {
@@ -149,29 +161,23 @@ impl HookRunner {
         HookRunResult::allow(messages)
     }
 
-    fn run_command(
-        &self,
-        command: &str,
-        event: HookEvent,
-        tool_name: &str,
-        tool_input: &str,
-        tool_output: Option<&str>,
-        is_error: bool,
-        payload: &str,
-    ) -> HookCommandOutcome {
+    fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome {
         let mut child = shell_command(command);
         child.stdin(std::process::Stdio::piped());
         child.stdout(std::process::Stdio::piped());
         child.stderr(std::process::Stdio::piped());
-        child.env("HOOK_EVENT", event.as_str());
-        child.env("HOOK_TOOL_NAME", tool_name);
-        child.env("HOOK_TOOL_INPUT", tool_input);
-        child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" });
-        if let Some(tool_output) = tool_output {
+        child.env("HOOK_EVENT", request.event.as_str());
+        child.env("HOOK_TOOL_NAME", request.tool_name);
+        child.env("HOOK_TOOL_INPUT", request.tool_input);
+        child.env(
+            "HOOK_TOOL_IS_ERROR",
+            if request.is_error { "1" } else { "0" },
+        );
+        if let Some(tool_output) = request.tool_output {
             child.env("HOOK_TOOL_OUTPUT", tool_output);
         }
 
-        match child.output_with_stdin(payload.as_bytes()) {
+        match child.output_with_stdin(request.payload.as_bytes()) {
             Ok(output) => {
                 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
                 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
@@ -189,16 +195,18 @@ impl HookRunner {
                     },
                     None => HookCommandOutcome::Warn {
                         message: format!(
-                            "{} hook `{command}` terminated by signal while handling `{tool_name}`",
-                            event.as_str()
+                            "{} hook `{command}` terminated by signal while handling `{}`",
+                            request.event.as_str(),
+                            request.tool_name
                         ),
                     },
                 }
             }
             Err(error) => HookCommandOutcome::Warn {
                 message: format!(
-                    "{} hook `{command}` failed to start for `{tool_name}`: {error}",
-                    event.as_str()
+                    "{} hook `{command}` failed to start for `{}`: {error}",
+                    request.event.as_str(),
+                    request.tool_name
                 ),
             },
         }