| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104 |
- use std::collections::{BTreeMap, VecDeque};
- use std::time::Duration;
- use serde::Deserialize;
- use serde_json::{json, Value};
- use crate::error::ApiError;
- use crate::types::{
- ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
- InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
- MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
- ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
- };
- use super::{Provider, ProviderFuture};
/// Default base URL for the xAI API.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Default base URL for the OpenAI API.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";

/// Primary response header consulted for a request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Fallback response header consulted when `REQUEST_ID_HEADER` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";

/// Retry policy defaults applied by `OpenAiCompatClient::new`.
const DEFAULT_MAX_RETRIES: u32 = 2;
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);

/// Static description of one OpenAI-compatible provider: its display name,
/// the environment variables it is configured from, and its default URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Display name; also selects the list returned by `credential_env_vars`.
    pub provider_name: &'static str,
    /// Environment variable holding the API key.
    pub api_key_env: &'static str,
    /// Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when `base_url_env` does not provide one.
    pub default_base_url: &'static str,
}

const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];

impl OpenAiCompatConfig {
    /// Preset for xAI (`XAI_API_KEY`, `XAI_BASE_URL`).
    #[must_use]
    pub const fn xai() -> Self {
        Self {
            provider_name: "xAI",
            api_key_env: "XAI_API_KEY",
            base_url_env: "XAI_BASE_URL",
            default_base_url: DEFAULT_XAI_BASE_URL,
        }
    }

    /// Preset for OpenAI (`OPENAI_API_KEY`, `OPENAI_BASE_URL`).
    #[must_use]
    pub const fn openai() -> Self {
        Self {
            provider_name: "OpenAI",
            api_key_env: "OPENAI_API_KEY",
            base_url_env: "OPENAI_BASE_URL",
            default_base_url: DEFAULT_OPENAI_BASE_URL,
        }
    }

