openai_compat.rs 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104
  1. use std::collections::{BTreeMap, VecDeque};
  2. use std::time::Duration;
  3. use serde::Deserialize;
  4. use serde_json::{json, Value};
  5. use crate::error::ApiError;
  6. use crate::types::{
  7. ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
  8. InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
  9. MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
  10. ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
  11. };
  12. use super::{Provider, ProviderFuture};
  /// Base URL used by the xAI preset when no override is supplied.
  pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";

  /// Base URL used by the OpenAI preset when no override is supplied.
  pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";

  /// Response header consulted first when extracting a request id.
  const REQUEST_ID_HEADER: &str = "request-id";

  /// Response header consulted when [`REQUEST_ID_HEADER`] is absent.
  const ALT_REQUEST_ID_HEADER: &str = "x-request-id";

  /// Delay before the first retry of a newly constructed client.
  const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);

  /// Largest retry delay of a newly constructed client.
  const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);

  /// Number of retries a newly constructed client performs after the first attempt.
  const DEFAULT_MAX_RETRIES: u32 = 2;
  20. #[derive(Debug, Clone, Copy, PartialEq, Eq)]
  21. pub struct OpenAiCompatConfig {
  22. pub provider_name: &'static str,
  23. pub api_key_env: &'static str,
  24. pub base_url_env: &'static str,
  25. pub default_base_url: &'static str,
  26. }
  27. const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
  28. const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
  29. impl OpenAiCompatConfig {
  30. #[must_use]
  31. pub const fn xai() -> Self {
  32. Self {
  33. provider_name: "xAI",
  34. api_key_env: "XAI_API_KEY",
  35. base_url_env: "XAI_BASE_URL",
  36. default_base_url: DEFAULT_XAI_BASE_URL,
  37. }
  38. }
  39. #[must_use]
  40. pub const fn openai() -> Self {
  41. Self {
  42. provider_name: "OpenAI",
  43. api_key_env: "OPENAI_API_KEY",
  44. base_url_env: "OPENAI_BASE_URL",
  45. default_base_url: DEFAULT_OPENAI_BASE_URL,
  46. }
  47. }
  48. #[must_use]
  49. pub fn credential_env_vars(self) -> &'static [&'static str] {
  50. match self.provider_name {
  51. "xAI" => XAI_ENV_VARS,
  52. "OpenAI" => OPENAI_ENV_VARS,
  53. _ => &[],
  54. }
  55. }
  56. }
  57. #[derive(Debug, Clone)]
  58. pub struct OpenAiCompatClient {
  59. http: reqwest::Client,
  60. api_key: String,
  61. config: OpenAiCompatConfig,
  62. base_url: String,
  63. max_retries: u32,
  64. initial_backoff: Duration,
  65. max_backoff: Duration,
  66. }
  67. impl OpenAiCompatClient {
  68. const fn config(&self) -> OpenAiCompatConfig {
  69. self.config
  70. }
  71. #[must_use]
  72. pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
  73. Self {
  74. http: reqwest::Client::new(),
  75. api_key: api_key.into(),
  76. config,
  77. base_url: read_base_url(config),
  78. max_retries: DEFAULT_MAX_RETRIES,
  79. initial_backoff: DEFAULT_INITIAL_BACKOFF,
  80. max_backoff: DEFAULT_MAX_BACKOFF,
  81. }
  82. }
  83. pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
  84. let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
  85. return Err(ApiError::missing_credentials(
  86. config.provider_name,
  87. config.credential_env_vars(),
  88. ));
  89. };
  90. Ok(Self::new(api_key, config))
  91. }
  92. #[must_use]
  93. pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
  94. self.base_url = base_url.into();
  95. self
  96. }
  97. #[must_use]
  98. pub fn with_retry_policy(
  99. mut self,
  100. max_retries: u32,
  101. initial_backoff: Duration,
  102. max_backoff: Duration,
  103. ) -> Self {
  104. self.max_retries = max_retries;
  105. self.initial_backoff = initial_backoff;
  106. self.max_backoff = max_backoff;
  107. self
  108. }
  109. pub async fn send_message(
  110. &self,
  111. request: &MessageRequest,
  112. ) -> Result<MessageResponse, ApiError> {
  113. let request = MessageRequest {
  114. stream: false,
  115. ..request.clone()
  116. };
  117. let response = self.send_with_retry(&request).await?;
  118. let request_id = request_id_from_headers(response.headers());
  119. let payload = response.json::<ChatCompletionResponse>().await?;
  120. let mut normalized = normalize_response(&request.model, payload)?;
  121. if normalized.request_id.is_none() {
  122. normalized.request_id = request_id;
  123. }
  124. Ok(normalized)
  125. }
  126. pub async fn stream_message(
  127. &self,
  128. request: &MessageRequest,
  129. ) -> Result<MessageStream, ApiError> {
  130. let response = self
  131. .send_with_retry(&request.clone().with_streaming())
  132. .await?;
  133. Ok(MessageStream {
  134. request_id: request_id_from_headers(response.headers()),
  135. response,
  136. parser: OpenAiSseParser::new(),
  137. pending: VecDeque::new(),
  138. done: false,
  139. state: StreamState::new(request.model.clone()),
  140. })
  141. }
  142. async fn send_with_retry(
  143. &self,
  144. request: &MessageRequest,
  145. ) -> Result<reqwest::Response, ApiError> {
  146. let mut attempts = 0;
  147. let last_error = loop {
  148. attempts += 1;
  149. let retryable_error = match self.send_raw_request(request).await {
  150. Ok(response) => match expect_success(response).await {
  151. Ok(response) => return Ok(response),
  152. Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
  153. Err(error) => return Err(error),
  154. },
  155. Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
  156. Err(error) => return Err(error),
  157. };
  158. if attempts > self.max_retries {
  159. break retryable_error;
  160. }
  161. tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
  162. };
  163. Err(ApiError::RetriesExhausted {
  164. attempts,
  165. last_error: Box::new(last_error),
  166. })
  167. }
  168. async fn send_raw_request(
  169. &self,
  170. request: &MessageRequest,
  171. ) -> Result<reqwest::Response, ApiError> {
  172. let request_url = chat_completions_endpoint(&self.base_url);
  173. self.http
  174. .post(&request_url)
  175. .header("content-type", "application/json")
  176. .bearer_auth(&self.api_key)
  177. .json(&build_chat_completion_request(request, self.config()))
  178. .send()
  179. .await
  180. .map_err(ApiError::from)
  181. }
  182. fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
  183. let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
  184. return Err(ApiError::BackoffOverflow {
  185. attempt,
  186. base_delay: self.initial_backoff,
  187. });
  188. };
  189. Ok(self
  190. .initial_backoff
  191. .checked_mul(multiplier)
  192. .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
  193. }
  194. }
  195. impl Provider for OpenAiCompatClient {
  196. type Stream = MessageStream;
  197. fn send_message<'a>(
  198. &'a self,
  199. request: &'a MessageRequest,
  200. ) -> ProviderFuture<'a, MessageResponse> {
  201. Box::pin(async move { self.send_message(request).await })
  202. }
  203. fn stream_message<'a>(
  204. &'a self,
  205. request: &'a MessageRequest,
  206. ) -> ProviderFuture<'a, Self::Stream> {
  207. Box::pin(async move { self.stream_message(request).await })
  208. }
  209. }
  210. #[derive(Debug)]
  211. pub struct MessageStream {
  212. request_id: Option<String>,
  213. response: reqwest::Response,
  214. parser: OpenAiSseParser,
  215. pending: VecDeque<StreamEvent>,
  216. done: bool,
  217. state: StreamState,
  218. }
  219. impl MessageStream {
  220. #[must_use]
  221. pub fn request_id(&self) -> Option<&str> {
  222. self.request_id.as_deref()
  223. }
  224. pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
  225. loop {
  226. if let Some(event) = self.pending.pop_front() {
  227. return Ok(Some(event));
  228. }
  229. if self.done {
  230. self.pending.extend(self.state.finish()?);
  231. if let Some(event) = self.pending.pop_front() {
  232. return Ok(Some(event));
  233. }
  234. return Ok(None);
  235. }
  236. match self.response.chunk().await? {
  237. Some(chunk) => {
  238. for parsed in self.parser.push(&chunk)? {
  239. self.pending.extend(self.state.ingest_chunk(parsed)?);
  240. }
  241. }
  242. None => {
  243. self.done = true;
  244. }
  245. }
  246. }
  247. }
  248. }
  249. #[derive(Debug, Default)]
  250. struct OpenAiSseParser {
  251. buffer: Vec<u8>,
  252. }
  253. impl OpenAiSseParser {
  254. fn new() -> Self {
  255. Self::default()
  256. }
  257. fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
  258. self.buffer.extend_from_slice(chunk);
  259. let mut events = Vec::new();
  260. while let Some(frame) = next_sse_frame(&mut self.buffer) {
  261. if let Some(event) = parse_sse_frame(&frame)? {
  262. events.push(event);
  263. }
  264. }
  265. Ok(events)
  266. }
  267. }
  268. #[allow(clippy::struct_excessive_bools)]
  269. #[derive(Debug)]
  270. struct StreamState {
  271. model: String,
  272. message_started: bool,
  273. text_started: bool,
  274. text_finished: bool,
  275. finished: bool,
  276. stop_reason: Option<String>,
  277. usage: Option<Usage>,
  278. tool_calls: BTreeMap<u32, ToolCallState>,
  279. }
  280. impl StreamState {
  281. fn new(model: String) -> Self {
  282. Self {
  283. model,
  284. message_started: false,
  285. text_started: false,
  286. text_finished: false,
  287. finished: false,
  288. stop_reason: None,
  289. usage: None,
  290. tool_calls: BTreeMap::new(),
  291. }
  292. }
  293. fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
  294. let mut events = Vec::new();
  295. if !self.message_started {
  296. self.message_started = true;
  297. events.push(StreamEvent::MessageStart(MessageStartEvent {
  298. message: MessageResponse {
  299. id: chunk.id.clone(),
  300. kind: "message".to_string(),
  301. role: "assistant".to_string(),
  302. content: Vec::new(),
  303. model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
  304. stop_reason: None,
  305. stop_sequence: None,
  306. usage: Usage {
  307. input_tokens: 0,
  308. cache_creation_input_tokens: 0,
  309. cache_read_input_tokens: 0,
  310. output_tokens: 0,
  311. },
  312. request_id: None,
  313. },
  314. }));
  315. }
  316. if let Some(usage) = chunk.usage {
  317. self.usage = Some(Usage {
  318. input_tokens: usage.prompt_tokens,
  319. cache_creation_input_tokens: 0,
  320. cache_read_input_tokens: 0,
  321. output_tokens: usage.completion_tokens,
  322. });
  323. }
  324. for choice in chunk.choices {
  325. if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
  326. if !self.text_started {
  327. self.text_started = true;
  328. events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
  329. index: 0,
  330. content_block: OutputContentBlock::Text {
  331. text: String::new(),
  332. },
  333. }));
  334. }
  335. events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
  336. index: 0,
  337. delta: ContentBlockDelta::TextDelta { text: content },
  338. }));
  339. }
  340. for tool_call in choice.delta.tool_calls {
  341. let state = self.tool_calls.entry(tool_call.index).or_default();
  342. state.apply(tool_call);
  343. let block_index = state.block_index();
  344. if !state.started {
  345. if let Some(start_event) = state.start_event()? {
  346. state.started = true;
  347. events.push(StreamEvent::ContentBlockStart(start_event));
  348. } else {
  349. continue;
  350. }
  351. }
  352. if let Some(delta_event) = state.delta_event() {
  353. events.push(StreamEvent::ContentBlockDelta(delta_event));
  354. }
  355. if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
  356. state.stopped = true;
  357. events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
  358. index: block_index,
  359. }));
  360. }
  361. }
  362. if let Some(finish_reason) = choice.finish_reason {
  363. self.stop_reason = Some(normalize_finish_reason(&finish_reason));
  364. if finish_reason == "tool_calls" {
  365. for state in self.tool_calls.values_mut() {
  366. if state.started && !state.stopped {
  367. state.stopped = true;
  368. events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
  369. index: state.block_index(),
  370. }));
  371. }
  372. }
  373. }
  374. }
  375. }
  376. Ok(events)
  377. }
  378. fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
  379. if self.finished {
  380. return Ok(Vec::new());
  381. }
  382. self.finished = true;
  383. let mut events = Vec::new();
  384. if self.text_started && !self.text_finished {
  385. self.text_finished = true;
  386. events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
  387. index: 0,
  388. }));
  389. }
  390. for state in self.tool_calls.values_mut() {
  391. if !state.started {
  392. if let Some(start_event) = state.start_event()? {
  393. state.started = true;
  394. events.push(StreamEvent::ContentBlockStart(start_event));
  395. if let Some(delta_event) = state.delta_event() {
  396. events.push(StreamEvent::ContentBlockDelta(delta_event));
  397. }
  398. }
  399. }
  400. if state.started && !state.stopped {
  401. state.stopped = true;
  402. events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
  403. index: state.block_index(),
  404. }));
  405. }
  406. }
  407. if self.message_started {
  408. events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
  409. delta: MessageDelta {
  410. stop_reason: Some(
  411. self.stop_reason
  412. .clone()
  413. .unwrap_or_else(|| "end_turn".to_string()),
  414. ),
  415. stop_sequence: None,
  416. },
  417. usage: self.usage.clone().unwrap_or(Usage {
  418. input_tokens: 0,
  419. cache_creation_input_tokens: 0,
  420. cache_read_input_tokens: 0,
  421. output_tokens: 0,
  422. }),
  423. }));
  424. events.push(StreamEvent::MessageStop(MessageStopEvent {}));
  425. }
  426. Ok(events)
  427. }
  428. }
  429. #[derive(Debug, Default)]
  430. struct ToolCallState {
  431. openai_index: u32,
  432. id: Option<String>,
  433. name: Option<String>,
  434. arguments: String,
  435. emitted_len: usize,
  436. started: bool,
  437. stopped: bool,
  438. }
  439. impl ToolCallState {
  440. fn apply(&mut self, tool_call: DeltaToolCall) {
  441. self.openai_index = tool_call.index;
  442. if let Some(id) = tool_call.id {
  443. self.id = Some(id);
  444. }
  445. if let Some(name) = tool_call.function.name {
  446. self.name = Some(name);
  447. }
  448. if let Some(arguments) = tool_call.function.arguments {
  449. self.arguments.push_str(&arguments);
  450. }
  451. }
  452. const fn block_index(&self) -> u32 {
  453. self.openai_index + 1
  454. }
  455. #[allow(clippy::unnecessary_wraps)]
  456. fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
  457. let Some(name) = self.name.clone() else {
  458. return Ok(None);
  459. };
  460. let id = self
  461. .id
  462. .clone()
  463. .unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
  464. Ok(Some(ContentBlockStartEvent {
  465. index: self.block_index(),
  466. content_block: OutputContentBlock::ToolUse {
  467. id,
  468. name,
  469. input: json!({}),
  470. },
  471. }))
  472. }
  473. fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
  474. if self.emitted_len >= self.arguments.len() {
  475. return None;
  476. }
  477. let delta = self.arguments[self.emitted_len..].to_string();
  478. self.emitted_len = self.arguments.len();
  479. Some(ContentBlockDeltaEvent {
  480. index: self.block_index(),
  481. delta: ContentBlockDelta::InputJsonDelta {
  482. partial_json: delta,
  483. },
  484. })
  485. }
  486. }
  487. #[derive(Debug, Deserialize)]
  488. struct ChatCompletionResponse {
  489. id: String,
  490. model: String,
  491. choices: Vec<ChatChoice>,
  492. #[serde(default)]
  493. usage: Option<OpenAiUsage>,
  494. }
  495. #[derive(Debug, Deserialize)]
  496. struct ChatChoice {
  497. message: ChatMessage,
  498. #[serde(default)]
  499. finish_reason: Option<String>,
  500. }
  501. #[derive(Debug, Deserialize)]
  502. struct ChatMessage {
  503. role: String,
  504. #[serde(default)]
  505. content: Option<String>,
  506. #[serde(default)]
  507. tool_calls: Vec<ResponseToolCall>,
  508. }
  509. #[derive(Debug, Deserialize)]
  510. struct ResponseToolCall {
  511. id: String,
  512. function: ResponseToolFunction,
  513. }
  514. #[derive(Debug, Deserialize)]
  515. struct ResponseToolFunction {
  516. name: String,
  517. arguments: String,
  518. }
  519. #[derive(Debug, Deserialize)]
  520. struct OpenAiUsage {
  521. #[serde(default)]
  522. prompt_tokens: u32,
  523. #[serde(default)]
  524. completion_tokens: u32,
  525. }
  526. #[derive(Debug, Deserialize)]
  527. struct ChatCompletionChunk {
  528. id: String,
  529. #[serde(default)]
  530. model: Option<String>,
  531. #[serde(default)]
  532. choices: Vec<ChunkChoice>,
  533. #[serde(default)]
  534. usage: Option<OpenAiUsage>,
  535. }
  536. #[derive(Debug, Deserialize)]
  537. struct ChunkChoice {
  538. delta: ChunkDelta,
  539. #[serde(default)]
  540. finish_reason: Option<String>,
  541. }
  542. #[derive(Debug, Default, Deserialize)]
  543. struct ChunkDelta {
  544. #[serde(default)]
  545. content: Option<String>,
  546. #[serde(default)]
  547. tool_calls: Vec<DeltaToolCall>,
  548. }
  549. #[derive(Debug, Deserialize)]
  550. struct DeltaToolCall {
  551. #[serde(default)]
  552. index: u32,
  553. #[serde(default)]
  554. id: Option<String>,
  555. #[serde(default)]
  556. function: DeltaFunction,
  557. }
  558. #[derive(Debug, Default, Deserialize)]
  559. struct DeltaFunction {
  560. #[serde(default)]
  561. name: Option<String>,
  562. #[serde(default)]
  563. arguments: Option<String>,
  564. }
  565. #[derive(Debug, Deserialize)]
  566. struct ErrorEnvelope {
  567. error: ErrorBody,
  568. }
  569. #[derive(Debug, Deserialize)]
  570. struct ErrorBody {
  571. #[serde(rename = "type")]
  572. error_type: Option<String>,
  573. message: Option<String>,
  574. }
  575. fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
  576. let mut messages = Vec::new();
  577. if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
  578. messages.push(json!({
  579. "role": "system",
  580. "content": system,
  581. }));
  582. }
  583. for message in &request.messages {
  584. messages.extend(translate_message(message));
  585. }
  586. let mut payload = json!({
  587. "model": request.model,
  588. "max_tokens": request.max_tokens,
  589. "messages": messages,
  590. "stream": request.stream,
  591. });
  592. if request.stream && should_request_stream_usage(config) {
  593. payload["stream_options"] = json!({ "include_usage": true });
  594. }
  595. if let Some(tools) = &request.tools {
  596. payload["tools"] =
  597. Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
  598. }
  599. if let Some(tool_choice) = &request.tool_choice {
  600. payload["tool_choice"] = openai_tool_choice(tool_choice);
  601. }
  602. payload
  603. }
  604. fn translate_message(message: &InputMessage) -> Vec<Value> {
  605. match message.role.as_str() {
  606. "assistant" => {
  607. let mut text = String::new();
  608. let mut tool_calls = Vec::new();
  609. for block in &message.content {
  610. match block {
  611. InputContentBlock::Text { text: value } => text.push_str(value),
  612. InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
  613. "id": id,
  614. "type": "function",
  615. "function": {
  616. "name": name,
  617. "arguments": input.to_string(),
  618. }
  619. })),
  620. InputContentBlock::ToolResult { .. } => {}
  621. }
  622. }
  623. if text.is_empty() && tool_calls.is_empty() {
  624. Vec::new()
  625. } else {
  626. vec![json!({
  627. "role": "assistant",
  628. "content": (!text.is_empty()).then_some(text),
  629. "tool_calls": tool_calls,
  630. })]
  631. }
  632. }
  633. _ => message
  634. .content
  635. .iter()
  636. .filter_map(|block| match block {
  637. InputContentBlock::Text { text } => Some(json!({
  638. "role": "user",
  639. "content": text,
  640. })),
  641. InputContentBlock::ToolResult {
  642. tool_use_id,
  643. content,
  644. is_error,
  645. } => Some(json!({
  646. "role": "tool",
  647. "tool_call_id": tool_use_id,
  648. "content": flatten_tool_result_content(content),
  649. "is_error": is_error,
  650. })),
  651. InputContentBlock::ToolUse { .. } => None,
  652. })
  653. .collect(),
  654. }
  655. }
  656. fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
  657. content
  658. .iter()
  659. .map(|block| match block {
  660. ToolResultContentBlock::Text { text } => text.clone(),
  661. ToolResultContentBlock::Json { value } => value.to_string(),
  662. })
  663. .collect::<Vec<_>>()
  664. .join("\n")
  665. }
  666. fn openai_tool_definition(tool: &ToolDefinition) -> Value {
  667. json!({
  668. "type": "function",
  669. "function": {
  670. "name": tool.name,
  671. "description": tool.description,
  672. "parameters": tool.input_schema,
  673. }
  674. })
  675. }
  676. fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
  677. match tool_choice {
  678. ToolChoice::Auto => Value::String("auto".to_string()),
  679. ToolChoice::Any => Value::String("required".to_string()),
  680. ToolChoice::Tool { name } => json!({
  681. "type": "function",
  682. "function": { "name": name },
  683. }),
  684. }
  685. }
  686. fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool {
  687. matches!(config.provider_name, "OpenAI")
  688. }
  689. fn normalize_response(
  690. model: &str,
  691. response: ChatCompletionResponse,
  692. ) -> Result<MessageResponse, ApiError> {
  693. let choice = response
  694. .choices
  695. .into_iter()
  696. .next()
  697. .ok_or(ApiError::InvalidSseFrame(
  698. "chat completion response missing choices",
  699. ))?;
  700. let mut content = Vec::new();
  701. if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
  702. content.push(OutputContentBlock::Text { text });
  703. }
  704. for tool_call in choice.message.tool_calls {
  705. content.push(OutputContentBlock::ToolUse {
  706. id: tool_call.id,
  707. name: tool_call.function.name,
  708. input: parse_tool_arguments(&tool_call.function.arguments),
  709. });
  710. }
  711. Ok(MessageResponse {
  712. id: response.id,
  713. kind: "message".to_string(),
  714. role: choice.message.role,
  715. content,
  716. model: response.model.if_empty_then(model.to_string()),
  717. stop_reason: choice
  718. .finish_reason
  719. .map(|value| normalize_finish_reason(&value)),
  720. stop_sequence: None,
  721. usage: Usage {
  722. input_tokens: response
  723. .usage
  724. .as_ref()
  725. .map_or(0, |usage| usage.prompt_tokens),
  726. cache_creation_input_tokens: 0,
  727. cache_read_input_tokens: 0,
  728. output_tokens: response
  729. .usage
  730. .as_ref()
  731. .map_or(0, |usage| usage.completion_tokens),
  732. },
  733. request_id: None,
  734. })
  735. }
  736. fn parse_tool_arguments(arguments: &str) -> Value {
  737. serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
  738. }
  /// Removes the first complete frame from `buffer` and returns it without its
  /// terminating blank line.
  ///
  /// A frame ends at the earliest `\n\n` or `\r\n\r\n`, whichever comes first, so
  /// frames using different line endings in the same buffer are never merged.
  /// Returns `None`, leaving `buffer` untouched, when no separator is present yet.
  /// Invalid UTF-8 in a frame is replaced lossily.
  fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
      let lf = buffer
          .windows(2)
          .position(|window| window == b"\n\n")
          .map(|position| (position, 2));
      let crlf = buffer
          .windows(4)
          .position(|window| window == b"\r\n\r\n")
          .map(|position| (position, 4));
      // Both kinds can be present at once; the frame boundary is whichever comes first.
      let (position, separator_len) = lf
          .into_iter()
          .chain(crlf)
          .min_by_key(|&(position, _)| position)?;

      let mut frame: Vec<u8> = buffer.drain(..position + separator_len).collect();
      frame.truncate(position);
      Some(String::from_utf8_lossy(&frame).into_owned())
  }
  755. fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
  756. let trimmed = frame.trim();
  757. if trimmed.is_empty() {
  758. return Ok(None);
  759. }
  760. let mut data_lines = Vec::new();
  761. for line in trimmed.lines() {
  762. if line.starts_with(':') {
  763. continue;
  764. }
  765. if let Some(data) = line.strip_prefix("data:") {
  766. data_lines.push(data.trim_start());
  767. }
  768. }
  769. if data_lines.is_empty() {
  770. return Ok(None);
  771. }
  772. let payload = data_lines.join("\n");
  773. if payload == "[DONE]" {
  774. return Ok(None);
  775. }
  776. serde_json::from_str(&payload)
  777. .map(Some)
  778. .map_err(ApiError::from)
  779. }
  780. fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
  781. match std::env::var(key) {
  782. Ok(value) if !value.is_empty() => Ok(Some(value)),
  783. Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
  784. Err(error) => Err(ApiError::from(error)),
  785. }
  786. }
  787. #[must_use]
  788. pub fn has_api_key(key: &str) -> bool {
  789. read_env_non_empty(key)
  790. .ok()
  791. .and_then(std::convert::identity)
  792. .is_some()
  793. }
  794. #[must_use]
  795. pub fn read_base_url(config: OpenAiCompatConfig) -> String {
  796. std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
  797. }
  /// Returns the chat-completions endpoint for `base_url`.
  ///
  /// Trailing slashes are removed, and `/chat/completions` is appended unless the
  /// URL already ends with it.
  fn chat_completions_endpoint(base_url: &str) -> String {
      const ENDPOINT_SUFFIX: &str = "/chat/completions";

      let mut endpoint = base_url.trim_end_matches('/').to_string();
      if !endpoint.ends_with(ENDPOINT_SUFFIX) {
          endpoint.push_str(ENDPOINT_SUFFIX);
      }
      endpoint
  }
  806. fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
  807. headers
  808. .get(REQUEST_ID_HEADER)
  809. .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
  810. .and_then(|value| value.to_str().ok())
  811. .map(ToOwned::to_owned)
  812. }
  813. async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
  814. let status = response.status();
  815. if status.is_success() {
  816. return Ok(response);
  817. }
  818. let body = response.text().await.unwrap_or_default();
  819. let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
  820. let retryable = is_retryable_status(status);
  821. Err(ApiError::Api {
  822. status,
  823. error_type: parsed_error
  824. .as_ref()
  825. .and_then(|error| error.error.error_type.clone()),
  826. message: parsed_error
  827. .as_ref()
  828. .and_then(|error| error.error.message.clone()),
  829. body,
  830. retryable,
  831. })
  832. }
  833. const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
  834. matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
  835. }
  /// Maps an OpenAI finish reason onto the stop-reason vocabulary used by the rest
  /// of the crate.
  ///
  /// `stop` becomes `end_turn`, `tool_calls` becomes `tool_use` and `length` (the
  /// token limit was reached) becomes `max_tokens`. Any other value is passed through
  /// unchanged.
  fn normalize_finish_reason(value: &str) -> String {
      match value {
          "stop" => "end_turn",
          "tool_calls" => "tool_use",
          "length" => "max_tokens",
          other => other,
      }
      .to_string()
  }
  /// Convenience extension for substituting a fallback for an empty string.
  trait StringExt {
      /// Returns `fallback` when `self` is empty and `self` otherwise.
      fn if_empty_then(self, fallback: String) -> String;
  }

  impl StringExt for String {
      fn if_empty_then(self, fallback: String) -> String {
          Some(self)
              .filter(|value| !value.is_empty())
              .unwrap_or(fallback)
      }
  }
  #[cfg(test)]
  mod tests {
      use super::{
          build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
          openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
      };
      use crate::error::ApiError;
      use crate::types::{
          InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
          ToolResultContentBlock,
      };
      use serde_json::json;
      use std::sync::{Mutex, MutexGuard, OnceLock};

      /// Serializes tests that modify process environment variables.
      fn env_lock() -> MutexGuard<'static, ()> {
          static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
          LOCK.get_or_init(|| Mutex::new(()))
              .lock()
              .expect("env lock")
      }

      /// Minimal streaming request for `model` with a single user message.
      fn streaming_request(model: &str) -> MessageRequest {
          MessageRequest {
              model: model.to_string(),
              max_tokens: 64,
              messages: vec![InputMessage::user_text("hello")],
              system: None,
              tools: None,
              tool_choice: None,
              stream: true,
          }
      }

      #[test]
      fn request_translation_uses_openai_compatible_shape() {
          let user_message = InputMessage {
              role: "user".to_string(),
              content: vec![
                  InputContentBlock::Text {
                      text: "hello".to_string(),
                  },
                  InputContentBlock::ToolResult {
                      tool_use_id: "tool_1".to_string(),
                      content: vec![ToolResultContentBlock::Json {
                          value: json!({"ok": true}),
                      }],
                      is_error: false,
                  },
              ],
          };
          let weather_tool = ToolDefinition {
              name: "weather".to_string(),
              description: Some("Get weather".to_string()),
              input_schema: json!({"type": "object"}),
          };
          let request = MessageRequest {
              model: "grok-3".to_string(),
              max_tokens: 64,
              messages: vec![user_message],
              system: Some("be helpful".to_string()),
              tools: Some(vec![weather_tool]),
              tool_choice: Some(ToolChoice::Auto),
              stream: false,
          };

          let payload = build_chat_completion_request(&request, OpenAiCompatConfig::xai());

          assert_eq!(payload["messages"][0]["role"], json!("system"));
          assert_eq!(payload["messages"][1]["role"], json!("user"));
          assert_eq!(payload["messages"][2]["role"], json!("tool"));
          assert_eq!(payload["tools"][0]["type"], json!("function"));
          assert_eq!(payload["tool_choice"], json!("auto"));
      }

      #[test]
      fn openai_streaming_requests_include_usage_opt_in() {
          let payload = build_chat_completion_request(
              &streaming_request("gpt-5"),
              OpenAiCompatConfig::openai(),
          );
          assert_eq!(payload["stream_options"], json!({"include_usage": true}));
      }

      #[test]
      fn xai_streaming_requests_skip_openai_specific_usage_opt_in() {
          let payload = build_chat_completion_request(
              &streaming_request("grok-3"),
              OpenAiCompatConfig::xai(),
          );
          assert!(payload.get("stream_options").is_none());
      }

      #[test]
      fn tool_choice_translation_supports_required_function() {
          assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));

          let named = ToolChoice::Tool {
              name: "weather".to_string(),
          };
          assert_eq!(
              openai_tool_choice(&named),
              json!({"type": "function", "function": {"name": "weather"}})
          );
      }

      #[test]
      fn parses_tool_arguments_fallback() {
          assert_eq!(
              parse_tool_arguments("{\"city\":\"Paris\"}"),
              json!({"city": "Paris"})
          );
          assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"}));
      }

      #[test]
      fn missing_xai_api_key_is_provider_specific() {
          let _lock = env_lock();
          std::env::remove_var("XAI_API_KEY");

          let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai())
              .expect_err("missing key should error");

          assert!(matches!(
              error,
              ApiError::MissingCredentials {
                  provider: "xAI",
                  ..
              }
          ));
      }

      #[test]
      fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
          let expected = "https://api.x.ai/v1/chat/completions";
          assert_eq!(chat_completions_endpoint("https://api.x.ai/v1"), expected);
          assert_eq!(chat_completions_endpoint("https://api.x.ai/v1/"), expected);
          assert_eq!(
              chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
              expected
          );
      }

      #[test]
      fn normalizes_stop_reasons() {
          assert_eq!(normalize_finish_reason("stop"), "end_turn");
          assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
      }
  }