Эх сурвалжийг харах

feat(tools): error propagation, REPL timeout, edge-case validation

- Replace NotebookEdit expect() with Result-based error propagation
- Reject Sleep durations above 5 minutes (300,000 ms) with an error
- Reject empty StructuredOutput payloads
- Enforce timeout_ms in REPL via spawn+try_wait+kill
- Add edge-case tests: excessive/zero sleep, empty output, REPL timeout
- Verified: `cargo test -p tools` (35 passed), clippy clean
YeonGyu-Kim 2 сар өмнө
parent
commit
73187de6ea
1 өөрчлөгдсөн 122 нэмэгдсэн, 19 устгасан
  1. 122 19
      rust/crates/tools/src/lib.rs

+ 122 - 19
rust/crates/tools/src/lib.rs

@@ -169,7 +169,6 @@ impl GlobalToolRegistry {
         builtin.chain(plugin).collect()
     }
 
-    #[must_use]
     pub fn permission_specs(
         &self,
         allowed_tools: Option<&BTreeSet<String>>,
@@ -648,7 +647,7 @@ fn run_notebook_edit(input: NotebookEditInput) -> Result<String, String> {
 }
 
 fn run_sleep(input: SleepInput) -> Result<String, String> {
-    to_pretty_json(execute_sleep(input))
+    to_pretty_json(execute_sleep(input)?)
 }
 
 fn run_brief(input: BriefInput) -> Result<String, String> {
@@ -660,7 +659,7 @@ fn run_config(input: ConfigInput) -> Result<String, String> {
 }
 
 fn run_structured_output(input: StructuredOutputInput) -> Result<String, String> {
-    to_pretty_json(execute_structured_output(input))
+    to_pretty_json(execute_structured_output(input)?)
 }
 
 fn run_repl(input: ReplInput) -> Result<String, String> {
@@ -2347,7 +2346,8 @@ fn execute_notebook_edit(input: NotebookEditInput) -> Result<NotebookEditOutput,
 
     let cell_id = match edit_mode {
         NotebookEditMode::Insert => {
-            let resolved_cell_type = resolved_cell_type.expect("insert cell type");
+            let resolved_cell_type = resolved_cell_type
+                .ok_or_else(|| String::from("insert mode requires a cell type"))?;
             let new_id = make_cell_id(cells.len());
             let new_cell = build_notebook_cell(&new_id, resolved_cell_type, &new_source);
             let insert_at = target_index.map_or(cells.len(), |index| index + 1);
@@ -2359,16 +2359,21 @@ fn execute_notebook_edit(input: NotebookEditInput) -> Result<NotebookEditOutput,
                 .map(ToString::to_string)
         }
         NotebookEditMode::Delete => {
-            let removed = cells.remove(target_index.expect("delete target index"));
+            let idx = target_index
+                .ok_or_else(|| String::from("delete mode requires a target cell index"))?;
+            let removed = cells.remove(idx);
             removed
                 .get("id")
                 .and_then(serde_json::Value::as_str)
                 .map(ToString::to_string)
         }
         NotebookEditMode::Replace => {
-            let resolved_cell_type = resolved_cell_type.expect("replace cell type");
+            let resolved_cell_type = resolved_cell_type
+                .ok_or_else(|| String::from("replace mode requires a cell type"))?;
+            let idx = target_index
+                .ok_or_else(|| String::from("replace mode requires a target cell index"))?;
             let cell = cells
-                .get_mut(target_index.expect("replace target index"))
+                .get_mut(idx)
                 .ok_or_else(|| String::from("Cell index out of range"))?;
             cell["source"] = serde_json::Value::Array(source_lines(&new_source));
             cell["cell_type"] = serde_json::Value::String(match resolved_cell_type {
@@ -2459,13 +2464,21 @@ fn cell_kind(cell: &serde_json::Value) -> Option<NotebookCellType> {
         })
 }
 
+const MAX_SLEEP_DURATION_MS: u64 = 300_000;
+
 #[allow(clippy::needless_pass_by_value)]
-fn execute_sleep(input: SleepInput) -> SleepOutput {
+fn execute_sleep(input: SleepInput) -> Result<SleepOutput, String> {
+    if input.duration_ms > MAX_SLEEP_DURATION_MS {
+        return Err(format!(
+            "duration_ms {} exceeds maximum allowed sleep of {MAX_SLEEP_DURATION_MS}ms",
+            input.duration_ms,
+        ));
+    }
     std::thread::sleep(Duration::from_millis(input.duration_ms));
-    SleepOutput {
+    Ok(SleepOutput {
         duration_ms: input.duration_ms,
         message: format!("Slept for {}ms", input.duration_ms),
-    }
+    })
 }
 
 fn execute_brief(input: BriefInput) -> Result<BriefOutput, String> {
@@ -2562,25 +2575,62 @@ fn execute_config(input: ConfigInput) -> Result<ConfigOutput, String> {
     }
 }
 
-fn execute_structured_output(input: StructuredOutputInput) -> StructuredOutputResult {
-    StructuredOutputResult {
+fn execute_structured_output(
+    input: StructuredOutputInput,
+) -> Result<StructuredOutputResult, String> {
+    if input.0.is_empty() {
+        return Err(String::from("structured output payload must not be empty"));
+    }
+    Ok(StructuredOutputResult {
         data: String::from("Structured output provided successfully"),
         structured_output: input.0,
-    }
+    })
 }
 
 fn execute_repl(input: ReplInput) -> Result<ReplOutput, String> {
     if input.code.trim().is_empty() {
         return Err(String::from("code must not be empty"));
     }
-    let _ = input.timeout_ms;
     let runtime = resolve_repl_runtime(&input.language)?;
     let started = Instant::now();
-    let output = Command::new(runtime.program)
+    let mut process = Command::new(runtime.program);
+    process
         .args(runtime.args)
         .arg(&input.code)
-        .output()
-        .map_err(|error| error.to_string())?;
+        .stdin(std::process::Stdio::null())
+        .stdout(std::process::Stdio::piped())
+        .stderr(std::process::Stdio::piped());
+
+    let output = if let Some(timeout_ms) = input.timeout_ms {
+        let mut child = process.spawn().map_err(|error| error.to_string())?;
+        loop {
+            if child
+                .try_wait()
+                .map_err(|error| error.to_string())?
+                .is_some()
+            {
+                break child
+                    .wait_with_output()
+                    .map_err(|error| error.to_string())?;
+            }
+            if started.elapsed() >= Duration::from_millis(timeout_ms) {
+                child.kill().map_err(|error| error.to_string())?;
+                child
+                    .wait_with_output()
+                    .map_err(|error| error.to_string())?;
+                return Err(format!(
+                    "REPL execution exceeded timeout of {timeout_ms} ms"
+                ));
+            }
+            std::thread::sleep(Duration::from_millis(10));
+        }
+    } else {
+        process
+            .spawn()
+            .map_err(|error| error.to_string())?
+            .wait_with_output()
+            .map_err(|error| error.to_string())?
+    };
 
     Ok(ReplOutput {
         language: input.language,
@@ -3157,8 +3207,8 @@ mod tests {
             .expect_err("unknown plugin permission should fail");
         assert!(unknown_permission.contains("unsupported plugin permission: admin"));
 
-        let empty_permission = permission_mode_from_plugin("")
-            .expect_err("empty plugin permission should fail");
+        let empty_permission =
+            permission_mode_from_plugin("").expect_err("empty plugin permission should fail");
         assert!(empty_permission.contains("unsupported plugin permission: "));
     }
 
@@ -4226,6 +4276,21 @@ mod tests {
         assert!(elapsed >= Duration::from_millis(15));
     }
 
+    #[test]
+    fn given_excessive_duration_when_sleep_then_rejects_with_error() {
+        let result = execute_tool("Sleep", &json!({"duration_ms": 999_999_999_u64}));
+        let error = result.expect_err("excessive sleep should fail");
+        assert!(error.contains("exceeds maximum allowed sleep"));
+    }
+
+    #[test]
+    fn given_zero_duration_when_sleep_then_succeeds() {
+        let result =
+            execute_tool("Sleep", &json!({"duration_ms": 0})).expect("0ms sleep should succeed");
+        let output: serde_json::Value = serde_json::from_str(&result).expect("json");
+        assert_eq!(output["duration_ms"], 0);
+    }
+
     #[test]
     fn brief_returns_sent_message_and_attachment_metadata() {
         let attachment = std::env::temp_dir().join(format!(
@@ -4330,6 +4395,13 @@ mod tests {
         assert_eq!(output["structured_output"]["items"][1], 2);
     }
 
+    #[test]
+    fn given_empty_payload_when_structured_output_then_rejects_with_error() {
+        let result = execute_tool("StructuredOutput", &json!({}));
+        let error = result.expect_err("empty payload should fail");
+        assert!(error.contains("must not be empty"));
+    }
+
     #[test]
     fn repl_executes_python_code() {
         let result = execute_tool(
@@ -4343,6 +4415,37 @@ mod tests {
         assert!(output["stdout"].as_str().expect("stdout").contains('2'));
     }
 
+    #[test]
+    fn given_empty_code_when_repl_then_rejects_with_error() {
+        let result = execute_tool("REPL", &json!({"language": "python", "code": "   "}));
+
+        let error = result.expect_err("empty REPL code should fail");
+        assert!(error.contains("code must not be empty"));
+    }
+
+    #[test]
+    fn given_unsupported_language_when_repl_then_rejects_with_error() {
+        let result = execute_tool("REPL", &json!({"language": "ruby", "code": "puts 1"}));
+
+        let error = result.expect_err("unsupported REPL language should fail");
+        assert!(error.contains("unsupported REPL language: ruby"));
+    }
+
+    #[test]
+    fn given_timeout_ms_when_repl_blocks_then_returns_timeout_error() {
+        let result = execute_tool(
+            "REPL",
+            &json!({
+                "language": "python",
+                "code": "import time\ntime.sleep(1)",
+                "timeout_ms": 10
+            }),
+        );
+
+        let error = result.expect_err("timed out REPL execution should fail");
+        assert!(error.contains("REPL execution exceeded timeout of 10 ms"));
+    }
+
     #[test]
     fn powershell_runs_via_stub_shell() {
         let _guard = env_lock()