    /// Environment variables that can supply credentials for this provider.
    ///
    /// The lookup is by `provider_name`; unknown providers get an empty list.
    #[must_use]
    pub fn credential_env_vars(self) -> &'static [&'static str] {
        if self.provider_name == "xAI" {
            XAI_ENV_VARS
        } else if self.provider_name == "OpenAI" {
            OPENAI_ENV_VARS
        } else {
            &[]
        }
    }
}
- #[derive(Debug, Clone)]
- pub struct OpenAiCompatClient {
- http: reqwest::Client,
- api_key: String,
- config: OpenAiCompatConfig,
- base_url: String,
- max_retries: u32,
- initial_backoff: Duration,
- max_backoff: Duration,
- }
- impl OpenAiCompatClient {
- const fn config(&self) -> OpenAiCompatConfig {
- self.config
- }
- #[must_use]
- pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
- Self {
- http: reqwest::Client::new(),
- api_key: api_key.into(),
- config,
- base_url: read_base_url(config),
- max_retries: DEFAULT_MAX_RETRIES,
- initial_backoff: DEFAULT_INITIAL_BACKOFF,
- max_backoff: DEFAULT_MAX_BACKOFF,
- }
- }
- pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
- let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
- return Err(ApiError::missing_credentials(
- config.provider_name,
- config.credential_env_vars(),
- ));
- };
- Ok(Self::new(api_key, config))
- }
- #[must_use]
- pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
- self.base_url = base_url.into();
- self
- }
- #[must_use]
- pub fn with_retry_policy(
- mut self,
- max_retries: u32,
- initial_backoff: Duration,
- max_backoff: Duration,
- ) -> Self {
- self.max_retries = max_retries;
- self.initial_backoff = initial_backoff;
- self.max_backoff = max_backoff;
- self
- }
- pub async fn send_message(
- &self,
- request: &MessageRequest,
- ) -> Result<MessageResponse, ApiError> {
- let request = MessageRequest {
- stream: false,
- ..request.clone()
- };
- let response = self.send_with_retry(&request).await?;
- let request_id = request_id_from_headers(response.headers());
- let payload = response.json::<ChatCompletionResponse>().await?;
- let mut normalized = normalize_response(&request.model, payload)?;
- if normalized.request_id.is_none() {
- normalized.request_id = request_id;
- }
- Ok(normalized)
- }
- pub async fn stream_message(
- &self,
- request: &MessageRequest,
- ) -> Result<MessageStream, ApiError> {
- let response = self
- .send_with_retry(&request.clone().with_streaming())
- .await?;
- Ok(MessageStream {
- request_id: request_id_from_headers(response.headers()),
- response,
- parser: OpenAiSseParser::new(),
- pending: VecDeque::new(),
- done: false,
- state: StreamState::new(request.model.clone()),
- })
- }
- async fn send_with_retry(
- &self,
- request: &MessageRequest,
- ) -> Result<reqwest::Response, ApiError> {
- let mut attempts = 0;
- let last_error = loop {
- attempts += 1;
- let retryable_error = match self.send_raw_request(request).await {
- Ok(response) => match expect_success(response).await {
- Ok(response) => return Ok(response),
- Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
- Err(error) => return Err(error),
- },
- Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
- Err(error) => return Err(error),
- };
- if attempts > self.max_retries {
- break retryable_error;
- }
- tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
- };
- Err(ApiError::RetriesExhausted {
- attempts,
- last_error: Box::new(last_error),
- })
- }
- async fn send_raw_request(
- &self,
- request: &MessageRequest,
- ) -> Result<reqwest::Response, ApiError> {
- let request_url = chat_completions_endpoint(&self.base_url);
- self.http
- .post(&request_url)
- .header("content-type", "application/json")
- .bearer_auth(&self.api_key)
- .json(&build_chat_completion_request(request, self.config()))
- .send()
- .await
- .map_err(ApiError::from)
- }
- fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
- let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
- return Err(ApiError::BackoffOverflow {
- attempt,
- base_delay: self.initial_backoff,
- });
- };
- Ok(self
- .initial_backoff
- .checked_mul(multiplier)
- .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
- }
- }
- impl Provider for OpenAiCompatClient {
- type Stream = MessageStream;
- fn send_message<'a>(
- &'a self,
- request: &'a MessageRequest,
- ) -> ProviderFuture<'a, MessageResponse> {
- Box::pin(async move { self.send_message(request).await })
- }
- fn stream_message<'a>(
- &'a self,
- request: &'a MessageRequest,
- ) -> ProviderFuture<'a, Self::Stream> {
- Box::pin(async move { self.stream_message(request).await })
- }
- }
- #[derive(Debug)]
- pub struct MessageStream {
- request_id: Option<String>,
- response: reqwest::Response,
- parser: OpenAiSseParser,
- pending: VecDeque<StreamEvent>,
- done: bool,
- state: StreamState,
- }
- impl MessageStream {
- #[must_use]
- pub fn request_id(&self) -> Option<&str> {
- self.request_id.as_deref()
- }
- pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
- loop {
- if let Some(event) = self.pending.pop_front() {
- return Ok(Some(event));
- }
- if self.done {
- self.pending.extend(self.state.finish()?);
- if let Some(event) = self.pending.pop_front() {
- return Ok(Some(event));
- }
- return Ok(None);
- }
- match self.response.chunk().await? {
- Some(chunk) => {
- for parsed in self.parser.push(&chunk)? {
- self.pending.extend(self.state.ingest_chunk(parsed)?);
- }
- }
- None => {
- self.done = true;
- }
- }
- }
- }
- }
- #[derive(Debug, Default)]
- struct OpenAiSseParser {
- buffer: Vec<u8>,
- }
- impl OpenAiSseParser {
- fn new() -> Self {
- Self::default()
- }
- fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
- self.buffer.extend_from_slice(chunk);
- let mut events = Vec::new();
- while let Some(frame) = next_sse_frame(&mut self.buffer) {
- if let Some(event) = parse_sse_frame(&frame)? {
- events.push(event);
- }
- }
- Ok(events)
- }
- }
/// Incremental translator from OpenAI chat-completion chunks to this crate's
/// `StreamEvent` sequence: message start, content blocks, then message
/// delta and stop.
///
/// Text is always emitted on content block 0; each tool call is emitted on
/// the block returned by `ToolCallState::block_index`.
// Several independent progress flags are tracked below.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug)]
struct StreamState {
    /// Requested model; used for `MessageStart` when the chunk has no `model`.
    model: String,
    /// `MessageStart` has been emitted.
    message_started: bool,
    /// Text block 0 has been opened.
    text_started: bool,
    /// Text block 0 has been closed.
    text_finished: bool,
    /// `finish` has already run; later calls return no events.
    finished: bool,
    /// Most recent finish reason, already passed through `normalize_finish_reason`.
    stop_reason: Option<String>,
    /// Most recent usage reported by the stream, if any.
    usage: Option<Usage>,
    /// Tool-call accumulators keyed by OpenAI's per-call `index`.
    tool_calls: BTreeMap<u32, ToolCallState>,
}
impl StreamState {
    /// Creates an empty state for a stream requested with `model`.
    fn new(model: String) -> Self {
        Self {
            model,
            message_started: false,
            text_started: false,
            text_finished: false,
            finished: false,
            stop_reason: None,
            usage: None,
            tool_calls: BTreeMap::new(),
        }
    }
    /// Translates one parsed chunk into zero or more stream events.
    ///
    /// # Errors
    ///
    /// Propagates errors from `ToolCallState::start_event`.
    fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
        let mut events = Vec::new();
        // The first chunk opens the message. Usage starts at zero; totals are
        // reported later in the `MessageDelta` emitted by `finish`.
        if !self.message_started {
            self.message_started = true;
            events.push(StreamEvent::MessageStart(MessageStartEvent {
                message: MessageResponse {
                    id: chunk.id.clone(),
                    kind: "message".to_string(),
                    role: "assistant".to_string(),
                    content: Vec::new(),
                    model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
                    stop_reason: None,
                    stop_sequence: None,
                    usage: Usage {
                        input_tokens: 0,
                        cache_creation_input_tokens: 0,
                        cache_read_input_tokens: 0,
                        output_tokens: 0,
                    },
                    request_id: None,
                },
            }));
        }
        // Usage may appear on any chunk (presumably the trailing usage-only
        // chunk when `include_usage` was requested); the latest one wins.
        if let Some(usage) = chunk.usage {
            self.usage = Some(Usage {
                input_tokens: usage.prompt_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
                output_tokens: usage.completion_tokens,
            });
        }
        for choice in chunk.choices {
            // Empty text fragments are ignored so they never open block 0.
            // NOTE(review): block 0 is only closed in `finish`, so when text
            // precedes tool calls its ContentBlockStop comes after the tool
            // blocks' events; confirm consumers accept interleaved blocks.
            if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
                if !self.text_started {
                    self.text_started = true;
                    events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
                        index: 0,
                        content_block: OutputContentBlock::Text {
                            text: String::new(),
                        },
                    }));
                }
                events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
                    index: 0,
                    delta: ContentBlockDelta::TextDelta { text: content },
                }));
            }
            for tool_call in choice.delta.tool_calls {
                let state = self.tool_calls.entry(tool_call.index).or_default();
                state.apply(tool_call);
                let block_index = state.block_index();
                if !state.started {
                    // A tool block can only start once its function name is
                    // known; until then argument text stays buffered in `state`.
                    if let Some(start_event) = state.start_event()? {
                        state.started = true;
                        events.push(StreamEvent::ContentBlockStart(start_event));
                    } else {
                        continue;
                    }
                }
                // Emits all argument text not yet sent, including text
                // buffered before the block started.
                if let Some(delta_event) = state.delta_event() {
                    events.push(StreamEvent::ContentBlockDelta(delta_event));
                }
                if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
                    state.stopped = true;
                    events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                        index: block_index,
                    }));
                }
            }
            if let Some(finish_reason) = choice.finish_reason {
                self.stop_reason = Some(normalize_finish_reason(&finish_reason));
                // Close every started tool block, including calls that got no
                // fragment in this chunk.
                if finish_reason == "tool_calls" {
                    for state in self.tool_calls.values_mut() {
                        if state.started && !state.stopped {
                            state.stopped = true;
                            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                                index: state.block_index(),
                            }));
                        }
                    }
                }
            }
        }
        Ok(events)
    }
    /// Emits the closing events once; subsequent calls return nothing.
    ///
    /// Closes the text block, starts (when the name is known) and closes any
    /// tool blocks still open, then, only if a message was started, emits a
    /// `MessageDelta` and a `MessageStop`. In the `MessageDelta` the stop
    /// reason defaults to `"end_turn"` and usage defaults to zeros.
    ///
    /// # Errors
    ///
    /// Propagates errors from `ToolCallState::start_event`.
    fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
        if self.finished {
            return Ok(Vec::new());
        }
        self.finished = true;
        let mut events = Vec::new();
        if self.text_started && !self.text_finished {
            self.text_finished = true;
            events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                index: 0,
            }));
        }
        for state in self.tool_calls.values_mut() {
            // Calls that never received a function name stay unstarted and
            // are not emitted.
            if !state.started {
                if let Some(start_event) = state.start_event()? {
                    state.started = true;
                    events.push(StreamEvent::ContentBlockStart(start_event));
                    if let Some(delta_event) = state.delta_event() {
                        events.push(StreamEvent::ContentBlockDelta(delta_event));
                    }
                }
            }
            if state.started && !state.stopped {
                state.stopped = true;
                events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
                    index: state.block_index(),
                }));
            }
        }
        if self.message_started {
            events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
                delta: MessageDelta {
                    stop_reason: Some(
                        self.stop_reason
                            .clone()
                            .unwrap_or_else(|| "end_turn".to_string()),
                    ),
                    stop_sequence: None,
                },
                usage: self.usage.clone().unwrap_or(Usage {
                    input_tokens: 0,
                    cache_creation_input_tokens: 0,
                    cache_read_input_tokens: 0,
                    output_tokens: 0,
                }),
            }));
            events.push(StreamEvent::MessageStop(MessageStopEvent {}));
        }
        Ok(events)
    }
}
- #[derive(Debug, Default)]
- struct ToolCallState {
- openai_index: u32,
- id: Option<String>,
- name: Option<String>,
- arguments: String,
- emitted_len: usize,
- started: bool,
- stopped: bool,
- }
- impl ToolCallState {
- fn apply(&mut self, tool_call: DeltaToolCall) {
- self.openai_index = tool_call.index;
- if let Some(id) = tool_call.id {
- self.id = Some(id);
- }
- if let Some(name) = tool_call.function.name {
- self.name = Some(name);
- }
- if let Some(arguments) = tool_call.function.arguments {
- self.arguments.push_str(&arguments);
- }
- }
- const fn block_index(&self) -> u32 {
- self.openai_index + 1
- }
- #[allow(clippy::unnecessary_wraps)]
- fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
- let Some(name) = self.name.clone() else {
- return Ok(None);
- };
- let id = self
- .id
- .clone()
- .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
- Ok(Some(ContentBlockStartEvent {
- index: self.block_index(),
- content_block: OutputContentBlock::ToolUse {
- id,
- name,
- input: json!({}),
- },
- }))
- }
- fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
- if self.emitted_len >= self.arguments.len() {
- return None;
- }
- let delta = self.arguments[self.emitted_len..].to_string();
- self.emitted_len = self.arguments.len();
- Some(ContentBlockDeltaEvent {
- index: self.block_index(),
- delta: ContentBlockDelta::InputJsonDelta {
- partial_json: delta,
- },
- })
- }
- }
- #[derive(Debug, Deserialize)]
- struct ChatCompletionResponse {
- id: String,
- model: String,
- choices: Vec<ChatChoice>,
- #[serde(default)]
- usage: Option<OpenAiUsage>,
- }
- #[derive(Debug, Deserialize)]
- struct ChatChoice {
- message: ChatMessage,
- #[serde(default)]
- finish_reason: Option<String>,
- }
- #[derive(Debug, Deserialize)]
- struct ChatMessage {
- role: String,
- #[serde(default)]
- content: Option<String>,
- #[serde(default)]
- tool_calls: Vec<ResponseToolCall>,
- }
- #[derive(Debug, Deserialize)]
- struct ResponseToolCall {
- id: String,
- function: ResponseToolFunction,
- }
- #[derive(Debug, Deserialize)]
- struct ResponseToolFunction {
- name: String,
- arguments: String,
- }
- #[derive(Debug, Deserialize)]
- struct OpenAiUsage {
- #[serde(default)]
- prompt_tokens: u32,
- #[serde(default)]
- completion_tokens: u32,
- }
- #[derive(Debug, Deserialize)]
- struct ChatCompletionChunk {
- id: String,
- #[serde(default)]
- model: Option<String>,
- #[serde(default)]
- choices: Vec<ChunkChoice>,
- #[serde(default)]
- usage: Option<OpenAiUsage>,
- }
- #[derive(Debug, Deserialize)]
- struct ChunkChoice {
- delta: ChunkDelta,
- #[serde(default)]
- finish_reason: Option<String>,
- }
- #[derive(Debug, Default, Deserialize)]
- struct ChunkDelta {
- #[serde(default)]
- content: Option<String>,
- #[serde(default)]
- tool_calls: Vec<DeltaToolCall>,
- }
- #[derive(Debug, Deserialize)]
- struct DeltaToolCall {
- #[serde(default)]
- index: u32,
- #[serde(default)]
- id: Option<String>,
- #[serde(default)]
- function: DeltaFunction,
- }
- #[derive(Debug, Default, Deserialize)]
- struct DeltaFunction {
- #[serde(default)]
- name: Option<String>,
- #[serde(default)]
- arguments: Option<String>,
- }
- #[derive(Debug, Deserialize)]
- struct ErrorEnvelope {
- error: ErrorBody,
- }
- #[derive(Debug, Deserialize)]
- struct ErrorBody {
- #[serde(rename = "type")]
- error_type: Option<String>,
- message: Option<String>,
- }
- fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
- let mut messages = Vec::new();
- if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
- messages.push(json!({
- "role": "system",
- "content": system,
- }));
- }
- for message in &request.messages {
- messages.extend(translate_message(message));
- }
- let mut payload = json!({
- "model": request.model,
- "max_tokens": request.max_tokens,
- "messages": messages,
- "stream": request.stream,
- });
- if request.stream && should_request_stream_usage(config) {
- payload["stream_options"] = json!({ "include_usage": true });
- }
- if let Some(tools) = &request.tools {
- payload["tools"] =
- Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
- }
- if let Some(tool_choice) = &request.tool_choice {
- payload["tool_choice"] = openai_tool_choice(tool_choice);
- }
- payload
- }
- fn translate_message(message: &InputMessage) -> Vec<Value> {
- match message.role.as_str() {
- "assistant" => {
- let mut text = String::new();
- let mut tool_calls = Vec::new();
- for block in &message.content {
- match block {
- InputContentBlock::Text { text: value } => text.push_str(value),
- InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
- "id": id,
- "type": "function",
- "function": {
- "name": name,
- "arguments": input.to_string(),
- }
- })),
- InputContentBlock::ToolResult { .. } => {}
- }
- }
- if text.is_empty() && tool_calls.is_empty() {
- Vec::new()
- } else {
- vec![json!({
- "role": "assistant",
- "content": (!text.is_empty()).then_some(text),
- "tool_calls": tool_calls,
- })]
- }
- }
- _ => message
- .content
- .iter()
- .filter_map(|block| match block {
- InputContentBlock::Text { text } => Some(json!({
- "role": "user",
- "content": text,
- })),
- InputContentBlock::ToolResult {
- tool_use_id,
- content,
- is_error,
- } => Some(json!({
- "role": "tool",
- "tool_call_id": tool_use_id,
- "content": flatten_tool_result_content(content),
- "is_error": is_error,
- })),
- InputContentBlock::ToolUse { .. } => None,
- })
- .collect(),
- }
- }
- fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
- content
- .iter()
- .map(|block| match block {
- ToolResultContentBlock::Text { text } => text.clone(),
- ToolResultContentBlock::Json { value } => value.to_string(),
- })
- .collect::<Vec<_>>()
- .join("\n")
- }
- fn openai_tool_definition(tool: &ToolDefinition) -> Value {
- json!({
- "type": "function",
- "function": {
- "name": tool.name,
- "description": tool.description,
- "parameters": tool.input_schema,
- }
- })
- }
- fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
- match tool_choice {
- ToolChoice::Auto => Value::String("auto".to_string()),
- ToolChoice::Any => Value::String("required".to_string()),
- ToolChoice::Tool { name } => json!({
- "type": "function",
- "function": { "name": name },
- }),
- }
- }
- fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool {
- matches!(config.provider_name, "OpenAI")
- }
- fn normalize_response(
- model: &str,
- response: ChatCompletionResponse,
- ) -> Result<MessageResponse, ApiError> {
- let choice = response
- .choices
- .into_iter()
- .next()
- .ok_or(ApiError::InvalidSseFrame(
- "chat completion response missing choices",
- ))?;
- let mut content = Vec::new();
- if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
- content.push(OutputContentBlock::Text { text });
- }
- for tool_call in choice.message.tool_calls {
- content.push(OutputContentBlock::ToolUse {
- id: tool_call.id,
- name: tool_call.function.name,
- input: parse_tool_arguments(&tool_call.function.arguments),
- });
- }
- Ok(MessageResponse {
- id: response.id,
- kind: "message".to_string(),
- role: choice.message.role,
- content,
- model: response.model.if_empty_then(model.to_string()),
- stop_reason: choice
- .finish_reason
- .map(|value| normalize_finish_reason(&value)),
- stop_sequence: None,
- usage: Usage {
- input_tokens: response
- .usage
- .as_ref()
- .map_or(0, |usage| usage.prompt_tokens),
- cache_creation_input_tokens: 0,
- cache_read_input_tokens: 0,
- output_tokens: response
- .usage
- .as_ref()
- .map_or(0, |usage| usage.completion_tokens),
- },
- request_id: None,
- })
- }
- fn parse_tool_arguments(arguments: &str) -> Value {
- serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
- }
/// Removes the first complete SSE frame from `buffer` and returns its text.
///
/// A frame ends at a blank line, written as `\n\n` or `\r\n\r\n`. Whichever
/// separator occurs *first* terminates the frame, so a CRLF-delimited event
/// is never merged with a later LF-delimited one. The separator is consumed
/// but not returned, and invalid UTF-8 is replaced lossily. Returns `None`,
/// leaving `buffer` untouched, when no complete frame is buffered yet.
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    let (position, separator_len) = (0..buffer.len()).find_map(|index| {
        let rest = &buffer[index..];
        if rest.starts_with(b"\n\n") {
            Some((index, 2))
        } else if rest.starts_with(b"\r\n\r\n") {
            Some((index, 4))
        } else {
            None
        }
    })?;
    let frame = String::from_utf8_lossy(&buffer[..position]).into_owned();
    buffer.drain(..position + separator_len);
    Some(frame)
}
- fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
- let trimmed = frame.trim();
- if trimmed.is_empty() {
- return Ok(None);
- }
- let mut data_lines = Vec::new();
- for line in trimmed.lines() {
- if line.starts_with(':') {
- continue;
- }
- if let Some(data) = line.strip_prefix("data:") {
- data_lines.push(data.trim_start());
- }
- }
- if data_lines.is_empty() {
- return Ok(None);
- }
- let payload = data_lines.join("\n");
- if payload == "[DONE]" {
- return Ok(None);
- }
- serde_json::from_str(&payload)
- .map(Some)
- .map_err(ApiError::from)
- }
- fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
- match std::env::var(key) {
- Ok(value) if !value.is_empty() => Ok(Some(value)),
- Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
- Err(error) => Err(ApiError::from(error)),
- }
- }
- #[must_use]
- pub fn has_api_key(key: &str) -> bool {
- read_env_non_empty(key)
- .ok()
- .and_then(std::convert::identity)
- .is_some()
- }
- #[must_use]
- pub fn read_base_url(config: OpenAiCompatConfig) -> String {
- std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
- }
/// Returns the full chat-completions URL for `base_url`.
///
/// Trailing slashes are removed. `/chat/completions` is appended unless the
/// URL already ends with it, so both base URLs and full endpoints work.
fn chat_completions_endpoint(base_url: &str) -> String {
    const SUFFIX: &str = "/chat/completions";
    let base = base_url.trim_end_matches('/');
    if !base.ends_with(SUFFIX) {
        return format!("{base}{SUFFIX}");
    }
    base.to_owned()
}
- fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
- headers
- .get(REQUEST_ID_HEADER)
- .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
- .and_then(|value| value.to_str().ok())
- .map(ToOwned::to_owned)
- }
- async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
- let status = response.status();
- if status.is_success() {
- return Ok(response);
- }
- let body = response.text().await.unwrap_or_default();
- let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
- let retryable = is_retryable_status(status);
- Err(ApiError::Api {
- status,
- error_type: parsed_error
- .as_ref()
- .and_then(|error| error.error.error_type.clone()),
- message: parsed_error
- .as_ref()
- .and_then(|error| error.error.message.clone()),
- body,
- retryable,
- })
- }
- const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
- matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
- }
/// Maps an OpenAI `finish_reason` onto the Anthropic-style stop reason.
///
/// - `"stop"` becomes `"end_turn"`.
/// - `"tool_calls"` becomes `"tool_use"`.
/// - `"length"` (the token limit was reached) becomes `"max_tokens"`.
/// - Any other value is passed through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    let normalized = match value {
        "stop" => "end_turn",
        "tool_calls" => "tool_use",
        "length" => "max_tokens",
        other => other,
    };
    normalized.to_string()
}
/// Small helper for substituting a fallback for empty strings.
trait StringExt {
    /// Returns `self` unless it is empty, in which case `fallback` is returned.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        if !self.is_empty() {
            return self;
        }
        fallback
    }
}
#[cfg(test)]
mod tests {
    use super::{
        build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
        openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
    };
    use crate::error::ApiError;
    use crate::types::{
        InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
        ToolResultContentBlock,
    };
    use serde_json::json;
    use std::sync::{Mutex, OnceLock, PoisonError};

