mcp_stdio.rs 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697
  1. use std::collections::BTreeMap;
  2. use std::io;
  3. use std::process::Stdio;
  4. use serde::de::DeserializeOwned;
  5. use serde::{Deserialize, Serialize};
  6. use serde_json::Value as JsonValue;
  7. use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
  8. use tokio::process::{Child, ChildStdin, ChildStdout, Command};
  9. use crate::config::{McpTransport, RuntimeConfig, ScopedMcpServerConfig};
  10. use crate::mcp::mcp_tool_name;
  11. use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
/// A JSON-RPC request/response identifier.
///
/// Serialized untagged, so the JSON value is the bare number, string, or `null`.
/// Variant order matters for deserialization: serde tries the untagged variants
/// in declaration order (`Number`, then `String`, then `Null`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum JsonRpcId {
    Number(u64),
    String(String),
    Null,
}
/// A JSON-RPC 2.0 request envelope with typed `params`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct JsonRpcRequest<T = JsonValue> {
    /// Protocol version string; `JsonRpcRequest::new` always sets `"2.0"`.
    pub jsonrpc: String,
    pub id: JsonRpcId,
    pub method: String,
    /// Omitted from the serialized JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<T>,
}
  27. impl<T> JsonRpcRequest<T> {
  28. #[must_use]
  29. pub fn new(id: JsonRpcId, method: impl Into<String>, params: Option<T>) -> Self {
  30. Self {
  31. jsonrpc: "2.0".to_string(),
  32. id,
  33. method: method.into(),
  34. params,
  35. }
  36. }
  37. }
/// The `error` member of a JSON-RPC response.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct JsonRpcError {
    pub code: i64,
    pub message: String,
    /// Optional server-defined payload; omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<JsonValue>,
}
/// A JSON-RPC 2.0 response envelope with a typed `result`.
///
/// Both `result` and `error` are optional at the type level; callers in this
/// module check `error` first and treat a missing `result` as invalid.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct JsonRpcResponse<T = JsonValue> {
    pub jsonrpc: String,
    pub id: JsonRpcId,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<T>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// Parameters of the MCP `initialize` request (fields serialized in camelCase).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct McpInitializeParams {
    pub protocol_version: String,
    /// Client capabilities, passed through as raw JSON.
    pub capabilities: JsonValue,
    pub client_info: McpInitializeClientInfo,
}
/// Client identification sent as `clientInfo` in `initialize`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct McpInitializeClientInfo {
    pub name: String,
    pub version: String,
}
/// Result payload of the MCP `initialize` request (fields in camelCase).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct McpInitializeResult {
    pub protocol_version: String,
    /// Server capabilities, kept as raw JSON.
    pub capabilities: JsonValue,
    pub server_info: McpInitializeServerInfo,
}
/// Server identification returned as `serverInfo` from `initialize`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct McpInitializeServerInfo {
    pub name: String,
    pub version: String,
}
  80. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  81. #[serde(rename_all = "camelCase")]
  82. pub struct McpListToolsParams {
  83. #[serde(skip_serializing_if = "Option::is_none")]
  84. pub cursor: Option<String>,
  85. }
  86. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  87. pub struct McpTool {
  88. pub name: String,
  89. #[serde(skip_serializing_if = "Option::is_none")]
  90. pub description: Option<String>,
  91. #[serde(rename = "inputSchema", skip_serializing_if = "Option::is_none")]
  92. pub input_schema: Option<JsonValue>,
  93. #[serde(skip_serializing_if = "Option::is_none")]
  94. pub annotations: Option<JsonValue>,
  95. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
  96. pub meta: Option<JsonValue>,
  97. }
  98. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  99. #[serde(rename_all = "camelCase")]
  100. pub struct McpListToolsResult {
  101. pub tools: Vec<McpTool>,
  102. #[serde(skip_serializing_if = "Option::is_none")]
  103. pub next_cursor: Option<String>,
  104. }
  105. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  106. #[serde(rename_all = "camelCase")]
  107. pub struct McpToolCallParams {
  108. pub name: String,
  109. #[serde(skip_serializing_if = "Option::is_none")]
  110. pub arguments: Option<JsonValue>,
  111. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
  112. pub meta: Option<JsonValue>,
  113. }
  114. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  115. pub struct McpToolCallContent {
  116. #[serde(rename = "type")]
  117. pub kind: String,
  118. #[serde(flatten)]
  119. pub data: BTreeMap<String, JsonValue>,
  120. }
  121. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
  122. #[serde(rename_all = "camelCase")]
  123. pub struct McpToolCallResult {
  124. #[serde(default)]
  125. pub content: Vec<McpToolCallContent>,
  126. #[serde(default)]
  127. pub structured_content: Option<JsonValue>,
  128. #[serde(default)]
  129. pub is_error: Option<bool>,
  130. #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
  131. pub meta: Option<JsonValue>,
  132. }
/// Parameters of the MCP `resources/list` request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct McpListResourcesParams {
    /// Pagination cursor from a previous page's `nextCursor`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cursor: Option<String>,
}
/// A resource entry advertised by an MCP server in `resources/list`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct McpResource {
    pub uri: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotations: Option<JsonValue>,
    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<JsonValue>,
}
/// One page of `resources/list` results.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct McpListResourcesResult {
    pub resources: Vec<McpResource>,
    /// Present when more pages are available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub next_cursor: Option<String>,
}
/// Parameters of the MCP `resources/read` request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct McpReadResourceParams {
    pub uri: String,
}
/// One content entry returned by `resources/read`.
///
/// The type allows both `text` and `blob` to be present or absent; it does not
/// enforce that exactly one is set.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct McpResourceContents {
    pub uri: String,
    #[serde(rename = "mimeType", skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// NOTE(review): presumably base64-encoded binary data per MCP — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blob: Option<String>,
    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
    pub meta: Option<JsonValue>,
}
/// Result payload of `resources/read`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct McpReadResourceResult {
    pub contents: Vec<McpResourceContents>,
}
/// A tool discovered on a managed server, together with its routing names.
#[derive(Debug, Clone, PartialEq)]
pub struct ManagedMcpTool {
    /// Configured name of the server that owns the tool.
    pub server_name: String,
    /// `mcp_tool_name(server_name, raw_name)`; the key accepted by `McpServerManager::call_tool`.
    pub qualified_name: String,
    /// Tool name exactly as reported by the server.
    pub raw_name: String,
    /// Full tool definition from `tools/list`.
    pub tool: McpTool,
}
/// A configured server the manager does not run, with a human-readable reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnsupportedMcpServer {
    pub server_name: String,
    pub transport: McpTransport,
    pub reason: String,
}
/// Errors produced by `McpServerManager`.
#[derive(Debug)]
pub enum McpServerManagerError {
    /// Spawning a server, or reading/writing its stdio, failed. Malformed JSON
    /// frames also land here because the stdio layer maps them to
    /// `io::ErrorKind::InvalidData`.
    Io(io::Error),
    /// The server answered `method` with a JSON-RPC `error` object.
    JsonRpc {
        server_name: String,
        method: &'static str,
        error: JsonRpcError,
    },
    /// The answer to `method` was unusable (e.g. no `result`), or the
    /// process handle was unexpectedly missing.
    InvalidResponse {
        server_name: String,
        method: &'static str,
        details: String,
    },
    /// No route exists for this qualified tool name (not yet discovered, or unknown).
    UnknownTool {
        qualified_name: String,
    },
    /// The name is not one of the managed stdio servers.
    UnknownServer {
        server_name: String,
    },
}
  214. impl std::fmt::Display for McpServerManagerError {
  215. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  216. match self {
  217. Self::Io(error) => write!(f, "{error}"),
  218. Self::JsonRpc {
  219. server_name,
  220. method,
  221. error,
  222. } => write!(
  223. f,
  224. "MCP server `{server_name}` returned JSON-RPC error for {method}: {} ({})",
  225. error.message, error.code
  226. ),
  227. Self::InvalidResponse {
  228. server_name,
  229. method,
  230. details,
  231. } => write!(
  232. f,
  233. "MCP server `{server_name}` returned invalid response for {method}: {details}"
  234. ),
  235. Self::UnknownTool { qualified_name } => {
  236. write!(f, "unknown MCP tool `{qualified_name}`")
  237. }
  238. Self::UnknownServer { server_name } => write!(f, "unknown MCP server `{server_name}`"),
  239. }
  240. }
  241. }
  242. impl std::error::Error for McpServerManagerError {
  243. fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
  244. match self {
  245. Self::Io(error) => Some(error),
  246. Self::JsonRpc { .. }
  247. | Self::InvalidResponse { .. }
  248. | Self::UnknownTool { .. }
  249. | Self::UnknownServer { .. } => None,
  250. }
  251. }
  252. }
  253. impl From<io::Error> for McpServerManagerError {
  254. fn from(value: io::Error) -> Self {
  255. Self::Io(value)
  256. }
  257. }
