mcp_client.rs 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. use std::collections::BTreeMap;
  2. use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
  3. use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
/// Connection target for an MCP server, produced from its configuration by
/// [`McpClientTransport::from_config`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpClientTransport {
    /// Command-based server (built from `McpServerConfig::Stdio`).
    Stdio(McpStdioTransport),
    /// Remote endpoint built from `McpServerConfig::Sse`.
    Sse(McpRemoteTransport),
    /// Remote endpoint built from `McpServerConfig::Http`.
    Http(McpRemoteTransport),
    /// Remote endpoint built from `McpServerConfig::Ws`.
    WebSocket(McpRemoteTransport),
    /// SDK server identified by name (built from `McpServerConfig::Sdk`).
    Sdk(McpSdkTransport),
    /// Proxy target with a URL and id (built from `McpServerConfig::ClaudeAiProxy`).
    ClaudeAiProxy(McpClaudeAiProxyTransport),
}
/// Command, arguments and environment for a stdio MCP server, copied from its config.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpStdioTransport {
    /// Command used to start the server.
    pub command: String,
    /// Arguments for `command`, in order.
    pub args: Vec<String>,
    /// Environment variables for the server process.
    /// NOTE(review): presumably added on top of the inherited environment — confirm at the spawn site.
    pub env: BTreeMap<String, String>,
}
/// Endpoint settings shared by the SSE, HTTP and WebSocket transports.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpRemoteTransport {
    /// Server URL.
    pub url: String,
    /// Static headers copied from the server config.
    pub headers: BTreeMap<String, String>,
    /// Optional headers helper copied from the server config.
    /// NOTE(review): presumably a command/script that yields extra headers — confirm.
    pub headers_helper: Option<String>,
    /// Authentication mode for this endpoint.
    pub auth: McpClientAuth,
}
/// Transport for an SDK-provided server; only its name is recorded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpSdkTransport {
    /// Server name copied from the SDK server config.
    pub name: String,
}
/// Proxy transport: the proxy URL plus the id copied from the proxy server config.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpClaudeAiProxyTransport {
    /// Proxy URL.
    pub url: String,
    /// Identifier copied from the config; its meaning is defined by the proxy.
    pub id: String,
}
/// Authentication attached to a remote transport.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpClientAuth {
    /// No authentication configured.
    None,
    /// OAuth settings taken from the server configuration.
    OAuth(McpOAuthConfig),
}
/// Everything a client needs to connect to one configured MCP server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpClientBootstrap {
    /// Server name exactly as it appears in the configuration.
    pub server_name: String,
    /// `server_name` after `normalize_name_for_mcp`.
    pub normalized_name: String,
    /// Prefix from `mcp_tool_prefix`, applied to this server's tool names.
    pub tool_prefix: String,
    /// Result of `mcp_server_signature`; may be `None` for some server kinds.
    pub signature: Option<String>,
    /// Transport derived from the server configuration.
    pub transport: McpClientTransport,
}
  48. impl McpClientBootstrap {
  49. #[must_use]
  50. pub fn from_scoped_config(server_name: &str, config: &ScopedMcpServerConfig) -> Self {
  51. Self {
  52. server_name: server_name.to_string(),
  53. normalized_name: normalize_name_for_mcp(server_name),
  54. tool_prefix: mcp_tool_prefix(server_name),
  55. signature: mcp_server_signature(&config.config),
  56. transport: McpClientTransport::from_config(&config.config),
  57. }
  58. }
  59. }
  60. impl McpClientTransport {
  61. #[must_use]
  62. pub fn from_config(config: &McpServerConfig) -> Self {
  63. match config {
  64. McpServerConfig::Stdio(config) => Self::Stdio(McpStdioTransport {
  65. command: config.command.clone(),
  66. args: config.args.clone(),
  67. env: config.env.clone(),
  68. }),
  69. McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
  70. url: config.url.clone(),
  71. headers: config.headers.clone(),
  72. headers_helper: config.headers_helper.clone(),
  73. auth: McpClientAuth::from_oauth(config.oauth.clone()),
  74. }),
  75. McpServerConfig::Http(config) => Self::Http(McpRemoteTransport {
  76. url: config.url.clone(),
  77. headers: config.headers.clone(),
  78. headers_helper: config.headers_helper.clone(),
  79. auth: McpClientAuth::from_oauth(config.oauth.clone()),
  80. }),
  81. McpServerConfig::Ws(config) => Self::WebSocket(McpRemoteTransport {
  82. url: config.url.clone(),
  83. headers: config.headers.clone(),
  84. headers_helper: config.headers_helper.clone(),
  85. auth: McpClientAuth::None,
  86. }),
  87. McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport {
  88. name: config.name.clone(),
  89. }),
  90. McpServerConfig::ClaudeAiProxy(config) => {
  91. Self::ClaudeAiProxy(McpClaudeAiProxyTransport {
  92. url: config.url.clone(),
  93. id: config.id.clone(),
  94. })
  95. }
  96. }
  97. }
  98. }
  99. impl McpClientAuth {
  100. #[must_use]
  101. pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
  102. oauth.map_or(Self::None, Self::OAuth)
  103. }
  104. #[must_use]
  105. pub const fn requires_user_auth(&self) -> bool {
  106. matches!(self, Self::OAuth(_))
  107. }
  108. }
  109. #[cfg(test)]
  110. mod tests {
  111. use std::collections::BTreeMap;
  112. use crate::config::{
  113. ConfigSource, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig,
  114. McpStdioServerConfig, McpWebSocketServerConfig, ScopedMcpServerConfig,
  115. };
  116. use super::{McpClientAuth, McpClientBootstrap, McpClientTransport};
  117. #[test]
  118. fn bootstraps_stdio_servers_into_transport_targets() {
  119. let config = ScopedMcpServerConfig {
  120. scope: ConfigSource::User,
  121. config: McpServerConfig::Stdio(McpStdioServerConfig {
  122. command: "uvx".to_string(),
  123. args: vec!["mcp-server".to_string()],
  124. env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
  125. }),
  126. };
  127. let bootstrap = McpClientBootstrap::from_scoped_config("stdio-server", &config);
  128. assert_eq!(bootstrap.normalized_name, "stdio-server");
  129. assert_eq!(bootstrap.tool_prefix, "mcp__stdio-server__");
  130. assert_eq!(
  131. bootstrap.signature.as_deref(),
  132. Some("stdio:[uvx|mcp-server]")
  133. );
  134. match bootstrap.transport {
  135. McpClientTransport::Stdio(transport) => {
  136. assert_eq!(transport.command, "uvx");
  137. assert_eq!(transport.args, vec!["mcp-server"]);
  138. assert_eq!(
  139. transport.env.get("TOKEN").map(String::as_str),
  140. Some("secret")
  141. );
  142. }
  143. other => panic!("expected stdio transport, got {other:?}"),
  144. }
  145. }
  146. #[test]
  147. fn bootstraps_remote_servers_with_oauth_auth() {
  148. let config = ScopedMcpServerConfig {
  149. scope: ConfigSource::Project,
  150. config: McpServerConfig::Http(McpRemoteServerConfig {
  151. url: "https://vendor.example/mcp".to_string(),
  152. headers: BTreeMap::from([("X-Test".to_string(), "1".to_string())]),
  153. headers_helper: Some("helper.sh".to_string()),
  154. oauth: Some(McpOAuthConfig {
  155. client_id: Some("client-id".to_string()),
  156. callback_port: Some(7777),
  157. auth_server_metadata_url: Some(
  158. "https://issuer.example/.well-known/oauth-authorization-server".to_string(),
  159. ),
  160. xaa: Some(true),
  161. }),
  162. }),
  163. };
  164. let bootstrap = McpClientBootstrap::from_scoped_config("remote server", &config);
  165. assert_eq!(bootstrap.normalized_name, "remote_server");
  166. match bootstrap.transport {
  167. McpClientTransport::Http(transport) => {
  168. assert_eq!(transport.url, "https://vendor.example/mcp");
  169. assert_eq!(transport.headers_helper.as_deref(), Some("helper.sh"));
  170. assert!(transport.auth.requires_user_auth());
  171. match transport.auth {
  172. McpClientAuth::OAuth(oauth) => {
  173. assert_eq!(oauth.client_id.as_deref(), Some("client-id"));
  174. }
  175. other @ McpClientAuth::None => panic!("expected oauth auth, got {other:?}"),
  176. }
  177. }
  178. other => panic!("expected http transport, got {other:?}"),
  179. }
  180. }
  181. #[test]
  182. fn bootstraps_websocket_and_sdk_transports_without_oauth() {
  183. let ws = ScopedMcpServerConfig {
  184. scope: ConfigSource::Local,
  185. config: McpServerConfig::Ws(McpWebSocketServerConfig {
  186. url: "wss://vendor.example/mcp".to_string(),
  187. headers: BTreeMap::new(),
  188. headers_helper: None,
  189. }),
  190. };
  191. let sdk = ScopedMcpServerConfig {
  192. scope: ConfigSource::Local,
  193. config: McpServerConfig::Sdk(McpSdkServerConfig {
  194. name: "sdk-server".to_string(),
  195. }),
  196. };
  197. let ws_bootstrap = McpClientBootstrap::from_scoped_config("ws server", &ws);
  198. match ws_bootstrap.transport {
  199. McpClientTransport::WebSocket(transport) => {
  200. assert_eq!(transport.url, "wss://vendor.example/mcp");
  201. assert!(!transport.auth.requires_user_auth());
  202. }
  203. other => panic!("expected websocket transport, got {other:?}"),
  204. }
  205. let sdk_bootstrap = McpClientBootstrap::from_scoped_config("sdk server", &sdk);
  206. assert_eq!(sdk_bootstrap.signature, None);
  207. match sdk_bootstrap.transport {
  208. McpClientTransport::Sdk(transport) => {
  209. assert_eq!(transport.name, "sdk-server");
  210. }
  211. other => panic!("expected sdk transport, got {other:?}"),
  212. }
  213. }
  214. }