Browse Source

Establish stdio JSON-RPC framing for MCP initialization

The runtime already knew how to spawn stdio MCP processes, but it still
needed transport primitives for framed JSON-RPC exchange. This change adds
minimal request/response types, line and frame helpers on the stdio wrapper,
and an initialize roundtrip helper so later MCP client slices can build on a
real transport foundation instead of raw byte plumbing.

Constraint: Keep the slice small and limited to stdio transport foundations
Constraint: Must verify framed request write and typed response parsing with a fake MCP process
Rejected: Introduce a broader MCP session layer now | would expand the slice beyond transport framing
Rejected: Leave JSON-RPC as untyped serde_json::Value only | weakens initialize roundtrip guarantees
Confidence: high
Scope-risk: narrow
Reversibility: clean
Directive: Preserve the camelCase MCP initialize field mapping when layering richer protocol support on top
Tested: cargo fmt --all --manifest-path rust/Cargo.toml
Tested: cargo clippy -p runtime --all-targets --manifest-path rust/Cargo.toml -- -D warnings
Tested: cargo test -p runtime --manifest-path rust/Cargo.toml
Not-tested: Integration against a real external MCP server process
Yeachan-Heo 2 tháng trước cách đây
mục cha
commit
8b6bf4cee7
2 tập tin đã thay đổi với 233 bổ sung45 xóa
  1. 5 0
      rust/crates/runtime/src/lib.rs
  2. 228 45
      rust/crates/runtime/src/mcp_stdio.rs

+ 5 - 0
rust/crates/runtime/src/lib.rs

@@ -44,6 +44,11 @@ pub use mcp_client::{
     McpClaudeAiProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport,
     McpRemoteTransport, McpSdkTransport, McpStdioTransport,
 };
+pub use mcp_stdio::{
+    spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
+    McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
+    McpStdioProcess,
+};
 pub use oauth::{
     code_challenge_s256, generate_pkce_pair, generate_state, loopback_redirect_uri,
     OAuthAuthorizationRequest, OAuthRefreshRequest, OAuthTokenExchangeRequest, OAuthTokenSet,

+ 228 - 45
rust/crates/runtime/src/mcp_stdio.rs

@@ -2,22 +2,98 @@ use std::collections::BTreeMap;
 use std::io;
 use std::process::Stdio;
 
+use serde::de::DeserializeOwned;
+use serde::{Deserialize, Serialize};
+use serde_json::Value as JsonValue;
 use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
 use tokio::process::{Child, ChildStdin, ChildStdout, Command};
 
-use serde_json::Value as JsonRpcMessage;
-
 use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(untagged)]