/// Where a qualified tool name is dispatched: the owning server and the
/// server-local tool name.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ToolRoute {
    server_name: String,
    raw_name: String,
}
/// Per-server state tracked by the manager.
#[derive(Debug)]
struct ManagedMcpServer {
    /// How to launch the server.
    bootstrap: McpClientBootstrap,
    /// Running process, or `None` until first use / after shutdown.
    process: Option<McpStdioProcess>,
    /// Whether `initialize` succeeded on the current process.
    initialized: bool,
}
  269. impl ManagedMcpServer {
  270. fn new(bootstrap: McpClientBootstrap) -> Self {
  271. Self {
  272. bootstrap,
  273. process: None,
  274. initialized: false,
  275. }
  276. }
  277. }
/// Owns the stdio MCP server processes declared in configuration and routes
/// qualified tool names to them.
#[derive(Debug)]
pub struct McpServerManager {
    /// Stdio servers keyed by configured server name.
    servers: BTreeMap<String, ManagedMcpServer>,
    /// Servers skipped because their transport is not stdio.
    unsupported_servers: Vec<UnsupportedMcpServer>,
    /// Qualified tool name -> owning server and raw tool name.
    tool_index: BTreeMap<String, ToolRoute>,
    /// Next JSON-RPC request id; starts at 1 and is shared by all servers.
    next_request_id: u64,
}
  285. impl McpServerManager {
  286. #[must_use]
  287. pub fn from_runtime_config(config: &RuntimeConfig) -> Self {
  288. Self::from_servers(config.mcp().servers())
  289. }
  290. #[must_use]
  291. pub fn from_servers(servers: &BTreeMap<String, ScopedMcpServerConfig>) -> Self {
  292. let mut managed_servers = BTreeMap::new();
  293. let mut unsupported_servers = Vec::new();
  294. for (server_name, server_config) in servers {
  295. if server_config.transport() == McpTransport::Stdio {
  296. let bootstrap = McpClientBootstrap::from_scoped_config(server_name, server_config);
  297. managed_servers.insert(server_name.clone(), ManagedMcpServer::new(bootstrap));
  298. } else {
  299. unsupported_servers.push(UnsupportedMcpServer {
  300. server_name: server_name.clone(),
  301. transport: server_config.transport(),
  302. reason: format!(
  303. "transport {:?} is not supported by McpServerManager",
  304. server_config.transport()
  305. ),
  306. });
  307. }
  308. }
  309. Self {
  310. servers: managed_servers,
  311. unsupported_servers,
  312. tool_index: BTreeMap::new(),
  313. next_request_id: 1,
  314. }
  315. }
  316. #[must_use]
  317. pub fn unsupported_servers(&self) -> &[UnsupportedMcpServer] {
  318. &self.unsupported_servers
  319. }
  320. pub async fn discover_tools(&mut self) -> Result<Vec<ManagedMcpTool>, McpServerManagerError> {
  321. let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
  322. let mut discovered_tools = Vec::new();
  323. for server_name in server_names {
  324. self.ensure_server_ready(&server_name).await?;
  325. self.clear_routes_for_server(&server_name);
  326. let mut cursor = None;
  327. loop {
  328. let request_id = self.take_request_id();
  329. let response = {
  330. let server = self.server_mut(&server_name)?;
  331. let process = server.process.as_mut().ok_or_else(|| {
  332. McpServerManagerError::InvalidResponse {
  333. server_name: server_name.clone(),
  334. method: "tools/list",
  335. details: "server process missing after initialization".to_string(),
  336. }
  337. })?;
  338. process
  339. .list_tools(
  340. request_id,
  341. Some(McpListToolsParams {
  342. cursor: cursor.clone(),
  343. }),
  344. )
  345. .await?
  346. };
  347. if let Some(error) = response.error {
  348. return Err(McpServerManagerError::JsonRpc {
  349. server_name: server_name.clone(),
  350. method: "tools/list",
  351. error,
  352. });
  353. }
  354. let result =
  355. response
  356. .result
  357. .ok_or_else(|| McpServerManagerError::InvalidResponse {
  358. server_name: server_name.clone(),
  359. method: "tools/list",
  360. details: "missing result payload".to_string(),
  361. })?;
  362. for tool in result.tools {
  363. let qualified_name = mcp_tool_name(&server_name, &tool.name);
  364. self.tool_index.insert(
  365. qualified_name.clone(),
  366. ToolRoute {
  367. server_name: server_name.clone(),
  368. raw_name: tool.name.clone(),
  369. },
  370. );
  371. discovered_tools.push(ManagedMcpTool {
  372. server_name: server_name.clone(),
  373. qualified_name,
  374. raw_name: tool.name.clone(),
  375. tool,
  376. });
  377. }
  378. match result.next_cursor {
  379. Some(next_cursor) => cursor = Some(next_cursor),
  380. None => break,
  381. }
  382. }
  383. }
  384. Ok(discovered_tools)
  385. }
  386. pub async fn call_tool(
  387. &mut self,
  388. qualified_tool_name: &str,
  389. arguments: Option<JsonValue>,
  390. ) -> Result<JsonRpcResponse<McpToolCallResult>, McpServerManagerError> {
  391. let route = self
  392. .tool_index
  393. .get(qualified_tool_name)
  394. .cloned()
  395. .ok_or_else(|| McpServerManagerError::UnknownTool {
  396. qualified_name: qualified_tool_name.to_string(),
  397. })?;
  398. self.ensure_server_ready(&route.server_name).await?;
  399. let request_id = self.take_request_id();
  400. let response =
  401. {
  402. let server = self.server_mut(&route.server_name)?;
  403. let process = server.process.as_mut().ok_or_else(|| {
  404. McpServerManagerError::InvalidResponse {
  405. server_name: route.server_name.clone(),
  406. method: "tools/call",
  407. details: "server process missing after initialization".to_string(),
  408. }
  409. })?;
  410. process
  411. .call_tool(
  412. request_id,
  413. McpToolCallParams {
  414. name: route.raw_name,
  415. arguments,
  416. meta: None,
  417. },
  418. )
  419. .await?
  420. };
  421. Ok(response)
  422. }
  423. pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> {
  424. let server_names = self.servers.keys().cloned().collect::<Vec<_>>();
  425. for server_name in server_names {
  426. let server = self.server_mut(&server_name)?;
  427. if let Some(process) = server.process.as_mut() {
  428. process.shutdown().await?;
  429. }
  430. server.process = None;
  431. server.initialized = false;
  432. }
  433. Ok(())
  434. }
  435. fn clear_routes_for_server(&mut self, server_name: &str) {
  436. self.tool_index
  437. .retain(|_, route| route.server_name != server_name);
  438. }
  439. fn server_mut(
  440. &mut self,
  441. server_name: &str,
  442. ) -> Result<&mut ManagedMcpServer, McpServerManagerError> {
  443. self.servers
  444. .get_mut(server_name)
  445. .ok_or_else(|| McpServerManagerError::UnknownServer {
  446. server_name: server_name.to_string(),
  447. })
  448. }
  449. fn take_request_id(&mut self) -> JsonRpcId {
  450. let id = self.next_request_id;
  451. self.next_request_id = self.next_request_id.saturating_add(1);
  452. JsonRpcId::Number(id)
  453. }
  454. async fn ensure_server_ready(
  455. &mut self,
  456. server_name: &str,
  457. ) -> Result<(), McpServerManagerError> {
  458. let needs_spawn = self
  459. .servers
  460. .get(server_name)
  461. .map(|server| server.process.is_none())
  462. .ok_or_else(|| McpServerManagerError::UnknownServer {
  463. server_name: server_name.to_string(),
  464. })?;
  465. if needs_spawn {
  466. let server = self.server_mut(server_name)?;
  467. server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?);
  468. server.initialized = false;
  469. }
  470. let needs_initialize = self
  471. .servers
  472. .get(server_name)
  473. .map(|server| !server.initialized)
  474. .ok_or_else(|| McpServerManagerError::UnknownServer {
  475. server_name: server_name.to_string(),
  476. })?;
  477. if needs_initialize {
  478. let request_id = self.take_request_id();
  479. let response = {
  480. let server = self.server_mut(server_name)?;
  481. let process = server.process.as_mut().ok_or_else(|| {
  482. McpServerManagerError::InvalidResponse {
  483. server_name: server_name.to_string(),
  484. method: "initialize",
  485. details: "server process missing before initialize".to_string(),
  486. }
  487. })?;
  488. process
  489. .initialize(request_id, default_initialize_params())
  490. .await?
  491. };
  492. if let Some(error) = response.error {
  493. return Err(McpServerManagerError::JsonRpc {
  494. server_name: server_name.to_string(),
  495. method: "initialize",
  496. error,
  497. });
  498. }
  499. if response.result.is_none() {
  500. return Err(McpServerManagerError::InvalidResponse {
  501. server_name: server_name.to_string(),
  502. method: "initialize",
  503. details: "missing result payload".to_string(),
  504. });
  505. }
  506. let server = self.server_mut(server_name)?;
  507. server.initialized = true;
  508. }
  509. Ok(())
  510. }
  511. }
