Explorar o código

feat: provider abstraction layer + Grok API support

Yeachan-Heo hai 2 meses
pai
achega
2a0f4b677a

+ 80 - 932
rust/crates/api/src/client.rs

@@ -1,994 +1,142 @@
-use std::collections::VecDeque;
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
-
-use runtime::{
-    load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
-    OAuthTokenExchangeRequest,
-};
-use serde::Deserialize;
-
 use crate::error::ApiError;
-use crate::sse::SseParser;
+use crate::providers::anthropic::{self, AnthropicClient, AuthSource};
+use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig};
+use crate::providers::{self, Provider, ProviderKind};
 use crate::types::{MessageRequest, MessageResponse, StreamEvent};
 
-const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
-const ANTHROPIC_VERSION: &str = "2023-06-01";
-const REQUEST_ID_HEADER: &str = "request-id";
-const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
-const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
-const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
-const DEFAULT_MAX_RETRIES: u32 = 2;
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum AuthSource {
-    None,
-    ApiKey(String),
-    BearerToken(String),
-    ApiKeyAndBearer {
-        api_key: String,
-        bearer_token: String,
-    },
-}
-
-impl AuthSource {
-    pub fn from_env() -> Result<Self, ApiError> {
-        let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
-        let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
-        match (api_key, auth_token) {
-            (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
-                api_key,
-                bearer_token,
-            }),
-            (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
-            (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
-            (None, None) => Err(ApiError::MissingApiKey),
-        }
-    }
-
-    #[must_use]
-    pub fn api_key(&self) -> Option<&str> {
-        match self {
-            Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
-            Self::None | Self::BearerToken(_) => None,
-        }
-    }
-
-    #[must_use]
-    pub fn bearer_token(&self) -> Option<&str> {
-        match self {
-            Self::BearerToken(token)
-            | Self::ApiKeyAndBearer {
-                bearer_token: token,
-                ..
-            } => Some(token),
-            Self::None | Self::ApiKey(_) => None,
-        }
-    }
-
-    #[must_use]
-    pub fn masked_authorization_header(&self) -> &'static str {
-        if self.bearer_token().is_some() {
-            "Bearer [REDACTED]"
-        } else {
-            "<absent>"
-        }
-    }
-
-    pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
-        if let Some(api_key) = self.api_key() {
-            request_builder = request_builder.header("x-api-key", api_key);
-        }
-        if let Some(token) = self.bearer_token() {
-            request_builder = request_builder.bearer_auth(token);
-        }
-        request_builder
-    }
+async fn send_via_provider<P: Provider>(
+    provider: &P,
+    request: &MessageRequest,
+) -> Result<MessageResponse, ApiError> {
+    provider.send_message(request).await
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
-pub struct OAuthTokenSet {
-    pub access_token: String,
-    pub refresh_token: Option<String>,
-    pub expires_at: Option<u64>,
-    #[serde(default)]
-    pub scopes: Vec<String>,
-}
-
-impl From<OAuthTokenSet> for AuthSource {
-    fn from(value: OAuthTokenSet) -> Self {
-        Self::BearerToken(value.access_token)
-    }
+async fn stream_via_provider<P: Provider>(
+    provider: &P,
+    request: &MessageRequest,
+) -> Result<P::Stream, ApiError> {
+    provider.stream_message(request).await
 }
 
 #[derive(Debug, Clone)]
-pub struct AnthropicClient {
-    http: reqwest::Client,
-    auth: AuthSource,
-    base_url: String,
-    max_retries: u32,
-    initial_backoff: Duration,
-    max_backoff: Duration,
+pub enum ProviderClient {
+    Anthropic(AnthropicClient),
+    Xai(OpenAiCompatClient),
+    OpenAi(OpenAiCompatClient),
 }
 
-impl AnthropicClient {
-    #[must_use]
-    pub fn new(api_key: impl Into<String>) -> Self {
-        Self {
-            http: reqwest::Client::new(),
-            auth: AuthSource::ApiKey(api_key.into()),
-            base_url: DEFAULT_BASE_URL.to_string(),
-            max_retries: DEFAULT_MAX_RETRIES,
-            initial_backoff: DEFAULT_INITIAL_BACKOFF,
-            max_backoff: DEFAULT_MAX_BACKOFF,
-        }
+impl ProviderClient {
+    pub fn from_model(model: &str) -> Result<Self, ApiError> {
+        Self::from_model_with_anthropic_auth(model, None)
     }
 
-    #[must_use]
-    pub fn from_auth(auth: AuthSource) -> Self {
-        Self {
-            http: reqwest::Client::new(),
-            auth,
-            base_url: DEFAULT_BASE_URL.to_string(),
-            max_retries: DEFAULT_MAX_RETRIES,
-            initial_backoff: DEFAULT_INITIAL_BACKOFF,
-            max_backoff: DEFAULT_MAX_BACKOFF,
+    pub fn from_model_with_anthropic_auth(
+        model: &str,
+        anthropic_auth: Option<AuthSource>,
+    ) -> Result<Self, ApiError> {
+        let resolved_model = providers::resolve_model_alias(model);
+        match providers::detect_provider_kind(&resolved_model) {
+            ProviderKind::Anthropic => Ok(Self::Anthropic(
+                anthropic_auth
+                    .map(AnthropicClient::from_auth)
+                    .unwrap_or(AnthropicClient::from_env()?),
+            )),
+            ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
+                OpenAiCompatConfig::xai(),
+            )?)),
+            ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env(
+                OpenAiCompatConfig::openai(),
+            )?)),
         }
     }
 
-    pub fn from_env() -> Result<Self, ApiError> {
-        Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
-    }
-
-    #[must_use]
-    pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
-        self.auth = auth;
-        self
-    }
-
     #[must_use]
-    pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
-        match (
-            self.auth.api_key().map(ToOwned::to_owned),
-            auth_token.filter(|token| !token.is_empty()),
-        ) {
-            (Some(api_key), Some(bearer_token)) => {
-                self.auth = AuthSource::ApiKeyAndBearer {
-                    api_key,
-                    bearer_token,
-                };
-            }
-            (Some(api_key), None) => {
-                self.auth = AuthSource::ApiKey(api_key);
-            }
-            (None, Some(bearer_token)) => {
-                self.auth = AuthSource::BearerToken(bearer_token);
-            }
-            (None, None) => {
-                self.auth = AuthSource::None;
-            }
+    pub const fn provider_kind(&self) -> ProviderKind {
+        match self {
+            Self::Anthropic(_) => ProviderKind::Anthropic,
+            Self::Xai(_) => ProviderKind::Xai,
+            Self::OpenAi(_) => ProviderKind::OpenAi,
         }
-        self
-    }
-
-    #[must_use]
-    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
-        self.base_url = base_url.into();
-        self
-    }
-
-    #[must_use]
-    pub fn with_retry_policy(
-        mut self,
-        max_retries: u32,
-        initial_backoff: Duration,
-        max_backoff: Duration,
-    ) -> Self {
-        self.max_retries = max_retries;
-        self.initial_backoff = initial_backoff;
-        self.max_backoff = max_backoff;
-        self
-    }
-
-    #[must_use]
-    pub fn auth_source(&self) -> &AuthSource {
-        &self.auth
     }
 
     pub async fn send_message(
         &self,
         request: &MessageRequest,
     ) -> Result<MessageResponse, ApiError> {
-        let request = MessageRequest {
-            stream: false,
-            ..request.clone()
-        };
-        let response = self.send_with_retry(&request).await?;
-        let request_id = request_id_from_headers(response.headers());
-        let mut response = response
-            .json::<MessageResponse>()
-            .await
-            .map_err(ApiError::from)?;
-        if response.request_id.is_none() {
-            response.request_id = request_id;
+        match self {
+            Self::Anthropic(client) => send_via_provider(client, request).await,
+            Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
         }
-        Ok(response)
     }
 
     pub async fn stream_message(
         &self,
         request: &MessageRequest,
     ) -> Result<MessageStream, ApiError> {
-        let response = self
-            .send_with_retry(&request.clone().with_streaming())
-            .await?;
-        Ok(MessageStream {
-            request_id: request_id_from_headers(response.headers()),
-            response,
-            parser: SseParser::new(),
-            pending: VecDeque::new(),
-            done: false,
-        })
-    }
-
-    pub async fn exchange_oauth_code(
-        &self,
-        config: &OAuthConfig,
-        request: &OAuthTokenExchangeRequest,
-    ) -> Result<OAuthTokenSet, ApiError> {
-        let response = self
-            .http
-            .post(&config.token_url)
-            .header("content-type", "application/x-www-form-urlencoded")
-            .form(&request.form_params())
-            .send()
-            .await
-            .map_err(ApiError::from)?;
-        let response = expect_success(response).await?;
-        response
-            .json::<OAuthTokenSet>()
-            .await
-            .map_err(ApiError::from)
-    }
-
-    pub async fn refresh_oauth_token(
-        &self,
-        config: &OAuthConfig,
-        request: &OAuthRefreshRequest,
-    ) -> Result<OAuthTokenSet, ApiError> {
-        let response = self
-            .http
-            .post(&config.token_url)
-            .header("content-type", "application/x-www-form-urlencoded")
-            .form(&request.form_params())
-            .send()
-            .await
-            .map_err(ApiError::from)?;
-        let response = expect_success(response).await?;
-        response
-            .json::<OAuthTokenSet>()
-            .await
-            .map_err(ApiError::from)
-    }
-
-    async fn send_with_retry(
-        &self,
-        request: &MessageRequest,
-    ) -> Result<reqwest::Response, ApiError> {
-        let mut attempts = 0;
-        let mut last_error: Option<ApiError>;
-
-        loop {
-            attempts += 1;
-            match self.send_raw_request(request).await {
-                Ok(response) => match expect_success(response).await {
-                    Ok(response) => return Ok(response),
-                    Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
-                        last_error = Some(error);
-                    }
-                    Err(error) => return Err(error),
-                },
-                Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
-                    last_error = Some(error);
-                }
-                Err(error) => return Err(error),
-            }
-
-            if attempts > self.max_retries {
-                break;
-            }
-
-            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
-        }
-
-        Err(ApiError::RetriesExhausted {
-            attempts,
-            last_error: Box::new(last_error.expect("retry loop must capture an error")),
-        })
-    }
-
-    async fn send_raw_request(
-        &self,
-        request: &MessageRequest,
-    ) -> Result<reqwest::Response, ApiError> {
-        let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
-        let request_builder = self
-            .http
-            .post(&request_url)
-            .header("anthropic-version", ANTHROPIC_VERSION)
-            .header("content-type", "application/json");
-        let mut request_builder = self.auth.apply(request_builder);
-
-        request_builder = request_builder.json(request);
-        request_builder.send().await.map_err(ApiError::from)
-    }
-
-    fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
-        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
-            return Err(ApiError::BackoffOverflow {
-                attempt,
-                base_delay: self.initial_backoff,
-            });
-        };
-        Ok(self
-            .initial_backoff
-            .checked_mul(multiplier)
-            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
-    }
-}
-
-impl AuthSource {
-    pub fn from_env_or_saved() -> Result<Self, ApiError> {
-        if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
-            return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
-                Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
-                    api_key,
-                    bearer_token,
-                }),
-                None => Ok(Self::ApiKey(api_key)),
-            };
-        }
-        if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
-            return Ok(Self::BearerToken(bearer_token));
-        }
-        match load_saved_oauth_token() {
-            Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
-                if token_set.refresh_token.is_some() {
-                    Err(ApiError::Auth(
-                        "saved OAuth token is expired; load runtime OAuth config to refresh it"
-                            .to_string(),
-                    ))
-                } else {
-                    Err(ApiError::ExpiredOAuthToken)
-                }
-            }
-            Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
-            Ok(None) => Err(ApiError::MissingApiKey),
-            Err(error) => Err(error),
+        match self {
+            Self::Anthropic(client) => stream_via_provider(client, request)
+                .await
+                .map(MessageStream::Anthropic),
+            Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
+                .await
+                .map(MessageStream::OpenAiCompat),
         }
     }
 }
 
