Просмотр исходного кода

Merge remote-tracking branch 'origin/rcc/cli' into dev/rust

Yeachan-Heo 2 месяцев назад
Родитель
Сommit
842abcfe85
2 измененных файлов с 227 добавлено и 20 удалено
  1. 3 1
      rust/README.md
  2. 224 19
      rust/crates/rusty-claude-cli/src/main.rs

+ 3 - 1
rust/README.md

@@ -132,7 +132,9 @@ cargo run -p rusty-claude-cli -- --resume session.json /memory /config
 - `bootstrap-plan` — print the current bootstrap skeleton
 - `system-prompt [--cwd PATH] [--date YYYY-MM-DD]` — render the synthesized system prompt
 - `--help` / `-h` — show CLI help
-- `--version` / `-V` — print the CLI version
+- `--version` / `-V` — print the CLI version and build info locally (no API call)
+- `--output-format text|json` — choose non-interactive prompt output rendering
+- `--allowedTools <tool[,tool...]>` — restrict enabled tools for interactive sessions and prompt-mode tool use
 
 ### Interactive slash commands
 

+ 224 - 19
rust/crates/rusty-claude-cli/src/main.rs

@@ -1,6 +1,7 @@
 mod input;
 mod render;
 
+use std::collections::{BTreeMap, BTreeSet};
 use std::env;
 use std::fs;
 use std::io::{self, Write};
@@ -32,6 +33,8 @@ const VERSION: &str = env!("CARGO_PKG_VERSION");
 const BUILD_TARGET: Option<&str> = option_env!("TARGET");
 const GIT_SHA: Option<&str> = option_env!("GIT_SHA");
 