/// A spawned MCP server process whose stdin/stdout are piped to this handle.
#[derive(Debug)]
pub struct McpStdioProcess {
    /// Child handle; its stdin/stdout pipes have been moved into the fields below.
    child: Child,
    stdin: ChildStdin,
    /// Buffered so frame headers can be read line by line before the exact-length body.
    stdout: BufReader<ChildStdout>,
}
  518. impl McpStdioProcess {
  519. pub fn spawn(transport: &McpStdioTransport) -> io::Result<Self> {
  520. let mut command = Command::new(&transport.command);
  521. command
  522. .args(&transport.args)
  523. .stdin(Stdio::piped())
  524. .stdout(Stdio::piped())
  525. .stderr(Stdio::inherit());
  526. apply_env(&mut command, &transport.env);
  527. let mut child = command.spawn()?;
  528. let stdin = child
  529. .stdin
  530. .take()
  531. .ok_or_else(|| io::Error::other("stdio MCP process missing stdin pipe"))?;
  532. let stdout = child
  533. .stdout
  534. .take()
  535. .ok_or_else(|| io::Error::other("stdio MCP process missing stdout pipe"))?;
  536. Ok(Self {
  537. child,
  538. stdin,
  539. stdout: BufReader::new(stdout),
  540. })
  541. }
  542. pub async fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
  543. self.stdin.write_all(bytes).await
  544. }
  545. pub async fn flush(&mut self) -> io::Result<()> {
  546. self.stdin.flush().await
  547. }
  548. pub async fn write_line(&mut self, line: &str) -> io::Result<()> {
  549. self.write_all(line.as_bytes()).await?;
  550. self.write_all(b"\n").await?;
  551. self.flush().await
  552. }
  553. pub async fn read_line(&mut self) -> io::Result<String> {
  554. let mut line = String::new();
  555. let bytes_read = self.stdout.read_line(&mut line).await?;
  556. if bytes_read == 0 {
  557. return Err(io::Error::new(
  558. io::ErrorKind::UnexpectedEof,
  559. "MCP stdio stream closed while reading line",
  560. ));
  561. }
  562. Ok(line)
  563. }
  564. pub async fn read_available(&mut self) -> io::Result<Vec<u8>> {
  565. let mut buffer = vec![0_u8; 4096];
  566. let read = self.stdout.read(&mut buffer).await?;
  567. buffer.truncate(read);
  568. Ok(buffer)
  569. }
  570. pub async fn write_frame(&mut self, payload: &[u8]) -> io::Result<()> {
  571. let encoded = encode_frame(payload);
  572. self.write_all(&encoded).await?;
  573. self.flush().await
  574. }
  575. pub async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
  576. let mut content_length = None;
  577. loop {
  578. let mut line = String::new();
  579. let bytes_read = self.stdout.read_line(&mut line).await?;
  580. if bytes_read == 0 {
  581. return Err(io::Error::new(
  582. io::ErrorKind::UnexpectedEof,
  583. "MCP stdio stream closed while reading headers",
  584. ));
  585. }
  586. if line == "\r\n" {
  587. break;
  588. }
  589. if let Some(value) = line.strip_prefix("Content-Length:") {
  590. let parsed = value
  591. .trim()
  592. .parse::<usize>()
  593. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
  594. content_length = Some(parsed);
  595. }
  596. }
  597. let content_length = content_length.ok_or_else(|| {
  598. io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
  599. })?;
  600. let mut payload = vec![0_u8; content_length];
  601. self.stdout.read_exact(&mut payload).await?;
  602. Ok(payload)
  603. }
  604. pub async fn write_jsonrpc_message<T: Serialize>(&mut self, message: &T) -> io::Result<()> {
  605. let body = serde_json::to_vec(message)
  606. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
  607. self.write_frame(&body).await
  608. }
  609. pub async fn read_jsonrpc_message<T: DeserializeOwned>(&mut self) -> io::Result<T> {
  610. let payload = self.read_frame().await?;
  611. serde_json::from_slice(&payload)
  612. .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
  613. }
  614. pub async fn send_request<T: Serialize>(
  615. &mut self,
  616. request: &JsonRpcRequest<T>,
  617. ) -> io::Result<()> {
  618. self.write_jsonrpc_message(request).await
  619. }
  620. pub async fn read_response<T: DeserializeOwned>(&mut self) -> io::Result<JsonRpcResponse<T>> {
  621. self.read_jsonrpc_message().await
  622. }
  623. pub async fn request<TParams: Serialize, TResult: DeserializeOwned>(
  624. &mut self,
  625. id: JsonRpcId,
  626. method: impl Into<String>,
  627. params: Option<TParams>,
  628. ) -> io::Result<JsonRpcResponse<TResult>> {
  629. let request = JsonRpcRequest::new(id, method, params);
  630. self.send_request(&request).await?;
  631. self.read_response().await
  632. }
  633. pub async fn initialize(
  634. &mut self,
  635. id: JsonRpcId,
  636. params: McpInitializeParams,
  637. ) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
  638. self.request(id, "initialize", Some(params)).await
  639. }
  640. pub async fn list_tools(
  641. &mut self,
  642. id: JsonRpcId,
  643. params: Option<McpListToolsParams>,
  644. ) -> io::Result<JsonRpcResponse<McpListToolsResult>> {
  645. self.request(id, "tools/list", params).await
  646. }
  647. pub async fn call_tool(
  648. &mut self,
  649. id: JsonRpcId,
  650. params: McpToolCallParams,
  651. ) -> io::Result<JsonRpcResponse<McpToolCallResult>> {
  652. self.request(id, "tools/call", Some(params)).await
  653. }
  654. pub async fn list_resources(
  655. &mut self,
  656. id: JsonRpcId,
  657. params: Option<McpListResourcesParams>,
  658. ) -> io::Result<JsonRpcResponse<McpListResourcesResult>> {
  659. self.request(id, "resources/list", params).await
  660. }
  661. pub async fn read_resource(
  662. &mut self,
  663. id: JsonRpcId,
  664. params: McpReadResourceParams,
  665. ) -> io::Result<JsonRpcResponse<McpReadResourceResult>> {
  666. self.request(id, "resources/read", Some(params)).await
  667. }
  668. pub async fn terminate(&mut self) -> io::Result<()> {
  669. self.child.kill().await
  670. }
  671. pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
  672. self.child.wait().await
  673. }
  674. async fn shutdown(&mut self) -> io::Result<()> {
  675. if self.child.try_wait()?.is_none() {
  676. self.child.kill().await?;
  677. }
  678. let _ = self.child.wait().await?;
  679. Ok(())
  680. }
  681. }
  682. pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
  683. match &bootstrap.transport {
  684. McpClientTransport::Stdio(transport) => McpStdioProcess::spawn(transport),
  685. other => Err(io::Error::new(
  686. io::ErrorKind::InvalidInput,
  687. format!(
  688. "MCP bootstrap transport for {} is not stdio: {other:?}",
  689. bootstrap.server_name
  690. ),
  691. )),
  692. }
  693. }
  694. fn apply_env(command: &mut Command, env: &BTreeMap<String, String>) {
  695. for (key, value) in env {
  696. command.env(key, value);
  697. }
  698. }
  699. fn encode_frame(payload: &[u8]) -> Vec<u8> {
  700. let header = format!("Content-Length: {}\r\n\r\n", payload.len());
  701. let mut framed = header.into_bytes();
  702. framed.extend_from_slice(payload);
  703. framed
  704. }
  705. fn default_initialize_params() -> McpInitializeParams {
  706. McpInitializeParams {
  707. protocol_version: "2025-03-26".to_string(),
  708. capabilities: JsonValue::Object(serde_json::Map::new()),
  709. client_info: McpInitializeClientInfo {
  710. name: "runtime".to_string(),
  711. version: env!("CARGO_PKG_VERSION").to_string(),
  712. },
  713. }
  714. }
  715. #[cfg(test)]
  716. mod tests {
  717. use std::collections::BTreeMap;
  718. use std::fs;
  719. use std::io::ErrorKind;
  720. use std::os::unix::fs::PermissionsExt;
  721. use std::path::{Path, PathBuf};
  722. use std::time::{SystemTime, UNIX_EPOCH};
  723. use serde_json::json;
  724. use tokio::runtime::Builder;
  725. use crate::config::{
  726. ConfigSource, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
  727. McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
  728. };
  729. use crate::mcp::mcp_tool_name;
  730. use crate::mcp_client::McpClientBootstrap;
  731. use super::{
  732. spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
  733. McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo,
  734. McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpServerManager,
  735. McpServerManagerError, McpStdioProcess, McpTool, McpToolCallParams,
  736. };
  737. fn temp_dir() -> PathBuf {
  738. let nanos = SystemTime::now()
  739. .duration_since(UNIX_EPOCH)
  740. .expect("time should be after epoch")
  741. .as_nanos();
  742. std::env::temp_dir().join(format!("runtime-mcp-stdio-{nanos}"))
  743. }
  744. fn write_echo_script() -> PathBuf {
  745. let root = temp_dir();
  746. fs::create_dir_all(&root).expect("temp dir");
  747. let script_path = root.join("echo-mcp.sh");
  748. fs::write(
  749. &script_path,
  750. "#!/bin/sh\nprintf 'READY:%s\\n' \"$MCP_TEST_TOKEN\"\nIFS= read -r line\nprintf 'ECHO:%s\\n' \"$line\"\n",
  751. )
  752. .expect("write script");
  753. let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
  754. permissions.set_mode(0o755);
  755. fs::set_permissions(&script_path, permissions).expect("chmod");
  756. script_path
  757. }
  758. fn write_jsonrpc_script() -> PathBuf {
  759. let root = temp_dir();
  760. fs::create_dir_all(&root).expect("temp dir");
  761. let script_path = root.join("jsonrpc-mcp.py");
  762. let script = [
  763. "#!/usr/bin/env python3",
  764. "import json, sys",
  765. "header = b''",
  766. r"while not header.endswith(b'\r\n\r\n'):",
  767. " chunk = sys.stdin.buffer.read(1)",
  768. " if not chunk:",
  769. " raise SystemExit(1)",
  770. " header += chunk",
  771. "length = 0",
  772. r"for line in header.decode().split('\r\n'):",
  773. r" if line.lower().startswith('content-length:'):",
  774. r" length = int(line.split(':', 1)[1].strip())",
  775. "payload = sys.stdin.buffer.read(length)",
  776. "request = json.loads(payload.decode())",
  777. r"assert request['jsonrpc'] == '2.0'",
  778. r"assert request['method'] == 'initialize'",
  779. r"response = json.dumps({",
  780. r" 'jsonrpc': '2.0',",
  781. r" 'id': request['id'],",
  782. r" 'result': {",
  783. r" 'protocolVersion': request['params']['protocolVersion'],",
  784. r" 'capabilities': {'tools': {}},",
  785. r" 'serverInfo': {'name': 'fake-mcp', 'version': '0.1.0'}",
  786. r" }",
  787. r"}).encode()",
  788. r"sys.stdout.buffer.write(f'Content-Length: {len(response)}\r\n\r\n'.encode() + response)",
  789. "sys.stdout.buffer.flush()",
  790. "",
  791. ]
  792. .join("\n");
  793. fs::write(&script_path, script).expect("write script");
  794. let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
  795. permissions.set_mode(0o755);
  796. fs::set_permissions(&script_path, permissions).expect("chmod");
  797. script_path
  798. }
  799. #[allow(clippy::too_many_lines)]
  800. fn write_mcp_server_script() -> PathBuf {
  801. let root = temp_dir();
  802. fs::create_dir_all(&root).expect("temp dir");
  803. let script_path = root.join("fake-mcp-server.py");
  804. let script = [
  805. "#!/usr/bin/env python3",
  806. "import json, sys",
  807. "",
  808. "def read_message():",
  809. " header = b''",
  810. r" while not header.endswith(b'\r\n\r\n'):",
  811. " chunk = sys.stdin.buffer.read(1)",
  812. " if not chunk:",
  813. " return None",
  814. " header += chunk",
  815. " length = 0",
  816. r" for line in header.decode().split('\r\n'):",
  817. r" if line.lower().startswith('content-length:'):",
  818. r" length = int(line.split(':', 1)[1].strip())",
  819. " payload = sys.stdin.buffer.read(length)",
  820. " return json.loads(payload.decode())",
  821. "",
  822. "def send_message(message):",
  823. " payload = json.dumps(message).encode()",
  824. r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
  825. " sys.stdout.buffer.flush()",
  826. "",
  827. "while True:",
  828. " request = read_message()",
  829. " if request is None:",
  830. " break",
  831. " method = request['method']",
  832. " if method == 'initialize':",
  833. " send_message({",
  834. " 'jsonrpc': '2.0',",
  835. " 'id': request['id'],",
  836. " 'result': {",
  837. " 'protocolVersion': request['params']['protocolVersion'],",
  838. " 'capabilities': {'tools': {}, 'resources': {}},",
  839. " 'serverInfo': {'name': 'fake-mcp', 'version': '0.2.0'}",
  840. " }",
  841. " })",
  842. " elif method == 'tools/list':",
  843. " send_message({",
  844. " 'jsonrpc': '2.0',",
  845. " 'id': request['id'],",
  846. " 'result': {",
  847. " 'tools': [",
  848. " {",
  849. " 'name': 'echo',",
  850. " 'description': 'Echoes text',",
  851. " 'inputSchema': {",
  852. " 'type': 'object',",
  853. " 'properties': {'text': {'type': 'string'}},",
  854. " 'required': ['text']",
  855. " }",
  856. " }",
  857. " ]",
  858. " }",
  859. " })",
  860. " elif method == 'tools/call':",
  861. " args = request['params'].get('arguments') or {}",
  862. " if request['params']['name'] == 'fail':",
  863. " send_message({",
  864. " 'jsonrpc': '2.0',",
  865. " 'id': request['id'],",
  866. " 'error': {'code': -32001, 'message': 'tool failed'},",
  867. " })",
  868. " else:",
  869. " text = args.get('text', '')",
  870. " send_message({",
  871. " 'jsonrpc': '2.0',",
  872. " 'id': request['id'],",
  873. " 'result': {",
  874. " 'content': [{'type': 'text', 'text': f'echo:{text}'}],",
  875. " 'structuredContent': {'echoed': text},",
  876. " 'isError': False",
  877. " }",
  878. " })",
  879. " elif method == 'resources/list':",
  880. " send_message({",
  881. " 'jsonrpc': '2.0',",
  882. " 'id': request['id'],",
  883. " 'result': {",
  884. " 'resources': [",
  885. " {",
  886. " 'uri': 'file://guide.txt',",
  887. " 'name': 'guide',",
  888. " 'description': 'Guide text',",
  889. " 'mimeType': 'text/plain'",
  890. " }",
  891. " ]",
  892. " }",
  893. " })",
  894. " elif method == 'resources/read':",
  895. " uri = request['params']['uri']",
  896. " send_message({",
  897. " 'jsonrpc': '2.0',",
  898. " 'id': request['id'],",
  899. " 'result': {",
  900. " 'contents': [",
  901. " {",
  902. " 'uri': uri,",
  903. " 'mimeType': 'text/plain',",
  904. " 'text': f'contents for {uri}'",
  905. " }",
  906. " ]",
  907. " }",
  908. " })",
  909. " else:",
  910. " send_message({",
  911. " 'jsonrpc': '2.0',",
  912. " 'id': request['id'],",
  913. " 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
  914. " })",
  915. "",
  916. ]
  917. .join("\n");
  918. fs::write(&script_path, script).expect("write script");
  919. let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
  920. permissions.set_mode(0o755);
  921. fs::set_permissions(&script_path, permissions).expect("chmod");
  922. script_path
  923. }
  924. #[allow(clippy::too_many_lines)]
  925. fn write_manager_mcp_server_script() -> PathBuf {
  926. let root = temp_dir();
  927. fs::create_dir_all(&root).expect("temp dir");
  928. let script_path = root.join("manager-mcp-server.py");
  929. let script = [
  930. "#!/usr/bin/env python3",
  931. "import json, os, sys",
  932. "",
  933. "LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
  934. "LOG_PATH = os.environ.get('MCP_LOG_PATH')",
  935. "initialize_count = 0",
  936. "",
  937. "def log(method):",
  938. " if LOG_PATH:",
  939. " with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
  940. " handle.write(f'{method}\\n')",
  941. "",
  942. "def read_message():",
  943. " header = b''",
  944. r" while not header.endswith(b'\r\n\r\n'):",
  945. " chunk = sys.stdin.buffer.read(1)",
  946. " if not chunk:",
  947. " return None",
  948. " header += chunk",
  949. " length = 0",
  950. r" for line in header.decode().split('\r\n'):",
  951. r" if line.lower().startswith('content-length:'):",
  952. r" length = int(line.split(':', 1)[1].strip())",
  953. " payload = sys.stdin.buffer.read(length)",
  954. " return json.loads(payload.decode())",
  955. "",
  956. "def send_message(message):",
  957. " payload = json.dumps(message).encode()",
  958. r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
  959. " sys.stdout.buffer.flush()",
  960. "",
  961. "while True:",
  962. " request = read_message()",
  963. " if request is None:",
  964. " break",
  965. " method = request['method']",
  966. " log(method)",
  967. " if method == 'initialize':",
  968. " initialize_count += 1",
  969. " send_message({",
  970. " 'jsonrpc': '2.0',",
  971. " 'id': request['id'],",
  972. " 'result': {",
  973. " 'protocolVersion': request['params']['protocolVersion'],",
  974. " 'capabilities': {'tools': {}},",
  975. " 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
  976. " }",
  977. " })",
  978. " elif method == 'tools/list':",
  979. " send_message({",
  980. " 'jsonrpc': '2.0',",
  981. " 'id': request['id'],",
  982. " 'result': {",
  983. " 'tools': [",
  984. " {",
  985. " 'name': 'echo',",
  986. " 'description': f'Echo tool for {LABEL}',",
  987. " 'inputSchema': {",
  988. " 'type': 'object',",
  989. " 'properties': {'text': {'type': 'string'}},",
  990. " 'required': ['text']",
  991. " }",
  992. " }",
  993. " ]",
  994. " }",
  995. " })",
  996. " elif method == 'tools/call':",
  997. " args = request['params'].get('arguments') or {}",
  998. " text = args.get('text', '')",
  999. " send_message({",
  1000. " 'jsonrpc': '2.0',",
  1001. " 'id': request['id'],",
  1002. " 'result': {",
  1003. " 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
  1004. " 'structuredContent': {",
  1005. " 'server': LABEL,",
  1006. " 'echoed': text,",
  1007. " 'initializeCount': initialize_count",
  1008. " },",
  1009. " 'isError': False",
  1010. " }",
  1011. " })",
  1012. " else:",
  1013. " send_message({",
  1014. " 'jsonrpc': '2.0',",
  1015. " 'id': request['id'],",
  1016. " 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
  1017. " })",
  1018. "",
  1019. ]
  1020. .join("\n");
  1021. fs::write(&script_path, script).expect("write script");
  1022. let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
  1023. permissions.set_mode(0o755);
  1024. fs::set_permissions(&script_path, permissions).expect("chmod");
  1025. script_path
  1026. }
  1027. fn sample_bootstrap(script_path: &Path) -> McpClientBootstrap {
  1028. let config = ScopedMcpServerConfig {
  1029. scope: ConfigSource::Local,
  1030. config: McpServerConfig::Stdio(McpStdioServerConfig {
  1031. command: "/bin/sh".to_string(),
  1032. args: vec![script_path.to_string_lossy().into_owned()],
  1033. env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "secret-value".to_string())]),
  1034. }),
  1035. };
  1036. McpClientBootstrap::from_scoped_config("stdio server", &config)
  1037. }
  1038. fn script_transport(script_path: &Path) -> crate::mcp_client::McpStdioTransport {
  1039. crate::mcp_client::McpStdioTransport {
  1040. command: "python3".to_string(),
  1041. args: vec![script_path.to_string_lossy().into_owned()],
  1042. env: BTreeMap::new(),
  1043. }
  1044. }
  /// Deletes the script at `script_path`, then its whole parent directory.
  ///
  /// Removing the directory also removes anything else the test left there, such as
  /// log files. Panics if either removal fails.
  fn cleanup_script(script_path: &Path) {
      fs::remove_file(script_path).expect("cleanup script");
      let containing_dir = script_path.parent().expect("script parent");
      fs::remove_dir_all(containing_dir).expect("cleanup dir");
  }
  1049. fn manager_server_config(
  1050. script_path: &Path,
  1051. label: &str,
  1052. log_path: &Path,
  1053. ) -> ScopedMcpServerConfig {
  1054. ScopedMcpServerConfig {
  1055. scope: ConfigSource::Local,
  1056. config: McpServerConfig::Stdio(McpStdioServerConfig {
  1057. command: "python3".to_string(),
  1058. args: vec![script_path.to_string_lossy().into_owned()],
  1059. env: BTreeMap::from([
  1060. ("MCP_SERVER_LABEL".to_string(), label.to_string()),
  1061. (
  1062. "MCP_LOG_PATH".to_string(),
  1063. log_path.to_string_lossy().into_owned(),
  1064. ),
  1065. ]),
  1066. }),
  1067. }
  1068. }
  /// Spawns the echo script from a `/bin/sh` bootstrap. Checks that the bootstrap's
  /// environment reaches the child and that one line round-trips through stdin/stdout.
  #[test]
  fn spawns_stdio_process_and_round_trips_io() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_echo_script();
          let bootstrap = sample_bootstrap(&script);
          let mut child = spawn_mcp_stdio_process(&bootstrap).expect("spawn stdio process");

          // The script first prints the token it found in its environment.
          let greeting = child.read_line().await.expect("read ready");
          assert_eq!(greeting, "READY:secret-value\n");

          child
              .write_line("ping from client")
              .await
              .expect("write line");
          let reply = child.read_line().await.expect("read echo");
          assert_eq!(reply, "ECHO:ping from client\n");

          // After echoing one line the script reaches its end and exits.
          let exit = child.wait().await.expect("wait for exit");
          assert!(exit.success());
          cleanup_script(&script);
      });
  }
  /// A bootstrap built from a non-stdio (SDK) config must be refused with
  /// `ErrorKind::InvalidInput`.
  #[test]
  fn rejects_non_stdio_bootstrap() {
      let sdk_config = ScopedMcpServerConfig {
          scope: ConfigSource::Local,
          config: McpServerConfig::Sdk(McpSdkServerConfig {
              name: "sdk-server".to_string(),
          }),
      };
      let bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk_config);
      let err = spawn_mcp_stdio_process(&bootstrap).expect_err("non-stdio should fail");
      assert_eq!(err.kind(), ErrorKind::InvalidInput);
  }
  /// Runs a typed `initialize` handshake against the one-shot JSON-RPC script.
  /// Checks the decoded result, which echoes the requested protocol version.
  #[test]
  fn round_trips_initialize_request_and_response_over_stdio_frames() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_jsonrpc_script();
          let transport = script_transport(&script);
          let mut child = McpStdioProcess::spawn(&transport).expect("spawn transport directly");

          let request_params = McpInitializeParams {
              protocol_version: "2025-03-26".to_string(),
              capabilities: json!({"roots": {}}),
              client_info: McpInitializeClientInfo {
                  name: "runtime-tests".to_string(),
                  version: "0.1.0".to_string(),
              },
          };
          let reply = child
              .initialize(JsonRpcId::Number(1), request_params)
              .await
              .expect("initialize roundtrip");

          let expected = McpInitializeResult {
              protocol_version: "2025-03-26".to_string(),
              capabilities: json!({"tools": {}}),
              server_info: McpInitializeServerInfo {
                  name: "fake-mcp".to_string(),
                  version: "0.1.0".to_string(),
              },
          };
          assert_eq!(reply.id, JsonRpcId::Number(1));
          assert_eq!(reply.error, None);
          assert_eq!(reply.result, Some(expected));

          // The script exits after writing its single response.
          let exit = child.wait().await.expect("wait for exit");
          assert!(exit.success());
          cleanup_script(&script);
      });
  }
  /// Sends a hand-built `initialize` request with `send_request` and reads the reply
  /// with `read_response`. The fake server parses the `Content-Length` header of
  /// what it receives before answering.
  #[test]
  fn write_jsonrpc_request_emits_content_length_frame() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_jsonrpc_script();
          let transport = script_transport(&script);
          let mut child = McpStdioProcess::spawn(&transport).expect("spawn transport directly");

          let init_params = json!({
              "protocolVersion": "2025-03-26",
              "capabilities": {},
              "clientInfo": {"name": "runtime-tests", "version": "0.1.0"}
          });
          let request = JsonRpcRequest::new(JsonRpcId::Number(7), "initialize", Some(init_params));
          child.send_request(&request).await.expect("send request");

          let reply: JsonRpcResponse<serde_json::Value> =
              child.read_response().await.expect("read response");
          assert_eq!(reply.id, JsonRpcId::Number(7));
          assert_eq!(reply.jsonrpc, "2.0");

          let exit = child.wait().await.expect("wait for exit");
          assert!(exit.success());
          cleanup_script(&script);
      });
  }
  /// Spawns the echo script straight from a transport, with no bootstrap. Checks
  /// that the transport's `env` map reaches the child, then terminates it.
  #[test]
  fn direct_spawn_uses_transport_env() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_echo_script();
          let mut env = BTreeMap::new();
          env.insert("MCP_TEST_TOKEN".to_string(), "direct-secret".to_string());
          let transport = crate::mcp_client::McpStdioTransport {
              command: "/bin/sh".to_string(),
              args: vec![script.to_string_lossy().into_owned()],
              env,
          };
          let mut child = McpStdioProcess::spawn(&transport).expect("spawn transport directly");

          let greeting = child.read_available().await.expect("read ready");
          assert_eq!(String::from_utf8_lossy(&greeting), "READY:direct-secret\n");

          // Nothing is written to its stdin, so the script would sit in `read`; stop it.
          child.terminate().await.expect("terminate child");
          let _ = child.wait().await.expect("wait after kill");
          cleanup_script(&script);
      });
  }
  /// Drives the multi-method fake server through `tools/list`, `tools/call`,
  /// `resources/list` and `resources/read`, checking each typed response.
  #[test]
  fn lists_tools_calls_tool_and_reads_resources_over_jsonrpc() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_mcp_server_script();
          let transport = script_transport(&script);
          let mut server = McpStdioProcess::spawn(&transport).expect("spawn fake mcp server");

          // tools/list: a single `echo` tool with a required string `text` argument.
          let listed = server
              .list_tools(JsonRpcId::Number(2), None)
              .await
              .expect("list tools");
          let echo_tool = McpTool {
              name: "echo".to_string(),
              description: Some("Echoes text".to_string()),
              input_schema: Some(json!({
                  "type": "object",
                  "properties": {"text": {"type": "string"}},
                  "required": ["text"]
              })),
              annotations: None,
              meta: None,
          };
          assert_eq!(listed.error, None);
          assert_eq!(listed.id, JsonRpcId::Number(2));
          assert_eq!(
              listed.result,
              Some(McpListToolsResult {
                  tools: vec![echo_tool],
                  next_cursor: None,
              })
          );

          // tools/call: the reply text is prefixed and mirrored in structured content.
          let echo_params = McpToolCallParams {
              name: "echo".to_string(),
              arguments: Some(json!({"text": "hello"})),
              meta: None,
          };
          let called = server
              .call_tool(JsonRpcId::String("call-1".to_string()), echo_params)
              .await
              .expect("call tool");
          assert_eq!(called.error, None);
          let outcome = called.result.expect("tool result");
          assert_eq!(outcome.is_error, Some(false));
          assert_eq!(outcome.structured_content, Some(json!({"echoed": "hello"})));
          assert_eq!(outcome.content.len(), 1);
          let first_block = &outcome.content[0];
          assert_eq!(first_block.kind, "text");
          assert_eq!(first_block.data.get("text"), Some(&json!("echo:hello")));

          // resources/list: one plain-text guide resource.
          let catalogue = server
              .list_resources(JsonRpcId::Number(3), None)
              .await
              .expect("list resources");
          let listing = catalogue.result.expect("resources result");
          assert_eq!(listing.resources.len(), 1);
          let guide = &listing.resources[0];
          assert_eq!(guide.uri, "file://guide.txt");
          assert_eq!(guide.mime_type.as_deref(), Some("text/plain"));

          // resources/read: the returned text names the requested URI.
          let read_params = McpReadResourceParams {
              uri: "file://guide.txt".to_string(),
          };
          let fetched = server
              .read_resource(JsonRpcId::Number(4), read_params)
              .await
              .expect("read resource");
          let expected_contents = super::McpResourceContents {
              uri: "file://guide.txt".to_string(),
              mime_type: Some("text/plain".to_string()),
              text: Some("contents for file://guide.txt".to_string()),
              blob: None,
              meta: None,
          };
          assert_eq!(
              fetched.result,
              Some(McpReadResourceResult {
                  contents: vec![expected_contents],
              })
          );

          server.terminate().await.expect("terminate child");
          let _ = server.wait().await.expect("wait after kill");
          cleanup_script(&script);
      });
  }
  /// A JSON-RPC error object returned for `tools/call` must come back in `error`,
  /// with code and message intact and no `result`. It must not surface as a
  /// transport failure.
  #[test]
  fn surfaces_jsonrpc_errors_from_tool_calls() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_mcp_server_script();
          let transport = script_transport(&script);
          let mut server = McpStdioProcess::spawn(&transport).expect("spawn fake mcp server");

          // The fake server answers a call to the tool named `fail` with error -32001.
          let failing_call = McpToolCallParams {
              name: "fail".to_string(),
              arguments: None,
              meta: None,
          };
          let reply = server
              .call_tool(JsonRpcId::Number(9), failing_call)
              .await
              .expect("call tool with error response");
          assert_eq!(reply.id, JsonRpcId::Number(9));
          assert!(reply.result.is_none());
          let rpc_error = reply.error.as_ref();
          assert_eq!(rpc_error.map(|e| e.code), Some(-32001));
          assert_eq!(rpc_error.map(|e| e.message.as_str()), Some("tool failed"));

          server.terminate().await.expect("terminate child");
          let _ = server.wait().await.expect("wait after kill");
          cleanup_script(&script);
      });
  }
  /// Discovery against one stdio server yields its single `echo` tool under the
  /// qualified name built by `mcp_tool_name`. No servers are reported as unsupported.
  #[test]
  fn manager_discovers_tools_from_stdio_config() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_manager_mcp_server_script();
          let dir = script.parent().expect("script parent");
          let log_file = dir.join("alpha.log");
          let mut servers = BTreeMap::new();
          servers.insert(
              "alpha".to_string(),
              manager_server_config(&script, "alpha", &log_file),
          );

          let mut manager = McpServerManager::from_servers(&servers);
          let discovered = manager.discover_tools().await.expect("discover tools");
          assert_eq!(discovered.len(), 1);
          let only = &discovered[0];
          assert_eq!(only.server_name, "alpha");
          assert_eq!(only.raw_name, "echo");
          assert_eq!(only.qualified_name, mcp_tool_name("alpha", "echo"));
          assert_eq!(only.tool.name, "echo");
          assert!(manager.unsupported_servers().is_empty());

          manager.shutdown().await.expect("shutdown");
          cleanup_script(&script);
      });
  }
  /// Runs two servers, `alpha` and `beta`, from the same script. Each qualified tool
  /// name must be routed to its own server, which the `server` label in each reply's
  /// structured content shows.
  #[test]
  fn manager_routes_tool_calls_to_correct_server() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_manager_mcp_server_script();
          let dir = script.parent().expect("script parent");
          let mut servers = BTreeMap::new();
          for label in ["alpha", "beta"] {
              // Each server gets its own `<label>.log` next to the script.
              let log_file = dir.join(format!("{label}.log"));
              servers.insert(
                  label.to_string(),
                  manager_server_config(&script, label, &log_file),
              );
          }

          let mut manager = McpServerManager::from_servers(&servers);
          let discovered = manager.discover_tools().await.expect("discover tools");
          assert_eq!(discovered.len(), 2);

          let alpha_reply = manager
              .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "hello"})))
              .await
              .expect("call alpha tool");
          let beta_reply = manager
              .call_tool(&mcp_tool_name("beta", "echo"), Some(json!({"text": "world"})))
              .await
              .expect("call beta tool");

          let alpha_label = alpha_reply
              .result
              .as_ref()
              .and_then(|result| result.structured_content.as_ref())
              .and_then(|value| value.get("server"));
          assert_eq!(alpha_label, Some(&json!("alpha")));
          let beta_label = beta_reply
              .result
              .as_ref()
              .and_then(|result| result.structured_content.as_ref())
              .and_then(|value| value.get("server"));
          assert_eq!(beta_label, Some(&json!("beta")));

          manager.shutdown().await.expect("shutdown");
          cleanup_script(&script);
      });
  }
  /// Building a manager from HTTP, SDK and WebSocket configs must not panic.
  /// All three are reported as unsupported, and this test expects them in
  /// server-name order.
  #[test]
  fn manager_records_unsupported_non_stdio_servers_without_panicking() {
      let local = |config: McpServerConfig| ScopedMcpServerConfig {
          scope: ConfigSource::Local,
          config,
      };
      let http = local(McpServerConfig::Http(McpRemoteServerConfig {
          url: "https://example.test/mcp".to_string(),
          headers: BTreeMap::new(),
          headers_helper: None,
          oauth: None,
      }));
      let sdk = local(McpServerConfig::Sdk(McpSdkServerConfig {
          name: "sdk-server".to_string(),
      }));
      let ws = local(McpServerConfig::Ws(McpWebSocketServerConfig {
          url: "wss://example.test/mcp".to_string(),
          headers: BTreeMap::new(),
          headers_helper: None,
      }));
      let servers = BTreeMap::from([
          ("http".to_string(), http),
          ("sdk".to_string(), sdk),
          ("ws".to_string(), ws),
      ]);

      let manager = McpServerManager::from_servers(&servers);
      let unsupported = manager.unsupported_servers();
      assert_eq!(unsupported.len(), 3);
      for (index, expected) in ["http", "sdk", "ws"].iter().enumerate() {
          assert_eq!(unsupported[index].server_name, *expected);
      }
  }
  /// After discovery, shutting the manager down twice must succeed both times.
  ///
  /// NOTE(review): this only checks that `shutdown` returns `Ok` twice. It does not
  /// itself observe that the child process has exited.
  #[test]
  fn manager_shutdown_terminates_spawned_children_and_is_idempotent() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_manager_mcp_server_script();
          let log_file = script.parent().expect("script parent").join("alpha.log");
          let servers = BTreeMap::from([(
              "alpha".to_string(),
              manager_server_config(&script, "alpha", &log_file),
          )]);

          let mut manager = McpServerManager::from_servers(&servers);
          // Discover first so a server has been started before shutting down.
          manager.discover_tools().await.expect("discover tools");
          manager.shutdown().await.expect("first shutdown");
          // Repeating the shutdown must be harmless.
          manager.shutdown().await.expect("second shutdown");
          cleanup_script(&script);
      });
  }
  /// Discovery and the following tool call must share one server process.
  /// The server reports `initializeCount == 1`, and its method log shows exactly
  /// one `initialize`, followed by `tools/list` and then `tools/call`.
  #[test]
  fn manager_reuses_spawned_server_between_discovery_and_call() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_manager_mcp_server_script();
          let log_file = script.parent().expect("script parent").join("alpha.log");
          let servers = BTreeMap::from([(
              "alpha".to_string(),
              manager_server_config(&script, "alpha", &log_file),
          )]);

          let mut manager = McpServerManager::from_servers(&servers);
          manager.discover_tools().await.expect("discover tools");
          let reply = manager
              .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "reuse"})))
              .await
              .expect("call tool");
          let init_count = reply
              .result
              .as_ref()
              .and_then(|result| result.structured_content.as_ref())
              .and_then(|value| value.get("initializeCount"));
          assert_eq!(init_count, Some(&json!(1)));

          // The script appends each received method name to the log, one per line.
          let log = fs::read_to_string(&log_file).expect("read log");
          let methods: Vec<&str> = log.lines().collect();
          assert_eq!(methods.iter().filter(|m| **m == "initialize").count(), 1);
          assert_eq!(methods, vec!["initialize", "tools/list", "tools/call"]);

          manager.shutdown().await.expect("shutdown");
          cleanup_script(&script);
      });
  }
  /// Calling a qualified tool name that the configured server does not expose fails
  /// with `UnknownTool`, which carries the requested qualified name.
  #[test]
  fn manager_reports_unknown_qualified_tool_name() {
      let rt = Builder::new_current_thread()
          .enable_all()
          .build()
          .expect("runtime");
      rt.block_on(async {
          let script = write_manager_mcp_server_script();
          let log_file = script.parent().expect("script parent").join("alpha.log");
          let servers = BTreeMap::from([(
              "alpha".to_string(),
              manager_server_config(&script, "alpha", &log_file),
          )]);

          let mut manager = McpServerManager::from_servers(&servers);
          let missing = mcp_tool_name("alpha", "missing");
          let failure = manager
              .call_tool(&missing, Some(json!({"text": "nope"})))
              .await
              .expect_err("unknown qualified tool should fail");
          let reported = match failure {
              McpServerManagerError::UnknownTool { qualified_name } => qualified_name,
              other => panic!("expected unknown tool error, got {other:?}"),
          };
          assert_eq!(reported, missing);
          cleanup_script(&script);
      });
  }
  1542. }