//! `mcp_client.rs`: bootstrap and transport descriptions for MCP clients, built from server
//! configuration entries.

  1. use std::collections::BTreeMap;
  2. use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
  3. use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
  4. pub const DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS: u64 = 60_000;
  5. #[derive(Debug, Clone, PartialEq, Eq)]
  6. pub enum McpClientTransport {
  7. Stdio(McpStdioTransport),
  8. Sse(McpRemoteTransport),
  9. Http(McpRemoteTransport),
  10. WebSocket(McpRemoteTransport),
  11. Sdk(McpSdkTransport),
  12. ManagedProxy(McpManagedProxyTransport),
  13. }
  14. #[derive(Debug, Clone, PartialEq, Eq)]
  15. pub struct McpStdioTransport {
  16. pub command: String,
  17. pub args: Vec<String>,
  18. pub env: BTreeMap<String, String>,
  19. pub tool_call_timeout_ms: Option<u64>,
  20. }
  21. #[derive(Debug, Clone, PartialEq, Eq)]
  22. pub struct McpRemoteTransport {
  23. pub url: String,
  24. pub headers: BTreeMap<String, String>,
  25. pub headers_helper: Option<String>,
  26. pub auth: McpClientAuth,
  27. }
  28. #[derive(Debug, Clone, PartialEq, Eq)]
  29. pub struct McpSdkTransport {
  30. pub name: String,
  31. }
  32. #[derive(Debug, Clone, PartialEq, Eq)]
  33. pub struct McpManagedProxyTransport {
  34. pub url: String,
  35. pub id: String,
  36. }
  37. #[derive(Debug, Clone, PartialEq, Eq)]
  38. pub enum McpClientAuth {
  39. None,
  40. OAuth(McpOAuthConfig),
  41. }
  42. #[derive(Debug, Clone, PartialEq, Eq)]
  43. pub struct McpClientBootstrap {
  44. pub server_name: String,
  45. pub normalized_name: String,
  46. pub tool_prefix: String,
  47. pub signature: Option<String>,
  48. pub transport: McpClientTransport,
  49. }
  50. impl McpClientBootstrap {
  51. #[must_use]
  52. pub fn from_scoped_config(server_name: &str, config: &ScopedMcpServerConfig) -> Self {
  53. Self {
  54. server_name: server_name.to_string(),
  55. normalized_name: normalize_name_for_mcp(server_name),
  56. tool_prefix: mcp_tool_prefix(server_name),
  57. signature: mcp_server_signature(&config.config),
  58. transport: McpClientTransport::from_config(&config.config),
  59. }
  60. }
  61. }
  62. impl McpClientTransport {
  63. #[must_use]
  64. pub fn from_config(config: &McpServerConfig) -> Self {
  65. match config {
  66. McpServerConfig::Stdio(config) => Self::Stdio(McpStdioTransport {
  67. command: config.command.clone(),
  68. args: config.args.clone(),
  69. env: config.env.clone(),
  70. tool_call_timeout_ms: config.tool_call_timeout_ms,
  71. }),
  72. McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
  73. url: config.url.clone(),
  74. headers: config.headers.clone(),
  75. headers_helper: config.headers_helper.clone(),
  76. auth: McpClientAuth::from_oauth(config.oauth.clone()),
  77. }),
  78. McpServerConfig::Http(config) => Self::Http(McpRemoteTransport {
  79. url: config.url.clone(),
  80. headers: config.headers.clone(),
  81. headers_helper: config.headers_helper.clone(),
  82. auth: McpClientAuth::from_oauth(config.oauth.clone()),
  83. }),
  84. McpServerConfig::Ws(config) => Self::WebSocket(McpRemoteTransport {
  85. url: config.url.clone(),
  86. headers: config.headers.clone(),
  87. headers_helper: config.headers_helper.clone(),
  88. auth: McpClientAuth::None,
  89. }),
  90. McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport {
  91. name: config.name.clone(),
  92. }),
  93. McpServerConfig::ManagedProxy(config) => Self::ManagedProxy(McpManagedProxyTransport {
  94. url: config.url.clone(),
  95. id: config.id.clone(),
  96. }),
  97. }
  98. }
  99. }
  100. impl McpStdioTransport {
  101. #[must_use]
  102. pub fn resolved_tool_call_timeout_ms(&self) -> u64 {
  103. self.tool_call_timeout_ms
  104. .unwrap_or(DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS)
  105. }
  106. }
  107. impl McpClientAuth {
  108. #[must_use]
  109. pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
  110. oauth.map_or(Self::None, Self::OAuth)
  111. }
  112. #[must_use]
  113. pub const fn requires_user_auth(&self) -> bool {
  114. matches!(self, Self::OAuth(_))
  115. }
  116. }
  117. #[cfg(test)]
  118. mod tests {
  119. use std::collections::BTreeMap;
  120. use crate::config::{
  121. ConfigSource, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
  122. McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
  123. };
  124. use super::{McpClientAuth, McpClientBootstrap, McpClientTransport};
  125. #[test]
  126. fn bootstraps_stdio_servers_into_transport_targets() {
  127. let config = ScopedMcpServerConfig {
  128. scope: ConfigSource::User,
  129. config: McpServerConfig::Stdio(McpStdioServerConfig {
  130. command: "uvx".to_string(),
  131. args: vec!["mcp-server".to_string()],
  132. env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
  133. tool_call_timeout_ms: Some(15_000),
  134. }),
  135. };
  136. let bootstrap = McpClientBootstrap::from_scoped_config("stdio-server", &config);
  137. assert_eq!(bootstrap.normalized_name, "stdio-server");
  138. assert_eq!(bootstrap.tool_prefix, "mcp__stdio-server__");
  139. assert_eq!(
  140. bootstrap.signature.as_deref(),
  141. Some("stdio:[uvx|mcp-server]")
  142. );
  143. match bootstrap.transport {
  144. McpClientTransport::Stdio(transport) => {
  145. assert_eq!(transport.command, "uvx");
  146. assert_eq!(transport.args, vec!["mcp-server"]);
  147. assert_eq!(
  148. transport.env.get("TOKEN").map(String::as_str),
  149. Some("secret")
  150. );
  151. assert_eq!(transport.tool_call_timeout_ms, Some(15_000));
  152. }
  153. other => panic!("expected stdio transport, got {other:?}"),
  154. }
  155. }
  156. #[test]
  157. fn bootstraps_remote_servers_with_oauth_auth() {
  158. let config = ScopedMcpServerConfig {
  159. scope: ConfigSource::Project,
  160. config: McpServerConfig::Http(McpRemoteServerConfig {
  161. url: "https://vendor.example/mcp".to_string(),
  162. headers: BTreeMap::from([("X-Test".to_string(), "1".to_string())]),
  163. headers_helper: Some("helper.sh".to_string()),
  164. oauth: Some(McpOAuthConfig {
  165. client_id: Some("client-id".to_string()),
  166. callback_port: Some(7777),
  167. auth_server_metadata_url: Some(
  168. "https://issuer.example/.well-known/oauth-authorization-server".to_string(),
  169. ),
  170. xaa: Some(true),
  171. }),
  172. }),
  173. };
  174. let bootstrap = McpClientBootstrap::from_scoped_config("remote server", &config);
  175. assert_eq!(bootstrap.normalized_name, "remote_server");
  176. match bootstrap.transport {
  177. McpClientTransport::Http(transport) => {
  178. assert_eq!(transport.url, "https://vendor.example/mcp");
  179. assert_eq!(transport.headers_helper.as_deref(), Some("helper.sh"));
  180. assert!(transport.auth.requires_user_auth());
  181. match transport.auth {
  182. McpClientAuth::OAuth(oauth) => {
  183. assert_eq!(oauth.client_id.as_deref(), Some("client-id"));
  184. }
  185. other @ McpClientAuth::None => panic!("expected oauth auth, got {other:?}"),
  186. }
  187. }
  188. other => panic!("expected http transport, got {other:?}"),
  189. }
  190. }
  191. #[test]
  192. fn bootstraps_websocket_and_sdk_transports_without_oauth() {
  193. let ws = ScopedMcpServerConfig {
  194. scope: ConfigSource::Local,
  195. config: McpServerConfig::Ws(McpWebSocketServerConfig {
  196. url: "wss://vendor.example/mcp".to_string(),
  197. headers: BTreeMap::new(),
  198. headers_helper: None,
  199. }),
  200. };
  201. let sdk = ScopedMcpServerConfig {
  202. scope: ConfigSource::Local,
  203. config: McpServerConfig::Sdk(McpSdkServerConfig {
  204. name: "sdk-server".to_string(),
  205. }),
  206. };
  207. let ws_bootstrap = McpClientBootstrap::from_scoped_config("ws server", &ws);
  208. match ws_bootstrap.transport {
  209. McpClientTransport::WebSocket(transport) => {
  210. assert_eq!(transport.url, "wss://vendor.example/mcp");
  211. assert!(!transport.auth.requires_user_auth());
  212. }
  213. other => panic!("expected websocket transport, got {other:?}"),
  214. }
  215. let sdk_bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk);
  216. assert_eq!(sdk_bootstrap.signature, None);
  217. match sdk_bootstrap.transport {
  218. McpClientTransport::Sdk(transport) => {
  219. assert_eq!(transport.name, "sdk-server");
  220. }
  221. other => panic!("expected sdk transport, got {other:?}"),
  222. }
  223. }
  224. }