| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518 |
- use std::collections::BTreeMap;
- use std::io;
- use std::process::Stdio;
- use serde::de::DeserializeOwned;
- use serde::{Deserialize, Serialize};
- use serde_json::Value as JsonValue;
- use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
- use tokio::process::{Child, ChildStdin, ChildStdout, Command};
- use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
- #[serde(untagged)]
- pub enum JsonRpcId {
- Number(u64),
- String(String),
- Null,
- }
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- pub struct JsonRpcRequest<T = JsonValue> {
- pub jsonrpc: String,
- pub id: JsonRpcId,
- pub method: String,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub params: Option<T>,
- }
- impl<T> JsonRpcRequest<T> {
- #[must_use]
- pub fn new(id: JsonRpcId, method: impl Into<String>, params: Option<T>) -> Self {
- Self {
- jsonrpc: "2.0".to_string(),
- id,
- method: method.into(),
- params,
- }
- }
- }
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- pub struct JsonRpcError {
- pub code: i64,
- pub message: String,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub data: Option<JsonValue>,
- }
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- pub struct JsonRpcResponse<T = JsonValue> {
- pub jsonrpc: String,
- pub id: JsonRpcId,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub result: Option<T>,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub error: Option<JsonRpcError>,
- }
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- #[serde(rename_all = "camelCase")]
- pub struct McpInitializeParams {
- pub protocol_version: String,
- pub capabilities: JsonValue,
- pub client_info: McpInitializeClientInfo,
- }
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
- #[serde(rename_all = "camelCase")]
- pub struct McpInitializeClientInfo {
- pub name: String,
- pub version: String,
- }
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
- #[serde(rename_all = "camelCase")]
- pub struct McpInitializeResult {
- pub protocol_version: String,
- pub capabilities: JsonValue,
- pub server_info: McpInitializeServerInfo,
- }
- #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
- #[serde(rename_all = "camelCase")]
- pub struct McpInitializeServerInfo {
- pub name: String,
- pub version: String,
- }
- #[derive(Debug)]
- pub struct McpStdioProcess {
- child: Child,
- stdin: ChildStdin,
- stdout: BufReader<ChildStdout>,
- }
- impl McpStdioProcess {
- pub fn spawn(transport: &McpStdioTransport) -> io::Result<Self> {
- let mut command = Command::new(&transport.command);
- command
- .args(&transport.args)
- .stdin(Stdio::piped())
- .stdout(Stdio::piped())
- .stderr(Stdio::inherit());
- apply_env(&mut command, &transport.env);
- let mut child = command.spawn()?;
- let stdin = child
- .stdin
- .take()
- .ok_or_else(|| io::Error::other("stdio MCP process missing stdin pipe"))?;
- let stdout = child
- .stdout
- .take()
- .ok_or_else(|| io::Error::other("stdio MCP process missing stdout pipe"))?;
- Ok(Self {
- child,
- stdin,
- stdout: BufReader::new(stdout),
- })
- }
- pub async fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
- self.stdin.write_all(bytes).await
- }
- pub async fn flush(&mut self) -> io::Result<()> {
- self.stdin.flush().await
- }
- pub async fn write_line(&mut self, line: &str) -> io::Result<()> {
- self.write_all(line.as_bytes()).await?;
- self.write_all(b"\n").await?;
- self.flush().await
- }
- pub async fn read_line(&mut self) -> io::Result<String> {
- let mut line = String::new();
- let bytes_read = self.stdout.read_line(&mut line).await?;
- if bytes_read == 0 {
- return Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- "MCP stdio stream closed while reading line",
- ));
- }
- Ok(line)
- }
- pub async fn read_available(&mut self) -> io::Result<Vec<u8>> {
- let mut buffer = vec![0_u8; 4096];
- let read = self.stdout.read(&mut buffer).await?;
- buffer.truncate(read);
- Ok(buffer)
- }
- pub async fn write_frame(&mut self, payload: &[u8]) -> io::Result<()> {
- let encoded = encode_frame(payload);
- self.write_all(&encoded).await?;
- self.flush().await
- }
- pub async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
- let mut content_length = None;
- loop {
- let mut line = String::new();
- let bytes_read = self.stdout.read_line(&mut line).await?;
- if bytes_read == 0 {
- return Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- "MCP stdio stream closed while reading headers",
- ));
- }
- if line == "\r\n" {
- break;
- }
- if let Some(value) = line.strip_prefix("Content-Length:") {
- let parsed = value
- .trim()
- .parse::<usize>()
- .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
- content_length = Some(parsed);
- }
- }
- let content_length = content_length.ok_or_else(|| {
- io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
- })?;
- let mut payload = vec![0_u8; content_length];
- self.stdout.read_exact(&mut payload).await?;
- Ok(payload)
- }
- pub async fn write_jsonrpc_message<T: Serialize>(&mut self, message: &T) -> io::Result<()> {
- let body = serde_json::to_vec(message)
- .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
- self.write_frame(&body).await
- }
- pub async fn read_jsonrpc_message<T: DeserializeOwned>(&mut self) -> io::Result<T> {
- let payload = self.read_frame().await?;
- serde_json::from_slice(&payload)
- .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
- }
- pub async fn send_request<T: Serialize>(
- &mut self,
- request: &JsonRpcRequest<T>,
- ) -> io::Result<()> {
- self.write_jsonrpc_message(request).await
- }
- pub async fn read_response<T: DeserializeOwned>(&mut self) -> io::Result<JsonRpcResponse<T>> {
- self.read_jsonrpc_message().await
- }
- pub async fn initialize(
- &mut self,
- id: JsonRpcId,
- params: McpInitializeParams,
- ) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
- let request = JsonRpcRequest::new(id, "initialize", Some(params));
- self.send_request(&request).await?;
- self.read_response().await
- }
- pub async fn terminate(&mut self) -> io::Result<()> {
- self.child.kill().await
- }
- pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
- self.child.wait().await
- }
- }
- pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
- match &bootstrap.transport {
- McpClientTransport::Stdio(transport) => McpStdioProcess::spawn(transport),
- other => Err(io::Error::new(
- io::ErrorKind::InvalidInput,
- format!(
- "MCP bootstrap transport for {} is not stdio: {other:?}",
- bootstrap.server_name
- ),
- )),
- }
- }
- fn apply_env(command: &mut Command, env: &BTreeMap<String, String>) {
- for (key, value) in env {
- command.env(key, value);
- }
- }
/// Prefixes `payload` with a `Content-Length: <len>\r\n\r\n` header.
///
/// `<len>` is the payload length in bytes.
fn encode_frame(payload: &[u8]) -> Vec<u8> {
    let header = format!("Content-Length: {}\r\n\r\n", payload.len());
    [header.as_bytes(), payload].concat()
}
// Gated on unix: the helpers use `std::os::unix::fs::PermissionsExt` and the fake
// servers are `/bin/sh` and `python3` scripts, so the module cannot build elsewhere.
#[cfg(all(test, unix))]
mod tests {
    use std::collections::BTreeMap;
    use std::fs;
    use std::future::Future;
    use std::io::ErrorKind;
    use std::os::unix::fs::PermissionsExt;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    use serde_json::json;
    use tokio::runtime::Builder;

