1use crate::config::McpServerEntry;
11use crate::error::{AppError, Result};
12use crate::types::{Tool, ToolsCallResult};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Duration;
18use tokio::sync::RwLock;
19
20#[derive(Debug, Serialize)]
22struct JsonRpcClientRequest {
23 jsonrpc: &'static str,
24 id: u64,
25 method: String,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 params: Option<serde_json::Value>,
28}
29
30#[derive(Debug, Deserialize)]
32struct JsonRpcClientResponse {
33 #[allow(dead_code)]
34 jsonrpc: String,
35 #[allow(dead_code)]
36 id: Option<serde_json::Value>,
37 result: Option<serde_json::Value>,
38 error: Option<JsonRpcClientError>,
39}
40
41#[derive(Debug, Deserialize)]
43struct JsonRpcClientError {
44 code: i64,
45 message: String,
46 #[allow(dead_code)]
47 data: Option<serde_json::Value>,
48}
49
50#[derive(Debug)]
52pub struct McpClient {
53 pub name: String,
55 transport: McpTransport,
57 next_id: AtomicU64,
59 rpc_timeout: Duration,
61 tools: RwLock<Vec<Tool>>,
63}
64
65#[derive(Debug)]
67enum McpTransport {
68 Http {
70 url: String,
71 headers: HashMap<String, String>,
72 client: reqwest::Client,
73 },
74 Stdio {
76 conn: Arc<tokio::sync::Mutex<StdioConnection>>,
77 },
78}
79
80#[derive(Debug)]
86struct StdioConnection {
87 child: tokio::process::Child,
88 stdin: tokio::process::ChildStdin,
89 stdout: tokio::io::BufReader<tokio::process::ChildStdout>,
90}
91
92impl Drop for StdioConnection {
93 fn drop(&mut self) {
94 #[cfg(unix)]
95 {
96 if let Some(pid) = self.child.id() {
97 let _ = std::process::Command::new("kill")
101 .arg("-9")
102 .arg(format!("-{}", pid))
103 .output();
104 }
105 }
106 let _ = self.child.start_kill();
108 }
109}
110
111pub struct McpClientRegistry {
113 clients: RwLock<HashMap<String, Arc<McpClient>>>,
114}
115
116impl Default for McpClientRegistry {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl McpClientRegistry {
123 pub fn new() -> Self {
125 Self {
126 clients: RwLock::new(HashMap::new()),
127 }
128 }
129
130 pub async fn connect(&self, entry: &McpServerEntry) -> Result<Vec<Tool>> {
135 if !entry.enabled {
136 return Err(AppError::Io(std::io::Error::other(format!(
137 "MCP server '{}' is disabled",
138 entry.name
139 ))));
140 }
141
142 let client = McpClient::connect(entry).await?;
143 let tools = client.tools.read().await.clone();
144 self.clients
145 .write()
146 .await
147 .insert(entry.name.clone(), Arc::new(client));
148 Ok(tools)
149 }
150
151 pub async fn get(&self, name: &str) -> Option<Arc<McpClient>> {
153 self.clients.read().await.get(name).cloned()
154 }
155
156 pub async fn disconnect(&self, name: &str) {
158 self.clients.write().await.remove(name);
159 }
160
161 pub async fn connected_servers(&self) -> Vec<String> {
163 self.clients.read().await.keys().cloned().collect()
164 }
165
166 pub async fn all_tools(&self) -> Vec<(String, Vec<Tool>)> {
168 let clients = self.clients.read().await;
169 let mut out = Vec::with_capacity(clients.len());
170 for (name, client) in clients.iter() {
171 let tools = client.tools.read().await.clone();
172 out.push((name.clone(), tools));
173 }
174 out
175 }
176
177 pub async fn call_tool(
179 &self,
180 server_name: &str,
181 tool_name: &str,
182 arguments: serde_json::Value,
183 ) -> Result<ToolsCallResult> {
184 let client = self.get(server_name).await.ok_or_else(|| {
185 AppError::Io(std::io::Error::other(format!(
186 "MCP server '{}' is not connected",
187 server_name
188 )))
189 })?;
190 client.call_tool(tool_name, arguments).await
191 }
192}
193
194impl McpClient {
199 pub async fn connect(entry: &McpServerEntry) -> Result<Self> {
202 use crate::config::McpTransportType;
203
204 let rpc_timeout = Duration::from_secs(entry.timeout_secs.max(1));
205
206 let transport = match entry.transport {
207 McpTransportType::Http | McpTransportType::Sse => {
208 let url = entry.url.clone().ok_or_else(|| {
209 AppError::Io(std::io::Error::other(format!(
210 "MCP server '{}': HTTP transport requires a url",
211 entry.name
212 )))
213 })?;
214 McpTransport::Http {
215 url,
216 headers: entry.headers.clone(),
217 client: reqwest::Client::builder()
218 .timeout(rpc_timeout)
219 .build()
220 .map_err(|e| {
221 AppError::Io(std::io::Error::other(format!(
222 "Failed to create HTTP client: {e}"
223 )))
224 })?,
225 }
226 }
227 McpTransportType::Stdio => {
228 let command = entry.command.as_deref().ok_or_else(|| {
229 AppError::Io(std::io::Error::other(format!(
230 "MCP server '{}': stdio transport requires a command",
231 entry.name
232 )))
233 })?;
234
235 let resolved_command = crate::cmd_utils::resolve_mcp_command(command);
236 let mut envs = entry.env.clone();
237 crate::cmd_utils::inject_enriched_path(&mut envs);
238
239 let mut cmd = tokio::process::Command::new(resolved_command.clone());
240 cmd.args(&entry.args)
241 .envs(&envs)
242 .kill_on_drop(true)
243 .stdin(std::process::Stdio::piped())
244 .stdout(std::process::Stdio::piped())
245 .stderr(std::process::Stdio::null());
246
247 #[cfg(unix)]
249 {
250 cmd.process_group(0);
251 }
252
253 let mut child = cmd.spawn().map_err(|e| {
254 AppError::Io(std::io::Error::other(format!(
255 "Failed to spawn MCP server '{}' ({}): {e}",
256 entry.name, resolved_command
257 )))
258 })?;
259
260 let stdin = child.stdin.take().ok_or_else(|| {
261 AppError::Io(std::io::Error::other(format!(
262 "MCP server '{}': stdin not available",
263 entry.name
264 )))
265 })?;
266 let stdout = child.stdout.take().ok_or_else(|| {
267 AppError::Io(std::io::Error::other(format!(
268 "MCP server '{}': stdout not available",
269 entry.name
270 )))
271 })?;
272
273 let conn = StdioConnection {
274 child,
275 stdin,
276 stdout: tokio::io::BufReader::new(stdout),
277 };
278
279 McpTransport::Stdio {
280 conn: Arc::new(tokio::sync::Mutex::new(conn)),
281 }
282 }
283 };
284
285 let client = Self {
286 name: entry.name.clone(),
287 transport,
288 next_id: AtomicU64::new(1),
289 rpc_timeout,
290 tools: RwLock::new(Vec::new()),
291 };
292
293 client.initialize().await?;
295
296 let tools = client.list_tools_rpc().await?;
298 *client.tools.write().await = tools;
299
300 tracing::info!(
301 "MCP client '{}': connected, {} tools discovered",
302 client.name,
303 client.tools.read().await.len()
304 );
305
306 Ok(client)
307 }
308
309 async fn rpc(
311 &self,
312 method: &str,
313 params: Option<serde_json::Value>,
314 ) -> Result<serde_json::Value> {
315 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
316 let request = JsonRpcClientRequest {
317 jsonrpc: "2.0",
318 id,
319 method: method.to_string(),
320 params,
321 };
322
323 let response_value = match &self.transport {
324 McpTransport::Http {
325 url,
326 headers,
327 client,
328 } => {
329 let mut req = client.post(url).json(&request);
330 for (k, v) in headers {
331 req = req.header(k, v);
332 }
333 let resp = req.send().await.map_err(|e| {
334 AppError::Io(std::io::Error::other(format!(
335 "MCP HTTP request to '{}' failed: {e}",
336 self.name
337 )))
338 })?;
339 if !resp.status().is_success() {
340 let status = resp.status();
341 let body = resp.text().await.unwrap_or_default();
342 return Err(AppError::Io(std::io::Error::other(format!(
343 "MCP server '{}' HTTP {}: {}",
344 self.name, status, body
345 ))));
346 }
347 resp.json::<JsonRpcClientResponse>().await.map_err(|e| {
348 AppError::Io(std::io::Error::other(format!(
349 "MCP server '{}': invalid JSON-RPC response: {e}",
350 self.name
351 )))
352 })?
353 }
354 McpTransport::Stdio { conn } => self.stdio_rpc(conn, &request, method).await?,
355 };
356
357 if let Some(err) = response_value.error {
358 return Err(AppError::Io(std::io::Error::other(format!(
359 "MCP server '{}' RPC error {}: {}",
360 self.name, err.code, err.message
361 ))));
362 }
363
364 response_value.result.ok_or_else(|| {
365 AppError::Io(std::io::Error::other(format!(
366 "MCP server '{}': empty result for method '{}'",
367 self.name, method
368 )))
369 })
370 }
371
372 async fn stdio_rpc(
374 &self,
375 conn: &Arc<tokio::sync::Mutex<StdioConnection>>,
376 request: &JsonRpcClientRequest,
377 method: &str,
378 ) -> Result<JsonRpcClientResponse> {
379 use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
380
381 let mut conn = conn.lock().await;
382
383 let mut line = serde_json::to_string(request).map_err(|e| {
384 AppError::Io(std::io::Error::other(format!(
385 "Failed to serialize JSON-RPC request: {e}"
386 )))
387 })?;
388 line.push('\n');
389 conn.stdin.write_all(line.as_bytes()).await?;
390 conn.stdin.flush().await?;
391
392 let mut buf = String::new();
393 let read = tokio::time::timeout(self.rpc_timeout, conn.stdout.read_line(&mut buf))
394 .await
395 .map_err(|_| {
396 AppError::Io(std::io::Error::other(format!(
397 "MCP server '{}': RPC '{}' timed out after {}s",
398 self.name,
399 method,
400 self.rpc_timeout.as_secs()
401 )))
402 })??;
403
404 if read == 0 {
405 return Err(AppError::Io(std::io::Error::other(format!(
406 "MCP server '{}': EOF while waiting for RPC '{}' response",
407 self.name, method
408 ))));
409 }
410
411 serde_json::from_str(&buf).map_err(|e| {
412 AppError::Io(std::io::Error::other(format!(
413 "MCP server '{}': invalid JSON-RPC response on stdout: {e}",
414 self.name
415 )))
416 })
417 }
418
419 async fn initialize(&self) -> Result<()> {
421 let params = serde_json::json!({
422 "protocolVersion": "2025-11-25",
423 "capabilities": {},
424 "clientInfo": {
425 "name": "gestura",
426 "version": env!("CARGO_PKG_VERSION")
427 }
428 });
429 let _result = self.rpc("initialize", Some(params)).await?;
430
431 let notif = serde_json::json!({
434 "jsonrpc": "2.0",
435 "method": "notifications/initialized"
436 });
437
438 match &self.transport {
439 McpTransport::Http {
440 url,
441 headers,
442 client,
443 } => {
444 let mut req = client.post(url).json(¬if);
445 for (k, v) in headers {
446 req = req.header(k, v);
447 }
448 let _ = req.send().await;
450 }
451 McpTransport::Stdio { conn } => {
452 use tokio::io::AsyncWriteExt;
453 let mut conn = conn.lock().await;
454 let mut line = serde_json::to_string(¬if).unwrap_or_default();
455 line.push('\n');
456 let _ = conn.stdin.write_all(line.as_bytes()).await;
457 let _ = conn.stdin.flush().await;
458 }
459 }
460
461 tracing::debug!("MCP client '{}': initialized", self.name);
462 Ok(())
463 }
464
465 async fn list_tools_rpc(&self) -> Result<Vec<Tool>> {
467 let result = self.rpc("tools/list", None).await?;
468
469 #[derive(Deserialize)]
470 struct ToolsListResponse {
471 tools: Vec<Tool>,
472 }
473
474 let parsed: ToolsListResponse = serde_json::from_value(result).map_err(|e| {
475 AppError::Io(std::io::Error::other(format!(
476 "MCP server '{}': failed to parse tools/list: {e}",
477 self.name
478 )))
479 })?;
480 Ok(parsed.tools)
481 }
482
483 pub async fn call_tool(
485 &self,
486 tool_name: &str,
487 arguments: serde_json::Value,
488 ) -> Result<ToolsCallResult> {
489 let params = serde_json::json!({
490 "name": tool_name,
491 "arguments": arguments
492 });
493 let result = self.rpc("tools/call", Some(params)).await?;
494
495 serde_json::from_value(result).map_err(|e| {
496 AppError::Io(std::io::Error::other(format!(
497 "MCP server '{}': failed to parse tools/call result: {e}",
498 self.name
499 )))
500 })
501 }
502
503 pub async fn get_tools(&self) -> Vec<Tool> {
505 self.tools.read().await.clone()
506 }
507
508 pub async fn refresh_tools(&self) -> Result<Vec<Tool>> {
510 let tools = self.list_tools_rpc().await?;
511 *self.tools.write().await = tools.clone();
512 Ok(tools)
513 }
514}
515
516static MCP_CLIENT_REGISTRY: std::sync::OnceLock<McpClientRegistry> = std::sync::OnceLock::new();
521
522pub fn get_mcp_client_registry() -> &'static McpClientRegistry {
524 MCP_CLIENT_REGISTRY.get_or_init(McpClientRegistry::new)
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh Tokio runtime for driving the async registry API from sync tests.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Runtime::new().unwrap()
    }

    #[test]
    fn registry_starts_empty() {
        let registry = McpClientRegistry::new();
        let servers = runtime().block_on(registry.connected_servers());
        assert!(servers.is_empty(), "a new registry must have no servers");
    }

    #[test]
    fn registry_disconnect_nonexistent_is_noop() {
        let registry = McpClientRegistry::new();
        runtime().block_on(async {
            registry.disconnect("nonexistent").await;
            assert!(registry.connected_servers().await.is_empty());
        });
    }
}