mod.rs 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. use std::future::Future;
  2. use std::pin::Pin;
  3. use crate::error::ApiError;
  4. use crate::types::{MessageRequest, MessageResponse};
  5. pub mod anthropic;
  6. pub mod openai_compat;
  7. pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>> + Send + 'a>>;
  8. pub trait Provider {
  9. type Stream;
  10. fn send_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, MessageResponse>;
  11. fn stream_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, Self::Stream>;
  12. }
  13. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  14. pub enum ProviderKind {
  15. Anthropic,
  16. Xai,
  17. OpenAi,
  18. }
  19. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  20. pub struct ProviderMetadata {
  21. pub provider: ProviderKind,
  22. pub canonical_model: &'static str,
  23. pub auth_env: &'static str,
  24. pub base_url_env: &'static str,
  25. pub default_base_url: &'static str,
  26. }
  27. const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
  28. (
  29. "opus",
  30. ProviderMetadata {
  31. provider: ProviderKind::Anthropic,
  32. canonical_model: "claude-opus-4-6",
  33. auth_env: "ANTHROPIC_API_KEY",
  34. base_url_env: "ANTHROPIC_BASE_URL",
  35. default_base_url: anthropic::DEFAULT_BASE_URL,
  36. },
  37. ),
  38. (
  39. "sonnet",
  40. ProviderMetadata {
  41. provider: ProviderKind::Anthropic,
  42. canonical_model: "claude-sonnet-4-6",
  43. auth_env: "ANTHROPIC_API_KEY",
  44. base_url_env: "ANTHROPIC_BASE_URL",
  45. default_base_url: anthropic::DEFAULT_BASE_URL,
  46. },
  47. ),
  48. (
  49. "haiku",
  50. ProviderMetadata {
  51. provider: ProviderKind::Anthropic,
  52. canonical_model: "claude-haiku-4-5-20251213",
  53. auth_env: "ANTHROPIC_API_KEY",
  54. base_url_env: "ANTHROPIC_BASE_URL",
  55. default_base_url: anthropic::DEFAULT_BASE_URL,
  56. },
  57. ),
  58. (
  59. "grok",
  60. ProviderMetadata {
  61. provider: ProviderKind::Xai,
  62. canonical_model: "grok-3",
  63. auth_env: "XAI_API_KEY",
  64. base_url_env: "XAI_BASE_URL",
  65. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  66. },
  67. ),
  68. (
  69. "grok-3",
  70. ProviderMetadata {
  71. provider: ProviderKind::Xai,
  72. canonical_model: "grok-3",
  73. auth_env: "XAI_API_KEY",
  74. base_url_env: "XAI_BASE_URL",
  75. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  76. },
  77. ),
  78. (
  79. "grok-mini",
  80. ProviderMetadata {
  81. provider: ProviderKind::Xai,
  82. canonical_model: "grok-3-mini",
  83. auth_env: "XAI_API_KEY",
  84. base_url_env: "XAI_BASE_URL",
  85. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  86. },
  87. ),
  88. (
  89. "grok-3-mini",
  90. ProviderMetadata {
  91. provider: ProviderKind::Xai,
  92. canonical_model: "grok-3-mini",
  93. auth_env: "XAI_API_KEY",
  94. base_url_env: "XAI_BASE_URL",
  95. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  96. },
  97. ),
  98. (
  99. "grok-2",
  100. ProviderMetadata {
  101. provider: ProviderKind::Xai,
  102. canonical_model: "grok-2",
  103. auth_env: "XAI_API_KEY",
  104. base_url_env: "XAI_BASE_URL",
  105. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  106. },
  107. ),
  108. ];
  109. #[must_use]
  110. pub fn resolve_model_alias(model: &str) -> String {
  111. let trimmed = model.trim();
  112. let lower = trimmed.to_ascii_lowercase();
  113. MODEL_REGISTRY
  114. .iter()
  115. .find_map(|(alias, metadata)| (*alias == lower).then_some(metadata.canonical_model))
  116. .map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
  117. }
  118. #[must_use]
  119. pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
  120. let canonical = resolve_model_alias(model);
  121. if canonical.starts_with("claude") {
  122. return Some(ProviderMetadata {
  123. provider: ProviderKind::Anthropic,
  124. canonical_model: Box::leak(canonical.into_boxed_str()),
  125. auth_env: "ANTHROPIC_API_KEY",
  126. base_url_env: "ANTHROPIC_BASE_URL",
  127. default_base_url: anthropic::DEFAULT_BASE_URL,
  128. });
  129. }
  130. if canonical.starts_with("grok") {
  131. return Some(ProviderMetadata {
  132. provider: ProviderKind::Xai,
  133. canonical_model: Box::leak(canonical.into_boxed_str()),
  134. auth_env: "XAI_API_KEY",
  135. base_url_env: "XAI_BASE_URL",
  136. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  137. });
  138. }
  139. None
  140. }
  141. #[must_use]
  142. pub fn detect_provider_kind(model: &str) -> ProviderKind {
  143. if let Some(metadata) = metadata_for_model(model) {
  144. return metadata.provider;
  145. }
  146. if anthropic::has_auth_from_env_or_saved().unwrap_or(false) {
  147. return ProviderKind::Anthropic;
  148. }
  149. if openai_compat::has_api_key("OPENAI_API_KEY") {
  150. return ProviderKind::OpenAi;
  151. }
  152. if openai_compat::has_api_key("XAI_API_KEY") {
  153. return ProviderKind::Xai;
  154. }
  155. ProviderKind::Anthropic
  156. }
  157. #[must_use]
  158. pub fn max_tokens_for_model(model: &str) -> u32 {
  159. let canonical = resolve_model_alias(model);
  160. if canonical.contains("opus") {
  161. 32_000
  162. } else {
  163. 64_000
  164. }
  165. }
  166. #[cfg(test)]
  167. mod tests {
  168. use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind};
  169. #[test]
  170. fn resolves_grok_aliases() {
  171. assert_eq!(resolve_model_alias("grok"), "grok-3");
  172. assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
  173. assert_eq!(resolve_model_alias("grok-2"), "grok-2");
  174. }
  175. #[test]
  176. fn detects_provider_from_model_name_first() {
  177. assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai);
  178. assert_eq!(detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic);
  179. }
  180. #[test]
  181. fn keeps_existing_max_token_heuristic() {
  182. assert_eq!(max_tokens_for_model("opus"), 32_000);
  183. assert_eq!(max_tokens_for_model("grok-3"), 64_000);
  184. }
  185. }