فهرست منبع

Preserve usage accounting on OpenAI SSE streams

OpenAI chat-completions streams can emit a final usage chunk when the\nclient opts in, but the Rust transport was not requesting it. This\nkeeps provider config on the client and adds stream_options.include_usage\nonly for OpenAI streams so normalized message_delta usage reflects the\ntransport without changing xAI request bodies.\n\nConstraint: Keep xAI request bodies unchanged because provider-specific streaming knobs may differ\nRejected: Enable stream_options for every OpenAI-compatible provider | risks sending unsupported params to xAI-style endpoints\nConfidence: high\nScope-risk: narrow\nDirective: Keep provider-specific streaming flags tied to OpenAiCompatConfig instead of inferring provider behavior from URLs\nTested: cargo clippy -p api --tests -- -D warnings\nTested: cargo test -p api openai_streaming_requests -- --nocapture\nTested: cargo test -p api xai_streaming_requests_skip_openai_specific_usage_opt_in -- --nocapture\nTested: cargo test -p api request_translation_uses_openai_compatible_shape -- --nocapture\nTested: cargo test -p api stream_message_normalizes_text_and_multiple_tool_calls -- --exact --nocapture\nNot-tested: Live OpenAI or xAI network calls
Yeachan-Heo 2 ماه پیش
والد
کامیت
5f1eddf03a
2فایلهای تغییر یافته به همراه160 افزوده شده و 31 حذف شده
  1. 81 29
      rust/crates/api/src/providers/openai_compat.rs
  2. 79 2
      rust/crates/api/tests/openai_compat_integration.rs

+ 81 - 29
rust/crates/api/src/providers/openai_compat.rs