-#[must_use]
-pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
-    token_set
-        .expires_at
-        .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
-}
-
-pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
-    let Some(token_set) = load_saved_oauth_token()? else {
-        return Ok(None);
-    };
-    resolve_saved_oauth_token_set(config, token_set).map(Some)
-}
-
-pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
-where
-    F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
-{
-    if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
-        return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
-            Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
-                api_key,
-                bearer_token,
-            }),
-            None => Ok(AuthSource::ApiKey(api_key)),
-        };
-    }
-    if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
-        return Ok(AuthSource::BearerToken(bearer_token));
-    }
-
-    let Some(token_set) = load_saved_oauth_token()? else {
-        return Err(ApiError::MissingApiKey);
-    };
-    if !oauth_token_is_expired(&token_set) {
-        return Ok(AuthSource::BearerToken(token_set.access_token));
-    }
-    if token_set.refresh_token.is_none() {
-        return Err(ApiError::ExpiredOAuthToken);
-    }
-
-    let Some(config) = load_oauth_config()? else {
-        return Err(ApiError::Auth(
-            "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
-        ));
-    };
-    Ok(AuthSource::from(resolve_saved_oauth_token_set(
-        &config, token_set,
-    )?))
-}
-
-fn resolve_saved_oauth_token_set(
-    config: &OAuthConfig,
-    token_set: OAuthTokenSet,
-) -> Result<OAuthTokenSet, ApiError> {
-    if !oauth_token_is_expired(&token_set) {
-        return Ok(token_set);
-    }
-    let Some(refresh_token) = token_set.refresh_token.clone() else {
-        return Err(ApiError::ExpiredOAuthToken);
-    };
-    let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
-    let refreshed = client_runtime_block_on(async {
-        client
-            .refresh_oauth_token(
-                config,
-                &OAuthRefreshRequest::from_config(
-                    config,
-                    refresh_token,
-                    Some(token_set.scopes.clone()),
-                ),
-            )
-            .await
-    })?;
-    let resolved = OAuthTokenSet {
-        access_token: refreshed.access_token,
-        refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
-        expires_at: refreshed.expires_at,
-        scopes: refreshed.scopes,
-    };
-    save_oauth_credentials(&runtime::OAuthTokenSet {
-        access_token: resolved.access_token.clone(),
-        refresh_token: resolved.refresh_token.clone(),
-        expires_at: resolved.expires_at,
-        scopes: resolved.scopes.clone(),
-    })
-    .map_err(ApiError::from)?;
-    Ok(resolved)
-}
-
-fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
-where
-    F: std::future::Future<Output = Result<T, ApiError>>,
-{
-    tokio::runtime::Runtime::new()
-        .map_err(ApiError::from)?
-        .block_on(future)
-}
-
-fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
-    let token_set = load_oauth_credentials().map_err(ApiError::from)?;
-    Ok(token_set.map(|token_set| OAuthTokenSet {
-        access_token: token_set.access_token,
-        refresh_token: token_set.refresh_token,
-        expires_at: token_set.expires_at,
-        scopes: token_set.scopes,
-    }))
-}
-
-fn now_unix_timestamp() -> u64 {
-    SystemTime::now()
-        .duration_since(UNIX_EPOCH)
-        .map_or(0, |duration| duration.as_secs())
-}
-
-fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
-    match std::env::var(key) {
-        Ok(value) if !value.is_empty() => Ok(Some(value)),
-        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
-        Err(error) => Err(ApiError::from(error)),
-    }
-}
-
-#[cfg(test)]
-fn read_api_key() -> Result<String, ApiError> {
-    let auth = AuthSource::from_env_or_saved()?;
-    auth.api_key()
-        .or_else(|| auth.bearer_token())
-        .map(ToOwned::to_owned)
-        .ok_or(ApiError::MissingApiKey)
-}
-
-#[cfg(test)]
-fn read_auth_token() -> Option<String> {
-    read_env_non_empty("ANTHROPIC_AUTH_TOKEN")
-        .ok()
-        .and_then(std::convert::identity)
-}
-
-#[must_use]
-pub fn read_base_url() -> String {
-    std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
-}
-
-fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
-    headers
-        .get(REQUEST_ID_HEADER)
-        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
-        .and_then(|value| value.to_str().ok())
-        .map(ToOwned::to_owned)
-}
-
 #[derive(Debug)]