    #[test]
    fn request_translation_uses_openai_compatible_shape() {
        let payload = build_chat_completion_request(
            &MessageRequest {
                model: "grok-3".to_string(),
                max_tokens: 64,
                messages: vec![InputMessage {
                    role: "user".to_string(),
                    content: vec![
                        InputContentBlock::Text {
                            text: "hello".to_string(),
                        },
                        InputContentBlock::ToolResult {
                            tool_use_id: "tool_1".to_string(),
                            content: vec![ToolResultContentBlock::Json {
                                value: json!({"ok": true}),
                            }],
                            is_error: false,
                        },
                    ],
                }],
                system: Some("be helpful".to_string()),
                tools: Some(vec![ToolDefinition {
                    name: "weather".to_string(),
                    description: Some("Get weather".to_string()),
                    input_schema: json!({"type": "object"}),
                }]),
                tool_choice: Some(ToolChoice::Auto),
                stream: false,
            },
            OpenAiCompatConfig::xai(),
        );
        assert_eq!(payload["messages"][0]["role"], json!("system"));
        assert_eq!(payload["messages"][1]["role"], json!("user"));
        assert_eq!(payload["messages"][2]["role"], json!("tool"));
        assert_eq!(payload["tools"][0]["type"], json!("function"));
        assert_eq!(payload["tool_choice"], json!("auto"));
    }

