mod.rs 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. use std::future::Future;
  2. use std::pin::Pin;
  3. use crate::error::ApiError;
  4. use crate::types::{MessageRequest, MessageResponse};
  5. pub mod anthropic;
  6. pub mod openai_compat;
  7. pub type ProviderFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T, ApiError>> + Send + 'a>>;
  8. pub trait Provider {
  9. type Stream;
  10. fn send_message<'a>(
  11. &'a self,
  12. request: &'a MessageRequest,
  13. ) -> ProviderFuture<'a, MessageResponse>;
  14. fn stream_message<'a>(
  15. &'a self,
  16. request: &'a MessageRequest,
  17. ) -> ProviderFuture<'a, Self::Stream>;
  18. }
  /// Backend family that a model is routed to.
  #[derive(Clone, Copy, Debug, Eq, PartialEq)]
  pub enum ProviderKind {
      /// Selected for model names starting with `claude` (including the
      /// `opus` / `sonnet` / `haiku` aliases) and used as the final fallback.
      Anthropic,
      /// Selected for model names starting with `grok`.
      Xai,
      /// Not derived from any model name in this module; `detect_provider_kind`
      /// returns it only through its `OPENAI_API_KEY` check.
      OpenAi,
  }
  25. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  26. pub struct ProviderMetadata {
  27. pub provider: ProviderKind,
  28. pub auth_env: &'static str,
  29. pub base_url_env: &'static str,
  30. pub default_base_url: &'static str,
  31. }
  32. const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[
  33. (
  34. "opus",
  35. ProviderMetadata {
  36. provider: ProviderKind::Anthropic,
  37. auth_env: "ANTHROPIC_API_KEY",
  38. base_url_env: "ANTHROPIC_BASE_URL",
  39. default_base_url: anthropic::DEFAULT_BASE_URL,
  40. },
  41. ),
  42. (
  43. "sonnet",
  44. ProviderMetadata {
  45. provider: ProviderKind::Anthropic,
  46. auth_env: "ANTHROPIC_API_KEY",
  47. base_url_env: "ANTHROPIC_BASE_URL",
  48. default_base_url: anthropic::DEFAULT_BASE_URL,
  49. },
  50. ),
  51. (
  52. "haiku",
  53. ProviderMetadata {
  54. provider: ProviderKind::Anthropic,
  55. auth_env: "ANTHROPIC_API_KEY",
  56. base_url_env: "ANTHROPIC_BASE_URL",
  57. default_base_url: anthropic::DEFAULT_BASE_URL,
  58. },
  59. ),
  60. (
  61. "grok",
  62. ProviderMetadata {
  63. provider: ProviderKind::Xai,
  64. auth_env: "XAI_API_KEY",
  65. base_url_env: "XAI_BASE_URL",
  66. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  67. },
  68. ),
  69. (
  70. "grok-3",
  71. ProviderMetadata {
  72. provider: ProviderKind::Xai,
  73. auth_env: "XAI_API_KEY",
  74. base_url_env: "XAI_BASE_URL",
  75. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  76. },
  77. ),
  78. (
  79. "grok-mini",
  80. ProviderMetadata {
  81. provider: ProviderKind::Xai,
  82. auth_env: "XAI_API_KEY",
  83. base_url_env: "XAI_BASE_URL",
  84. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  85. },
  86. ),
  87. (
  88. "grok-3-mini",
  89. ProviderMetadata {
  90. provider: ProviderKind::Xai,
  91. auth_env: "XAI_API_KEY",
  92. base_url_env: "XAI_BASE_URL",
  93. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  94. },
  95. ),
  96. (
  97. "grok-2",
  98. ProviderMetadata {
  99. provider: ProviderKind::Xai,
  100. auth_env: "XAI_API_KEY",
  101. base_url_env: "XAI_BASE_URL",
  102. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  103. },
  104. ),
  105. ];
  106. #[must_use]
  107. pub fn resolve_model_alias(model: &str) -> String {
  108. let trimmed = model.trim();
  109. let lower = trimmed.to_ascii_lowercase();
  110. MODEL_REGISTRY
  111. .iter()
  112. .find_map(|(alias, metadata)| {
  113. (*alias == lower).then_some(match metadata.provider {
  114. ProviderKind::Anthropic => match *alias {
  115. "opus" => "claude-opus-4-6",
  116. "sonnet" => "claude-sonnet-4-6",
  117. "haiku" => "claude-haiku-4-5-20251213",
  118. _ => trimmed,
  119. },
  120. ProviderKind::Xai => match *alias {
  121. "grok" | "grok-3" => "grok-3",
  122. "grok-mini" | "grok-3-mini" => "grok-3-mini",
  123. "grok-2" => "grok-2",
  124. _ => trimmed,
  125. },
  126. ProviderKind::OpenAi => trimmed,
  127. })
  128. })
  129. .map_or_else(|| trimmed.to_string(), ToOwned::to_owned)
  130. }
  131. #[must_use]
  132. pub fn metadata_for_model(model: &str) -> Option<ProviderMetadata> {
  133. let canonical = resolve_model_alias(model);
  134. if canonical.starts_with("claude") {
  135. return Some(ProviderMetadata {
  136. provider: ProviderKind::Anthropic,
  137. auth_env: "ANTHROPIC_API_KEY",
  138. base_url_env: "ANTHROPIC_BASE_URL",
  139. default_base_url: anthropic::DEFAULT_BASE_URL,
  140. });
  141. }
  142. if canonical.starts_with("grok") {
  143. return Some(ProviderMetadata {
  144. provider: ProviderKind::Xai,
  145. auth_env: "XAI_API_KEY",
  146. base_url_env: "XAI_BASE_URL",
  147. default_base_url: openai_compat::DEFAULT_XAI_BASE_URL,
  148. });
  149. }
  150. None
  151. }
  152. #[must_use]
  153. pub fn detect_provider_kind(model: &str) -> ProviderKind {
  154. if let Some(metadata) = metadata_for_model(model) {
  155. return metadata.provider;
  156. }
  157. if anthropic::has_auth_from_env_or_saved().unwrap_or(false) {
  158. return ProviderKind::Anthropic;
  159. }
  160. if openai_compat::has_api_key("OPENAI_API_KEY") {
  161. return ProviderKind::OpenAi;
  162. }
  163. if openai_compat::has_api_key("XAI_API_KEY") {
  164. return ProviderKind::Xai;
  165. }
  166. ProviderKind::Anthropic
  167. }
  168. #[must_use]
  169. pub fn max_tokens_for_model(model: &str) -> u32 {
  170. let canonical = resolve_model_alias(model);
  171. if canonical.contains("opus") {
  172. 32_000
  173. } else {
  174. 64_000
  175. }
  176. }
  #[cfg(test)]
  mod tests {
      use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind};

      /// Each grok alias expands to the expected full model name.
      #[test]
      fn resolves_grok_aliases() {
          let cases = [
              ("grok", "grok-3"),
              ("grok-mini", "grok-3-mini"),
              ("grok-2", "grok-2"),
          ];
          for &(alias, expected) in &cases {
              assert_eq!(resolve_model_alias(alias), expected, "alias: {}", alias);
          }
      }

      /// A recognisable model name decides the provider before any
      /// environment-based fallback is consulted.
      #[test]
      fn detects_provider_from_model_name_first() {
          let cases = [
              ("grok", ProviderKind::Xai),
              ("claude-sonnet-4-6", ProviderKind::Anthropic),
          ];
          for &(model, expected) in &cases {
              assert_eq!(detect_provider_kind(model), expected, "model: {}", model);
          }
      }

      /// Opus models get 32_000; everything else gets 64_000.
      #[test]
      fn keeps_existing_max_token_heuristic() {
          let cases = [("opus", 32_000_u32), ("grok-3", 64_000_u32)];
          for &(model, expected) in &cases {
              assert_eq!(max_tokens_for_model(model), expected, "model: {}", model);
          }
      }
  }