-pub struct MessageStream {
-    request_id: Option<String>,
-    response: reqwest::Response,
-    parser: SseParser,
-    pending: VecDeque<StreamEvent>,
-    done: bool,
+pub enum MessageStream {
+    Anthropic(anthropic::MessageStream),
+    OpenAiCompat(openai_compat::MessageStream),
 }
 
 impl MessageStream {
     #[must_use]
     pub fn request_id(&self) -> Option<&str> {
-        self.request_id.as_deref()
+        match self {
+            Self::Anthropic(stream) => stream.request_id(),
+            Self::OpenAiCompat(stream) => stream.request_id(),
+        }
     }
 
     pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
-        loop {
-            if let Some(event) = self.pending.pop_front() {
-                return Ok(Some(event));
-            }
-
-            if self.done {
-                let remaining = self.parser.finish()?;
-                self.pending.extend(remaining);
-                if let Some(event) = self.pending.pop_front() {
-                    return Ok(Some(event));
-                }
-                return Ok(None);
-            }
-
-            match self.response.chunk().await? {
-                Some(chunk) => {
-                    self.pending.extend(self.parser.push(&chunk)?);
-                }
-                None => {
-                    self.done = true;
-                }
-            }
+        match self {
+            Self::Anthropic(stream) => stream.next_event().await,
+            Self::OpenAiCompat(stream) => stream.next_event().await,
         }
     }
 }
 
-async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
-    let status = response.status();
-    if status.is_success() {
-        return Ok(response);
-    }
-
-    let body = response.text().await.unwrap_or_else(|_| String::new());
-    let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
-    let retryable = is_retryable_status(status);
-
-    Err(ApiError::Api {
-        status,
-        error_type: parsed_error
-            .as_ref()
-            .map(|error| error.error.error_type.clone()),
-        message: parsed_error
-            .as_ref()
-            .map(|error| error.error.message.clone()),
-        body,
-        retryable,
-    })
-}
-
-const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
-    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
-}
-
-#[derive(Debug, Deserialize)]
-struct AnthropicErrorEnvelope {
-    error: AnthropicErrorBody,
+pub use anthropic::{
+    oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet,
+};
+#[must_use]
+pub fn read_base_url() -> String {
+    anthropic::read_base_url()
 }
 
-#[derive(Debug, Deserialize)]
-struct AnthropicErrorBody {
-    #[serde(rename = "type")]
-    error_type: String,
-    message: String,
+#[must_use]
+pub fn read_xai_base_url() -> String {
+    openai_compat::read_base_url(OpenAiCompatConfig::xai())
 }
 
 #[cfg(test)]
 mod tests {
-    use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
-    use std::io::{Read, Write};
-    use std::net::TcpListener;
-    use std::sync::{Mutex, OnceLock};
-    use std::thread;
-    use std::time::{Duration, SystemTime, UNIX_EPOCH};
-
-    use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
-
-    use crate::client::{
-        now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
-        resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
-    };
-    use crate::types::{ContentBlockDelta, MessageRequest};
-
-    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
-        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
-        LOCK.get_or_init(|| Mutex::new(()))
-            .lock()
-            .expect("env lock")
-    }
-
-    fn temp_config_home() -> std::path::PathBuf {
-        std::env::temp_dir().join(format!(
-            "api-oauth-test-{}-{}",
-            std::process::id(),
-            SystemTime::now()
-                .duration_since(UNIX_EPOCH)
-                .expect("time")
-                .as_nanos()
-        ))
-    }
-
-    fn sample_oauth_config(token_url: String) -> OAuthConfig {
-        OAuthConfig {
-            client_id: "runtime-client".to_string(),
-            authorize_url: "https://console.test/oauth/authorize".to_string(),
-            token_url,
-            callback_port: Some(4545),
-            manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
-            scopes: vec!["org:read".to_string(), "user:write".to_string()],
-        }
-    }
-
-    fn spawn_token_server(response_body: &'static str) -> String {
-        let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
-        let address = listener.local_addr().expect("local addr");
-        thread::spawn(move || {
-            let (mut stream, _) = listener.accept().expect("accept connection");
-            let mut buffer = [0_u8; 4096];
-            let _ = stream.read(&mut buffer).expect("read request");
-            let response = format!(
-                "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
-                response_body.len(),
-                response_body
-            );
-            stream
-                .write_all(response.as_bytes())
-                .expect("write response");
-        });
-        format!("http://{address}/oauth/token")
-    }
-
-    #[test]
-    fn read_api_key_requires_presence() {
-        let _guard = env_lock();
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-        std::env::remove_var("CLAUDE_CONFIG_HOME");
-        let error = super::read_api_key().expect_err("missing key should error");
-        assert!(matches!(error, crate::error::ApiError::MissingApiKey));
-    }
-
-    #[test]
-    fn read_api_key_requires_non_empty_value() {
-        let _guard = env_lock();
-        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-        let error = super::read_api_key().expect_err("empty key should error");
-        assert!(matches!(error, crate::error::ApiError::MissingApiKey));
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-    }
-
-    #[test]
-    fn read_api_key_prefers_api_key_env() {
-        let _guard = env_lock();
-        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
-        std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
-        assert_eq!(
-            super::read_api_key().expect("api key should load"),
-            "legacy-key"
-        );
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-    }
-
-    #[test]
-    fn read_auth_token_reads_auth_token_env() {
-        let _guard = env_lock();
-        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
-        assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-    }
-
-    #[test]
-    fn oauth_token_maps_to_bearer_auth_source() {
-        let auth = AuthSource::from(OAuthTokenSet {
-            access_token: "access-token".to_string(),
-            refresh_token: Some("refresh".to_string()),
-            expires_at: Some(123),
-            scopes: vec!["scope:a".to_string()],
-        });
-        assert_eq!(auth.bearer_token(), Some("access-token"));
-        assert_eq!(auth.api_key(), None);
-    }
-
-    #[test]
-    fn auth_source_from_env_combines_api_key_and_bearer_token() {
-        let _guard = env_lock();
-        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
-        std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
-        let auth = AuthSource::from_env().expect("env auth");
-        assert_eq!(auth.api_key(), Some("legacy-key"));
-        assert_eq!(auth.bearer_token(), Some("auth-token"));
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-    }
-
-    #[test]
-    fn auth_source_from_saved_oauth_when_env_absent() {
-        let _guard = env_lock();
-        let config_home = temp_config_home();
-        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-        save_oauth_credentials(&runtime::OAuthTokenSet {
-            access_token: "saved-access-token".to_string(),
-            refresh_token: Some("refresh".to_string()),
-            expires_at: Some(now_unix_timestamp() + 300),
-            scopes: vec!["scope:a".to_string()],
-        })
-        .expect("save oauth credentials");
-
-        let auth = AuthSource::from_env_or_saved().expect("saved auth");
-        assert_eq!(auth.bearer_token(), Some("saved-access-token"));
-
-        clear_oauth_credentials().expect("clear credentials");
-        std::env::remove_var("CLAUDE_CONFIG_HOME");
-        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
-    }
+    use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind};
 
     #[test]
-    fn oauth_token_expiry_uses_expires_at_timestamp() {
-        assert!(oauth_token_is_expired(&OAuthTokenSet {
-            access_token: "access-token".to_string(),
-            refresh_token: None,
-            expires_at: Some(1),
-            scopes: Vec::new(),
-        }));
-        assert!(!oauth_token_is_expired(&OAuthTokenSet {
-            access_token: "access-token".to_string(),
-            refresh_token: None,
-            expires_at: Some(now_unix_timestamp() + 60),
-            scopes: Vec::new(),
-        }));
+    fn resolves_existing_and_grok_aliases() {
+        assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
+        assert_eq!(resolve_model_alias("grok"), "grok-3");
+        assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
     }
 
     #[test]
-    fn resolve_saved_oauth_token_refreshes_expired_credentials() {
-        let _guard = env_lock();
-        let config_home = temp_config_home();
-        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-        save_oauth_credentials(&runtime::OAuthTokenSet {
-            access_token: "expired-access-token".to_string(),
-            refresh_token: Some("refresh-token".to_string()),
-            expires_at: Some(1),
-            scopes: vec!["scope:a".to_string()],
-        })
-        .expect("save expired oauth credentials");
-
-        let token_url = spawn_token_server(
-            "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
-        );
-        let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
-            .expect("resolve refreshed token")
-            .expect("token set present");
-        assert_eq!(resolved.access_token, "refreshed-token");
-        let stored = runtime::load_oauth_credentials()
-            .expect("load stored credentials")
-            .expect("stored token set");
-        assert_eq!(stored.access_token, "refreshed-token");
-
-        clear_oauth_credentials().expect("clear credentials");
-        std::env::remove_var("CLAUDE_CONFIG_HOME");
-        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
-    }
-
-    #[test]
-    fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
-        let _guard = env_lock();
-        let config_home = temp_config_home();
-        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-        save_oauth_credentials(&runtime::OAuthTokenSet {
-            access_token: "saved-access-token".to_string(),
-            refresh_token: Some("refresh".to_string()),
-            expires_at: Some(now_unix_timestamp() + 300),
-            scopes: vec!["scope:a".to_string()],
-        })
-        .expect("save oauth credentials");
-
-        let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
-            .expect("startup auth");
-        assert_eq!(auth.bearer_token(), Some("saved-access-token"));
-
-        clear_oauth_credentials().expect("clear credentials");
-        std::env::remove_var("CLAUDE_CONFIG_HOME");
-        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
-    }
-
-    #[test]
-    fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
-        let _guard = env_lock();
-        let config_home = temp_config_home();
-        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-        save_oauth_credentials(&runtime::OAuthTokenSet {
-            access_token: "expired-access-token".to_string(),
-            refresh_token: Some("refresh-token".to_string()),
-            expires_at: Some(1),
-            scopes: vec!["scope:a".to_string()],
-        })
-        .expect("save expired oauth credentials");
-
-        let error =
-            resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
-        assert!(
-            matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
-        );
-
-        let stored = runtime::load_oauth_credentials()
-            .expect("load stored credentials")
-            .expect("stored token set");
-        assert_eq!(stored.access_token, "expired-access-token");
-        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
-
-        clear_oauth_credentials().expect("clear credentials");
-        std::env::remove_var("CLAUDE_CONFIG_HOME");
-        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
-    }
-
-    #[test]
-    fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
-        let _guard = env_lock();
-        let config_home = temp_config_home();
-        std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
-        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
-        std::env::remove_var("ANTHROPIC_API_KEY");
-        save_oauth_credentials(&runtime::OAuthTokenSet {
-            access_token: "expired-access-token".to_string(),
-            refresh_token: Some("refresh-token".to_string()),
-            expires_at: Some(1),
-            scopes: vec!["scope:a".to_string()],
-        })
-        .expect("save expired oauth credentials");
-
-        let token_url = spawn_token_server(
-            "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
-        );
-        let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
-            .expect("resolve refreshed token")
-            .expect("token set present");
-        assert_eq!(resolved.access_token, "refreshed-token");
-        assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
-        let stored = runtime::load_oauth_credentials()
-            .expect("load stored credentials")
-            .expect("stored token set");
-        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
-
-        clear_oauth_credentials().expect("clear credentials");
-        std::env::remove_var("CLAUDE_CONFIG_HOME");
-        std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
-    }
-
-    #[test]
-    fn message_request_stream_helper_sets_stream_true() {
-        let request = MessageRequest {
-            model: "claude-opus-4-6".to_string(),
-            max_tokens: 64,
-            messages: vec![],
-            system: None,
-            tools: None,
-            tool_choice: None,
-            stream: false,
-        };
-
-        assert!(request.with_streaming().stream);
-    }
-
-    #[test]
-    fn backoff_doubles_until_maximum() {
-        let client = AnthropicClient::new("test-key").with_retry_policy(
-            3,
-            Duration::from_millis(10),
-            Duration::from_millis(25),
-        );
-        assert_eq!(
-            client.backoff_for_attempt(1).expect("attempt 1"),
-            Duration::from_millis(10)
-        );
-        assert_eq!(
-            client.backoff_for_attempt(2).expect("attempt 2"),
-            Duration::from_millis(20)
-        );
-        assert_eq!(
-            client.backoff_for_attempt(3).expect("attempt 3"),
-            Duration::from_millis(25)
-        );
-    }
-
-    #[test]
-    fn retryable_statuses_are_detected() {
-        assert!(super::is_retryable_status(
-            reqwest::StatusCode::TOO_MANY_REQUESTS
-        ));
-        assert!(super::is_retryable_status(
-            reqwest::StatusCode::INTERNAL_SERVER_ERROR
-        ));
-        assert!(!super::is_retryable_status(
-            reqwest::StatusCode::UNAUTHORIZED
-        ));
-    }
-
-    #[test]
-    fn tool_delta_variant_round_trips() {
-        let delta = ContentBlockDelta::InputJsonDelta {
-            partial_json: "{\"city\":\"Paris\"}".to_string(),
-        };
-        let encoded = serde_json::to_string(&delta).expect("delta should serialize");
-        let decoded: ContentBlockDelta =
-            serde_json::from_str(&encoded).expect("delta should deserialize");
-        assert_eq!(decoded, delta);
-    }
-
-    #[test]
-    fn request_id_uses_primary_or_fallback_header() {
-        let mut headers = reqwest::header::HeaderMap::new();
-        headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
-        assert_eq!(
-            super::request_id_from_headers(&headers).as_deref(),
-            Some("req_primary")
-        );
-
-        headers.clear();
-        headers.insert(
-            ALT_REQUEST_ID_HEADER,
-            "req_fallback".parse().expect("header"),
-        );
-        assert_eq!(
-            super::request_id_from_headers(&headers).as_deref(),
-            Some("req_fallback")
-        );
-    }
-
-    #[test]
-    fn auth_source_applies_headers() {
-        let auth = AuthSource::ApiKeyAndBearer {
-            api_key: "test-key".to_string(),
-            bearer_token: "proxy-token".to_string(),
-        };
-        let request = auth
-            .apply(reqwest::Client::new().post("https://example.test"))
-            .build()
-            .expect("request build");
-        let headers = request.headers();
-        assert_eq!(
-            headers.get("x-api-key").and_then(|v| v.to_str().ok()),
-            Some("test-key")
-        );
+    fn provider_detection_prefers_model_family() {
+        assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai);
         assert_eq!(
-            headers.get("authorization").and_then(|v| v.to_str().ok()),
-            Some("Bearer proxy-token")
+            detect_provider_kind("claude-sonnet-4-6"),
+            ProviderKind::Anthropic
         );
     }
 }