    use crate::config::{
        ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
    };
    use crate::mcp_client::McpClientBootstrap;

    use super::{
        spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, McpInitializeClientInfo,
        McpInitializeParams, McpInitializeResult, McpInitializeServerInfo, McpStdioProcess,
    };

    /// Returns a fresh (not yet created) directory path under the system temp dir.
    ///
    /// The name combines the process id, a per-process counter and the current time.
    /// Tests running in parallel therefore never share a directory, and one test's
    /// cleanup cannot delete another test's script.
    fn temp_dir() -> PathBuf {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let unique = COUNTER.fetch_add(1, Ordering::Relaxed);
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "runtime-mcp-stdio-{}-{unique}-{nanos}",
            std::process::id()
        ))
    }

    /// Writes `contents` to a file called `name` inside a fresh temp dir.
    ///
    /// The file is marked executable (mode 0o755) and its path is returned.
    fn write_executable_script(name: &str, contents: &str) -> PathBuf {
        let root = temp_dir();
        fs::create_dir_all(&root).expect("temp dir");
        let script_path = root.join(name);
        fs::write(&script_path, contents).expect("write script");
        let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
        permissions.set_mode(0o755);
        fs::set_permissions(&script_path, permissions).expect("chmod");
        script_path
    }

    /// Removes the temp directory (and the script in it) created for `script_path`.
    fn remove_script_dir(script_path: &Path) {
        fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
    }

    /// Runs `future` to completion on a fresh current-thread Tokio runtime.
    fn block_on<F: Future>(future: F) -> F::Output {
        Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("runtime")
            .block_on(future)
    }

    /// Shell script: prints `READY:$MCP_TEST_TOKEN`, reads one line from stdin,
    /// echoes it back as `ECHO:<line>`, then exits.
    fn write_echo_script() -> PathBuf {
        write_executable_script(
            "echo-mcp.sh",
            "#!/bin/sh\nprintf 'READY:%s\\n' \"$MCP_TEST_TOKEN\"\nIFS= read -r line\nprintf 'ECHO:%s\\n' \"$line\"\n",
        )
    }

    /// Python script: reads one `Content-Length`-framed `initialize` request from
    /// stdin and answers it with a single framed fake `initialize` result.
    fn write_jsonrpc_script() -> PathBuf {
        let script = [
            "#!/usr/bin/env python3",
            "import json, sys",
            "header = b''",
            r"while not header.endswith(b'\r\n\r\n'):",
            "    chunk = sys.stdin.buffer.read(1)",
            "    if not chunk:",
            "        raise SystemExit(1)",
            "    header += chunk",
            "length = 0",
            r"for line in header.decode().split('\r\n'):",
            r"    if line.lower().startswith('content-length:'):",
            r"        length = int(line.split(':', 1)[1].strip())",
            "payload = sys.stdin.buffer.read(length)",
            "request = json.loads(payload.decode())",
            r"assert request['jsonrpc'] == '2.0'",
            r"assert request['method'] == 'initialize'",
            r"response = json.dumps({",
            r"    'jsonrpc': '2.0',",
            r"    'id': request['id'],",
            r"    'result': {",
            r"        'protocolVersion': request['params']['protocolVersion'],",
            r"        'capabilities': {'tools': {}},",
            r"        'serverInfo': {'name': 'fake-mcp', 'version': '0.1.0'}",
            r"    }",
            r"}).encode()",
            r"sys.stdout.buffer.write(f'Content-Length: {len(response)}\r\n\r\n'.encode() + response)",
            "sys.stdout.buffer.flush()",
            "",
        ]
        .join("\n");
        write_executable_script("jsonrpc-mcp.py", &script)
    }

    /// Builds a stdio bootstrap that runs `script_path` with
    /// `MCP_TEST_TOKEN=secret-value`.
    fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
        let config = ScopedMcpServerConfig {
            scope: ConfigSource::Local,
            config: McpServerConfig::Stdio(McpStdioServerConfig {
                command: script_path.to_string_lossy().into_owned(),
                args: Vec::new(),
                env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "secret-value".to_string())]),
            }),
        };
        McpClientBootstrap::from_scoped_config("stdio server", &config)
    }

    /// Builds a transport that runs `script_path` through `python3` with no extra env.
    fn python_transport(script_path: &Path) -> crate::mcp_client::McpStdioTransport {
        crate::mcp_client::McpStdioTransport {
            command: "python3".to_string(),
            args: vec![script_path.to_string_lossy().into_owned()],
            env: BTreeMap::new(),
        }
    }

    #[test]
    fn spawns_stdio_process_and_round_trips_io() {
        block_on(async {
            let script_path = write_echo_script();
            let bootstrap = sample_bootstrap(&script_path);
            let mut process = spawn_mcp_stdio_process(&bootstrap).expect("spawn stdio process");

            let ready = process.read_line().await.expect("read ready");
            assert_eq!(ready, "READY:secret-value\n");

            process
                .write_line("ping from client")
                .await
                .expect("write line");
            let echoed = process.read_line().await.expect("read echo");
            assert_eq!(echoed, "ECHO:ping from client\n");

            let status = process.wait().await.expect("wait for exit");
            assert!(status.success());
            remove_script_dir(&script_path);
        });
    }

    #[test]
    fn rejects_non_stdio_bootstrap() {
        let config = ScopedMcpServerConfig {
            scope: ConfigSource::Local,
            config: McpServerConfig::Sdk(crate::config::McpSdkServerConfig {
                name: "sdk-server".to_string(),
            }),
        };
        let bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &config);
        let error = spawn_mcp_stdio_process(&bootstrap).expect_err("non-stdio should fail");
        assert_eq!(error.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn round_trips_initialize_request_and_response_over_stdio_frames() {
        block_on(async {
            let script_path = write_jsonrpc_script();
            let transport = python_transport(&script_path);
            let mut process =
                McpStdioProcess::spawn(&transport).expect("spawn transport directly");

            let response = process
                .initialize(
                    JsonRpcId::Number(1),
                    McpInitializeParams {
                        protocol_version: "2025-03-26".to_string(),
                        capabilities: json!({"roots": {}}),
                        client_info: McpInitializeClientInfo {
                            name: "runtime-tests".to_string(),
                            version: "0.1.0".to_string(),
                        },
                    },
                )
                .await
                .expect("initialize roundtrip");

            assert_eq!(response.id, JsonRpcId::Number(1));
            assert_eq!(response.error, None);
            assert_eq!(
                response.result,
                Some(McpInitializeResult {
                    protocol_version: "2025-03-26".to_string(),
                    capabilities: json!({"tools": {}}),
                    server_info: McpInitializeServerInfo {
                        name: "fake-mcp".to_string(),
                        version: "0.1.0".to_string(),
                    },
                })
            );

            let status = process.wait().await.expect("wait for exit");
            assert!(status.success());
            remove_script_dir(&script_path);
        });
    }

    #[test]
    fn write_jsonrpc_request_emits_content_length_frame() {
        block_on(async {
            let script_path = write_jsonrpc_script();
            let transport = python_transport(&script_path);
            let mut process =
                McpStdioProcess::spawn(&transport).expect("spawn transport directly");

            let request = JsonRpcRequest::new(
                JsonRpcId::Number(7),
                "initialize",
                Some(json!({
                    "protocolVersion": "2025-03-26",
                    "capabilities": {},
                    "clientInfo": {"name": "runtime-tests", "version": "0.1.0"}
                })),
            );
            process.send_request(&request).await.expect("send request");

            let response: super::JsonRpcResponse<serde_json::Value> =
                process.read_response().await.expect("read response");
            assert_eq!(response.id, JsonRpcId::Number(7));
            assert_eq!(response.jsonrpc, "2.0");

            let status = process.wait().await.expect("wait for exit");
            assert!(status.success());
            remove_script_dir(&script_path);
        });
    }

    #[test]
    fn direct_spawn_uses_transport_env() {
        block_on(async {
            let script_path = write_echo_script();
            let transport = crate::mcp_client::McpStdioTransport {
                command: script_path.to_string_lossy().into_owned(),
                args: Vec::new(),
                env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "direct-secret".to_string())]),
            };
            let mut process =
                McpStdioProcess::spawn(&transport).expect("spawn transport directly");

            let ready = process.read_available().await.expect("read ready");
            assert_eq!(String::from_utf8_lossy(&ready), "READY:direct-secret\n");

            process.terminate().await.expect("terminate child");
            let _ = process.wait().await.expect("wait after kill");
            remove_script_dir(&script_path);
        });
    }
}
|