+type AllowedToolSet = BTreeSet<String>;
+
 fn main() {
     if let Err(error) = run() {
         eprintln!(
@@ -49,6 +52,7 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
         CliAction::DumpManifests => dump_manifests(),
         CliAction::BootstrapPlan => print_bootstrap_plan(),
         CliAction::PrintSystemPrompt { cwd, date } => print_system_prompt(cwd, date),
+        CliAction::Version => print_version(),
         CliAction::ResumeSession {
             session_path,
             commands,
@@ -57,8 +61,13 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
             prompt,
             model,
             output_format,
-        } => LiveCli::new(model, false)?.run_turn_with_output(&prompt, output_format)?,
-        CliAction::Repl { model } => run_repl(model)?,
+            allowed_tools,
+        } => LiveCli::new(model, false, allowed_tools)?
+            .run_turn_with_output(&prompt, output_format)?,
+        CliAction::Repl {
+            model,
+            allowed_tools,
+        } => run_repl(model, allowed_tools)?,
         CliAction::Help => print_help(),
     }
     Ok(())
@@ -72,6 +81,7 @@ enum CliAction {
         cwd: PathBuf,
         date: String,
     },
+    Version,
     ResumeSession {
         session_path: PathBuf,
         commands: Vec<String>,
@@ -80,9 +90,11 @@ enum CliAction {
         prompt: String,
         model: String,
         output_format: CliOutputFormat,
+        allowed_tools: Option<AllowedToolSet>,
     },
     Repl {
         model: String,
+        allowed_tools: Option<AllowedToolSet>,
     },
     // prompt-mode formatting is only supported for non-interactive runs
     Help,
@@ -109,11 +121,17 @@ impl CliOutputFormat {
 fn parse_args(args: &[String]) -> Result<CliAction, String> {
     let mut model = DEFAULT_MODEL.to_string();
     let mut output_format = CliOutputFormat::Text;
+    let mut wants_version = false;
+    let mut allowed_tool_values = Vec::new();
     let mut rest = Vec::new();
     let mut index = 0;
 
     while index < args.len() {
         match args[index].as_str() {
+            "--version" | "-V" => {
+                wants_version = true;
+                index += 1;
+            }
             "--model" => {
                 let value = args
                     .get(index + 1)
@@ -136,6 +154,21 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
                 output_format = CliOutputFormat::parse(&flag[16..])?;
                 index += 1;
             }
+            "--allowedTools" | "--allowed-tools" => {
+                let value = args
+                    .get(index + 1)
+                    .ok_or_else(|| "missing value for --allowedTools".to_string())?;
+                allowed_tool_values.push(value.clone());
+                index += 2;
+            }
+            flag if flag.starts_with("--allowedTools=") => {
+                allowed_tool_values.push(flag[15..].to_string());
+                index += 1;
+            }
+            flag if flag.starts_with("--allowed-tools=") => {
+                allowed_tool_values.push(flag[16..].to_string());
+                index += 1;
+            }
             other => {
                 rest.push(other.to_string());
                 index += 1;
@@ -143,8 +176,17 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
         }
     }
 
+    if wants_version {
+        return Ok(CliAction::Version);
+    }
+
+    let allowed_tools = normalize_allowed_tools(&allowed_tool_values)?;
+
     if rest.is_empty() {
-        return Ok(CliAction::Repl { model });
+        return Ok(CliAction::Repl {
+            model,
+            allowed_tools,
+        });
     }
     if matches!(rest.first().map(String::as_str), Some("--help" | "-h")) {
         return Ok(CliAction::Help);
@@ -166,17 +208,74 @@ fn parse_args(args: &[String]) -> Result<CliAction, String> {
                 prompt,
                 model,
                 output_format,
+                allowed_tools,
             })
         }
         other if !other.starts_with('/') => Ok(CliAction::Prompt {
             prompt: rest.join(" "),
             model,
             output_format,
+            allowed_tools,
         }),
         other => Err(format!("unknown subcommand: {other}")),
     }
 }
 
+fn normalize_allowed_tools(values: &[String]) -> Result<Option<AllowedToolSet>, String> {
+    if values.is_empty() {
+        return Ok(None);
+    }
+
+    let canonical_names = mvp_tool_specs()
+        .into_iter()
+        .map(|spec| spec.name.to_string())
+        .collect::<Vec<_>>();
+    let mut name_map = canonical_names
+        .iter()
+        .map(|name| (normalize_tool_name(name), name.clone()))
+        .collect::<BTreeMap<_, _>>();
+
+    for (alias, canonical) in [
+        ("read", "read_file"),
+        ("write", "write_file"),
+        ("edit", "edit_file"),
+        ("glob", "glob_search"),
+        ("grep", "grep_search"),
+    ] {
+        name_map.insert(alias.to_string(), canonical.to_string());
+    }
+
+    let mut allowed = AllowedToolSet::new();
+    for value in values {
+        for token in value
+            .split(|ch: char| ch == ',' || ch.is_whitespace())
+            .filter(|token| !token.is_empty())
+        {
+            let normalized = normalize_tool_name(token);
+            let canonical = name_map.get(&normalized).ok_or_else(|| {
+                format!(
+                    "unsupported tool in --allowedTools: {token} (expected one of: {})",
+                    canonical_names.join(", ")
+                )
+            })?;
+            allowed.insert(canonical.clone());
+        }
+    }
+
+    Ok(Some(allowed))
+}
+
+fn normalize_tool_name(value: &str) -> String {
+    value.trim().replace('-', "_").to_ascii_lowercase()
+}
+
+fn filter_tool_specs(allowed_tools: Option<&AllowedToolSet>) -> Vec<tools::ToolSpec> {
+    mvp_tool_specs()
+        .into_iter()
+        .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name)))
+        .collect()
+}
+
 fn parse_system_prompt_args(args: &[String]) -> Result<CliAction, String> {
     let mut cwd = env::current_dir().map_err(|error| error.to_string())?;
     let mut date = DEFAULT_DATE.to_string();
@@ -255,6 +354,10 @@ fn print_system_prompt(cwd: PathBuf, date: String) {
     }
 }
 