+pub enum JsonRpcId {
+    Number(u64),
+    String(String),
+    Null,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct JsonRpcRequest<T = JsonValue> {
+    pub jsonrpc: String,
+    pub id: JsonRpcId,
+    pub method: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub params: Option<T>,
+}
+
+impl<T> JsonRpcRequest<T> {
+    #[must_use]
+    pub fn new(id: JsonRpcId, method: impl Into<String>, params: Option<T>) -> Self {
+        Self {
+            jsonrpc: "2.0".to_string(),
+            id,
+            method: method.into(),
+            params,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct JsonRpcError {
+    pub code: i64,
+    pub message: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub data: Option<JsonValue>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct JsonRpcResponse<T = JsonValue> {
+    pub jsonrpc: String,
+    pub id: JsonRpcId,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub result: Option<T>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<JsonRpcError>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct McpInitializeParams {
+    pub protocol_version: String,
+    pub capabilities: JsonValue,
+    pub client_info: McpInitializeClientInfo,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "camelCase")]
+pub struct McpInitializeClientInfo {
+    pub name: String,
+    pub version: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(rename_all = "camelCase")]
+pub struct McpInitializeResult {
+    pub protocol_version: String,
+    pub capabilities: JsonValue,
+    pub server_info: McpInitializeServerInfo,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "camelCase")]
+pub struct McpInitializeServerInfo {
+    pub name: String,
+    pub version: String,
+}
+
 #[derive(Debug)]
-#[allow(dead_code)]
 pub struct McpStdioProcess {
     child: Child,
     stdin: ChildStdin,
     stdout: BufReader<ChildStdout>,
 }
 
-#[allow(dead_code)]
 impl McpStdioProcess {
     pub fn spawn(transport: &McpStdioTransport) -> io::Result<Self> {
         let mut command = Command::new(&transport.command);
@@ -53,6 +129,24 @@ impl McpStdioProcess {
         self.stdin.flush().await
     }
 
+    pub async fn write_line(&mut self, line: &str) -> io::Result<()> {
+        self.write_all(line.as_bytes()).await?;
+        self.write_all(b"\n").await?;
+        self.flush().await
+    }
+
+    pub async fn read_line(&mut self) -> io::Result<String> {
+        let mut line = String::new();
+        let bytes_read = self.stdout.read_line(&mut line).await?;
+        if bytes_read == 0 {
+            return Err(io::Error::new(
+                io::ErrorKind::UnexpectedEof,
+                "MCP stdio stream closed while reading line",
+            ));
+        }
+        Ok(line)
+    }
+
     pub async fn read_available(&mut self) -> io::Result<Vec<u8>> {
         let mut buffer = vec![0_u8; 4096];
         let read = self.stdout.read(&mut buffer).await?;
@@ -60,19 +154,13 @@ impl McpStdioProcess {
         Ok(buffer)
     }
 
-    pub async fn write_jsonrpc_message(&mut self, message: &JsonRpcMessage) -> io::Result<()> {
-        let encoded = encode_jsonrpc_message(message)?;
+    pub async fn write_frame(&mut self, payload: &[u8]) -> io::Result<()> {
+        let encoded = encode_frame(payload);
         self.write_all(&encoded).await?;
         self.flush().await
     }
 
-    pub async fn read_jsonrpc_message(&mut self) -> io::Result<JsonRpcMessage> {
-        let payload = self.read_frame().await?;
-        serde_json::from_slice(&payload)
-            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
-    }
-
-    async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
+    pub async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
         let mut content_length = None;
         loop {
             let mut line = String::new();
@@ -103,6 +191,39 @@ impl McpStdioProcess {
         Ok(payload)
     }
 
+    pub async fn write_jsonrpc_message<T: Serialize>(&mut self, message: &T) -> io::Result<()> {
+        let body = serde_json::to_vec(message)
+            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
+        self.write_frame(&body).await
+    }
+
+    pub async fn read_jsonrpc_message<T: DeserializeOwned>(&mut self) -> io::Result<T> {
+        let payload = self.read_frame().await?;
+        serde_json::from_slice(&payload)
+            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
+    }
+
+    pub async fn send_request<T: Serialize>(
+        &mut self,
+        request: &JsonRpcRequest<T>,
+    ) -> io::Result<()> {
+        self.write_jsonrpc_message(request).await
+    }
+
+    pub async fn read_response<T: DeserializeOwned>(&mut self) -> io::Result<JsonRpcResponse<T>> {
+        self.read_jsonrpc_message().await
+    }
+
+    pub async fn initialize(
+        &mut self,
+        id: JsonRpcId,
+        params: McpInitializeParams,
+    ) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
+        let request = JsonRpcRequest::new(id, "initialize", Some(params));
+        self.send_request(&request).await?;
+        self.read_response().await
+    }
+
     pub async fn terminate(&mut self) -> io::Result<()> {
         self.child.kill().await
     }
@@ -112,7 +233,6 @@ impl McpStdioProcess {
     }
 }
 
-#[allow(dead_code)]
 pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
     match &bootstrap.transport {
         McpClientTransport::Stdio(transport) => McpStdioProcess::spawn(transport),
@@ -126,20 +246,17 @@ pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<Mcp
     }
 }
 
-#[allow(dead_code)]
 fn apply_env(command: &mut Command, env: &BTreeMap<String, String>) {
     for (key, value) in env {
         command.env(key, value);
     }
 }
 
-fn encode_jsonrpc_message(message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
-    let body = serde_json::to_vec(message)
-        .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
-    let header = format!("Content-Length: {}\r\n\r\n", body.len());
+fn encode_frame(payload: &[u8]) -> Vec<u8> {
+    let header = format!("Content-Length: {}\r\n\r\n", payload.len());
     let mut framed = header.into_bytes();
-    framed.extend(body);
-    Ok(framed)
+    framed.extend_from_slice(payload);
+    framed
 }
 
 #[cfg(test)]
@@ -151,6 +268,7 @@ mod tests {
     use std::path::{Path, PathBuf};
     use std::time::{SystemTime, UNIX_EPOCH};
 
+    use serde_json::json;
     use tokio::runtime::Builder;
 
     use crate::config::{
@@ -158,7 +276,10 @@ mod tests {
     };
     use crate::mcp_client::McpClientBootstrap;
 
-    use super::{spawn_mcp_stdio_process, McpStdioProcess};
+    use super::{
+        spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, McpInitializeClientInfo,
+        McpInitializeParams, McpInitializeResult, McpInitializeServerInfo, McpStdioProcess,
+    };
 
     fn temp_dir() -> PathBuf {
         let nanos = SystemTime::now()
@@ -201,8 +322,18 @@ mod tests {
             r"    if line.lower().startswith('content-length:'):",
             r"        length = int(line.split(':', 1)[1].strip())",
             "payload = sys.stdin.buffer.read(length)",
-            "json.loads(payload.decode())",
-            r"response = json.dumps({'jsonrpc': '2.0', 'id': 1, 'result': {'echo': True}}).encode()",
+            "request = json.loads(payload.decode())",
+            r"assert request['jsonrpc'] == '2.0'",
+            r"assert request['method'] == 'initialize'",
+            r"response = json.dumps({",
+            r"    'jsonrpc': '2.0',",
+            r"    'id': request['id'],",
+            r"    'result': {",
+            r"        'protocolVersion': request['params']['protocolVersion'],",
+            r"        'capabilities': {'tools': {}},",
+            r"        'serverInfo': {'name': 'fake-mcp', 'version': '0.1.0'}",
+            r"    }",
+            r"}).encode()",
             r"sys.stdout.buffer.write(f'Content-Length: {len(response)}\r\n\r\n'.encode() + response)",
             "sys.stdout.buffer.flush()",
             "",
@@ -214,6 +345,7 @@ mod tests {
         fs::set_permissions(&script_path, permissions).expect("chmod");
         script_path
     }
+
     fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
         let config = ScopedMcpServerConfig {
             scope: ConfigSource::Local,
@@ -237,17 +369,16 @@ mod tests {
             let bootstrap = sample_bootstrap(&script_path);
             let mut process = spawn_mcp_stdio_process(&bootstrap).expect("spawn stdio process");
 
-            let ready = process.read_available().await.expect("read ready");
-            assert_eq!(String::from_utf8_lossy(&ready), "READY:secret-value\n");
+            let ready = process.read_line().await.expect("read ready");
+            assert_eq!(ready, "READY:secret-value\n");
 
             process
-                .write_all(b"ping from client\n")
+                .write_line("ping from client")
                 .await
-                .expect("write input");
-            process.flush().await.expect("flush");
+                .expect("write line");
 
-            let echoed = process.read_available().await.expect("read echo");
-            assert_eq!(String::from_utf8_lossy(&echoed), "ECHO:ping from client\n");
+            let echoed = process.read_line().await.expect("read echo");
+            assert_eq!(echoed, "ECHO:ping from client\n");
 
             let status = process.wait().await.expect("wait for exit");
             assert!(status.success());
@@ -271,7 +402,7 @@ mod tests {
     }
 
     #[test]
-    fn round_trips_jsonrpc_messages_over_stdio_frames() {
+    fn round_trips_initialize_request_and_response_over_stdio_frames() {
         let runtime = Builder::new_current_thread()
             .enable_all()
             .build()
@@ -284,22 +415,74 @@ mod tests {
                 env: BTreeMap::new(),
             };
             let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly");
-            process
-                .write_jsonrpc_message(&serde_json::json!({
-                    "jsonrpc": "2.0",
-                    "id": 1,
-                    "method": "initialize"
-                }))
-                .await
-                .expect("write jsonrpc message");
 
             let response = process
-                .read_jsonrpc_message()
+                .initialize(
+                    JsonRpcId::Number(1),
+                    McpInitializeParams {
+                        protocol_version: "2025-03-26".to_string(),
+                        capabilities: json!({"roots": {}}),
+                        client_info: McpInitializeClientInfo {
+                            name: "runtime-tests".to_string(),
+                            version: "0.1.0".to_string(),
+                        },
+                    },
+                )
                 .await
-                .expect("read jsonrpc response");
-            assert_eq!(response["jsonrpc"], serde_json::json!("2.0"));
-            assert_eq!(response["id"], serde_json::json!(1));
-            assert_eq!(response["result"]["echo"], serde_json::json!(true));
+                .expect("initialize roundtrip");
+
+            assert_eq!(response.id, JsonRpcId::Number(1));
+            assert_eq!(response.error, None);
+            assert_eq!(
+                response.result,
+                Some(McpInitializeResult {
+                    protocol_version: "2025-03-26".to_string(),
+                    capabilities: json!({"tools": {}}),
+                    server_info: McpInitializeServerInfo {
+                        name: "fake-mcp".to_string(),
+                        version: "0.1.0".to_string(),
+                    },
+                })
+            );
+
+            let status = process.wait().await.expect("wait for exit");
+            assert!(status.success());
+
+            fs::remove_file(&script_path).expect("cleanup script");
+            fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
+        });
+    }
+
+    #[test]
+    fn write_jsonrpc_request_emits_content_length_frame() {
+        let runtime = Builder::new_current_thread()
+            .enable_all()
+            .build()
+            .expect("runtime");
+        runtime.block_on(async {
+            let script_path = write_jsonrpc_script();
+            let transport = crate::mcp_client::McpStdioTransport {
+                command: "python3".to_string(),
+                args: vec![script_path.to_string_lossy().into_owned()],
+                env: BTreeMap::new(),
+            };
+            let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly");
+            let request = JsonRpcRequest::new(
+                JsonRpcId::Number(7),
+                "initialize",
+                Some(json!({
+                    "protocolVersion": "2025-03-26",
+                    "capabilities": {},
+                    "clientInfo": {"name": "runtime-tests", "version": "0.1.0"}
+                })),
+            );
+
+            process.send_request(&request).await.expect("send request");
+            let response: super::JsonRpcResponse<serde_json::Value> =
+                process.read_response().await.expect("read response");
+
+            assert_eq!(response.id, JsonRpcId::Number(7));
+            assert_eq!(response.jsonrpc, "2.0");
 
             let status = process.wait().await.expect("wait for exit");
             assert!(status.success());