|
|
@@ -96,6 +96,7 @@ pub struct ToolSpec {
|
|
|
#[derive(Debug, Clone, PartialEq)]
|
|
|
pub struct GlobalToolRegistry {
|
|
|
plugin_tools: Vec<PluginTool>,
|
|
|
+ enforcer: Option<PermissionEnforcer>,
|
|
|
}
|
|
|
|
|
|
impl GlobalToolRegistry {
|
|
|
@@ -103,6 +104,7 @@ impl GlobalToolRegistry {
|
|
|
pub fn builtin() -> Self {
|
|
|
Self {
|
|
|
plugin_tools: Vec::new(),
|
|
|
+ enforcer: None,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -125,7 +127,7 @@ impl GlobalToolRegistry {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- Ok(Self { plugin_tools })
|
|
|
+ Ok(Self { plugin_tools, enforcer: None })
|
|
|
}
|
|
|
|
|
|
pub fn normalize_allowed_tools(
|
|
|
@@ -229,7 +231,14 @@ impl GlobalToolRegistry {
|
|
|
Ok(builtin.chain(plugin).collect())
|
|
|
}
|
|
|
|
|
|
+ pub fn set_enforcer(&mut self, enforcer: PermissionEnforcer) {
|
|
|
+ self.enforcer = Some(enforcer);
|
|
|
+ }
|
|
|
+
|
|
|
pub fn execute(&self, name: &str, input: &Value) -> Result<String, String> {
|
|
|
+ if let Some(enforcer) = &self.enforcer {
|
|
|
+ enforce_permission_check(enforcer, name, input)?;
|
|
|
+ }
|
|
|
if mvp_tool_specs().iter().any(|spec| spec.name == name) {
|
|
|
return execute_tool(name, input);
|
|
|
}
|
|
|
@@ -2776,11 +2785,12 @@ impl ApiClient for ProviderRuntimeClient {
|
|
|
|
|
|
struct SubagentToolExecutor {
|
|
|
allowed_tools: BTreeSet<String>,
|
|
|
+ enforcer: Option<PermissionEnforcer>,
|
|
|
}
|
|
|
|
|
|
impl SubagentToolExecutor {
|
|
|
fn new(allowed_tools: BTreeSet<String>) -> Self {
|
|
|
- Self { allowed_tools }
|
|
|
+ Self { allowed_tools, enforcer: None }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -2793,6 +2803,10 @@ impl ToolExecutor for SubagentToolExecutor {
|
|
|
}
|
|
|
let value = serde_json::from_str(input)
|
|
|
.map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?;
|
|
|
+ if let Some(enforcer) = &self.enforcer {
|
|
|
+ enforce_permission_check(enforcer, tool_name, &value)
|
|
|
+ .map_err(ToolError::new)?;
|
|
|
+ }
|
|
|
execute_tool(tool_name, &value).map_err(ToolError::new)
|
|
|
}
|
|
|
}
|
|
|
@@ -4890,7 +4904,7 @@ mod tests {
|
|
|
AssistantEvent::MessageStop,
|
|
|
])
|
|
|
}
|
|
|
- _ => panic!("unexpected mock stream call"),
|
|
|
+ _ => unreachable!("extra mock stream call"),
|
|
|
}
|
|
|
}
|
|
|
}
|