client.rs 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. use crate::error::ApiError;
  2. use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
  3. use crate::providers::anthropic::{self, AnthropicClient, AuthSource};
  4. use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig};
  5. use crate::providers::{self, Provider, ProviderKind};
  6. use crate::types::{MessageRequest, MessageResponse, StreamEvent};
  7. async fn send_via_provider<P: Provider>(
  8. provider: &P,
  9. request: &MessageRequest,
  10. ) -> Result<MessageResponse, ApiError> {
  11. provider.send_message(request).await
  12. }
  13. async fn stream_via_provider<P: Provider>(
  14. provider: &P,
  15. request: &MessageRequest,
  16. ) -> Result<P::Stream, ApiError> {
  17. provider.stream_message(request).await
  18. }
  19. #[allow(clippy::large_enum_variant)]
  20. #[derive(Debug, Clone)]
  21. pub enum ProviderClient {
  22. Anthropic(AnthropicClient),
  23. Xai(OpenAiCompatClient),
  24. OpenAi(OpenAiCompatClient),
  25. }
  26. impl ProviderClient {
  27. pub fn from_model(model: &str) -> Result<Self, ApiError> {
  28. Self::from_model_with_anthropic_auth(model, None)
  29. }
  30. pub fn from_model_with_anthropic_auth(
  31. model: &str,
  32. anthropic_auth: Option<AuthSource>,
  33. ) -> Result<Self, ApiError> {
  34. let resolved_model = providers::resolve_model_alias(model);
  35. match providers::detect_provider_kind(&resolved_model) {
  36. ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth {
  37. Some(auth) => AnthropicClient::from_auth(auth),
  38. None => AnthropicClient::from_env()?,
  39. })),
  40. ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
  41. OpenAiCompatConfig::xai(),
  42. )?)),
  43. ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env(
  44. OpenAiCompatConfig::openai(),
  45. )?)),
  46. }
  47. }
  48. #[must_use]
  49. pub const fn provider_kind(&self) -> ProviderKind {
  50. match self {
  51. Self::Anthropic(_) => ProviderKind::Anthropic,
  52. Self::Xai(_) => ProviderKind::Xai,
  53. Self::OpenAi(_) => ProviderKind::OpenAi,
  54. }
  55. }
  56. #[must_use]
  57. pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self {
  58. match self {
  59. Self::Anthropic(client) => Self::Anthropic(client.with_prompt_cache(prompt_cache)),
  60. other => other,
  61. }
  62. }
  63. #[must_use]
  64. pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
  65. match self {
  66. Self::Anthropic(client) => client.prompt_cache_stats(),
  67. Self::Xai(_) | Self::OpenAi(_) => None,
  68. }
  69. }
  70. #[must_use]
  71. pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
  72. match self {
  73. Self::Anthropic(client) => client.take_last_prompt_cache_record(),
  74. Self::Xai(_) | Self::OpenAi(_) => None,
  75. }
  76. }
  77. pub async fn send_message(
  78. &self,
  79. request: &MessageRequest,
  80. ) -> Result<MessageResponse, ApiError> {
  81. match self {
  82. Self::Anthropic(client) => send_via_provider(client, request).await,
  83. Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
  84. }
  85. }
  86. pub async fn stream_message(
  87. &self,
  88. request: &MessageRequest,
  89. ) -> Result<MessageStream, ApiError> {
  90. match self {
  91. Self::Anthropic(client) => stream_via_provider(client, request)
  92. .await
  93. .map(MessageStream::Anthropic),
  94. Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
  95. .await
  96. .map(MessageStream::OpenAiCompat),
  97. }
  98. }
  99. }
  100. #[derive(Debug)]
  101. pub enum MessageStream {
  102. Anthropic(anthropic::MessageStream),
  103. OpenAiCompat(openai_compat::MessageStream),
  104. }
  105. impl MessageStream {
  106. #[must_use]
  107. pub fn request_id(&self) -> Option<&str> {
  108. match self {
  109. Self::Anthropic(stream) => stream.request_id(),
  110. Self::OpenAiCompat(stream) => stream.request_id(),
  111. }
  112. }
  113. pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
  114. match self {
  115. Self::Anthropic(stream) => stream.next_event().await,
  116. Self::OpenAiCompat(stream) => stream.next_event().await,
  117. }
  118. }
  119. }
  120. pub use anthropic::{
  121. oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet,
  122. };
  123. #[must_use]
  124. pub fn read_base_url() -> String {
  125. anthropic::read_base_url()
  126. }
  127. #[must_use]
  128. pub fn read_xai_base_url() -> String {
  129. openai_compat::read_base_url(OpenAiCompatConfig::xai())
  130. }
  131. #[cfg(test)]
  132. mod tests {
  133. use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind};
  134. #[test]
  135. fn resolves_existing_and_grok_aliases() {
  136. assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
  137. assert_eq!(resolve_model_alias("grok"), "grok-3");
  138. assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini");
  139. }
  140. #[test]
  141. fn provider_detection_prefers_model_family() {
  142. assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai);
  143. assert_eq!(
  144. detect_provider_kind("claude-sonnet-4-6"),
  145. ProviderKind::Anthropic
  146. );
  147. }
  148. }