+ 8 - 2
rust/crates/api/src/lib.rs

@@ -1,13 +1,19 @@
 mod client;
 mod error;
+mod providers;
 mod sse;
 mod types;
 
 pub use client::{
-    oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, resolve_startup_auth_source,
-    AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
+    oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token,
+    resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
 };
 pub use error::ApiError;
+pub use providers::anthropic::{AnthropicClient, AuthSource};
+pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
+pub use providers::{
+    detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind,
+};
 pub use sse::{parse_frame, SseParser};
 pub use types::{
     ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,

+ 51 - 7
rust/crates/api/src/providers/anthropic.rs

@@ -8,10 +8,12 @@ use runtime::{
 use serde::Deserialize;
 
 use crate::error::ApiError;
+
+use super::{Provider, ProviderFuture};
 use crate::sse::SseParser;
 use crate::types::{MessageRequest, MessageResponse, StreamEvent};
 
-const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
+pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
 const ANTHROPIC_VERSION: &str = "2023-06-01";
 const REQUEST_ID_HEADER: &str = "request-id";
 const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
@@ -41,7 +43,10 @@ impl AuthSource {
             }),
             (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
             (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
-            (None, None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])),
+            (None, None) => Err(ApiError::missing_credentials(
+                "Anthropic",
+                &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
+            )),
         }
     }
 
@@ -362,7 +367,10 @@ impl AuthSource {
                 }
             }
             Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
-            Ok(None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])),
+            Ok(None) => Err(ApiError::missing_credentials(
+                "Anthropic",
+                &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
+            )),
             Err(error) => Err(error),
         }
     }
@@ -382,6 +390,12 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTok
     resolve_saved_oauth_token_set(config, token_set).map(Some)
 }
 
+pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
+    Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
+        || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
+        || load_saved_oauth_token()?.is_some())
+}
+
 pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
 where
     F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
@@ -400,7 +414,10 @@ where
     }
 
     let Some(token_set) = load_saved_oauth_token()? else {
-        return Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]));
+        return Err(ApiError::missing_credentials(
+            "Anthropic",
+            &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
+        ));
     };
     if !oauth_token_is_expired(&token_set) {
         return Ok(AuthSource::BearerToken(token_set.access_token));
@@ -497,7 +514,10 @@ fn read_api_key() -> Result<String, ApiError> {
     auth.api_key()
         .or_else(|| auth.bearer_token())
         .map(ToOwned::to_owned)
-        .ok_or(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]))
+        .ok_or(ApiError::missing_credentials(
+            "Anthropic",
+            &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
+        ))
 }
 
 #[cfg(test)]
@@ -520,6 +540,24 @@ fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<Strin
         .map(ToOwned::to_owned)
 }
 
+impl Provider for AnthropicClient {
+    type Stream = MessageStream;
+
+    fn send_message<'a>(
+        &'a self,
+        request: &'a MessageRequest,
+    ) -> ProviderFuture<'a, MessageResponse> {
+        Box::pin(async move { self.send_message(request).await })
+    }
+
+    fn stream_message<'a>(
+        &'a self,
+        request: &'a MessageRequest,
+    ) -> ProviderFuture<'a, Self::Stream> {
+        Box::pin(async move { self.stream_message(request).await })
+    }
+}
+
 #[derive(Debug)]
 pub struct MessageStream {
     request_id: Option<String>,
@@ -673,7 +711,10 @@ mod tests {
         std::env::remove_var("ANTHROPIC_API_KEY");
         std::env::remove_var("CLAUDE_CONFIG_HOME");
         let error = super::read_api_key().expect_err("missing key should error");
-        assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. }));
+        assert!(matches!(
+            error,
+            crate::error::ApiError::MissingCredentials { .. }
+        ));
     }
 
     #[test]
@@ -682,7 +723,10 @@ mod tests {
         std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
         std::env::remove_var("ANTHROPIC_API_KEY");
         let error = super::read_api_key().expect_err("empty key should error");
-        assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. }));
+        assert!(matches!(
+            error,
+            crate::error::ApiError::MissingCredentials { .. }
+        ));
         std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
     }
 

+ 30 - 16
rust/crates/api/src/providers/mod.rs