    #[test]
    fn openai_streaming_requests_include_usage_opt_in() {
        let payload = build_chat_completion_request(
            &MessageRequest {
                model: "gpt-5".to_string(),
                max_tokens: 64,
                messages: vec![InputMessage::user_text("hello")],
                system: None,
                tools: None,
                tool_choice: None,
                stream: true,
            },
            OpenAiCompatConfig::openai(),
        );
        assert_eq!(payload["stream_options"], json!({"include_usage": true}));
    }

    #[test]
    fn xai_streaming_requests_skip_openai_specific_usage_opt_in() {
        let payload = build_chat_completion_request(
            &MessageRequest {
                model: "grok-3".to_string(),
                max_tokens: 64,
                messages: vec![InputMessage::user_text("hello")],
                system: None,
                tools: None,
                tool_choice: None,
                stream: true,
            },
            OpenAiCompatConfig::xai(),
        );
        assert!(payload.get("stream_options").is_none());
    }

    #[test]
    fn tool_choice_translation_supports_required_function() {
        assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
        assert_eq!(
            openai_tool_choice(&ToolChoice::Tool {
                name: "weather".to_string(),
            }),
            json!({"type": "function", "function": {"name": "weather"}})
        );
    }

    #[test]
    fn parses_tool_arguments_fallback() {
        assert_eq!(
            parse_tool_arguments("{\"city\":\"Paris\"}"),
            json!({"city": "Paris"})
        );
        assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"}));
    }

    #[test]
    fn missing_xai_api_key_is_provider_specific() {
        let _lock = env_lock();
        // Snapshot the variable and restore it before asserting, so this test
        // does not leak environment changes into other tests in the process.
        let previous = std::env::var_os("XAI_API_KEY");
        std::env::remove_var("XAI_API_KEY");
        let result = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai());
        if let Some(value) = previous {
            std::env::set_var("XAI_API_KEY", value);
        }
        let error = result.expect_err("missing key should error");
        assert!(matches!(
            error,
            ApiError::MissingCredentials {
                provider: "xAI",
                ..
            }
        ));
    }

    #[test]
    fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
            "https://api.x.ai/v1/chat/completions"
        );
    }

    /// Serializes tests that mutate process environment variables.
    ///
    /// A poisoned lock (another env test panicked while holding it) is
    /// recovered rather than propagated, so one failure does not cascade into
    /// unrelated tests.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn normalizes_stop_reasons() {
        assert_eq!(normalize_finish_reason("stop"), "end_turn");
        assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
    }
}
|