| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 |
- use std::future::Future;
- use std::pin::Pin;
- use crate::error::ApiError;
- use crate::types::{MessageRequest, MessageResponse};
- pub mod anthropic;
- pub mod openai_compat;
- pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>> + Send + 'a>>;
- pub trait Provider {
- type Stream;
- fn send_message<'a>(
- &'a self,
- request: &'a MessageRequest,
- ) -> ProviderFuture<'a, MessageResponse>;
- fn stream_message<'a>(
- &'a self,
- request: &'a MessageRequest,
- ) -> ProviderFuture<'a, Self::Stream>;
- }
/// Backend family that serves a model.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProviderKind {
    /// Anthropic; used for `claude…` model names and the Claude aliases.
    Anthropic,
    /// xAI; used for `grok…` model names.
    Xai,
    /// OpenAI; chosen from credentials when `OPENAI_API_KEY` is available.
    OpenAi,
}
- #[derive(Debug, Clone, Copy, PartialEq, Eq)]
- pub struct ProviderMetadata {
- pub provider: ProviderKind,
- pub auth_env: &'static str,
- pub base_url_env: &'static str,
- pub default_base_url: &'static str,
- }
- const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
- (
- "opus",
- ProviderMetadata {
- provider: ProviderKind::Anthropic,
- auth_env: "ANTHROPIC_API_KEY",
- base_url_env: "ANTHROPIC_BASE_URL",
- default_base_url: anthropic::DEFAULT_BASE_URL,
- },
- ),
- (
- "sonnet",
- ProviderMetadata {
- provider: ProviderKind::Anthropic,
- auth_env: "ANTHROPIC_API_KEY",
- base_url_env: "ANTHROPIC_BASE_URL",
- default_base_url: anthropic::DEFAULT_BASE_URL,
- },
- ),
- (
- "haiku",
- ProviderMetadata {
- provider: ProviderKind::Anthropic,
- auth_env: "ANTHROPIC_API_KEY",
- base_url_env: "ANTHROPIC_BASE_URL",
- default_base_url: anthropic::DEFAULT_BASE_URL,
- },
- ),
- (
- "grok",
- ProviderMetadata {
- provider: ProviderKind::Xai,
- auth_env: "XAI_API_KEY",
- base_url_env: "XAI_BASE_URL",
- default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
- },
- ),
- (
- "grok-3",
- ProviderMetadata {
- provider: ProviderKind::Xai,
- auth_env: "XAI_API_KEY",
- base_url_env: "XAI_BASE_URL",
- default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
- },
- ),
- (
- "grok-mini",
- ProviderMetadata {
- provider: ProviderKind::Xai,
- auth_env: "XAI_API_KEY",
- base_url_env: "XAI_BASE_URL",
- default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
- },
- ),
- (
- "grok-3-mini",
- ProviderMetadata {
- provider: ProviderKind::Xai,
- auth_env: "XAI_API_KEY",
- base_url_env: "XAI_BASE_URL",
- default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
- },
- ),
- (
- "grok-2",
- ProviderMetadata {
- provider: ProviderKind::Xai,
- auth_env: "XAI_API_KEY",
- base_url_env: "XAI_BASE_URL",
- default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
- },
- ),
- ];
- #[must_use]
- pub fn resolve_model_alias(model: &str) -> String {
- let trimmed = model.trim();
- let lower = trimmed.to_ascii_lowercase();
- MODEL_REGISTRY
- .iter()
- .find_map(|(alias, metadata)| {
- (*alias == lower).then_some(match metadata.provider {
- ProviderKind::Anthropic => match *alias {
- "opus" => "claude-opus-4-6",
- "sonnet" => "claude-sonnet-4-6",
- "haiku" => "claude-haiku-4-5-20251213",
- _ => trimmed,
- },
- ProviderKind::Xai => match *alias {
- "grok" | "grok-3" => "grok-3",
- "grok-mini" | "grok-3-mini" => "grok-3-mini",
- "grok-2" => "grok-2",
- _ => trimmed,
- },
- ProviderKind::OpenAi => trimmed,
- })
- })
- .map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
- }
- #[must_use]
- pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
- let canonical = resolve_model_alias(model);
- if canonical.starts_with("claude") {
- return Some(ProviderMetadata {
- provider: ProviderKind::Anthropic,
- auth_env: "ANTHROPIC_API_KEY",
- base_url_env: "ANTHROPIC_BASE_URL",
- default_base_url: anthropic::DEFAULT_BASE_URL,
- });
- }
- if canonical.starts_with("grok") {
- return Some(ProviderMetadata {
- provider: ProviderKind::Xai,
- auth_env: "XAI_API_KEY",
- base_url_env: "XAI_BASE_URL",
- default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
- });
- }
- None
- }
- #[must_use]
- pub fn detect_provider_kind(model: &str) -> ProviderKind {
- if let Some(metadata) = metadata_for_model(model) {
- return metadata.provider;
- }
- if anthropic::has_auth_from_env_or_saved().unwrap_or(false) {
- return ProviderKind::Anthropic;
- }
- if openai_compat::has_api_key("OPENAI_API_KEY") {
- return ProviderKind::OpenAi;
- }
- if openai_compat::has_api_key("XAI_API_KEY") {
- return ProviderKind::Xai;
- }
- ProviderKind::Anthropic
- }
- #[must_use]
- pub fn max_tokens_for_model(model: &str) -> u32 {
- let canonical = resolve_model_alias(model);
- if canonical.contains("opus") {
- 32_000
- } else {
- 64_000
- }
- }
#[cfg(test)]
mod tests {
    use super::{
        detect_provider_kind, max_tokens_for_model, metadata_for_model, resolve_model_alias,
        ProviderKind,
    };

    #[test]
    fn resolves_grok_aliases() {
        assert_eq!(resolve_model_alias("grok"), "grok-3");
        assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
        assert_eq!(resolve_model_alias("grok-2"), "grok-2");
    }

    #[test]
    fn resolves_anthropic_aliases() {
        assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
        assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6");
        assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213");
    }

    #[test]
    fn alias_lookup_trims_and_ignores_case() {
        assert_eq!(resolve_model_alias("  Opus "), "claude-opus-4-6");
        assert_eq!(resolve_model_alias("GROK-MINI"), "grok-3-mini");
    }

    #[test]
    fn unknown_models_pass_through_trimmed() {
        assert_eq!(resolve_model_alias("  gpt-4o\n"), "gpt-4o");
        assert_eq!(resolve_model_alias("claude-sonnet-4-6"), "claude-sonnet-4-6");
    }

    #[test]
    fn metadata_follows_provider_family() {
        let anthropic = metadata_for_model("sonnet").expect("sonnet is a registered alias");
        assert_eq!(anthropic.provider, ProviderKind::Anthropic);
        assert_eq!(anthropic.auth_env, "ANTHROPIC_API_KEY");

        let xai = metadata_for_model("grok-2").expect("grok-2 is a registered alias");
        assert_eq!(xai.provider, ProviderKind::Xai);
        assert_eq!(xai.base_url_env, "XAI_BASE_URL");

        assert_eq!(metadata_for_model("gpt-4o"), None);
    }

    #[test]
    fn detects_provider_from_model_name_first() {
        assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
        assert_eq!(
            detect_provider_kind("claude-sonnet-4-6"),
            ProviderKind::Anthropic
        );
    }

    #[test]
    fn keeps_existing_max_token_heuristic() {
        assert_eq!(max_tokens_for_model("opus"), 32_000);
        assert_eq!(max_tokens_for_model("grok-3"), 64_000);
    }
}
|