@@ -12,9 +12,15 @@ pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>
 pub trait Provider {
     type Stream;
 
-    fn send_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, MessageResponse>;
-
-    fn stream_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, Self::Stream>;
+    fn send_message<'a>(
+        &'a self,
+        request: &'a MessageRequest,
+    ) -> ProviderFuture<'a, MessageResponse>;
+
+    fn stream_message<'a>(
+        &'a self,
+        request: &'a MessageRequest,
+    ) -> ProviderFuture<'a, Self::Stream>;
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -27,7 +33,6 @@ pub enum ProviderKind {
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub struct ProviderMetadata {
     pub provider: ProviderKind,
-    pub canonical_model: &'static str,
     pub auth_env: &'static str,
     pub base_url_env: &'static str,
     pub default_base_url: &'static str,
@@ -38,7 +43,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "opus",
         ProviderMetadata {
             provider: ProviderKind::Anthropic,
-            canonical_model: "claude-opus-4-6",
             auth_env: "ANTHROPIC_API_KEY",
             base_url_env: "ANTHROPIC_BASE_URL",
             default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -48,7 +52,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "sonnet",
         ProviderMetadata {
             provider: ProviderKind::Anthropic,
-            canonical_model: "claude-sonnet-4-6",
             auth_env: "ANTHROPIC_API_KEY",
             base_url_env: "ANTHROPIC_BASE_URL",
             default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -58,7 +61,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "haiku",
         ProviderMetadata {
             provider: ProviderKind::Anthropic,
-            canonical_model: "claude-haiku-4-5-20251213",
             auth_env: "ANTHROPIC_API_KEY",
             base_url_env: "ANTHROPIC_BASE_URL",
             default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -68,7 +70,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "grok",
         ProviderMetadata {
             provider: ProviderKind::Xai,
-            canonical_model: "grok-3",
             auth_env: "XAI_API_KEY",
             base_url_env: "XAI_BASE_URL",
             default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -78,7 +79,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "grok-3",
         ProviderMetadata {
             provider: ProviderKind::Xai,
-            canonical_model: "grok-3",
             auth_env: "XAI_API_KEY",
             base_url_env: "XAI_BASE_URL",
             default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -88,7 +88,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "grok-mini",
         ProviderMetadata {
             provider: ProviderKind::Xai,
-            canonical_model: "grok-3-mini",
             auth_env: "XAI_API_KEY",
             base_url_env: "XAI_BASE_URL",
             default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -98,7 +97,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "grok-3-mini",
         ProviderMetadata {
             provider: ProviderKind::Xai,
-            canonical_model: "grok-3-mini",
             auth_env: "XAI_API_KEY",
             base_url_env: "XAI_BASE_URL",
             default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -108,7 +106,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
         "grok-2",
         ProviderMetadata {
             provider: ProviderKind::Xai,
-            canonical_model: "grok-2",
             auth_env: "XAI_API_KEY",
             base_url_env: "XAI_BASE_URL",
             default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -122,7 +119,23 @@ pub fn resolve_model_alias(model: &str) -> String {
     let lower = trimmed.to_ascii_lowercase();
     MODEL_REGISTRY
         .iter()
-        .find_map(|(alias, metadata)| (*alias == lower).then_some(metadata.canonical_model))
+        .find_map(|(alias, metadata)| {
+            (*alias == lower).then_some(match metadata.provider {
+                ProviderKind::Anthropic => match *alias {
+                    "opus" => "claude-opus-4-6",
+                    "sonnet" => "claude-sonnet-4-6",
+                    "haiku" => "claude-haiku-4-5-20251213",
+                    _ => trimmed,
+                },
+                ProviderKind::Xai => match *alias {
+                    "grok" | "grok-3" => "grok-3",
+                    "grok-mini" | "grok-3-mini" => "grok-3-mini",
+                    "grok-2" => "grok-2",
+                    _ => trimmed,
+                },
+                ProviderKind::OpenAi => trimmed,
+            })
+        })
         .map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
 }
 
@@ -132,7 +145,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
     if canonical.starts_with("claude") {
         return Some(ProviderMetadata {
             provider: ProviderKind::Anthropic,
-            canonical_model: Box::leak(canonical.into_boxed_str()),
             auth_env: "ANTHROPIC_API_KEY",
             base_url_env: "ANTHROPIC_BASE_URL",
             default_base_url: anthropic::DEFAULT_BASE_URL,
@@ -141,7 +153,6 @@ pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
     if canonical.starts_with("grok") {
         return Some(ProviderMetadata {
             provider: ProviderKind::Xai,
-            canonical_model: Box::leak(canonical.into_boxed_str()),
             auth_env: "XAI_API_KEY",
             base_url_env: "XAI_BASE_URL",
             default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
@@ -191,7 +202,10 @@ mod tests {
     #[test]
     fn detects_provider_from_model_name_first() {
         assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
-        assert_eq!(detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic);
+        assert_eq!(
+            detect_provider_kind("claude-sonnet-4-6"),
+            ProviderKind::Anthropic
+        );
     }
 
     #[test]

+ 1025 - 0
rust/crates/api/src/providers/openai_compat.rs

@@ -0,0 +1,1025 @@
+use std::collections::{BTreeMap, VecDeque};
+use std::time::Duration;
+
+use serde::Deserialize;
+use serde_json::{json, Value};
+
+use crate::error::ApiError;
+use crate::types::{
+    ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
+    InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
+    MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
+    ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
+};
+
+use super::{Provider, ProviderFuture};
+
+pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
+pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
+const REQUEST_ID_HEADER: &str = "request-id";
+const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
+const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
+const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
+const DEFAULT_MAX_RETRIES: u32 = 2;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct OpenAiCompatConfig {
+    pub provider_name: &'static str,
+    pub api_key_env: &'static str,
+    pub base_url_env: &'static str,
+    pub default_base_url: &'static str,
+}
+
+const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
+const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
+
+impl OpenAiCompatConfig {
+    #[must_use]
+    pub const fn xai() -> Self {
+        Self {
+            provider_name: "xAI",
+            api_key_env: "XAI_API_KEY",
+            base_url_env: "XAI_BASE_URL",
+            default_base_url: DEFAULT_XAI_BASE_URL,
+        }
+    }
+
+    #[must_use]
+    pub const fn openai() -> Self {
+        Self {
+            provider_name: "OpenAI",
+            api_key_env: "OPENAI_API_KEY",
+            base_url_env: "OPENAI_BASE_URL",
+            default_base_url: DEFAULT_OPENAI_BASE_URL,
+        }
+    }
+    #[must_use]
+    pub fn credential_env_vars(self) -> &'static [&'static str] {
+        match self.provider_name {
+            "xAI" => XAI_ENV_VARS,
+            "OpenAI" => OPENAI_ENV_VARS,
+            _ => &[],
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct OpenAiCompatClient {
+    http: reqwest::Client,
+    api_key: String,
+    base_url: String,
+    max_retries: u32,
+    initial_backoff: Duration,
+    max_backoff: Duration,
+}
+
+impl OpenAiCompatClient {
+    #[must_use]
+    pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
+        Self {
+            http: reqwest::Client::new(),
+            api_key: api_key.into(),
+            base_url: read_base_url(config),
+            max_retries: DEFAULT_MAX_RETRIES,
+            initial_backoff: DEFAULT_INITIAL_BACKOFF,
+            max_backoff: DEFAULT_MAX_BACKOFF,
+        }
+    }
+
+    pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
+        let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
+            return Err(ApiError::missing_credentials(
+                config.provider_name,
+                config.credential_env_vars(),
+            ));
+        };
+        Ok(Self::new(api_key, config))
+    }
+
+    #[must_use]
+    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
+        self.base_url = base_url.into();
+        self
+    }
+
+    #[must_use]
+    pub fn with_retry_policy(
+        mut self,
+        max_retries: u32,
+        initial_backoff: Duration,
+        max_backoff: Duration,
+    ) -> Self {
+        self.max_retries = max_retries;
+        self.initial_backoff = initial_backoff;
+        self.max_backoff = max_backoff;
+        self
+    }
+
+    pub async fn send_message(
+        &self,
+        request: &MessageRequest,
+    ) -> Result<MessageResponse, ApiError> {
+        let request = MessageRequest {
+            stream: false,
+            ..request.clone()
+        };
+        let response = self.send_with_retry(&request).await?;
+        let request_id = request_id_from_headers(response.headers());
+        let payload = response.json::<ChatCompletionResponse>().await?;
+        let mut normalized = normalize_response(&request.model, payload)?;
+        if normalized.request_id.is_none() {
+            normalized.request_id = request_id;
+        }
+        Ok(normalized)
+    }
+
+    pub async fn stream_message(
+        &self,
+        request: &MessageRequest,
+    ) -> Result<MessageStream, ApiError> {
+        let response = self
+            .send_with_retry(&request.clone().with_streaming())
+            .await?;
+        Ok(MessageStream {
+            request_id: request_id_from_headers(response.headers()),
+            response,
+            parser: OpenAiSseParser::new(),
+            pending: VecDeque::new(),
+            done: false,
+            state: StreamState::new(request.model.clone()),
+        })
+    }
+
+    async fn send_with_retry(
+        &self,
+        request: &MessageRequest,
+    ) -> Result<reqwest::Response, ApiError> {
+        let mut attempts = 0;
+
+        let last_error = loop {
+            attempts += 1;
+            let retryable_error = match self.send_raw_request(request).await {
+                Ok(response) => match expect_success(response).await {
+                    Ok(response) => return Ok(response),
+                    Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
+                    Err(error) => return Err(error),
+                },
+                Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
+                Err(error) => return Err(error),
+            };
+
+            if attempts > self.max_retries {
+                break retryable_error;
+            }
+
+            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
+        };
+
+        Err(ApiError::RetriesExhausted {
+            attempts,
+            last_error: Box::new(last_error),
+        })
+    }
+
+    async fn send_raw_request(
+        &self,
+        request: &MessageRequest,
+    ) -> Result<reqwest::Response, ApiError> {
+        let request_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
+        self.http
+            .post(&request_url)
+            .header("content-type", "application/json")
+            .bearer_auth(&self.api_key)
+            .json(&build_chat_completion_request(request))
+            .send()
+            .await
+            .map_err(ApiError::from)
+    }
+
+    fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
+        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
+            return Err(ApiError::BackoffOverflow {
+                attempt,
+                base_delay: self.initial_backoff,
+            });
+        };
+        Ok(self
+            .initial_backoff
+            .checked_mul(multiplier)
+            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
+    }
+}
+
+impl Provider for OpenAiCompatClient {
+    type Stream = MessageStream;
+
+    fn send_message<'a>(
+        &'a self,
+        request: &'a MessageRequest,
+    ) -> ProviderFuture<'a, MessageResponse> {
+        Box::pin(async move { self.send_message(request).await })
+    }
+
+    fn stream_message<'a>(
+        &'a self,
+        request: &'a MessageRequest,
+    ) -> ProviderFuture<'a, Self::Stream> {
+        Box::pin(async move { self.stream_message(request).await })
+    }
+}
+
+#[derive(Debug)]
+pub struct MessageStream {
+    request_id: Option<String>,
+    response: reqwest::Response,
+    parser: OpenAiSseParser,
+    pending: VecDeque<StreamEvent>,
+    done: bool,
+    state: StreamState,
+}
+
+impl MessageStream {
+    #[must_use]
+    pub fn request_id(&self) -> Option<&str> {
+        self.request_id.as_deref()
+    }
+
+    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
+        loop {
+            if let Some(event) = self.pending.pop_front() {
+                return Ok(Some(event));
+            }
+
+            if self.done {
+                self.pending.extend(self.state.finish()?);
+                if let Some(event) = self.pending.pop_front() {
+                    return Ok(Some(event));
+                }
+                return Ok(None);
+            }
+
+            match self.response.chunk().await? {
+                Some(chunk) => {
+                    for parsed in self.parser.push(&chunk)? {
+                        self.pending.extend(self.state.ingest_chunk(parsed)?);
+                    }
+                }
+                None => {
+                    self.done = true;
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug, Default)]
+struct OpenAiSseParser {
+    buffer: Vec<u8>,
+}
+
+impl OpenAiSseParser {
+    fn new() -> Self {
+        Self::default()
+    }
+
+    fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
+        self.buffer.extend_from_slice(chunk);
+        let mut events = Vec::new();
+
+        while let Some(frame) = next_sse_frame(&mut self.buffer) {
+            if let Some(event) = parse_sse_frame(&frame)? {
+                events.push(event);
+            }
+        }
+
+        Ok(events)
+    }
+}
+
+#[derive(Debug)]
+struct StreamState {
+    model: String,
+    message_started: bool,
+    text_started: bool,
+    text_finished: bool,
+    finished: bool,
+    stop_reason: Option<String>,
+    usage: Option<Usage>,
+    tool_calls: BTreeMap<u32, ToolCallState>,
+}
+
+impl StreamState {
+    fn new(model: String) -> Self {
+        Self {
+            model,
+            message_started: false,
+            text_started: false,
+            text_finished: false,
+            finished: false,
+            stop_reason: None,
+            usage: None,
+            tool_calls: BTreeMap::new(),
+        }
+    }
+
+    fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
+        let mut events = Vec::new();
+        if !self.message_started {
+            self.message_started = true;
+            events.push(StreamEvent::MessageStart(MessageStartEvent {
+                message: MessageResponse {
+                    id: chunk.id.clone(),
+                    kind: "message".to_string(),
+                    role: "assistant".to_string(),
+                    content: Vec::new(),
+                    model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
+                    stop_reason: None,
+                    stop_sequence: None,
+                    usage: Usage {
+                        input_tokens: 0,
+                        cache_creation_input_tokens: 0,
+                        cache_read_input_tokens: 0,
+                        output_tokens: 0,
+                    },
+                    request_id: None,
+                },
+            }));
+        }
+
+        if let Some(usage) = chunk.usage {
+            self.usage = Some(Usage {
+                input_tokens: usage.prompt_tokens,
+                cache_creation_input_tokens: 0,
+                cache_read_input_tokens: 0,
+                output_tokens: usage.completion_tokens,
+            });
+        }
+
+        for choice in chunk.choices {
+            if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
+                if !self.text_started {
+                    self.text_started = true;
+                    events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
+                        index: 0,
+                        content_block: OutputContentBlock::Text {
+                            text: String::new(),
+                        },
+                    }));
+                }
+                events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
+                    index: 0,
+                    delta: ContentBlockDelta::TextDelta { text: content },
+                }));
+            }
+
+            for tool_call in choice.delta.tool_calls {
+                let state = self.tool_calls.entry(tool_call.index).or_default();
+                state.apply(tool_call);
+                let block_index = state.block_index();
+                if !state.started {
+                    if let Some(start_event) = state.start_event()? {
+                        state.started = true;
+                        events.push(StreamEvent::ContentBlockStart(start_event));
+                    } else {
+                        continue;
+                    }
+                }
+                if let Some(delta_event) = state.delta_event() {
+                    events.push(StreamEvent::ContentBlockDelta(delta_event));
+                }
+                if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
+                    state.stopped = true;
+                    events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
+                        index: block_index,
+                    }));
+                }
+            }
+
+            if let Some(finish_reason) = choice.finish_reason {
+                self.stop_reason = Some(normalize_finish_reason(&finish_reason));
+                if finish_reason == "tool_calls" {
+                    for state in self.tool_calls.values_mut() {
+                        if state.started && !state.stopped {
+                            state.stopped = true;
+                            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
+                                index: state.block_index(),
+                            }));
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(events)
+    }
+
+    fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
+        if self.finished {
+            return Ok(Vec::new());
+        }
+        self.finished = true;
+
+        let mut events = Vec::new();
+        if self.text_started && !self.text_finished {
+            self.text_finished = true;
+            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
+                index: 0,
+            }));
+        }
+
+        for state in self.tool_calls.values_mut() {
+            if !state.started {
+                if let Some(start_event) = state.start_event()? {
+                    state.started = true;
+                    events.push(StreamEvent::ContentBlockStart(start_event));
+                    if let Some(delta_event) = state.delta_event() {
+                        events.push(StreamEvent::ContentBlockDelta(delta_event));
+                    }
+                }
+            }
+            if state.started && !state.stopped {
+                state.stopped = true;
+                events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
+                    index: state.block_index(),
+                }));
+            }
+        }
+
+        if self.message_started {
+            events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
+                delta: MessageDelta {
+                    stop_reason: Some(
+                        self.stop_reason
+                            .clone()
+                            .unwrap_or_else(|| "end_turn".to_string()),
+                    ),
+                    stop_sequence: None,
+                },
+                usage: self.usage.clone().unwrap_or(Usage {
+                    input_tokens: 0,
+                    cache_creation_input_tokens: 0,
+                    cache_read_input_tokens: 0,
+                    output_tokens: 0,
+                }),
+            }));
+            events.push(StreamEvent::MessageStop(MessageStopEvent {}));
+        }
+        Ok(events)
+    }
+}
+
+#[derive(Debug, Default)]
+struct ToolCallState {
+    openai_index: u32,
+    id: Option<String>,
+    name: Option<String>,
+    arguments: String,
+    emitted_len: usize,
+    started: bool,
+    stopped: bool,
+}
+
+impl ToolCallState {
+    fn apply(&mut self, tool_call: DeltaToolCall) {
+        self.openai_index = tool_call.index;
+        if let Some(id) = tool_call.id {
+            self.id = Some(id);
+        }
+        if let Some(name) = tool_call.function.name {
+            self.name = Some(name);
+        }
+        if let Some(arguments) = tool_call.function.arguments {
+            self.arguments.push_str(&arguments);
+        }
+    }
+
+    const fn block_index(&self) -> u32 {
+        self.openai_index + 1
+    }
+
+    fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
+        let Some(name) = self.name.clone() else {
+            return Ok(None);
+        };
+        let id = self
+            .id
+            .clone()
+            .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
+        Ok(Some(ContentBlockStartEvent {
+            index: self.block_index(),
+            content_block: OutputContentBlock::ToolUse {
+                id,
+                name,
+                input: json!({}),
+            },
+        }))
+    }
+
+    fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
+        if self.emitted_len >= self.arguments.len() {
+            return None;
+        }
+        let delta = self.arguments[self.emitted_len..].to_string();
+        self.emitted_len = self.arguments.len();
+        Some(ContentBlockDeltaEvent {
+            index: self.block_index(),
+            delta: ContentBlockDelta::InputJsonDelta {
+                partial_json: delta,
+            },
+        })
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct ChatCompletionResponse {
+    id: String,
+    model: String,
+    choices: Vec<ChatChoice>,
+    #[serde(default)]
+    usage: Option<OpenAiUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ChatChoice {
+    message: ChatMessage,
+    #[serde(default)]
+    finish_reason: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ChatMessage {
+    role: String,
+    #[serde(default)]
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Vec<ResponseToolCall>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ResponseToolCall {
+    id: String,
+    function: ResponseToolFunction,
+}
+
+#[derive(Debug, Deserialize)]
+struct ResponseToolFunction {
+    name: String,
+    arguments: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAiUsage {
+    #[serde(default)]
+    prompt_tokens: u32,
+    #[serde(default)]
+    completion_tokens: u32,
+}
+
+#[derive(Debug, Deserialize)]
+struct ChatCompletionChunk {
+    id: String,
+    #[serde(default)]
+    model: Option<String>,
+    #[serde(default)]
+    choices: Vec<ChunkChoice>,
+    #[serde(default)]
+    usage: Option<OpenAiUsage>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ChunkChoice {
+    delta: ChunkDelta,
+    #[serde(default)]
+    finish_reason: Option<String>,
+}
+
+#[derive(Debug, Default, Deserialize)]
+struct ChunkDelta {
+    #[serde(default)]
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Vec<DeltaToolCall>,
+}
+
+#[derive(Debug, Deserialize)]
+struct DeltaToolCall {
+    #[serde(default)]
+    index: u32,
+    #[serde(default)]
+    id: Option<String>,
+    #[serde(default)]
+    function: DeltaFunction,
+}
+
+#[derive(Debug, Default, Deserialize)]
+struct DeltaFunction {
+    #[serde(default)]
+    name: Option<String>,
+    #[serde(default)]
+    arguments: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ErrorEnvelope {
+    error: ErrorBody,
+}
+
+#[derive(Debug, Deserialize)]
+struct ErrorBody {
+    #[serde(rename = "type")]
+    error_type: Option<String>,
+    message: Option<String>,
+}
+
+fn build_chat_completion_request(request: &MessageRequest) -> Value {
+    let mut messages = Vec::new();
+    if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
+        messages.push(json!({
+            "role": "system",
+            "content": system,
+        }));
+    }
+    for message in &request.messages {
+        messages.extend(translate_message(message));
+    }
+
+    let mut payload = json!({
+        "model": request.model,
+        "max_tokens": request.max_tokens,
+        "messages": messages,
+        "stream": request.stream,
+    });
+
+    if let Some(tools) = &request.tools {
+        payload["tools"] =
+            Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
+    }
+    if let Some(tool_choice) = &request.tool_choice {
+        payload["tool_choice"] = openai_tool_choice(tool_choice);
+    }
+
+    payload
+}
+
+fn translate_message(message: &InputMessage) -> Vec<Value> {
+    match message.role.as_str() {
+        "assistant" => {
+            let mut text = String::new();
+            let mut tool_calls = Vec::new();
+            for block in &message.content {
+                match block {
+                    InputContentBlock::Text { text: value } => text.push_str(value),
+                    InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
+                        "id": id,
+                        "type": "function",
+                        "function": {
+                            "name": name,
+                            "arguments": input.to_string(),
+                        }
+                    })),
+                    InputContentBlock::ToolResult { .. } => {}
+                }
+            }
+            if text.is_empty() && tool_calls.is_empty() {
+                Vec::new()
+            } else {
+                vec![json!({
+                    "role": "assistant",
+                    "content": (!text.is_empty()).then_some(text),
+                    "tool_calls": tool_calls,
+                })]
+            }
+        }
+        _ => message
+            .content
+            .iter()
+            .filter_map(|block| match block {
+                InputContentBlock::Text { text } => Some(json!({
+                    "role": "user",
+                    "content": text,
+                })),
+                InputContentBlock::ToolResult {
+                    tool_use_id,
+                    content,
+                    is_error,
+                } => Some(json!({
+                    "role": "tool",
+                    "tool_call_id": tool_use_id,
+                    "content": flatten_tool_result_content(content),
+                    "is_error": is_error,
+                })),
+                InputContentBlock::ToolUse { .. } => None,
+            })
+            .collect(),
+    }
+}
+
+fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
+    content
+        .iter()
+        .map(|block| match block {
+            ToolResultContentBlock::Text { text } => text.clone(),
+            ToolResultContentBlock::Json { value } => value.to_string(),
+        })
+        .collect::<Vec<_>>()
+        .join("\n")
+}
+
+fn openai_tool_definition(tool: &ToolDefinition) -> Value {
+    json!({
+        "type": "function",
+        "function": {
+            "name": tool.name,
+            "description": tool.description,
+            "parameters": tool.input_schema,
+        }
+    })
+}
+
+fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
+    match tool_choice {
+        ToolChoice::Auto => Value::String("auto".to_string()),
+        ToolChoice::Any => Value::String("required".to_string()),
+        ToolChoice::Tool { name } => json!({
+            "type": "function",
+            "function": { "name": name },
+        }),
+    }
+}
+
+fn normalize_response(
+    model: &str,
+    response: ChatCompletionResponse,
+) -> Result<MessageResponse, ApiError> {
+    let choice = response
+        .choices
+        .into_iter()
+        .next()
+        .ok_or(ApiError::InvalidSseFrame(
+            "chat completion response missing choices",
+        ))?;
+    let mut content = Vec::new();
+    if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
+        content.push(OutputContentBlock::Text { text });
+    }
+    for tool_call in choice.message.tool_calls {
+        content.push(OutputContentBlock::ToolUse {
+            id: tool_call.id,
+            name: tool_call.function.name,
+            input: parse_tool_arguments(&tool_call.function.arguments),
+        });
+    }
+
+    Ok(MessageResponse {
+        id: response.id,
+        kind: "message".to_string(),
+        role: choice.message.role,
+        content,
+        model: response.model.if_empty_then(model.to_string()),
+        stop_reason: choice
+            .finish_reason
+            .map(|value| normalize_finish_reason(&value)),
+        stop_sequence: None,
+        usage: Usage {
+            input_tokens: response
+                .usage
+                .as_ref()
+                .map_or(0, |usage| usage.prompt_tokens),
+            cache_creation_input_tokens: 0,
+            cache_read_input_tokens: 0,
+            output_tokens: response
+                .usage
+                .as_ref()
+                .map_or(0, |usage| usage.completion_tokens),
+        },
+        request_id: None,
+    })
+}
+
+fn parse_tool_arguments(arguments: &str) -> Value {
+    serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
+}
+
+fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
+    let separator = buffer
+        .windows(2)
+        .position(|window| window == b"\n\n")
+        .map(|position| (position, 2))
+        .or_else(|| {
+            buffer
+                .windows(4)
+                .position(|window| window == b"\r\n\r\n")
+                .map(|position| (position, 4))
+        })?;
+
+    let (position, separator_len) = separator;
+    let frame = buffer.drain(..position + separator_len).collect::<Vec<_>>();
+    let frame_len = frame.len().saturating_sub(separator_len);
+    Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned())
+}
+
+fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
+    let trimmed = frame.trim();
+    if trimmed.is_empty() {
+        return Ok(None);
+    }
+
+    let mut data_lines = Vec::new();
+    for line in trimmed.lines() {
+        if line.starts_with(':') {
+            continue;
+        }
+        if let Some(data) = line.strip_prefix("data:") {
+            data_lines.push(data.trim_start());
+        }
+    }
+    if data_lines.is_empty() {
+        return Ok(None);
+    }
+    let payload = data_lines.join("\n");
+    if payload == "[DONE]" {
+        return Ok(None);
+    }
+    serde_json::from_str(&payload)
+        .map(Some)
+        .map_err(ApiError::from)
+}
+
+fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
+    match std::env::var(key) {
+        Ok(value) if !value.is_empty() => Ok(Some(value)),
+        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
+        Err(error) => Err(ApiError::from(error)),
+    }
+}
+
+#[must_use]
+pub fn has_api_key(key: &str) -> bool {
+    read_env_non_empty(key)
+        .ok()
+        .and_then(std::convert::identity)
+        .is_some()
+}
+
+#[must_use]
+pub fn read_base_url(config: OpenAiCompatConfig) -> String {
+    std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
+}
+
+fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
+    headers
+        .get(REQUEST_ID_HEADER)
+        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
+        .and_then(|value| value.to_str().ok())
+        .map(ToOwned::to_owned)
+}
+
+async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
+    let status = response.status();
+    if status.is_success() {
+        return Ok(response);
+    }
+
+    let body = response.text().await.unwrap_or_default();
+    let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
+    let retryable = is_retryable_status(status);
+
+    Err(ApiError::Api {
+        status,
+        error_type: parsed_error
+            .as_ref()
+            .and_then(|error| error.error.error_type.clone()),
+        message: parsed_error
+            .as_ref()
+            .and_then(|error| error.error.message.clone()),
+        body,
+        retryable,
+    })
+}
+
+const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
+    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
+}
+
+fn normalize_finish_reason(value: &str) -> String {
+    match value {
+        "stop" => "end_turn",
+        "tool_calls" => "tool_use",
+        other => other,
+    }
+    .to_string()
+}
+
+trait StringExt {
+    fn if_empty_then(self, fallback: String) -> String;
+}
+
+impl StringExt for String {
+    fn if_empty_then(self, fallback: String) -> String {
+        if self.is_empty() {
+            fallback
+        } else {
+            self
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{
+        build_chat_completion_request, normalize_finish_reason, openai_tool_choice,
+        parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
+    };
+    use crate::error::ApiError;
+    use crate::types::{
+        InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
+        ToolResultContentBlock,
+    };
+    use serde_json::json;
+    use std::sync::{Mutex, OnceLock};
+
+    #[test]
+    fn request_translation_uses_openai_compatible_shape() {
+        let payload = build_chat_completion_request(&MessageRequest {
+            model: "grok-3".to_string(),
+            max_tokens: 64,
+            messages: vec![InputMessage {
+                role: "user".to_string(),
+                content: vec![
+                    InputContentBlock::Text {
+                        text: "hello".to_string(),
+                    },
+                    InputContentBlock::ToolResult {
+                        tool_use_id: "tool_1".to_string(),
+                        content: vec![ToolResultContentBlock::Json {
+                            value: json!({"ok": true}),
+                        }],
+                        is_error: false,
+                    },
+                ],
+            }],
+            system: Some("be helpful".to_string()),
+            tools: Some(vec![ToolDefinition {
+                name: "weather".to_string(),
+                description: Some("Get weather".to_string()),
+                input_schema: json!({"type": "object"}),
+            }]),
+            tool_choice: Some(ToolChoice::Auto),
+            stream: false,
+        });
+
+        assert_eq!(payload["messages"][0]["role"], json!("system"));
+        assert_eq!(payload["messages"][1]["role"], json!("user"));
+        assert_eq!(payload["messages"][2]["role"], json!("tool"));
+        assert_eq!(payload["tools"][0]["type"], json!("function"));
+        assert_eq!(payload["tool_choice"], json!("auto"));
+    }
+
+    #[test]
+    fn tool_choice_translation_supports_required_function() {
+        assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
+        assert_eq!(
+            openai_tool_choice(&ToolChoice::Tool {
+                name: "weather".to_string(),
+            }),
+            json!({"type": "function", "function": {"name": "weather"}})
+        );
+    }
+
+    #[test]
+    fn parses_tool_arguments_fallback() {
+        assert_eq!(
+            parse_tool_arguments("{\"city\":\"Paris\"}"),
+            json!({"city": "Paris"})
+        );
+        assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"}));
+    }
+
+    #[test]
+    fn missing_xai_api_key_is_provider_specific() {
+        let _lock = env_lock();
+        std::env::remove_var("XAI_API_KEY");
+        let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai())
+            .expect_err("missing key should error");
+        assert!(matches!(
+            error,
+            ApiError::MissingCredentials {
+                provider: "xAI",
+                ..
+            }
+        ));
+    }
+
+    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
+        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+        LOCK.get_or_init(|| Mutex::new(()))
+            .lock()
+            .expect("env lock")
+    }
+
+    #[test]
+    fn normalizes_stop_reasons() {
+        assert_eq!(normalize_finish_reason("stop"), "end_turn");
+        assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
+    }
+}