@@ -67,6 +67,7 @@ impl OpenAiCompatConfig {
 pub struct OpenAiCompatClient {
     http: reqwest::Client,
     api_key: String,
+    config: OpenAiCompatConfig,
     base_url: String,
     max_retries: u32,
     initial_backoff: Duration,
@@ -74,11 +75,15 @@ pub struct OpenAiCompatClient {
 }
 
 impl OpenAiCompatClient {
+    const fn config(&self) -> OpenAiCompatConfig {
+        self.config
+    }
     #[must_use]
     pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
         Self {
             http: reqwest::Client::new(),
             api_key: api_key.into(),
+            config,
             base_url: read_base_url(config),
             max_retries: DEFAULT_MAX_RETRIES,
             initial_backoff: DEFAULT_INITIAL_BACKOFF,
@@ -190,7 +195,7 @@ impl OpenAiCompatClient {
             .post(&request_url)
             .header("content-type", "application/json")
             .bearer_auth(&self.api_key)
-            .json(&build_chat_completion_request(request))
+            .json(&build_chat_completion_request(request, self.config()))
             .send()
             .await
             .map_err(ApiError::from)
@@ -633,7 +638,7 @@ struct ErrorBody {
     message: Option<String>,
 }
 
-fn build_chat_completion_request(request: &MessageRequest) -> Value {
+fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
     let mut messages = Vec::new();
     if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
         messages.push(json!({
@@ -652,6 +657,10 @@ fn build_chat_completion_request(request: &MessageRequest) -> Value {
         "stream": request.stream,
     });
 
+    if request.stream && should_request_stream_usage(config) {
+        payload["stream_options"] = json!({ "include_usage": true });
+    }
+
     if let Some(tools) = &request.tools {
         payload["tools"] =
             Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
@@ -749,6 +758,10 @@ fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
     }
 }
 
+fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool {
+    matches!(config.provider_name, "OpenAI")
+}
+
 fn normalize_response(
     model: &str,
     response: ChatCompletionResponse,
@@ -951,33 +964,36 @@ mod tests {
 
     #[test]
     fn request_translation_uses_openai_compatible_shape() {
-        let payload = build_chat_completion_request(&MessageRequest {
-            model: "grok-3".to_string(),
-            max_tokens: 64,
-            messages: vec![InputMessage {
-                role: "user".to_string(),
-                content: vec![
-                    InputContentBlock::Text {
-                        text: "hello".to_string(),
-                    },
-                    InputContentBlock::ToolResult {
-                        tool_use_id: "tool_1".to_string(),
-                        content: vec![ToolResultContentBlock::Json {
-                            value: json!({"ok": true}),
-                        }],
-                        is_error: false,
-                    },
-                ],
-            }],
-            system: Some("be helpful".to_string()),
-            tools: Some(vec![ToolDefinition {
-                name: "weather".to_string(),
-                description: Some("Get weather".to_string()),
-                input_schema: json!({"type": "object"}),
-            }]),
-            tool_choice: Some(ToolChoice::Auto),
-            stream: false,
-        });
+        let payload = build_chat_completion_request(
+            &MessageRequest {
+                model: "grok-3".to_string(),
+                max_tokens: 64,
+                messages: vec![InputMessage {
+                    role: "user".to_string(),
+                    content: vec![
+                        InputContentBlock::Text {
+                            text: "hello".to_string(),
+                        },
+                        InputContentBlock::ToolResult {
+                            tool_use_id: "tool_1".to_string(),
+                            content: vec![ToolResultContentBlock::Json {
+                                value: json!({"ok": true}),
+                            }],
+                            is_error: false,
+                        },
+                    ],
+                }],
+                system: Some("be helpful".to_string()),
+                tools: Some(vec![ToolDefinition {
+                    name: "weather".to_string(),
+                    description: Some("Get weather".to_string()),
+                    input_schema: json!({"type": "object"}),
+                }]),
+                tool_choice: Some(ToolChoice::Auto),
+                stream: false,
+            },
+            OpenAiCompatConfig::xai(),
+        );
 
         assert_eq!(payload["messages"][0]["role"], json!("system"));
         assert_eq!(payload["messages"][1]["role"], json!("user"));
@@ -986,6 +1002,42 @@ mod tests {
         assert_eq!(payload["tool_choice"], json!("auto"));
     }
 
+    #[test]
+    fn openai_streaming_requests_include_usage_opt_in() {
+        let payload = build_chat_completion_request(
+            &MessageRequest {
+                model: "gpt-5".to_string(),
+                max_tokens: 64,
+                messages: vec![InputMessage::user_text("hello")],
+                system: None,
+                tools: None,
+                tool_choice: None,
+                stream: true,
+            },
+            OpenAiCompatConfig::openai(),
+        );
+
+        assert_eq!(payload["stream_options"], json!({"include_usage": true}));
+    }
+
+    #[test]
+    fn xai_streaming_requests_skip_openai_specific_usage_opt_in() {
+        let payload = build_chat_completion_request(
+            &MessageRequest {
+                model: "grok-3".to_string(),
+                max_tokens: 64,
+                messages: vec![InputMessage::user_text("hello")],
+                system: None,
+                tools: None,
+                tool_choice: None,
+                stream: true,
+            },
+            OpenAiCompatConfig::xai(),
+        );
+
+        assert!(payload.get("stream_options").is_none());
+    }
+
     #[test]
     fn tool_choice_translation_supports_required_function() {
         assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));

+ 79 - 2
rust/crates/api/tests/openai_compat_integration.rs

@@ -5,8 +5,9 @@ use std::sync::{Mutex as StdMutex, OnceLock};
 
 use api::{
     ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
-    InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
-    OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
+    InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OpenAiCompatClient,
+    OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, ToolChoice,
+    ToolDefinition,
 };
 use serde_json::json;
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -195,6 +196,82 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
     assert!(request.body.contains("\"stream\":true"));
 }
 
+#[allow(clippy::await_holding_lock)]
+#[tokio::test]
+async fn openai_streaming_requests_opt_into_usage_chunks() {
+    let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
+    let sse = concat!(
+        "data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n",
+        "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
+        "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n",
+        "data: [DONE]\n\n"
+    );
+    let server = spawn_server(
+        state.clone(),
+        vec![http_response_with_headers(
+            "200 OK",
+            "text/event-stream",
+            sse,
+            &[("x-request-id", "req_openai_stream")],
+        )],
+    )
+    .await;
+
+    let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
+        .with_base_url(server.base_url());
+    let mut stream = client
+        .stream_message(&sample_request(false))
+        .await
+        .expect("stream should start");
+
+    assert_eq!(stream.request_id(), Some("req_openai_stream"));
+
+    let mut events = Vec::new();
+    while let Some(event) = stream.next_event().await.expect("event should parse") {
+        events.push(event);
+    }
+
+    assert!(matches!(events[0], StreamEvent::MessageStart(_)));
+    assert!(matches!(
+        events[1],
+        StreamEvent::ContentBlockStart(ContentBlockStartEvent {
+            content_block: OutputContentBlock::Text { .. },
+            ..
+        })
+    ));
+    assert!(matches!(
+        events[2],
+        StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
+            delta: ContentBlockDelta::TextDelta { .. },
+            ..
+        })
+    ));
+    assert!(matches!(
+        events[3],
+        StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
+    ));
+    assert!(matches!(
+        events[4],
+        StreamEvent::MessageDelta(MessageDeltaEvent { .. })
+    ));
+    assert!(matches!(events[5], StreamEvent::MessageStop(_)));
+
+    match &events[4] {
+        StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
+            assert_eq!(usage.input_tokens, 9);
+            assert_eq!(usage.output_tokens, 4);
+        }
+        other => panic!("expected message delta, got {other:?}"),
+    }
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("captured request");
+    assert_eq!(request.path, "/chat/completions");
+    let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
+    assert_eq!(body["stream"], json!(true));
+    assert_eq!(body["stream_options"], json!({"include_usage": true}));
+}
+
 #[allow(clippy::await_holding_lock)]
 #[tokio::test]
 async fn provider_client_dispatches_xai_requests_from_env() {