+fn print_version() {
+    println!("{}", render_version_report());
+}
+
 fn resume_session(session_path: &Path, commands: &[String]) {
     let session = match Session::load_from_path(session_path) {
         Ok(session) => session,
@@ -608,8 +711,11 @@ fn run_resume_command(
     }
 }
 
-fn run_repl(model: String) -> Result<(), Box<dyn std::error::Error>> {
-    let mut cli = LiveCli::new(model, true)?;
+fn run_repl(
+    model: String,
+    allowed_tools: Option<AllowedToolSet>,
+) -> Result<(), Box<dyn std::error::Error>> {
+    let mut cli = LiveCli::new(model, true, allowed_tools)?;
     let editor = input::LineEditor::new("› ");
     println!("{}", cli.startup_banner());
 
@@ -647,13 +753,18 @@ struct ManagedSessionSummary {
 
 struct LiveCli {
     model: String,
+    allowed_tools: Option<AllowedToolSet>,
     system_prompt: Vec<String>,
     runtime: ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>,
     session: SessionHandle,
 }
 
 impl LiveCli {
-    fn new(model: String, enable_tools: bool) -> Result<Self, Box<dyn std::error::Error>> {
+    fn new(
+        model: String,
+        enable_tools: bool,
+        allowed_tools: Option<AllowedToolSet>,
+    ) -> Result<Self, Box<dyn std::error::Error>> {
         let system_prompt = build_system_prompt()?;
         let session = create_managed_session_handle()?;
         let runtime = build_runtime(
@@ -661,9 +772,11 @@ impl LiveCli {
             model.clone(),
             system_prompt.clone(),
             enable_tools,
+            allowed_tools.clone(),
         )?;
         let cli = Self {
             model,
+            allowed_tools,
             system_prompt,
             runtime,
             session,
@@ -849,7 +962,13 @@ impl LiveCli {
         let previous = self.model.clone();
         let session = self.runtime.session().clone();
         let message_count = session.messages.len();
-        self.runtime = build_runtime(session, model.clone(), self.system_prompt.clone(), true)?;
+        self.runtime = build_runtime(
+            session,
+            model.clone(),
+            self.system_prompt.clone(),
+            true,
+            self.allowed_tools.clone(),
+        )?;
         self.model.clone_from(&model);
         self.persist_session()?;
         println!(
@@ -883,6 +1002,7 @@ impl LiveCli {
             self.model.clone(),
             self.system_prompt.clone(),
             true,
+            self.allowed_tools.clone(),
             normalized,
         )?;
         self.persist_session()?;
@@ -907,6 +1027,7 @@ impl LiveCli {
             self.model.clone(),
             self.system_prompt.clone(),
             true,
+            self.allowed_tools.clone(),
             permission_mode_label(),
         )?;
         self.persist_session()?;
@@ -941,6 +1062,7 @@ impl LiveCli {
             self.model.clone(),
             self.system_prompt.clone(),
             true,
+            self.allowed_tools.clone(),
             permission_mode_label(),
         )?;
         self.session = handle;
@@ -1017,6 +1139,7 @@ impl LiveCli {
                     self.model.clone(),
                     self.system_prompt.clone(),
                     true,
+                    self.allowed_tools.clone(),
                     permission_mode_label(),
                 )?;
                 self.session = handle;
@@ -1046,6 +1169,7 @@ impl LiveCli {
             self.model.clone(),
             self.system_prompt.clone(),
             true,
+            self.allowed_tools.clone(),
             permission_mode_label(),
         )?;
         self.persist_session()?;
@@ -1571,6 +1695,7 @@ fn build_runtime(
     model: String,
     system_prompt: Vec<String>,
     enable_tools: bool,
+    allowed_tools: Option<AllowedToolSet>,
 ) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
 {
     build_runtime_with_permission_mode(
@@ -1578,6 +1703,7 @@ fn build_runtime(
         model,
         system_prompt,
         enable_tools,
+        allowed_tools,
         permission_mode_label(),
     )
 }
@@ -1587,13 +1713,14 @@ fn build_runtime_with_permission_mode(
     model: String,
     system_prompt: Vec<String>,
     enable_tools: bool,
+    allowed_tools: Option<AllowedToolSet>,
     permission_mode: &str,
 ) -> Result<ConversationRuntime<AnthropicRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>>
 {
     Ok(ConversationRuntime::new(
         session,
-        AnthropicRuntimeClient::new(model, enable_tools)?,
-        CliToolExecutor::new(),
+        AnthropicRuntimeClient::new(model, enable_tools, allowed_tools.clone())?,
+        CliToolExecutor::new(allowed_tools),
         permission_policy(permission_mode),
         system_prompt,
     ))
@@ -1604,15 +1731,21 @@ struct AnthropicRuntimeClient {
     client: AnthropicClient,
     model: String,
     enable_tools: bool,
+    allowed_tools: Option<AllowedToolSet>,
 }
 
 impl AnthropicRuntimeClient {
-    fn new(model: String, enable_tools: bool) -> Result<Self, Box<dyn std::error::Error>> {
+    fn new(
+        model: String,
+        enable_tools: bool,
+        allowed_tools: Option<AllowedToolSet>,
+    ) -> Result<Self, Box<dyn std::error::Error>> {
         Ok(Self {
             runtime: tokio::runtime::Runtime::new()?,
             client: AnthropicClient::from_env()?,
             model,
             enable_tools,
+            allowed_tools,
         })
     }
 }
@@ -1626,7 +1759,7 @@ impl ApiClient for AnthropicRuntimeClient {
             messages: convert_messages(&request.messages),
             system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")),
             tools: self.enable_tools.then(|| {
-                mvp_tool_specs()
+                filter_tool_specs(self.allowed_tools.as_ref())
                     .into_iter()
                     .map(|spec| ToolDefinition {
                         name: spec.name.to_string(),
@@ -1781,18 +1914,29 @@ fn response_to_events(
 
 struct CliToolExecutor {
     renderer: TerminalRenderer,
+    allowed_tools: Option<AllowedToolSet>,
 }
 
 impl CliToolExecutor {
-    fn new() -> Self {
+    fn new(allowed_tools: Option<AllowedToolSet>) -> Self {
         Self {
             renderer: TerminalRenderer::new(),
+            allowed_tools,
         }
     }
 }
 
 impl ToolExecutor for CliToolExecutor {
     fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> {
+        if self
+            .allowed_tools
+            .as_ref()
+            .is_some_and(|allowed| !allowed.contains(tool_name))
+        {
+            return Err(ToolError::new(format!(
+                "tool `{tool_name}` is not enabled by the current --allowedTools setting"
+            )));
+        }
         let value = serde_json::from_str(input)
             .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
         match execute_tool(tool_name, &value) {
@@ -1864,7 +2008,7 @@ fn print_help() {
     println!("rusty-claude-cli v{VERSION}");
     println!();
     println!("Usage:");
-    println!("  rusty-claude-cli [--model MODEL]");
+    println!("  rusty-claude-cli [--model MODEL] [--allowedTools TOOL[,TOOL...]]");
     println!("      Start the interactive REPL");
     println!("  rusty-claude-cli [--model MODEL] [--output-format text|json] prompt TEXT");
     println!("      Send one prompt and exit");
@@ -1879,6 +2023,8 @@ fn print_help() {
     println!("Flags:");
     println!("  --model MODEL              Override the active model");
     println!("  --output-format FORMAT     Non-interactive output format: text or json");
+    println!("  --allowedTools TOOLS       Restrict enabled tools (repeatable; comma-separated aliases supported)");
+    println!("  --version, -V              Print version and build information locally");
     println!();
     println!("Interactive slash commands:");
     println!("{}", render_slash_command_help());
@@ -1895,18 +2041,20 @@ fn print_help() {
     println!("Examples:");
     println!("  rusty-claude-cli --model claude-opus \"summarize this repo\"");
     println!("  rusty-claude-cli --output-format json prompt \"explain src/main.rs\"");
+    println!("  rusty-claude-cli --allowedTools read,glob \"summarize Cargo.toml\"");
     println!("  rusty-claude-cli --resume session.json /status /diff /export notes.txt");
 }
 
 #[cfg(test)]
 mod tests {
     use super::{
-        format_compact_report, format_cost_report, format_init_report, format_model_report,
-        format_model_switch_report, format_permissions_report, format_permissions_switch_report,
-        format_resume_report, format_status_report, normalize_permission_mode, parse_args,
-        parse_git_status_metadata, render_config_report, render_init_claude_md,
-        render_memory_report, render_repl_help, resume_supported_slash_commands, status_context,
-        CliAction, CliOutputFormat, SlashCommand, StatusUsage, DEFAULT_MODEL,
+        filter_tool_specs, format_compact_report, format_cost_report, format_init_report,
+        format_model_report, format_model_switch_report, format_permissions_report,
+        format_permissions_switch_report, format_resume_report, format_status_report,
+        normalize_permission_mode, parse_args, parse_git_status_metadata, render_config_report,
+        render_init_claude_md, render_memory_report, render_repl_help,
+        resume_supported_slash_commands, status_context, CliAction, CliOutputFormat, SlashCommand,
+        StatusUsage, DEFAULT_MODEL,
     };
     use runtime::{ContentBlock, ConversationMessage, MessageRole};
     use std::path::{Path, PathBuf};
@@ -1917,6 +2065,7 @@ mod tests {
             parse_args(&[]).expect("args should parse"),
             CliAction::Repl {
                 model: DEFAULT_MODEL.to_string(),
+                allowed_tools: None,
             }
         );
     }
@@ -1934,6 +2083,7 @@ mod tests {
                 prompt: "hello world".to_string(),
                 model: DEFAULT_MODEL.to_string(),
                 output_format: CliOutputFormat::Text,
+                allowed_tools: None,
             }
         );
     }
@@ -1953,10 +2103,51 @@ mod tests {
                 prompt: "explain this".to_string(),
                 model: "claude-opus".to_string(),
                 output_format: CliOutputFormat::Json,
+                allowed_tools: None,
             }
         );
     }
 
+    #[test]
+    fn parses_version_flags_without_initializing_prompt_mode() {
+        assert_eq!(
+            parse_args(&["--version".to_string()]).expect("args should parse"),
+            CliAction::Version
+        );
+        assert_eq!(
+            parse_args(&["-V".to_string()]).expect("args should parse"),
+            CliAction::Version
+        );
+    }
+
+    #[test]
+    fn parses_allowed_tools_flags_with_aliases_and_lists() {
+        let args = vec![
+            "--allowedTools".to_string(),
+            "read,glob".to_string(),
+            "--allowed-tools=write_file".to_string(),
+        ];
+        assert_eq!(
+            parse_args(&args).expect("args should parse"),
+            CliAction::Repl {
+                model: DEFAULT_MODEL.to_string(),
+                allowed_tools: Some(
+                    ["glob_search", "read_file", "write_file"]
+                        .into_iter()
+                        .map(str::to_string)
+                        .collect()
+                ),
+            }
+        );
+    }
+
+    #[test]
+    fn rejects_unknown_allowed_tools() {
+        let error = parse_args(&["--allowedTools".to_string(), "teleport".to_string()])
+            .expect_err("tool should be rejected");
+        assert!(error.contains("unsupported tool in --allowedTools: teleport"));
+    }
+
     #[test]
     fn parses_system_prompt_options() {
         let args = vec![
@@ -2013,6 +2204,20 @@ mod tests {
         );
     }
 
+    #[test]
+    fn filtered_tool_specs_respect_allowlist() {
+        let allowed = ["read_file", "grep_search"]
+            .into_iter()
+            .map(str::to_string)
+            .collect();
+        let filtered = filter_tool_specs(Some(&allowed));
+        let names = filtered
+            .into_iter()
+            .map(|spec| spec.name)
+            .collect::<Vec<_>>();
+        assert_eq!(names, vec!["read_file", "grep_search"]);
+    }
+
     #[test]
     fn shared_help_uses_resume_annotation_copy() {
         let help = commands::render_slash_command_help();