mcp_stdio.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. use std::collections::BTreeMap;
  2. use std::io;
  3. use std::process::Stdio;
  4. use serde::de::DeserializeOwned;
  5. use serde::{Deserialize, Serialize};
  6. use serde_json::Value as JsonValue;
  7. use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
  8. use tokio::process::{Child, ChildStdin, ChildStdout, Command};
  9. use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
  10. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
  11. #[serde(untagged)]
  12. pub enum JsonRpcId {
  13. Number(u64),
  14. String(String),
  15. Null,
  16. }
  17. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  18. pub struct JsonRpcRequest<T = JsonValue> {
  19. pub jsonrpc: String,
  20. pub id: JsonRpcId,
  21. pub method: String,
  22. #[serde(skip_serializing_if = "Option::is_none")]
  23. pub params: Option<T>,
  24. }
  25. impl<T> JsonRpcRequest<T> {
  26. #[must_use]
  27. pub fn new(id: JsonRpcId, method: impl Into<String>, params: Option<T>) -> Self {
  28. Self {
  29. jsonrpc: "2.0".to_string(),
  30. id,
  31. method: method.into(),
  32. params,
  33. }
  34. }
  35. }
  36. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  37. pub struct JsonRpcError {
  38. pub code: i64,
  39. pub message: String,
  40. #[serde(skip_serializing_if = "Option::is_none")]
  41. pub data: Option<JsonValue>,
  42. }
  43. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  44. pub struct JsonRpcResponse<T = JsonValue> {
  45. pub jsonrpc: String,
  46. pub id: JsonRpcId,
  47. #[serde(skip_serializing_if = "Option::is_none")]
  48. pub result: Option<T>,
  49. #[serde(skip_serializing_if = "Option::is_none")]
  50. pub error: Option<JsonRpcError>,
  51. }
  52. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  53. #[serde(rename_all = "camelCase")]
  54. pub struct McpInitializeParams {
  55. pub protocol_version: String,
  56. pub capabilities: JsonValue,
  57. pub client_info: McpInitializeClientInfo,
  58. }
  59. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
  60. #[serde(rename_all = "camelCase")]
  61. pub struct McpInitializeClientInfo {
  62. pub name: String,
  63. pub version: String,
  64. }
  65. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  66. #[serde(rename_all = "camelCase")]
  67. pub struct McpInitializeResult {
  68. pub protocol_version: String,
  69. pub capabilities: JsonValue,
  70. pub server_info: McpInitializeServerInfo,
  71. }
  72. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
  73. #[serde(rename_all = "camelCase")]
  74. pub struct McpInitializeServerInfo {
  75. pub name: String,
  76. pub version: String,
  77. }
  78. #[derive(Debug)]
  79. pub struct McpStdioProcess {
  80. child: Child,
  81. stdin: ChildStdin,
  82. stdout: BufReader<ChildStdout>,
  83. }
  84. impl McpStdioProcess {
  85. pub fn spawn(transport: &McpStdioTransport) -> io::Result<Self> {
  86. let mut command = Command::new(&transport.command);
  87. command
  88. .args(&transport.args)
  89. .stdin(Stdio::piped())
  90. .stdout(Stdio::piped())
  91. .stderr(Stdio::inherit());
  92. apply_env(&mut command, &transport.env);
  93. let mut child = command.spawn()?;
  94. let stdin = child
  95. .stdin
  96. .take()
  97. .ok_or_else(|| io::Error::other("stdio MCP process missing stdin pipe"))?;
  98. let stdout = child
  99. .stdout
  100. .take()
  101. .ok_or_else(|| io::Error::other("stdio MCP process missing stdout pipe"))?;
  102. Ok(Self {
  103. child,
  104. stdin,
  105. stdout: BufReader::new(stdout),
  106. })
  107. }
  108. pub async fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
  109. self.stdin.write_all(bytes).await
  110. }
  111. pub async fn flush(&mut self) -> io::Result<()> {
  112. self.stdin.flush().await
  113. }
  114. pub async fn write_line(&mut self, line: &str) -> io::Result<()> {
  115. self.write_all(line.as_bytes()).await?;
  116. self.write_all(b"\n").await?;
  117. self.flush().await
  118. }
  119. pub async fn read_line(&mut self) -> io::Result<String> {
  120. let mut line = String::new();
  121. let bytes_read = self.stdout.read_line(&mut line).await?;
  122. if bytes_read == 0 {
  123. return Err(io::Error::new(
  124. io::ErrorKind::UnexpectedEof,
  125. "MCP stdio stream closed while reading line",
  126. ));
  127. }
  128. Ok(line)
  129. }
  130. pub async fn read_available(&mut self) -> io::Result<Vec<u8>> {
  131. let mut buffer = vec![0_u8; 4096];
  132. let read = self.stdout.read(&mut buffer).await?;
  133. buffer.truncate(read);
  134. Ok(buffer)
  135. }
  136. pub async fn write_frame(&mut self, payload: &[u8]) -> io::Result<()> {
  137. let encoded = encode_frame(payload);
  138. self.write_all(&encoded).await?;
  139. self.flush().await
  140. }
  141. pub async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
  142. let mut content_length = None;
  143. loop {
  144. let mut line = String::new();
  145. let bytes_read = self.stdout.read_line(&mut line).await?;
  146. if bytes_read == 0 {
  147. return Err(io::Error::new(
  148. io::ErrorKind::UnexpectedEof,
  149. "MCP stdio stream closed while reading headers",
  150. ));
  151. }
  152. if line == "\r\n" {
  153. break;
  154. }
  155. if let Some(value) = line.strip_prefix("Content-Length:") {
  156. let parsed = value
  157. .trim()
  158. .parse::<usize>()
  159. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
  160. content_length = Some(parsed);
  161. }
  162. }
  163. let content_length = content_length.ok_or_else(|| {
  164. io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
  165. })?;
  166. let mut payload = vec![0_u8; content_length];
  167. self.stdout.read_exact(&mut payload).await?;
  168. Ok(payload)
  169. }
  170. pub async fn write_jsonrpc_message<T: Serialize>(&mut self, message: &T) -> io::Result<()> {
  171. let body = serde_json::to_vec(message)
  172. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
  173. self.write_frame(&body).await
  174. }
  175. pub async fn read_jsonrpc_message<T: DeserializeOwned>(&mut self) -> io::Result<T> {
  176. let payload = self.read_frame().await?;
  177. serde_json::from_slice(&payload)
  178. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
  179. }
  180. pub async fn send_request<T: Serialize>(
  181. &mut self,
  182. request: &JsonRpcRequest<T>,
  183. ) -> io::Result<()> {
  184. self.write_jsonrpc_message(request).await
  185. }
  186. pub async fn read_response<T: DeserializeOwned>(&mut self) -> io::Result<JsonRpcResponse<T>> {
  187. self.read_jsonrpc_message().await
  188. }
  189. pub async fn initialize(
  190. &mut self,
  191. id: JsonRpcId,
  192. params: McpInitializeParams,
  193. ) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
  194. let request = JsonRpcRequest::new(id, "initialize", Some(params));
  195. self.send_request(&request).await?;
  196. self.read_response().await
  197. }
  198. pub async fn terminate(&mut self) -> io::Result<()> {
  199. self.child.kill().await
  200. }
  201. pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
  202. self.child.wait().await
  203. }
  204. }
  205. pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
  206. match &bootstrap.transport {
  207. McpClientTransport::Stdio(transport) => McpStdioProcess::spawn(transport),
  208. other => Err(io::Error::new(
  209. io::ErrorKind::InvalidInput,
  210. format!(
  211. "MCP bootstrap transport for {} is not stdio: {other:?}",
  212. bootstrap.server_name
  213. ),
  214. )),
  215. }
  216. }
  217. fn apply_env(command: &mut Command, env: &BTreeMap<String, String>) {
  218. for (key, value) in env {
  219. command.env(key, value);
  220. }
  221. }
  222. fn encode_frame(payload: &[u8]) -> Vec<u8> {
  223. let header = format!("Content-Length: {}\r\n\r\n", payload.len());
  224. let mut framed = header.into_bytes();
  225. framed.extend_from_slice(payload);
  226. framed
  227. }
  228. #[cfg(test)]
  229. mod tests {
  230. use std::collections::BTreeMap;
  231. use std::fs;
  232. use std::io::ErrorKind;
  233. use std::os::unix::fs::PermissionsExt;
  234. use std::path::{Path, PathBuf};
  235. use std::time::{SystemTime, UNIX_EPOCH};
  236. use serde_json::json;
  237. use tokio::runtime::Builder;
  238. use crate::config::{
  239. ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
  240. };
  241. use crate::mcp_client::McpClientBootstrap;
  242. use super::{
  243. spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, McpInitializeClientInfo,
  244. McpInitializeParams, McpInitializeResult, McpInitializeServerInfo, McpStdioProcess,
  245. };
  246. fn temp_dir() -> PathBuf {
  247. let nanos = SystemTime::now()
  248. .duration_since(UNIX_EPOCH)
  249. .expect("time should be after epoch")
  250. .as_nanos();
  251. std::env::temp_dir().join(format!("runtime-mcp-stdio-{nanos}"))
  252. }
  253. fn write_echo_script() -> PathBuf {
  254. let root = temp_dir();
  255. fs::create_dir_all(&root).expect("temp dir");
  256. let script_path = root.join("echo-mcp.sh");
  257. fs::write(
  258. &script_path,
  259. "#!/bin/sh\nprintf 'READY:%s\\n' \"$MCP_TEST_TOKEN\"\nIFS= read -r line\nprintf 'ECHO:%s\\n' \"$line\"\n",
  260. )
  261. .expect("write script");
  262. let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
  263. permissions.set_mode(0o755);
  264. fs::set_permissions(&script_path, permissions).expect("chmod");
  265. script_path
  266. }
  267. fn write_jsonrpc_script() -> PathBuf {
  268. let root = temp_dir();
  269. fs::create_dir_all(&root).expect("temp dir");
  270. let script_path = root.join("jsonrpc-mcp.py");
  271. let script = [
  272. "#!/usr/bin/env python3",
  273. "import json, sys",
  274. "header = b''",
  275. r"while not header.endswith(b'\r\n\r\n'):",
  276. " chunk = sys.stdin.buffer.read(1)",
  277. " if not chunk:",
  278. " raise SystemExit(1)",
  279. " header += chunk",
  280. "length = 0",
  281. r"for line in header.decode().split('\r\n'):",
  282. r" if line.lower().startswith('content-length:'):",
  283. r" length = int(line.split(':', 1)[1].strip())",
  284. "payload = sys.stdin.buffer.read(length)",
  285. "request = json.loads(payload.decode())",
  286. r"assert request['jsonrpc'] == '2.0'",
  287. r"assert request['method'] == 'initialize'",
  288. r"response = json.dumps({",
  289. r" 'jsonrpc': '2.0',",
  290. r" 'id': request['id'],",
  291. r" 'result': {",
  292. r" 'protocolVersion': request['params']['protocolVersion'],",
  293. r" 'capabilities': {'tools': {}},",
  294. r" 'serverInfo': {'name': 'fake-mcp', 'version': '0.1.0'}",
  295. r" }",
  296. r"}).encode()",
  297. r"sys.stdout.buffer.write(f'Content-Length: {len(response)}\r\n\r\n'.encode() + response)",
  298. "sys.stdout.buffer.flush()",
  299. "",
  300. ]
  301. .join("\n");
  302. fs::write(&script_path, script).expect("write script");
  303. let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
  304. permissions.set_mode(0o755);
  305. fs::set_permissions(&script_path, permissions).expect("chmod");
  306. script_path
  307. }
  308. fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
  309. let config = ScopedMcpServerConfig {
  310. scope: ConfigSource::Local,
  311. config: McpServerConfig::Stdio(McpStdioServerConfig {
  312. command: script_path.to_string_lossy().into_owned(),
  313. args: Vec::new(),
  314. env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "secret-value".to_string())]),
  315. }),
  316. };
  317. McpClientBootstrap::from_scoped_config("stdio server", &config)
  318. }
  319. #[test]
  320. fn spawns_stdio_process_and_round_trips_io() {
  321. let runtime = Builder::new_current_thread()
  322. .enable_all()
  323. .build()
  324. .expect("runtime");
  325. runtime.block_on(async {
  326. let script_path = write_echo_script();
  327. let bootstrap = sample_bootstrap(&script_path);
  328. let mut process = spawn_mcp_stdio_process(&bootstrap).expect("spawn stdio process");
  329. let ready = process.read_line().await.expect("read ready");
  330. assert_eq!(ready, "READY:secret-value\n");
  331. process
  332. .write_line("ping from client")
  333. .await
  334. .expect("write line");
  335. let echoed = process.read_line().await.expect("read echo");
  336. assert_eq!(echoed, "ECHO:ping from client\n");
  337. let status = process.wait().await.expect("wait for exit");
  338. assert!(status.success());
  339. fs::remove_file(&script_path).expect("cleanup script");
  340. fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
  341. });
  342. }
  343. #[test]
  344. fn rejects_non_stdio_bootstrap() {
  345. let config = ScopedMcpServerConfig {
  346. scope: ConfigSource::Local,
  347. config: McpServerConfig::Sdk(crate::config::McpSdkServerConfig {
  348. name: "sdk-server".to_string(),
  349. }),
  350. };
  351. let bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &config);
  352. let error = spawn_mcp_stdio_process(&bootstrap).expect_err("non-stdio should fail");
  353. assert_eq!(error.kind(), ErrorKind::InvalidInput);
  354. }
  355. #[test]
  356. fn round_trips_initialize_request_and_response_over_stdio_frames() {
  357. let runtime = Builder::new_current_thread()
  358. .enable_all()
  359. .build()
  360. .expect("runtime");
  361. runtime.block_on(async {
  362. let script_path = write_jsonrpc_script();
  363. let transport = crate::mcp_client::McpStdioTransport {
  364. command: "python3".to_string(),
  365. args: vec![script_path.to_string_lossy().into_owned()],
  366. env: BTreeMap::new(),
  367. };
  368. let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly");
  369. let response = process
  370. .initialize(
  371. JsonRpcId::Number(1),
  372. McpInitializeParams {
  373. protocol_version: "2025-03-26".to_string(),
  374. capabilities: json!({"roots": {}}),
  375. client_info: McpInitializeClientInfo {
  376. name: "runtime-tests".to_string(),
  377. version: "0.1.0".to_string(),
  378. },
  379. },
  380. )
  381. .await
  382. .expect("initialize roundtrip");
  383. assert_eq!(response.id, JsonRpcId::Number(1));
  384. assert_eq!(response.error, None);
  385. assert_eq!(
  386. response.result,
  387. Some(McpInitializeResult {
  388. protocol_version: "2025-03-26".to_string(),
  389. capabilities: json!({"tools": {}}),
  390. server_info: McpInitializeServerInfo {
  391. name: "fake-mcp".to_string(),
  392. version: "0.1.0".to_string(),
  393. },
  394. })
  395. );
  396. let status = process.wait().await.expect("wait for exit");
  397. assert!(status.success());
  398. fs::remove_file(&script_path).expect("cleanup script");
  399. fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
  400. });
  401. }
  402. #[test]
  403. fn write_jsonrpc_request_emits_content_length_frame() {
  404. let runtime = Builder::new_current_thread()
  405. .enable_all()
  406. .build()
  407. .expect("runtime");
  408. runtime.block_on(async {
  409. let script_path = write_jsonrpc_script();
  410. let transport = crate::mcp_client::McpStdioTransport {
  411. command: "python3".to_string(),
  412. args: vec![script_path.to_string_lossy().into_owned()],
  413. env: BTreeMap::new(),
  414. };
  415. let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly");
  416. let request = JsonRpcRequest::new(
  417. JsonRpcId::Number(7),
  418. "initialize",
  419. Some(json!({
  420. "protocolVersion": "2025-03-26",
  421. "capabilities": {},
  422. "clientInfo": {"name": "runtime-tests", "version": "0.1.0"}
  423. })),
  424. );
  425. process.send_request(&request).await.expect("send request");
  426. let response: super::JsonRpcResponse<serde_json::Value> =
  427. process.read_response().await.expect("read response");
  428. assert_eq!(response.id, JsonRpcId::Number(7));
  429. assert_eq!(response.jsonrpc, "2.0");
  430. let status = process.wait().await.expect("wait for exit");
  431. assert!(status.success());
  432. fs::remove_file(&script_path).expect("cleanup script");
  433. fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
  434. });
  435. }
  436. #[test]
  437. fn direct_spawn_uses_transport_env() {
  438. let runtime = Builder::new_current_thread()
  439. .enable_all()
  440. .build()
  441. .expect("runtime");
  442. runtime.block_on(async {
  443. let script_path = write_echo_script();
  444. let transport = crate::mcp_client::McpStdioTransport {
  445. command: script_path.to_string_lossy().into_owned(),
  446. args: Vec::new(),
  447. env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "direct-secret".to_string())]),
  448. };
  449. let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly");
  450. let ready = process.read_available().await.expect("read ready");
  451. assert_eq!(String::from_utf8_lossy(&ready), "READY:direct-secret\n");
  452. process.terminate().await.expect("terminate child");
  453. let _ = process.wait().await.expect("wait after kill");
  454. fs::remove_file(&script_path).expect("cleanup script");
  455. fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
  456. });
  457. }
  458. }