+ 312 - 0
rust/crates/api/tests/openai_compat_integration.rs

@@ -0,0 +1,312 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use api::{
+    ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
+    InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
+    OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition,
+};
+use serde_json::json;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+use tokio::net::TcpListener;
+use tokio::sync::Mutex;
+
+#[tokio::test]
+async fn send_message_uses_openai_compatible_endpoint_and_auth() {
+    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
+    let body = concat!(
+        "{",
+        "\"id\":\"chatcmpl_test\",",
+        "\"model\":\"grok-3\",",
+        "\"choices\":[{",
+        "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},",
+        "\"finish_reason\":\"stop\"",
+        "}],",
+        "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}",
+        "}"
+    );
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response("200 OK", "application/json", body)],
+    )
+    .await;
+
+    let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
+        .with_base_url(server.base_url());
+    let response = client
+        .send_message(&sample_request(false))
+        .await
+        .expect("request should succeed");
+
+    assert_eq!(response.model, "grok-3");
+    assert_eq!(response.total_tokens(), 16);
+    assert_eq!(
+        response.content,
+        vec![OutputContentBlock::Text {
+            text: "Hello from Grok".to_string(),
+        }]
+    );
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("server should capture request");
+    assert_eq!(request.path, "/chat/completions");
+    assert_eq!(
+        request.headers.get("authorization").map(String::as_str),
+        Some("Bearer xai-test-key")
+    );
+    let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
+    assert_eq!(body["model"], json!("grok-3"));
+    assert_eq!(body["messages"][0]["role"], json!("system"));
+    assert_eq!(body["tools"][0]["type"], json!("function"));
+}
+
+#[tokio::test]
+async fn stream_message_normalizes_text_and_multiple_tool_calls() {
+    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
+    let sse = concat!(
+        "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n",
+        "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n",
+        "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n",
+        "data: [DONE]\n\n"
+    );
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response_with_headers(
+            "200 OK",
+            "text/event-stream",
+            sse,
+            &[("x-request-id", "req_grok_stream")],
+        )],
+    )
+    .await;
+
+    let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
+        .with_base_url(server.base_url());
+    let mut stream = client
+        .stream_message(&sample_request(false))
+        .await
+        .expect("stream should start");
+
+    assert_eq!(stream.request_id(), Some("req_grok_stream"));
+
+    let mut events = Vec::new();
+    while let Some(event) = stream.next_event().await.expect("event should parse") {
+        events.push(event);
+    }
+
+    assert!(matches!(events[0], StreamEvent::MessageStart(_)));
+    assert!(matches!(
+        events[1],
+        StreamEvent::ContentBlockStart(ContentBlockStartEvent {
+            content_block: OutputContentBlock::Text { .. },
+            ..
+        })
+    ));
+    assert!(matches!(
+        events[2],
+        StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
+            delta: ContentBlockDelta::TextDelta { .. },
+            ..
+        })
+    ));
+    assert!(matches!(
+        events[3],
+        StreamEvent::ContentBlockStart(ContentBlockStartEvent {
+            index: 1,
+            content_block: OutputContentBlock::ToolUse { .. },
+        })
+    ));
+    assert!(matches!(
+        events[4],
+        StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
+            index: 1,
+            delta: ContentBlockDelta::InputJsonDelta { .. },
+        })
+    ));
+    assert!(matches!(
+        events[5],
+        StreamEvent::ContentBlockStart(ContentBlockStartEvent {
+            index: 2,
+            content_block: OutputContentBlock::ToolUse { .. },
+        })
+    ));
+    assert!(matches!(
+        events[6],
+        StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
+            index: 2,
+            delta: ContentBlockDelta::InputJsonDelta { .. },
+        })
+    ));
+    assert!(matches!(
+        events[7],
+        StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 })
+    ));
+    assert!(matches!(
+        events[8],
+        StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 })
+    ));
+    assert!(matches!(
+        events[9],
+        StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
+    ));
+    assert!(matches!(events[10], StreamEvent::MessageDelta(_)));
+    assert!(matches!(events[11], StreamEvent::MessageStop(_)));
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("captured request");
+    assert_eq!(request.path, "/chat/completions");
+    assert!(request.body.contains("\"stream\":true"));
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct CapturedRequest {
+    path: String,
+    headers: HashMap<String, String>,
+    body: String,
+}
+
+struct TestServer {
+    base_url: String,
+    join_handle: tokio::task::JoinHandle<()>,
+}
+
+impl TestServer {
+    fn base_url(&self) -> String {
+        self.base_url.clone()
+    }
+}
+
+impl Drop for TestServer {
+    fn drop(&mut self) {
+        self.join_handle.abort();
+    }
+}
+
+async fn spawn_server(
+    state: Arc<Mutex<Vec<CapturedRequest>>>,
+    responses: Vec<String>,
+) -> TestServer {
+    let listener = TcpListener::bind("127.0.0.1:0")
+        .await
+        .expect("listener should bind");
+    let address = listener.local_addr().expect("listener addr");
+    let join_handle = tokio::spawn(async move {
+        for response in responses {
+            let (mut socket, _) = listener.accept().await.expect("accept");
+            let mut buffer = Vec::new();
+            let mut header_end = None;
+            loop {
+                let mut chunk = [0_u8; 1024];
+                let read = socket.read(&mut chunk).await.expect("read request");
+                if read == 0 {
+                    break;
+                }
+                buffer.extend_from_slice(&chunk[..read]);
+                if let Some(position) = find_header_end(&buffer) {
+                    header_end = Some(position);
+                    break;
+                }
+            }
+
+            let header_end = header_end.expect("headers should exist");
+            let (header_bytes, remaining) = buffer.split_at(header_end);
+            let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
+            let mut lines = header_text.split("\r\n");
+            let request_line = lines.next().expect("request line");
+            let path = request_line
+                .split_whitespace()
+                .nth(1)
+                .expect("path")
+                .to_string();
+            let mut headers = HashMap::new();
+            let mut content_length = 0_usize;
+            for line in lines {
+                if line.is_empty() {
+                    continue;
+                }
+                let (name, value) = line.split_once(':').expect("header");
+                let value = value.trim().to_string();
+                if name.eq_ignore_ascii_case("content-length") {
+                    content_length = value.parse().expect("content length");
+                }
+                headers.insert(name.to_ascii_lowercase(), value);
+            }
+
+            let mut body = remaining[4..].to_vec();
+            while body.len() < content_length {
+                let mut chunk = vec![0_u8; content_length - body.len()];
+                let read = socket.read(&mut chunk).await.expect("read body");
+                if read == 0 {
+                    break;
+                }
+                body.extend_from_slice(&chunk[..read]);
+            }
+
+            state.lock().await.push(CapturedRequest {
+                path,
+                headers,
+                body: String::from_utf8(body).expect("utf8 body"),
+            });
+
+            socket
+                .write_all(response.as_bytes())
+                .await
+                .expect("write response");
+        }
+    });
+
+    TestServer {
+        base_url: format!("http://{address}"),
+        join_handle,
+    }
+}
+
+fn find_header_end(bytes: &[u8]) -> Option<usize> {
+    bytes.windows(4).position(|window| window == b"\r\n\r\n")
+}
+
+fn http_response(status: &str, content_type: &str, body: &str) -> String {
+    http_response_with_headers(status, content_type, body, &[])
+}
+
+fn http_response_with_headers(
+    status: &str,
+    content_type: &str,
+    body: &str,
+    headers: &[(&str, &str)],
+) -> String {
+    let mut extra_headers = String::new();
+    for (name, value) in headers {
+        use std::fmt::Write as _;
+        write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write");
+    }
+    format!(
+        "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
+        body.len()
+    )
+}
+
+fn sample_request(stream: bool) -> MessageRequest {
+    MessageRequest {
+        model: "grok-3".to_string(),
+        max_tokens: 64,
+        messages: vec![InputMessage {
+            role: "user".to_string(),
+            content: vec![InputContentBlock::Text {
+                text: "Say hello".to_string(),
+            }],
+        }],
+        system: Some("Use tools when needed".to_string()),
+        tools: Some(vec![ToolDefinition {
+            name: "weather".to_string(),
+            description: Some("Fetches weather".to_string()),
+            input_schema: json!({
+                "type": "object",
+                "properties": {"city": {"type": "string"}},
+                "required": ["city"]
+            }),
+        }]),
+        tool_choice: Some(ToolChoice::Auto),
+        stream,
+    }
+}

