Преглед изворни кода

feat: grok provider tests + cargo fmt

Yeachan-Heo пре 2 месеци
родитељ
комит
178934a9a0
2 измењених фајлова са 172 додато и 39 уклоњено
  1. 93 17
      rust/crates/rusty-claude-cli/src/main.rs
  2. 79 22
      rust/crates/tools/src/lib.rs

+ 93 - 17
rust/crates/rusty-claude-cli/src/main.rs

@@ -2046,7 +2046,7 @@ impl ApiClient for ProviderRuntimeClient {
             let renderer = TerminalRenderer::new();
             let mut markdown_stream = MarkdownStreamState::default();
             let mut events = Vec::new();
-            let mut pending_tool: Option<(String, String, String)> = None;
+            let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new();
             let mut saw_stop = false;
 
             while let Some(event) = stream
@@ -2057,15 +2057,23 @@ impl ApiClient for ProviderRuntimeClient {
                 match event {
                     ApiStreamEvent::MessageStart(start) => {
                         for block in start.message.content {
-                            push_output_block(block, out, &mut events, &mut pending_tool, true)?;
+                            push_output_block(
+                                block,
+                                0,
+                                out,
+                                &mut events,
+                                &mut pending_tools,
+                                true,
+                            )?;
                         }
                     }
                     ApiStreamEvent::ContentBlockStart(start) => {
                         push_output_block(
                             start.content_block,
+                            start.index,
                             out,
                             &mut events,
-                            &mut pending_tool,
+                            &mut pending_tools,
                             true,
                         )?;
                     }
@@ -2081,18 +2089,18 @@ impl ApiClient for ProviderRuntimeClient {
                             }
                         }
                         ContentBlockDelta::InputJsonDelta { partial_json } => {
-                            if let Some((_, _, input)) = &mut pending_tool {
+                            if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) {
                                 input.push_str(&partial_json);
                             }
                         }
                     },
-                    ApiStreamEvent::ContentBlockStop(_) => {
+                    ApiStreamEvent::ContentBlockStop(stop) => {
                         if let Some(rendered) = markdown_stream.flush(&renderer) {
                             write!(out, "{rendered}")
                                 .and_then(|()| out.flush())
                                 .map_err(|error| RuntimeError::new(error.to_string()))?;
                         }
-                        if let Some((id, name, input)) = pending_tool.take() {
+                        if let Some((id, name, input)) = pending_tools.remove(&stop.index) {
                             // Display tool call now that input is fully accumulated
                             writeln!(out, "\n{}", format_tool_call_start(&name, &input))
                                 .and_then(|()| out.flush())
@@ -2556,9 +2564,10 @@ fn truncate_for_summary(value: &str, limit: usize) -> String {
 
 fn push_output_block(
     block: OutputContentBlock,
+    block_index: u32,
     out: &mut (impl Write + ?Sized),
     events: &mut Vec<AssistantEvent>,
-    pending_tool: &mut Option<(String, String, String)>,
+    pending_tools: &mut BTreeMap<u32, (String, String, String)>,
     streaming_tool_input: bool,
 ) -> Result<(), RuntimeError> {
     match block {
@@ -2583,7 +2592,7 @@ fn push_output_block(
             } else {
                 input.to_string()
             };
-            *pending_tool = Some((id, name, initial_input));
+            pending_tools.insert(block_index, (id, name, initial_input));
         }
     }
     Ok(())
@@ -2594,11 +2603,13 @@ fn response_to_events(
     out: &mut (impl Write + ?Sized),
 ) -> Result<Vec<AssistantEvent>, RuntimeError> {
     let mut events = Vec::new();
-    let mut pending_tool = None;
+    let mut pending_tools = BTreeMap::new();
 
-    for block in response.content {
-        push_output_block(block, out, &mut events, &mut pending_tool, false)?;
-        if let Some((id, name, input)) = pending_tool.take() {
+    for (index, block) in response.content.into_iter().enumerate() {
+        let index =
+            u32::try_from(index).map_err(|_| RuntimeError::new("response block index overflow"))?;
+        push_output_block(block, index, out, &mut events, &mut pending_tools, false)?;
+        if let Some((id, name, input)) = pending_tools.remove(&index) {
             events.push(AssistantEvent::ToolUse { id, name, input });
         }
     }
@@ -2824,6 +2835,7 @@ mod tests {
     use api::{MessageResponse, OutputContentBlock, Usage};
     use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode};
     use serde_json::json;
+    use std::collections::BTreeMap;
     use std::path::PathBuf;
 
     #[test]
@@ -3373,15 +3385,16 @@ mod tests {
     fn push_output_block_renders_markdown_text() {
         let mut out = Vec::new();
         let mut events = Vec::new();
-        let mut pending_tool = None;
+        let mut pending_tools = BTreeMap::new();
 
         push_output_block(
             OutputContentBlock::Text {
                 text: "# Heading".to_string(),
             },
+            0,
             &mut out,
             &mut events,
-            &mut pending_tool,
+            &mut pending_tools,
             false,
         )
         .expect("text block should render");
@@ -3395,7 +3408,7 @@ mod tests {
     fn push_output_block_skips_empty_object_prefix_for_tool_streams() {
         let mut out = Vec::new();
         let mut events = Vec::new();
-        let mut pending_tool = None;
+        let mut pending_tools = BTreeMap::new();
 
         push_output_block(
             OutputContentBlock::ToolUse {
@@ -3403,20 +3416,83 @@ mod tests {
                 name: "read_file".to_string(),
                 input: json!({}),
             },
+            1,
             &mut out,
             &mut events,
-            &mut pending_tool,
+            &mut pending_tools,
             true,
         )
         .expect("tool block should accumulate");
 
         assert!(events.is_empty());
         assert_eq!(
-            pending_tool,
+            pending_tools.remove(&1),
             Some(("tool-1".to_string(), "read_file".to_string(), String::new(),))
         );
     }
 
+    #[test]
+    fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() {
+        let mut out = Vec::new();
+        let mut events = Vec::new();
+        let mut pending_tools = BTreeMap::new();
+
+        push_output_block(
+            OutputContentBlock::ToolUse {
+                id: "tool-1".to_string(),
+                name: "read_file".to_string(),
+                input: json!({}),
+            },
+            1,
+            &mut out,
+            &mut events,
+            &mut pending_tools,
+            true,
+        )
+        .expect("first tool should accumulate");
+        push_output_block(
+            OutputContentBlock::ToolUse {
+                id: "tool-2".to_string(),
+                name: "grep_search".to_string(),
+                input: json!({}),
+            },
+            2,
+            &mut out,
+            &mut events,
+            &mut pending_tools,
+            true,
+        )
+        .expect("second tool should accumulate");
+
+        pending_tools
+            .get_mut(&1)
+            .expect("first tool pending")
+            .2
+            .push_str("{\"path\":\"src/main.rs\"}");
+        pending_tools
+            .get_mut(&2)
+            .expect("second tool pending")
+            .2
+            .push_str("{\"pattern\":\"TODO\"}");
+
+        assert_eq!(
+            pending_tools.remove(&1),
+            Some((
+                "tool-1".to_string(),
+                "read_file".to_string(),
+                "{\"path\":\"src/main.rs\"}".to_string(),
+            ))
+        );
+        assert_eq!(
+            pending_tools.remove(&2),
+            Some((
+                "tool-2".to_string(),
+                "grep_search".to_string(),
+                "{\"pattern\":\"TODO\"}".to_string(),
+            ))
+        );
+    }
+
     #[test]
     fn response_to_events_preserves_empty_object_json_input_outside_streaming() {
         let mut out = Vec::new();

+ 79 - 22
rust/crates/tools/src/lib.rs

@@ -4,10 +4,9 @@ use std::process::Command;
 use std::time::{Duration, Instant};
 
 use api::{
-    detect_provider_kind, max_tokens_for_model, resolve_model_alias, ContentBlockDelta,
-    InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
-    ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition,
-    ToolResultContentBlock,
+    max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage,
+    MessageRequest, MessageResponse, OutputContentBlock, ProviderClient,
+    StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
 };
 use reqwest::blocking::Client;
 use runtime::{
@@ -1646,11 +1645,7 @@ struct ProviderRuntimeClient {
 impl ProviderRuntimeClient {
     fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> {
         let model = resolve_model_alias(&model).to_string();
-        let client = match detect_provider_kind(&model) {
-            ProviderKind::Anthropic | ProviderKind::Xai | ProviderKind::OpenAi => {
-                ProviderClient::from_model(&model).map_err(|error| error.to_string())?
-            }
-        };
+        let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?;
         Ok(Self {
             runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?,
             client,
@@ -1687,7 +1682,7 @@ impl ApiClient for ProviderRuntimeClient {
                 .await
                 .map_err(|error| RuntimeError::new(error.to_string()))?;
             let mut events = Vec::new();
-            let mut pending_tool: Option<(String, String, String)> = None;
+            let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new();
             let mut saw_stop = false;
 
             while let Some(event) = stream
@@ -1698,14 +1693,15 @@ impl ApiClient for ProviderRuntimeClient {
                 match event {
                     ApiStreamEvent::MessageStart(start) => {
                         for block in start.message.content {
-                            push_output_block(block, &mut events, &mut pending_tool, true);
+                            push_output_block(block, 0, &mut events, &mut pending_tools, true);
                         }
                     }
                     ApiStreamEvent::ContentBlockStart(start) => {
                         push_output_block(
                             start.content_block,
+                            start.index,
                             &mut events,
-                            &mut pending_tool,
+                            &mut pending_tools,
                             true,
                         );
                     }
@@ -1716,13 +1712,13 @@ impl ApiClient for ProviderRuntimeClient {
                             }
                         }
                         ContentBlockDelta::InputJsonDelta { partial_json } => {
-                            if let Some((_, _, input)) = &mut pending_tool {
+                            if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) {
                                 input.push_str(&partial_json);
                             }
                         }
                     },
-                    ApiStreamEvent::ContentBlockStop(_) => {
-                        if let Some((id, name, input)) = pending_tool.take() {
+                    ApiStreamEvent::ContentBlockStop(stop) => {
+                        if let Some((id, name, input)) = pending_tools.remove(&stop.index) {
                             events.push(AssistantEvent::ToolUse { id, name, input });
                         }
                     }
@@ -1843,8 +1839,9 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
 
 fn push_output_block(
     block: OutputContentBlock,
+    block_index: u32,
     events: &mut Vec<AssistantEvent>,
-    pending_tool: &mut Option<(String, String, String)>,
+    pending_tools: &mut BTreeMap<u32, (String, String, String)>,
     streaming_tool_input: bool,
 ) {
     match block {
@@ -1862,18 +1859,19 @@ fn push_output_block(
             } else {
                 input.to_string()
             };
-            *pending_tool = Some((id, name, initial_input));
+            pending_tools.insert(block_index, (id, name, initial_input));
         }
     }
 }
 
 fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> {
     let mut events = Vec::new();
-    let mut pending_tool = None;
+    let mut pending_tools = BTreeMap::new();
 
-    for block in response.content {
-        push_output_block(block, &mut events, &mut pending_tool, false);
-        if let Some((id, name, input)) = pending_tool.take() {
+    for (index, block) in response.content.into_iter().enumerate() {
+        let index = u32::try_from(index).expect("response block index overflow");
+        push_output_block(block, index, &mut events, &mut pending_tools, false);
+        if let Some((id, name, input)) = pending_tools.remove(&index) {
             events.push(AssistantEvent::ToolUse { id, name, input });
         }
     }
@@ -2897,6 +2895,7 @@ fn parse_skill_description(contents: &str) -> Option<String> {
 
 #[cfg(test)]
 mod tests {
+    use std::collections::BTreeMap;
     use std::collections::BTreeSet;
     use std::fs;
     use std::io::{Read, Write};
@@ -2909,8 +2908,9 @@ mod tests {
     use super::{
         agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn,
         execute_tool, final_assistant_text, mvp_tool_specs, persist_agent_terminal_state,
-        AgentInput, AgentJob, SubagentToolExecutor,
+        push_output_block, AgentInput, AgentJob, SubagentToolExecutor,
     };
+    use api::OutputContentBlock;
     use runtime::{ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session};
     use serde_json::json;
 
@@ -3125,6 +3125,63 @@ mod tests {
         assert!(error.contains("relative URL without a base") || error.contains("empty host"));
     }
 
+    #[test]
+    fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() {
+        let mut events = Vec::new();
+        let mut pending_tools = BTreeMap::new();
+
+        push_output_block(
+            OutputContentBlock::ToolUse {
+                id: "tool-1".to_string(),
+                name: "read_file".to_string(),
+                input: json!({}),
+            },
+            1,
+            &mut events,
+            &mut pending_tools,
+            true,
+        );
+        push_output_block(
+            OutputContentBlock::ToolUse {
+                id: "tool-2".to_string(),
+                name: "grep_search".to_string(),
+                input: json!({}),
+            },
+            2,
+            &mut events,
+            &mut pending_tools,
+            true,
+        );
+
+        pending_tools
+            .get_mut(&1)
+            .expect("first tool pending")
+            .2
+            .push_str("{\"path\":\"src/main.rs\"}");
+        pending_tools
+            .get_mut(&2)
+            .expect("second tool pending")
+            .2
+            .push_str("{\"pattern\":\"TODO\"}");
+
+        assert_eq!(
+            pending_tools.remove(&1),
+            Some((
+                "tool-1".to_string(),
+                "read_file".to_string(),
+                "{\"path\":\"src/main.rs\"}".to_string(),
+            ))
+        );
+        assert_eq!(
+            pending_tools.remove(&2),
+            Some((
+                "tool-2".to_string(),
+                "grep_search".to_string(),
+                "{\"pattern\":\"TODO\"}".to_string(),
+            ))
+        );
+    }
+
     #[test]
     fn todo_write_persists_and_returns_previous_state() {
         let _guard = env_lock()