|
|
@@ -2,22 +2,98 @@ use std::collections::BTreeMap;
|
|
|
use std::io;
|
|
|
use std::process::Stdio;
|
|
|
|
|
|
+use serde::de::DeserializeOwned;
|
|
|
+use serde::{Deserialize, Serialize};
|
|
|
+use serde_json::Value as JsonValue;
|
|
|
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
|
|
|
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
|
|
|
|
|
|
-use serde_json::Value as JsonRpcMessage;
|
|
|
-
|
|
|
use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
|
|
|
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
|
+#[serde(untagged)]
|
|
|
+pub enum JsonRpcId {
|
|
|
+ Number(u64),
|
|
|
+ String(String),
|
|
|
+ Null,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
|
+pub struct JsonRpcRequest<T = JsonValue> {
|
|
|
+ pub jsonrpc: String,
|
|
|
+ pub id: JsonRpcId,
|
|
|
+ pub method: String,
|
|
|
+ #[serde(skip_serializing_if = "Option::is_none")]
|
|
|
+ pub params: Option<T>,
|
|
|
+}
|
|
|
+
|
|
|
+impl<T> JsonRpcRequest<T> {
|
|
|
+ #[must_use]
|
|
|
+ pub fn new(id: JsonRpcId, method: impl Into<String>, params: Option<T>) -> Self {
|
|
|
+ Self {
|
|
|
+ jsonrpc: "2.0".to_string(),
|
|
|
+ id,
|
|
|
+ method: method.into(),
|
|
|
+ params,
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
|
+pub struct JsonRpcError {
|
|
|
+ pub code: i64,
|
|
|
+ pub message: String,
|
|
|
+ #[serde(skip_serializing_if = "Option::is_none")]
|
|
|
+ pub data: Option<JsonValue>,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
|
+pub struct JsonRpcResponse<T = JsonValue> {
|
|
|
+ pub jsonrpc: String,
|
|
|
+ pub id: JsonRpcId,
|
|
|
+ #[serde(skip_serializing_if = "Option::is_none")]
|
|
|
+ pub result: Option<T>,
|
|
|
+ #[serde(skip_serializing_if = "Option::is_none")]
|
|
|
+ pub error: Option<JsonRpcError>,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
|
+#[serde(rename_all = "camelCase")]
|
|
|
+pub struct McpInitializeParams {
|
|
|
+ pub protocol_version: String,
|
|
|
+ pub capabilities: JsonValue,
|
|
|
+ pub client_info: McpInitializeClientInfo,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
|
+#[serde(rename_all = "camelCase")]
|
|
|
+pub struct McpInitializeClientInfo {
|
|
|
+ pub name: String,
|
|
|
+ pub version: String,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
|
+#[serde(rename_all = "camelCase")]
|
|
|
+pub struct McpInitializeResult {
|
|
|
+ pub protocol_version: String,
|
|
|
+ pub capabilities: JsonValue,
|
|
|
+ pub server_info: McpInitializeServerInfo,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
|
|
+#[serde(rename_all = "camelCase")]
|
|
|
+pub struct McpInitializeServerInfo {
|
|
|
+ pub name: String,
|
|
|
+ pub version: String,
|
|
|
+}
|
|
|
+
|
|
|
#[derive(Debug)]
|
|
|
-#[allow(dead_code)]
|
|
|
pub struct McpStdioProcess {
|
|
|
child: Child,
|
|
|
stdin: ChildStdin,
|
|
|
stdout: BufReader<ChildStdout>,
|
|
|
}
|
|
|
|
|
|
-#[allow(dead_code)]
|
|
|
impl McpStdioProcess {
|
|
|
pub fn spawn(transport: &McpStdioTransport) -> io::Result<Self> {
|
|
|
let mut command = Command::new(&transport.command);
|
|
|
@@ -53,6 +129,24 @@ impl McpStdioProcess {
|
|
|
self.stdin.flush().await
|
|
|
}
|
|
|
|
|
|
+ pub async fn write_line(&mut self, line: &str) -> io::Result<()> {
|
|
|
+ self.write_all(line.as_bytes()).await?;
|
|
|
+ self.write_all(b"\n").await?;
|
|
|
+ self.flush().await
|
|
|
+ }
|
|
|
+
|
|
|
+ pub async fn read_line(&mut self) -> io::Result<String> {
|
|
|
+ let mut line = String::new();
|
|
|
+ let bytes_read = self.stdout.read_line(&mut line).await?;
|
|
|
+ if bytes_read == 0 {
|
|
|
+ return Err(io::Error::new(
|
|
|
+ io::ErrorKind::UnexpectedEof,
|
|
|
+ "MCP stdio stream closed while reading line",
|
|
|
+ ));
|
|
|
+ }
|
|
|
+ Ok(line)
|
|
|
+ }
|
|
|
+
|
|
|
pub async fn read_available(&mut self) -> io::Result<Vec<u8>> {
|
|
|
let mut buffer = vec![0_u8; 4096];
|
|
|
let read = self.stdout.read(&mut buffer).await?;
|
|
|
@@ -60,19 +154,13 @@ impl McpStdioProcess {
|
|
|
Ok(buffer)
|
|
|
}
|
|
|
|
|
|
- pub async fn write_jsonrpc_message(&mut self, message: &JsonRpcMessage) -> io::Result<()> {
|
|
|
- let encoded = encode_jsonrpc_message(message)?;
|
|
|
+ pub async fn write_frame(&mut self, payload: &[u8]) -> io::Result<()> {
|
|
|
+ let encoded = encode_frame(payload);
|
|
|
self.write_all(&encoded).await?;
|
|
|
self.flush().await
|
|
|
}
|
|
|
|
|
|
- pub async fn read_jsonrpc_message(&mut self) -> io::Result<JsonRpcMessage> {
|
|
|
- let payload = self.read_frame().await?;
|
|
|
- serde_json::from_slice(&payload)
|
|
|
- .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
|
|
|
- }
|
|
|
-
|
|
|
- async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
|
|
|
+ pub async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
|
|
|
let mut content_length = None;
|
|
|
loop {
|
|
|
let mut line = String::new();
|
|
|
@@ -103,6 +191,39 @@ impl McpStdioProcess {
|
|
|
Ok(payload)
|
|
|
}
|
|
|
|
|
|
+ pub async fn write_jsonrpc_message<T: Serialize>(&mut self, message: &T) -> io::Result<()> {
|
|
|
+ let body = serde_json::to_vec(message)
|
|
|
+ .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
|
|
+ self.write_frame(&body).await
|
|
|
+ }
|
|
|
+
|
|
|
+ pub async fn read_jsonrpc_message<T: DeserializeOwned>(&mut self) -> io::Result<T> {
|
|
|
+ let payload = self.read_frame().await?;
|
|
|
+ serde_json::from_slice(&payload)
|
|
|
+ .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
|
|
|
+ }
|
|
|
+
|
|
|
+ pub async fn send_request<T: Serialize>(
|
|
|
+ &mut self,
|
|
|
+ request: &JsonRpcRequest<T>,
|
|
|
+ ) -> io::Result<()> {
|
|
|
+ self.write_jsonrpc_message(request).await
|
|
|
+ }
|
|
|
+
|
|
|
+ pub async fn read_response<T: DeserializeOwned>(&mut self) -> io::Result<JsonRpcResponse<T>> {
|
|
|
+ self.read_jsonrpc_message().await
|
|
|
+ }
|
|
|
+
|
|
|
+ pub async fn initialize(
|
|
|
+ &mut self,
|
|
|
+ id: JsonRpcId,
|
|
|
+ params: McpInitializeParams,
|
|
|
+ ) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
|
|
|
+ let request = JsonRpcRequest::new(id, "initialize", Some(params));
|
|
|
+ self.send_request(&request).await?;
|
|
|
+ self.read_response().await
|
|
|
+ }
|
|
|
+
|
|
|
pub async fn terminate(&mut self) -> io::Result<()> {
|
|
|
self.child.kill().await
|
|
|
}
|
|
|
@@ -112,7 +233,6 @@ impl McpStdioProcess {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-#[allow(dead_code)]
|
|
|
pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
|
|
|
match &bootstrap.transport {
|
|
|
McpClientTransport::Stdio(transport) => McpStdioProcess::spawn(transport),
|
|
|
@@ -126,20 +246,17 @@ pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<Mcp
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-#[allow(dead_code)]
|
|
|
fn apply_env(command: &mut Command, env: &BTreeMap<String, String>) {
|
|
|
for (key, value) in env {
|
|
|
command.env(key, value);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-fn encode_jsonrpc_message(message: &JsonRpcMessage) -> io::Result<Vec<u8>> {
|
|
|
- let body = serde_json::to_vec(message)
|
|
|
- .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
|
|
- let header = format!("Content-Length: {}\r\n\r\n", body.len());
|
|
|
+fn encode_frame(payload: &[u8]) -> Vec<u8> {
|
|
|
+ let header = format!("Content-Length: {}\r\n\r\n", payload.len());
|
|
|
let mut framed = header.into_bytes();
|
|
|
- framed.extend(body);
|
|
|
- Ok(framed)
|
|
|
+ framed.extend_from_slice(payload);
|
|
|
+ framed
|
|
|
}
|
|
|
|
|
|
#[cfg(test)]
|
|
|
@@ -151,6 +268,7 @@ mod tests {
|
|
|
use std::path::{Path, PathBuf};
|
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
|
|
|
|
+ use serde_json::json;
|
|
|
use tokio::runtime::Builder;
|
|
|
|
|
|
use crate::config::{
|
|
|
@@ -158,7 +276,10 @@ mod tests {
|
|
|
};
|
|
|
use crate::mcp_client::McpClientBootstrap;
|
|
|
|
|
|
- use super::{spawn_mcp_stdio_process, McpStdioProcess};
|
|
|
+ use super::{
|
|
|
+ spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, McpInitializeClientInfo,
|
|
|
+ McpInitializeParams, McpInitializeResult, McpInitializeServerInfo, McpStdioProcess,
|
|
|
+ };
|
|
|
|
|
|
fn temp_dir() -> PathBuf {
|
|
|
let nanos = SystemTime::now()
|
|
|
@@ -201,8 +322,18 @@ mod tests {
|
|
|
r" if line.lower().startswith('content-length:'):",
|
|
|
r" length = int(line.split(':', 1)[1].strip())",
|
|
|
"payload = sys.stdin.buffer.read(length)",
|
|
|
- "json.loads(payload.decode())",
|
|
|
- r"response = json.dumps({'jsonrpc': '2.0', 'id': 1, 'result': {'echo': True}}).encode()",
|
|
|
+ "request = json.loads(payload.decode())",
|
|
|
+ r"assert request['jsonrpc'] == '2.0'",
|
|
|
+ r"assert request['method'] == 'initialize'",
|
|
|
+ r"response = json.dumps({",
|
|
|
+ r" 'jsonrpc': '2.0',",
|
|
|
+ r" 'id': request['id'],",
|
|
|
+ r" 'result': {",
|
|
|
+ r" 'protocolVersion': request['params']['protocolVersion'],",
|
|
|
+ r" 'capabilities': {'tools': {}},",
|
|
|
+ r" 'serverInfo': {'name': 'fake-mcp', 'version': '0.1.0'}",
|
|
|
+ r" }",
|
|
|
+ r"}).encode()",
|
|
|
r"sys.stdout.buffer.write(f'Content-Length: {len(response)}\r\n\r\n'.encode() + response)",
|
|
|
"sys.stdout.buffer.flush()",
|
|
|
"",
|
|
|
@@ -214,6 +345,7 @@ mod tests {
|
|
|
fs::set_permissions(&script_path, permissions).expect("chmod");
|
|
|
script_path
|
|
|
}
|
|
|
+
|
|
|
fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
|
|
|
let config = ScopedMcpServerConfig {
|
|
|
scope: ConfigSource::Local,
|
|
|
@@ -237,17 +369,16 @@ mod tests {
|
|
|
let bootstrap = sample_bootstrap(&script_path);
|
|
|
let mut process = spawn_mcp_stdio_process(&bootstrap).expect("spawn stdio process");
|
|
|
|
|
|
- let ready = process.read_available().await.expect("read ready");
|
|
|
- assert_eq!(String::from_utf8_lossy(&ready), "READY:secret-value\n");
|
|
|
+ let ready = process.read_line().await.expect("read ready");
|
|
|
+ assert_eq!(ready, "READY:secret-value\n");
|
|
|
|
|
|
process
|
|
|
- .write_all(b"ping from client\n")
|
|
|
+ .write_line("ping from client")
|
|
|
.await
|
|
|
- .expect("write input");
|
|
|
- process.flush().await.expect("flush");
|
|
|
+ .expect("write line");
|
|
|
|
|
|
- let echoed = process.read_available().await.expect("read echo");
|
|
|
- assert_eq!(String::from_utf8_lossy(&echoed), "ECHO:ping from client\n");
|
|
|
+ let echoed = process.read_line().await.expect("read echo");
|
|
|
+ assert_eq!(echoed, "ECHO:ping from client\n");
|
|
|
|
|
|
let status = process.wait().await.expect("wait for exit");
|
|
|
assert!(status.success());
|
|
|
@@ -271,7 +402,7 @@ mod tests {
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
|
- fn round_trips_jsonrpc_messages_over_stdio_frames() {
|
|
|
+ fn round_trips_initialize_request_and_response_over_stdio_frames() {
|
|
|
let runtime = Builder::new_current_thread()
|
|
|
.enable_all()
|
|
|
.build()
|
|
|
@@ -284,22 +415,74 @@ mod tests {
|
|
|
env: BTreeMap::new(),
|
|
|
};
|
|
|
let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly");
|
|
|
- process
|
|
|
- .write_jsonrpc_message(&serde_json::json!({
|
|
|
- "jsonrpc": "2.0",
|
|
|
- "id": 1,
|
|
|
- "method": "initialize"
|
|
|
- }))
|
|
|
- .await
|
|
|
- .expect("write jsonrpc message");
|
|
|
|
|
|
let response = process
|
|
|
- .read_jsonrpc_message()
|
|
|
+ .initialize(
|
|
|
+ JsonRpcId::Number(1),
|
|
|
+ McpInitializeParams {
|
|
|
+ protocol_version: "2025-03-26".to_string(),
|
|
|
+ capabilities: json!({"roots": {}}),
|
|
|
+ client_info: McpInitializeClientInfo {
|
|
|
+ name: "runtime-tests".to_string(),
|
|
|
+ version: "0.1.0".to_string(),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ )
|
|
|
.await
|
|
|
- .expect("read jsonrpc response");
|
|
|
- assert_eq!(response["jsonrpc"], serde_json::json!("2.0"));
|
|
|
- assert_eq!(response["id"], serde_json::json!(1));
|
|
|
- assert_eq!(response["result"]["echo"], serde_json::json!(true));
|
|
|
+ .expect("initialize roundtrip");
|
|
|
+
|
|
|
+ assert_eq!(response.id, JsonRpcId::Number(1));
|
|
|
+ assert_eq!(response.error, None);
|
|
|
+ assert_eq!(
|
|
|
+ response.result,
|
|
|
+ Some(McpInitializeResult {
|
|
|
+ protocol_version: "2025-03-26".to_string(),
|
|
|
+ capabilities: json!({"tools": {}}),
|
|
|
+ server_info: McpInitializeServerInfo {
|
|
|
+ name: "fake-mcp".to_string(),
|
|
|
+ version: "0.1.0".to_string(),
|
|
|
+ },
|
|
|
+ })
|
|
|
+ );
|
|
|
+
|
|
|
+ let status = process.wait().await.expect("wait for exit");
|
|
|
+ assert!(status.success());
|
|
|
+
|
|
|
+ fs::remove_file(&script_path).expect("cleanup script");
|
|
|
+ fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir");
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn write_jsonrpc_request_emits_content_length_frame() {
|
|
|
+ let runtime = Builder::new_current_thread()
|
|
|
+ .enable_all()
|
|
|
+ .build()
|
|
|
+ .expect("runtime");
|
|
|
+ runtime.block_on(async {
|
|
|
+ let script_path = write_jsonrpc_script();
|
|
|
+ let transport = crate::mcp_client::McpStdioTransport {
|
|
|
+ command: "python3".to_string(),
|
|
|
+ args: vec![script_path.to_string_lossy().into_owned()],
|
|
|
+ env: BTreeMap::new(),
|
|
|
+ };
|
|
|
+ let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly");
|
|
|
+ let request = JsonRpcRequest::new(
|
|
|
+ JsonRpcId::Number(7),
|
|
|
+ "initialize",
|
|
|
+ Some(json!({
|
|
|
+ "protocolVersion": "2025-03-26",
|
|
|
+ "capabilities": {},
|
|
|
+ "clientInfo": {"name": "runtime-tests", "version": "0.1.0"}
|
|
|
+ })),
|
|
|
+ );
|
|
|
+
|
|
|
+ process.send_request(&request).await.expect("send request");
|
|
|
+ let response: super::JsonRpcResponse<serde_json::Value> =
|
|
|
+ process.read_response().await.expect("read response");
|
|
|
+
|
|
|
+ assert_eq!(response.id, JsonRpcId::Number(7));
|
|
|
+ assert_eq!(response.jsonrpc, "2.0");
|
|
|
|
|
|
let status = process.wait().await.expect("wait for exit");
|
|
|
assert!(status.success());
|