|
@@ -67,6 +67,7 @@ impl OpenAiCompatConfig {
|
|
|
pub struct OpenAiCompatClient {
|
|
pub struct OpenAiCompatClient {
|
|
|
http: reqwest::Client,
|
|
http: reqwest::Client,
|
|
|
api_key: String,
|
|
api_key: String,
|
|
|
|
|
+ config: OpenAiCompatConfig,
|
|
|
base_url: String,
|
|
base_url: String,
|
|
|
max_retries: u32,
|
|
max_retries: u32,
|
|
|
initial_backoff: Duration,
|
|
initial_backoff: Duration,
|
|
@@ -74,11 +75,15 @@ pub struct OpenAiCompatClient {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
impl OpenAiCompatClient {
|
|
impl OpenAiCompatClient {
|
|
|
|
|
+ const fn config(&self) -> OpenAiCompatConfig {
|
|
|
|
|
+ self.config
|
|
|
|
|
+ }
|
|
|
#[must_use]
|
|
#[must_use]
|
|
|
pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
|
|
pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
|
|
|
Self {
|
|
Self {
|
|
|
http: reqwest::Client::new(),
|
|
http: reqwest::Client::new(),
|
|
|
api_key: api_key.into(),
|
|
api_key: api_key.into(),
|
|
|
|
|
+ config,
|
|
|
base_url: read_base_url(config),
|
|
base_url: read_base_url(config),
|
|
|
max_retries: DEFAULT_MAX_RETRIES,
|
|
max_retries: DEFAULT_MAX_RETRIES,
|
|
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
|
@@ -190,7 +195,7 @@ impl OpenAiCompatClient {
|
|
|
.post(&request_url)
|
|
.post(&request_url)
|
|
|
.header("content-type", "application/json")
|
|
.header("content-type", "application/json")
|
|
|
.bearer_auth(&self.api_key)
|
|
.bearer_auth(&self.api_key)
|
|
|
- .json(&build_chat_completion_request(request))
|
|
|
|
|
|
|
+ .json(&build_chat_completion_request(request, self.config()))
|
|
|
.send()
|
|
.send()
|
|
|
.await
|
|
.await
|
|
|
.map_err(ApiError::from)
|
|
.map_err(ApiError::from)
|
|
@@ -633,7 +638,7 @@ struct ErrorBody {
|
|
|
message: Option<String>,
|
|
message: Option<String>,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-fn build_chat_completion_request(request: &MessageRequest) -> Value {
|
|
|
|
|
|
|
+fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
|
|
|
let mut messages = Vec::new();
|
|
let mut messages = Vec::new();
|
|
|
if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
|
|
if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
|
|
|
messages.push(json!({
|
|
messages.push(json!({
|
|
@@ -652,6 +657,10 @@ fn build_chat_completion_request(request: &MessageRequest) -> Value {
|
|
|
"stream": request.stream,
|
|
"stream": request.stream,
|
|
|
});
|
|
});
|
|
|
|
|
|
|
|
|
|
+ if request.stream && should_request_stream_usage(config) {
|
|
|
|
|
+ payload["stream_options"] = json!({ "include_usage": true });
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if let Some(tools) = &request.tools {
|
|
if let Some(tools) = &request.tools {
|
|
|
payload["tools"] =
|
|
payload["tools"] =
|
|
|
Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
|
|
Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
|
|
@@ -749,6 +758,10 @@ fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool {
|
|
|
|
|
+ matches!(config.provider_name, "OpenAI")
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
fn normalize_response(
|
|
fn normalize_response(
|
|
|
model: &str,
|
|
model: &str,
|
|
|
response: ChatCompletionResponse,
|
|
response: ChatCompletionResponse,
|
|
@@ -951,33 +964,36 @@ mod tests {
|
|
|
|
|
|
|
|
#[test]
|
|
#[test]
|
|
|
fn request_translation_uses_openai_compatible_shape() {
|
|
fn request_translation_uses_openai_compatible_shape() {
|
|
|
- let payload = build_chat_completion_request(&MessageRequest {
|
|
|
|
|
- model: "grok-3".to_string(),
|
|
|
|
|
- max_tokens: 64,
|
|
|
|
|
- messages: vec![InputMessage {
|
|
|
|
|
- role: "user".to_string(),
|
|
|
|
|
- content: vec![
|
|
|
|
|
- InputContentBlock::Text {
|
|
|
|
|
- text: "hello".to_string(),
|
|
|
|
|
- },
|
|
|
|
|
- InputContentBlock::ToolResult {
|
|
|
|
|
- tool_use_id: "tool_1".to_string(),
|
|
|
|
|
- content: vec![ToolResultContentBlock::Json {
|
|
|
|
|
- value: json!({"ok": true}),
|
|
|
|
|
- }],
|
|
|
|
|
- is_error: false,
|
|
|
|
|
- },
|
|
|
|
|
- ],
|
|
|
|
|
- }],
|
|
|
|
|
- system: Some("be helpful".to_string()),
|
|
|
|
|
- tools: Some(vec![ToolDefinition {
|
|
|
|
|
- name: "weather".to_string(),
|
|
|
|
|
- description: Some("Get weather".to_string()),
|
|
|
|
|
- input_schema: json!({"type": "object"}),
|
|
|
|
|
- }]),
|
|
|
|
|
- tool_choice: Some(ToolChoice::Auto),
|
|
|
|
|
- stream: false,
|
|
|
|
|
- });
|
|
|
|
|
|
|
+ let payload = build_chat_completion_request(
|
|
|
|
|
+ &MessageRequest {
|
|
|
|
|
+ model: "grok-3".to_string(),
|
|
|
|
|
+ max_tokens: 64,
|
|
|
|
|
+ messages: vec![InputMessage {
|
|
|
|
|
+ role: "user".to_string(),
|
|
|
|
|
+ content: vec![
|
|
|
|
|
+ InputContentBlock::Text {
|
|
|
|
|
+ text: "hello".to_string(),
|
|
|
|
|
+ },
|
|
|
|
|
+ InputContentBlock::ToolResult {
|
|
|
|
|
+ tool_use_id: "tool_1".to_string(),
|
|
|
|
|
+ content: vec![ToolResultContentBlock::Json {
|
|
|
|
|
+ value: json!({"ok": true}),
|
|
|
|
|
+ }],
|
|
|
|
|
+ is_error: false,
|
|
|
|
|
+ },
|
|
|
|
|
+ ],
|
|
|
|
|
+ }],
|
|
|
|
|
+ system: Some("be helpful".to_string()),
|
|
|
|
|
+ tools: Some(vec![ToolDefinition {
|
|
|
|
|
+ name: "weather".to_string(),
|
|
|
|
|
+ description: Some("Get weather".to_string()),
|
|
|
|
|
+ input_schema: json!({"type": "object"}),
|
|
|
|
|
+ }]),
|
|
|
|
|
+ tool_choice: Some(ToolChoice::Auto),
|
|
|
|
|
+ stream: false,
|
|
|
|
|
+ },
|
|
|
|
|
+ OpenAiCompatConfig::xai(),
|
|
|
|
|
+ );
|
|
|
|
|
|
|
|
assert_eq!(payload["messages"][0]["role"], json!("system"));
|
|
assert_eq!(payload["messages"][0]["role"], json!("system"));
|
|
|
assert_eq!(payload["messages"][1]["role"], json!("user"));
|
|
assert_eq!(payload["messages"][1]["role"], json!("user"));
|
|
@@ -986,6 +1002,42 @@ mod tests {
|
|
|
assert_eq!(payload["tool_choice"], json!("auto"));
|
|
assert_eq!(payload["tool_choice"], json!("auto"));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn openai_streaming_requests_include_usage_opt_in() {
|
|
|
|
|
+ let payload = build_chat_completion_request(
|
|
|
|
|
+ &MessageRequest {
|
|
|
|
|
+ model: "gpt-5".to_string(),
|
|
|
|
|
+ max_tokens: 64,
|
|
|
|
|
+ messages: vec![InputMessage::user_text("hello")],
|
|
|
|
|
+ system: None,
|
|
|
|
|
+ tools: None,
|
|
|
|
|
+ tool_choice: None,
|
|
|
|
|
+ stream: true,
|
|
|
|
|
+ },
|
|
|
|
|
+ OpenAiCompatConfig::openai(),
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ assert_eq!(payload["stream_options"], json!({"include_usage": true}));
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn xai_streaming_requests_skip_openai_specific_usage_opt_in() {
|
|
|
|
|
+ let payload = build_chat_completion_request(
|
|
|
|
|
+ &MessageRequest {
|
|
|
|
|
+ model: "grok-3".to_string(),
|
|
|
|
|
+ max_tokens: 64,
|
|
|
|
|
+ messages: vec![InputMessage::user_text("hello")],
|
|
|
|
|
+ system: None,
|
|
|
|
|
+ tools: None,
|
|
|
|
|
+ tool_choice: None,
|
|
|
|
|
+ stream: true,
|
|
|
|
|
+ },
|
|
|
|
|
+ OpenAiCompatConfig::xai(),
|
|
|
|
|
+ );
|
|
|
|
|
+
|
|
|
|
|
+ assert!(payload.get("stream_options").is_none());
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
#[test]
|
|
#[test]
|
|
|
fn tool_choice_translation_supports_required_function() {
|
|
fn tool_choice_translation_supports_required_function() {
|
|
|
assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
|
|
assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
|