gestura_core_mcp/
client.rs

//! MCP Client — connects to external MCP servers and invokes tools.
//!
//! Supports two transports:
//! - **stdio**: spawns a child process, communicates via JSON-RPC over stdin/stdout.
//! - **http**: sends JSON-RPC requests over HTTP POST (Streamable HTTP transport).
//!
//! The client performs the MCP initialize/initialized handshake, discovers tools
//! via `tools/list`, and invokes tools via `tools/call`.
9
10use crate::config::McpServerEntry;
11use crate::error::{AppError, Result};
12use crate::types::{Tool, ToolsCallResult};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::time::Duration;
18use tokio::sync::RwLock;
19
20/// JSON-RPC request envelope used by the MCP client.
21#[derive(Debug, Serialize)]
22struct JsonRpcClientRequest {
23    jsonrpc: &'static str,
24    id: u64,
25    method: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    params: Option<serde_json::Value>,
28}
29
30/// JSON-RPC response envelope received from MCP servers.
31#[derive(Debug, Deserialize)]
32struct JsonRpcClientResponse {
33    #[allow(dead_code)]
34    jsonrpc: String,
35    #[allow(dead_code)]
36    id: Option<serde_json::Value>,
37    result: Option<serde_json::Value>,
38    error: Option<JsonRpcClientError>,
39}
40
41/// JSON-RPC error object.
42#[derive(Debug, Deserialize)]
43struct JsonRpcClientError {
44    code: i64,
45    message: String,
46    #[allow(dead_code)]
47    data: Option<serde_json::Value>,
48}
49
50/// An active connection to a single MCP server.
51#[derive(Debug)]
52pub struct McpClient {
53    /// Server name (from config).
54    pub name: String,
55    /// Transport backend.
56    transport: McpTransport,
57    /// Monotonically increasing request ID.
58    next_id: AtomicU64,
59    /// Per-RPC timeout used for both HTTP and stdio transports.
60    rpc_timeout: Duration,
61    /// Tools discovered from this server.
62    tools: RwLock<Vec<Tool>>,
63}
64
65/// Transport backend for an MCP client connection.
66#[derive(Debug)]
67enum McpTransport {
68    /// HTTP transport — uses a shared reqwest client.
69    Http {
70        url: String,
71        headers: HashMap<String, String>,
72        client: reqwest::Client,
73    },
74    /// Stdio transport — owns a child process handle.
75    Stdio {
76        conn: Arc<tokio::sync::Mutex<StdioConnection>>,
77    },
78}
79
80/// Persistent stdio connection state for a single MCP server.
81///
82/// We keep a single `BufReader` for stdout across requests. Re-creating and
83/// dropping a `BufReader` per RPC can discard buffered bytes and lead to
84/// seemingly-random hangs on subsequent requests.
85#[derive(Debug)]
86struct StdioConnection {
87    child: tokio::process::Child,
88    stdin: tokio::process::ChildStdin,
89    stdout: tokio::io::BufReader<tokio::process::ChildStdout>,
90}
91
92impl Drop for StdioConnection {
93    fn drop(&mut self) {
94        #[cfg(unix)]
95        {
96            if let Some(pid) = self.child.id() {
97                // Send SIGKILL to the process group (negative PID).
98                // Requires libc or simple command execution to kill process subtree if possible.
99                // We will use standard kill command to be safe and avoid pulling in `libc` directly.
100                let _ = std::process::Command::new("kill")
101                    .arg("-9")
102                    .arg(format!("-{}", pid))
103                    .output();
104            }
105        }
106        // Fallback for non-unix or immediate child stop
107        let _ = self.child.start_kill();
108    }
109}
110
111/// Global registry of active MCP client connections, keyed by server name.
112pub struct McpClientRegistry {
113    clients: RwLock<HashMap<String, Arc<McpClient>>>,
114}
115
116impl Default for McpClientRegistry {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl McpClientRegistry {
123    /// Create an empty registry.
124    pub fn new() -> Self {
125        Self {
126            clients: RwLock::new(HashMap::new()),
127        }
128    }
129
130    /// Connect to an MCP server described by `entry`.
131    ///
132    /// On success the client is stored in the registry and its discovered tools
133    /// are returned.
134    pub async fn connect(&self, entry: &McpServerEntry) -> Result<Vec<Tool>> {
135        if !entry.enabled {
136            return Err(AppError::Io(std::io::Error::other(format!(
137                "MCP server '{}' is disabled",
138                entry.name
139            ))));
140        }
141
142        let client = McpClient::connect(entry).await?;
143        let tools = client.tools.read().await.clone();
144        self.clients
145            .write()
146            .await
147            .insert(entry.name.clone(), Arc::new(client));
148        Ok(tools)
149    }
150
151    /// Get an active client by server name.
152    pub async fn get(&self, name: &str) -> Option<Arc<McpClient>> {
153        self.clients.read().await.get(name).cloned()
154    }
155
156    /// Remove and drop a client connection.
157    pub async fn disconnect(&self, name: &str) {
158        self.clients.write().await.remove(name);
159    }
160
161    /// List all connected server names.
162    pub async fn connected_servers(&self) -> Vec<String> {
163        self.clients.read().await.keys().cloned().collect()
164    }
165
166    /// Get all discovered tools across all connected servers.
167    pub async fn all_tools(&self) -> Vec<(String, Vec<Tool>)> {
168        let clients = self.clients.read().await;
169        let mut out = Vec::with_capacity(clients.len());
170        for (name, client) in clients.iter() {
171            let tools = client.tools.read().await.clone();
172            out.push((name.clone(), tools));
173        }
174        out
175    }
176
177    /// Call a tool on a specific server.
178    pub async fn call_tool(
179        &self,
180        server_name: &str,
181        tool_name: &str,
182        arguments: serde_json::Value,
183    ) -> Result<ToolsCallResult> {
184        let client = self.get(server_name).await.ok_or_else(|| {
185            AppError::Io(std::io::Error::other(format!(
186                "MCP server '{}' is not connected",
187                server_name
188            )))
189        })?;
190        client.call_tool(tool_name, arguments).await
191    }
192}
193
194// ============================================================================
195// McpClient implementation
196// ============================================================================
197
198impl McpClient {
199    /// Connect to an MCP server, perform the initialize handshake, and
200    /// discover tools.
201    pub async fn connect(entry: &McpServerEntry) -> Result<Self> {
202        use crate::config::McpTransportType;
203
204        let rpc_timeout = Duration::from_secs(entry.timeout_secs.max(1));
205
206        let transport = match entry.transport {
207            McpTransportType::Http | McpTransportType::Sse => {
208                let url = entry.url.clone().ok_or_else(|| {
209                    AppError::Io(std::io::Error::other(format!(
210                        "MCP server '{}': HTTP transport requires a url",
211                        entry.name
212                    )))
213                })?;
214                McpTransport::Http {
215                    url,
216                    headers: entry.headers.clone(),
217                    client: reqwest::Client::builder()
218                        .timeout(rpc_timeout)
219                        .build()
220                        .map_err(|e| {
221                            AppError::Io(std::io::Error::other(format!(
222                                "Failed to create HTTP client: {e}"
223                            )))
224                        })?,
225                }
226            }
227            McpTransportType::Stdio => {
228                let command = entry.command.as_deref().ok_or_else(|| {
229                    AppError::Io(std::io::Error::other(format!(
230                        "MCP server '{}': stdio transport requires a command",
231                        entry.name
232                    )))
233                })?;
234
235                let resolved_command = crate::cmd_utils::resolve_mcp_command(command);
236                let mut envs = entry.env.clone();
237                crate::cmd_utils::inject_enriched_path(&mut envs);
238
239                let mut cmd = tokio::process::Command::new(resolved_command.clone());
240                cmd.args(&entry.args)
241                    .envs(&envs)
242                    .kill_on_drop(true)
243                    .stdin(std::process::Stdio::piped())
244                    .stdout(std::process::Stdio::piped())
245                    .stderr(std::process::Stdio::null());
246
247                // On Unix platforms, create a new process group so that we can kill the entire tree
248                #[cfg(unix)]
249                {
250                    cmd.process_group(0);
251                }
252
253                let mut child = cmd.spawn().map_err(|e| {
254                    AppError::Io(std::io::Error::other(format!(
255                        "Failed to spawn MCP server '{}' ({}): {e}",
256                        entry.name, resolved_command
257                    )))
258                })?;
259
260                let stdin = child.stdin.take().ok_or_else(|| {
261                    AppError::Io(std::io::Error::other(format!(
262                        "MCP server '{}': stdin not available",
263                        entry.name
264                    )))
265                })?;
266                let stdout = child.stdout.take().ok_or_else(|| {
267                    AppError::Io(std::io::Error::other(format!(
268                        "MCP server '{}': stdout not available",
269                        entry.name
270                    )))
271                })?;
272
273                let conn = StdioConnection {
274                    child,
275                    stdin,
276                    stdout: tokio::io::BufReader::new(stdout),
277                };
278
279                McpTransport::Stdio {
280                    conn: Arc::new(tokio::sync::Mutex::new(conn)),
281                }
282            }
283        };
284
285        let client = Self {
286            name: entry.name.clone(),
287            transport,
288            next_id: AtomicU64::new(1),
289            rpc_timeout,
290            tools: RwLock::new(Vec::new()),
291        };
292
293        // Perform MCP initialize handshake
294        client.initialize().await?;
295
296        // Discover tools
297        let tools = client.list_tools_rpc().await?;
298        *client.tools.write().await = tools;
299
300        tracing::info!(
301            "MCP client '{}': connected, {} tools discovered",
302            client.name,
303            client.tools.read().await.len()
304        );
305
306        Ok(client)
307    }
308
309    /// Send a JSON-RPC request and return the result value.
310    async fn rpc(
311        &self,
312        method: &str,
313        params: Option<serde_json::Value>,
314    ) -> Result<serde_json::Value> {
315        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
316        let request = JsonRpcClientRequest {
317            jsonrpc: "2.0",
318            id,
319            method: method.to_string(),
320            params,
321        };
322
323        let response_value = match &self.transport {
324            McpTransport::Http {
325                url,
326                headers,
327                client,
328            } => {
329                let mut req = client.post(url).json(&request);
330                for (k, v) in headers {
331                    req = req.header(k, v);
332                }
333                let resp = req.send().await.map_err(|e| {
334                    AppError::Io(std::io::Error::other(format!(
335                        "MCP HTTP request to '{}' failed: {e}",
336                        self.name
337                    )))
338                })?;
339                if !resp.status().is_success() {
340                    let status = resp.status();
341                    let body = resp.text().await.unwrap_or_default();
342                    return Err(AppError::Io(std::io::Error::other(format!(
343                        "MCP server '{}' HTTP {}: {}",
344                        self.name, status, body
345                    ))));
346                }
347                resp.json::<JsonRpcClientResponse>().await.map_err(|e| {
348                    AppError::Io(std::io::Error::other(format!(
349                        "MCP server '{}': invalid JSON-RPC response: {e}",
350                        self.name
351                    )))
352                })?
353            }
354            McpTransport::Stdio { conn } => self.stdio_rpc(conn, &request, method).await?,
355        };
356
357        if let Some(err) = response_value.error {
358            return Err(AppError::Io(std::io::Error::other(format!(
359                "MCP server '{}' RPC error {}: {}",
360                self.name, err.code, err.message
361            ))));
362        }
363
364        response_value.result.ok_or_else(|| {
365            AppError::Io(std::io::Error::other(format!(
366                "MCP server '{}': empty result for method '{}'",
367                self.name, method
368            )))
369        })
370    }
371
372    /// Send/receive a single JSON-RPC message over a child process stdio.
373    async fn stdio_rpc(
374        &self,
375        conn: &Arc<tokio::sync::Mutex<StdioConnection>>,
376        request: &JsonRpcClientRequest,
377        method: &str,
378    ) -> Result<JsonRpcClientResponse> {
379        use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
380
381        let mut conn = conn.lock().await;
382
383        let mut line = serde_json::to_string(request).map_err(|e| {
384            AppError::Io(std::io::Error::other(format!(
385                "Failed to serialize JSON-RPC request: {e}"
386            )))
387        })?;
388        line.push('\n');
389        conn.stdin.write_all(line.as_bytes()).await?;
390        conn.stdin.flush().await?;
391
392        let mut buf = String::new();
393        let read = tokio::time::timeout(self.rpc_timeout, conn.stdout.read_line(&mut buf))
394            .await
395            .map_err(|_| {
396                AppError::Io(std::io::Error::other(format!(
397                    "MCP server '{}': RPC '{}' timed out after {}s",
398                    self.name,
399                    method,
400                    self.rpc_timeout.as_secs()
401                )))
402            })??;
403
404        if read == 0 {
405            return Err(AppError::Io(std::io::Error::other(format!(
406                "MCP server '{}': EOF while waiting for RPC '{}' response",
407                self.name, method
408            ))));
409        }
410
411        serde_json::from_str(&buf).map_err(|e| {
412            AppError::Io(std::io::Error::other(format!(
413                "MCP server '{}': invalid JSON-RPC response on stdout: {e}",
414                self.name
415            )))
416        })
417    }
418
419    /// Perform the MCP `initialize` / `notifications/initialized` handshake.
420    async fn initialize(&self) -> Result<()> {
421        let params = serde_json::json!({
422            "protocolVersion": "2025-11-25",
423            "capabilities": {},
424            "clientInfo": {
425                "name": "gestura",
426                "version": env!("CARGO_PKG_VERSION")
427            }
428        });
429        let _result = self.rpc("initialize", Some(params)).await?;
430
431        // Send `notifications/initialized` (no id, no result expected).
432        // For HTTP this is fire-and-forget; for stdio we write but don't read.
433        let notif = serde_json::json!({
434            "jsonrpc": "2.0",
435            "method": "notifications/initialized"
436        });
437
438        match &self.transport {
439            McpTransport::Http {
440                url,
441                headers,
442                client,
443            } => {
444                let mut req = client.post(url).json(&notif);
445                for (k, v) in headers {
446                    req = req.header(k, v);
447                }
448                // Fire-and-forget — ignore errors.
449                let _ = req.send().await;
450            }
451            McpTransport::Stdio { conn } => {
452                use tokio::io::AsyncWriteExt;
453                let mut conn = conn.lock().await;
454                let mut line = serde_json::to_string(&notif).unwrap_or_default();
455                line.push('\n');
456                let _ = conn.stdin.write_all(line.as_bytes()).await;
457                let _ = conn.stdin.flush().await;
458            }
459        }
460
461        tracing::debug!("MCP client '{}': initialized", self.name);
462        Ok(())
463    }
464
465    /// Discover tools from the server via `tools/list`.
466    async fn list_tools_rpc(&self) -> Result<Vec<Tool>> {
467        let result = self.rpc("tools/list", None).await?;
468
469        #[derive(Deserialize)]
470        struct ToolsListResponse {
471            tools: Vec<Tool>,
472        }
473
474        let parsed: ToolsListResponse = serde_json::from_value(result).map_err(|e| {
475            AppError::Io(std::io::Error::other(format!(
476                "MCP server '{}': failed to parse tools/list: {e}",
477                self.name
478            )))
479        })?;
480        Ok(parsed.tools)
481    }
482
483    /// Invoke a tool on this server via `tools/call`.
484    pub async fn call_tool(
485        &self,
486        tool_name: &str,
487        arguments: serde_json::Value,
488    ) -> Result<ToolsCallResult> {
489        let params = serde_json::json!({
490            "name": tool_name,
491            "arguments": arguments
492        });
493        let result = self.rpc("tools/call", Some(params)).await?;
494
495        serde_json::from_value(result).map_err(|e| {
496            AppError::Io(std::io::Error::other(format!(
497                "MCP server '{}': failed to parse tools/call result: {e}",
498                self.name
499            )))
500        })
501    }
502
503    /// Get the list of discovered tools (cached from the last `tools/list`).
504    pub async fn get_tools(&self) -> Vec<Tool> {
505        self.tools.read().await.clone()
506    }
507
508    /// Refresh the tool list from the server.
509    pub async fn refresh_tools(&self) -> Result<Vec<Tool>> {
510        let tools = self.list_tools_rpc().await?;
511        *self.tools.write().await = tools.clone();
512        Ok(tools)
513    }
514}
515
516// ============================================================================
517// Global singleton for the MCP client registry
518// ============================================================================
519
520static MCP_CLIENT_REGISTRY: std::sync::OnceLock<McpClientRegistry> = std::sync::OnceLock::new();
521
522/// Get the global MCP client registry.
523pub fn get_mcp_client_registry() -> &'static McpClientRegistry {
524    MCP_CLIENT_REGISTRY.get_or_init(McpClientRegistry::new)
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh Tokio runtime for driving the registry's async API from a sync test.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Runtime::new().unwrap()
    }

    /// A newly created registry reports no connected servers.
    #[test]
    fn registry_starts_empty() {
        let rt = runtime();
        let registry = McpClientRegistry::new();
        let servers = rt.block_on(async { registry.connected_servers().await });
        assert!(servers.is_empty());
    }

    /// Disconnecting an unknown server neither fails nor changes the registry.
    #[test]
    fn registry_disconnect_nonexistent_is_noop() {
        let rt = runtime();
        let registry = McpClientRegistry::new();
        rt.block_on(async { registry.disconnect("nonexistent").await });
        let servers = rt.block_on(async { registry.connected_servers().await });
        assert!(servers.is_empty());
    }
}
546}