+ 22 - 27
rust/crates/rusty-claude-cli/src/main.rs

@@ -12,8 +12,9 @@ use std::process::Command;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use api::{
-    resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
-    InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
+    detect_provider_kind, max_tokens_for_model, resolve_model_alias, resolve_startup_auth_source,
+    AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage,
+    MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, ProviderKind,
     StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
 };
 
@@ -35,13 +36,6 @@ use serde_json::json;
 use tools::{execute_tool, mvp_tool_specs, ToolSpec};
 
 const DEFAULT_MODEL: &str = "claude-opus-4-6";
-fn max_tokens_for_model(model: &str) -> u32 {
-    if model.contains("opus") {
-        32_000
-    } else {
-        64_000
-    }
-}
 const DEFAULT_DATE: &str = "2026-03-31";
 const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545;
 const VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -288,15 +282,6 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
     }
 }
 
-fn resolve_model_alias(model: &str) -> &str {
-    match model {
-        "opus" => "claude-opus-4-6",
-        "sonnet" => "claude-sonnet-4-6",
-        "haiku" => "claude-haiku-4-5-20251213",
-        _ => model,
-    }
-}
-
 fn normalize_allowed_tools(values: &[String]) -> Result<Option<AllowedToolSet>, String> {
     if values.is_empty() {
         return Ok(None);
@@ -980,7 +965,7 @@ struct LiveCli {
     allowed_tools: Option<AllowedToolSet>,
     permission_mode: PermissionMode,
     system_prompt: Vec<String>,
-    runtime: ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>,
+    runtime: ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>,
     session: SessionHandle,
 }
 
@@ -1920,11 +1905,11 @@ fn build_runtime(
     emit_output: bool,
     allowed_tools: Option<AllowedToolSet>,
     permission_mode: PermissionMode,
-) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
+) -> Result<ConversationRuntime<ProviderRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
 {
     Ok(ConversationRuntime::new_with_features(
         session,
-        AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
+        ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?,
         CliToolExecutor::new(allowed_tools, emit_output),
         permission_policy(permission_mode),
         system_prompt,
@@ -1978,26 +1963,33 @@ impl runtime::PermissionPrompter for CliPermissionPrompter {
     }
 }
 
-struct AnthropicRuntimeClient {
+struct ProviderRuntimeClient {
     runtime: tokio::runtime::Runtime,
-    client: AnthropicClient,
+    client: ProviderClient,
     model: String,
     enable_tools: bool,
     emit_output: bool,
     allowed_tools: Option<AllowedToolSet>,
 }
 
-impl AnthropicRuntimeClient {
+impl ProviderRuntimeClient {
     fn new(
         model: String,
         enable_tools: bool,
         emit_output: bool,
         allowed_tools: Option<AllowedToolSet>,
     ) -> Result<Self, Box<dyn std::error::Error>> {
+        let model = resolve_model_alias(&model).to_string();
+        let client = match detect_provider_kind(&model) {
+            ProviderKind::Anthropic => ProviderClient::from_model_with_anthropic_auth(
+                &model,
+                Some(resolve_cli_auth_source()?),
+            )?,
+            ProviderKind::Xai | ProviderKind::OpenAi => ProviderClient::from_model(&model)?,
+        };
         Ok(Self {
             runtime: tokio::runtime::Runtime::new()?,
-            client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
-                .with_base_url(api::read_base_url()),
+            client,
             model,
             enable_tools,
             emit_output,
@@ -2016,7 +2008,7 @@ fn resolve_cli_auth_source() -> Result<AuthSource, Box<dyn std::error::Error>> {
     })?)
 }
 
-impl ApiClient for AnthropicRuntimeClient {
+impl ApiClient for ProviderRuntimeClient {
     #[allow(clippy::too_many_lines)]
     fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
         let message_request = MessageRequest {
@@ -2911,6 +2903,9 @@ mod tests {
         assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
         assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6");
         assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213");
+        assert_eq!(resolve_model_alias("grok"), "grok-3");
+        assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
+        assert_eq!(resolve_model_alias("grok-2"), "grok-2");
         assert_eq!(resolve_model_alias("claude-opus"), "claude-opus");
     }
 

+ 17 - 13
rust/crates/tools/src/lib.rs

@@ -4,9 +4,10 @@ use std::process::Command;
 use std::time::{Duration, Instant};
 
 use api::{
-    read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage,
-    MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice,
-    ToolDefinition, ToolResultContentBlock,
+    detect_provider_kind, max_tokens_for_model, resolve_model_alias, ContentBlockDelta,
+    InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
+    ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition,
+    ToolResultContentBlock,
 };
 use reqwest::blocking::Client;
 use runtime::{
@@ -1459,14 +1460,14 @@ fn run_agent_job(job: &AgentJob) -> Result<(), String> {
 
 fn build_agent_runtime(
     job: &AgentJob,
-) -> Result<ConversationRuntime<AnthropicRuntimeClient, SubagentToolExecutor>, String> {
+) -> Result<ConversationRuntime<ProviderRuntimeClient, SubagentToolExecutor>, String> {
     let model = job
         .manifest
         .model
         .clone()
         .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string());
     let allowed_tools = job.allowed_tools.clone();
-    let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?;
+    let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?;
     let tool_executor = SubagentToolExecutor::new(allowed_tools);
     Ok(ConversationRuntime::new(
         Session::new(),
@@ -1635,18 +1636,21 @@ fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Optio
     sections.join("")
 }
 
-struct AnthropicRuntimeClient {
+struct ProviderRuntimeClient {
     runtime: tokio::runtime::Runtime,
-    client: AnthropicClient,
+    client: ProviderClient,
     model: String,
     allowed_tools: BTreeSet<String>,
 }
 
-impl AnthropicRuntimeClient {
+impl ProviderRuntimeClient {
     fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
-        let client = AnthropicClient::from_env()
-            .map_err(|error| error.to_string())?
-            .with_base_url(read_base_url());
+        let model = resolve_model_alias(&model).to_string();
+        let client = match detect_provider_kind(&model) {
+            ProviderKind::Anthropic | ProviderKind::Xai | ProviderKind::OpenAi => {
+                ProviderClient::from_model(&model).map_err(|error| error.to_string())?
+            }
+        };
         Ok(Self {
             runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
             client,
@@ -1656,7 +1660,7 @@ impl AnthropicRuntimeClient {
     }
 }
 
-impl ApiClient for AnthropicRuntimeClient {
+impl ApiClient for ProviderRuntimeClient {
     fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> {
         let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools))
             .into_iter()
@@ -1668,7 +1672,7 @@ impl ApiClient for AnthropicRuntimeClient {
             .collect::<Vec<_>>();
         let message_request = MessageRequest {
             model: self.model.clone(),
-            max_tokens: 32_000,
+            max_tokens: max_tokens_for_model(&self.model),
             messages: convert_messages(&request.messages),
             system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")),
             tools: (!tools.is_empty()).then_some(tools),