gestura_core_mcp/
discovery.rs

1//! MCP Tool Discovery and Registry
2//!
3//! Provides unified tool discovery from external MCP servers, capability negotiation,
4//! and tool metadata caching for performance.
5
6#[cfg(test)]
7use super::types::ToolAnnotations;
8use super::types::{Tool, ToolsCapability};
9use crate::execution_mode::ToolCategory;
10use crate::tool_inspection::ToolMetadata;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::RwLock;
14use std::time::{Duration, Instant};
15
/// Configuration for MCP server connection
///
/// Stored verbatim inside [`ServerInfo::config`] when the server is
/// registered with `McpDiscoveryManager::register_server`. The manager keys
/// its server table by `name`, so names should be unique.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Server name/identifier; also used as the prefix of cached tool keys
    /// (`"{name}:{tool}"`).
    pub name: String,
    /// Server URI (e.g., "stdio://path/to/server" or "http://localhost:3000")
    pub uri: String,
    /// Whether this server is enabled.
    /// NOTE(review): not read by the discovery manager in this module;
    /// presumably honored by the connection layer — confirm.
    pub enabled: bool,
    /// Connection timeout in seconds (30 in the `Default` impl).
    /// NOTE(review): not read in this module — confirm where it is applied.
    pub timeout_secs: u64,
    /// Auto-reconnect on failure.
    /// NOTE(review): not read in this module — confirm where it is applied.
    pub auto_reconnect: bool,
}
30
31impl Default for McpServerConfig {
32    fn default() -> Self {
33        Self {
34            name: "default".to_string(),
35            uri: String::new(),
36            enabled: true,
37            timeout_secs: 30,
38            auto_reconnect: true,
39        }
40    }
41}
42
/// Cached tool information from an MCP server
///
/// One entry per `"{server_name}:{tool_name}"` key in the discovery manager's
/// tool cache. Entries older than the manager's TTL (measured from
/// `cached_at`) are treated as expired.
#[derive(Debug, Clone)]
pub struct CachedTool {
    /// The MCP tool definition, exactly as supplied to `cache_tools`
    pub tool: Tool,
    /// Derived metadata for permission checking (category, risk level,
    /// side effects), computed once when the tool is cached
    pub metadata: ToolMetadata,
    /// Source server name
    pub server_name: String,
    /// When this was cached; compared against the manager's TTL
    pub cached_at: Instant,
}
55
/// MCP server connection state
///
/// Newly registered servers start as `Disconnected`; transitions are recorded
/// through `McpDiscoveryManager::update_server_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ServerState {
    /// Not connected (initial state after registration)
    Disconnected,
    /// Connecting
    Connecting,
    /// Connected and ready; entering this state stamps `last_connected` and
    /// clears `last_error`
    Connected,
    /// Connection failed
    Failed,
}
68
/// Information about a registered MCP server
///
/// Despite the name, an entry exists for every registered server regardless
/// of its connection `state`.
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// Server configuration as passed to `register_server`
    pub config: McpServerConfig,
    /// Current connection state
    pub state: ServerState,
    /// Server's advertised tools capability.
    /// NOTE(review): initialized to `None` and never written in this module.
    pub tools_capability: Option<ToolsCapability>,
    /// Number of cached tool entries attributed to this server, recomputed by
    /// `cache_tools` (includes entries that have outlived the cache TTL but
    /// have not yet been removed)
    pub tool_count: usize,
    /// Last time the server was reported `Connected`
    pub last_connected: Option<Instant>,
    /// Last error message reported with a non-`Connected` state; cleared when
    /// the server becomes `Connected`
    pub last_error: Option<String>,
}
85
/// MCP Tool Discovery Manager
///
/// Tracks registered external MCP servers and their reported connection
/// state, and caches the tools they advertise together with derived
/// permission metadata. This type does not open connections itself; callers
/// report state via `update_server_state` and supply tools via `cache_tools`.
///
/// All state sits behind `std::sync::RwLock`s, so every method takes `&self`.
/// A poisoned lock is treated as "no data": mutators become no-ops and
/// accessors return empty/default results instead of an error.
pub struct McpDiscoveryManager {
    /// Registered MCP servers, keyed by `McpServerConfig::name`
    servers: RwLock<HashMap<String, ServerInfo>>,
    /// Cached tools from all servers, keyed by `"{server_name}:{tool_name}"`
    tool_cache: RwLock<HashMap<String, CachedTool>>,
    /// Age after which a cached tool is considered expired
    cache_ttl: Duration,
}
98
99impl Default for McpDiscoveryManager {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl McpDiscoveryManager {
106    /// Create a new discovery manager
107    pub fn new() -> Self {
108        Self {
109            servers: RwLock::new(HashMap::new()),
110            tool_cache: RwLock::new(HashMap::new()),
111            cache_ttl: Duration::from_secs(300), // 5 minutes default
112        }
113    }
114
115    /// Create with custom cache TTL
116    pub fn with_cache_ttl(cache_ttl: Duration) -> Self {
117        Self {
118            servers: RwLock::new(HashMap::new()),
119            tool_cache: RwLock::new(HashMap::new()),
120            cache_ttl,
121        }
122    }
123
124    /// Register an MCP server
125    pub fn register_server(&self, config: McpServerConfig) {
126        if let Ok(mut servers) = self.servers.write() {
127            let info = ServerInfo {
128                config: config.clone(),
129                state: ServerState::Disconnected,
130                tools_capability: None,
131                tool_count: 0,
132                last_connected: None,
133                last_error: None,
134            };
135            servers.insert(config.name.clone(), info);
136            tracing::info!("Registered MCP server: {}", config.name);
137        }
138    }
139
140    /// Unregister an MCP server
141    pub fn unregister_server(&self, name: &str) {
142        if let Ok(mut servers) = self.servers.write() {
143            servers.remove(name);
144            tracing::info!("Unregistered MCP server: {}", name);
145        }
146        // Also remove cached tools from this server
147        if let Ok(mut cache) = self.tool_cache.write() {
148            cache.retain(|_, v| v.server_name != name);
149        }
150    }
151
152    /// Get all registered servers
153    pub fn list_servers(&self) -> Vec<ServerInfo> {
154        self.servers
155            .read()
156            .map(|s| s.values().cloned().collect())
157            .unwrap_or_default()
158    }
159
160    /// Get server info by name
161    pub fn get_server(&self, name: &str) -> Option<ServerInfo> {
162        self.servers.read().ok().and_then(|s| s.get(name).cloned())
163    }
164
165    /// Update server state
166    pub fn update_server_state(&self, name: &str, state: ServerState, error: Option<String>) {
167        if let Ok(mut servers) = self.servers.write()
168            && let Some(info) = servers.get_mut(name)
169        {
170            info.state = state;
171            if state == ServerState::Connected {
172                info.last_connected = Some(Instant::now());
173                info.last_error = None;
174            } else if let Some(err) = error {
175                info.last_error = Some(err);
176            }
177        }
178    }
179
180    /// Cache tools from a server
181    pub fn cache_tools(&self, server_name: &str, tools: Vec<Tool>) {
182        let now = Instant::now();
183        if let Ok(mut cache) = self.tool_cache.write() {
184            for tool in tools {
185                let metadata = self.derive_metadata(&tool, server_name);
186                let key = format!("{}:{}", server_name, tool.name);
187                cache.insert(
188                    key,
189                    CachedTool {
190                        tool,
191                        metadata,
192                        server_name: server_name.to_string(),
193                        cached_at: now,
194                    },
195                );
196            }
197        }
198        // Update server tool count
199        if let Ok(mut servers) = self.servers.write()
200            && let Some(info) = servers.get_mut(server_name)
201        {
202            info.tool_count = self
203                .tool_cache
204                .read()
205                .map(|c| c.values().filter(|t| t.server_name == server_name).count())
206                .unwrap_or(0);
207        }
208    }
209
210    /// Derive ToolMetadata from MCP Tool definition
211    fn derive_metadata(&self, tool: &Tool, server_name: &str) -> ToolMetadata {
212        let category = self.infer_category(tool);
213        let risk_level = self.infer_risk_level(tool, category);
214        let has_side_effects = tool
215            .annotations
216            .as_ref()
217            .map(|a| a.destructive_hint || a.open_world_hint)
218            .unwrap_or(category != ToolCategory::ReadOnly);
219
220        ToolMetadata {
221            name: format!("{}:{}", server_name, tool.name),
222            description: tool
223                .description
224                .clone()
225                .unwrap_or_else(|| format!("MCP tool from {}", server_name)),
226            category,
227            has_side_effects,
228            risk_level,
229            required_capabilities: vec!["mcp".to_string(), server_name.to_string()],
230        }
231    }
232
233    /// Infer tool category from MCP tool definition
234    fn infer_category(&self, tool: &Tool) -> ToolCategory {
235        let name = tool.name.to_lowercase();
236        let desc = tool
237            .description
238            .as_ref()
239            .map(|s| s.to_lowercase())
240            .unwrap_or_default();
241
242        // Check annotations first
243        if let Some(annotations) = &tool.annotations
244            && annotations.destructive_hint
245        {
246            return ToolCategory::Write;
247        }
248
249        // Infer from name/description
250        if name.contains("read") || name.contains("get") || name.contains("list") {
251            ToolCategory::ReadOnly
252        } else if name.contains("write")
253            || name.contains("create")
254            || name.contains("delete")
255            || name.contains("update")
256        {
257            ToolCategory::Write
258        } else if name.contains("shell")
259            || name.contains("exec")
260            || name.contains("run")
261            || name.contains("command")
262        {
263            ToolCategory::Shell
264        } else if name.contains("git") || desc.contains("git") {
265            ToolCategory::Git
266        } else if name.contains("http")
267            || name.contains("fetch")
268            || name.contains("request")
269            || desc.contains("network")
270        {
271            ToolCategory::Network
272        } else {
273            // Default to Shell (most restrictive) for unknown tools
274            ToolCategory::Shell
275        }
276    }
277
278    /// Infer risk level from tool definition
279    fn infer_risk_level(&self, tool: &Tool, category: ToolCategory) -> u8 {
280        let base_risk = match category {
281            ToolCategory::ReadOnly => 0,
282            ToolCategory::Network => 2,
283            ToolCategory::Write => 4,
284            ToolCategory::Git => 5,
285            ToolCategory::Shell => 7,
286            ToolCategory::System => 9,
287        };
288
289        // Adjust based on annotations
290        if let Some(annotations) = &tool.annotations {
291            if annotations.destructive_hint {
292                return (base_risk + 2).min(10);
293            }
294            if annotations.idempotent_hint {
295                return base_risk.saturating_sub(1);
296            }
297        }
298
299        base_risk
300    }
301
302    /// Get all cached tools
303    pub fn list_tools(&self) -> Vec<CachedTool> {
304        let now = Instant::now();
305        self.tool_cache
306            .read()
307            .map(|c| {
308                c.values()
309                    .filter(|t| now.duration_since(t.cached_at) < self.cache_ttl)
310                    .cloned()
311                    .collect()
312            })
313            .unwrap_or_default()
314    }
315
316    /// Get a specific tool by server:name
317    pub fn get_tool(&self, server_name: &str, tool_name: &str) -> Option<CachedTool> {
318        let key = format!("{}:{}", server_name, tool_name);
319        self.tool_cache
320            .read()
321            .ok()
322            .and_then(|c| c.get(&key).cloned())
323    }
324
325    /// Get all tools from a specific server
326    pub fn tools_from_server(&self, server_name: &str) -> Vec<CachedTool> {
327        self.tool_cache
328            .read()
329            .map(|c| {
330                c.values()
331                    .filter(|t| t.server_name == server_name)
332                    .cloned()
333                    .collect()
334            })
335            .unwrap_or_default()
336    }
337
338    /// Clear expired cache entries
339    pub fn clear_expired(&self) {
340        let now = Instant::now();
341        if let Ok(mut cache) = self.tool_cache.write() {
342            cache.retain(|_, v| now.duration_since(v.cached_at) < self.cache_ttl);
343        }
344    }
345
346    /// Clear all cached tools
347    pub fn clear_cache(&self) {
348        if let Ok(mut cache) = self.tool_cache.write() {
349            cache.clear();
350        }
351    }
352
353    /// Get cache statistics
354    pub fn cache_stats(&self) -> CacheStats {
355        let now = Instant::now();
356        let (total, expired) = self
357            .tool_cache
358            .read()
359            .map(|c| {
360                let total = c.len();
361                let expired = c
362                    .values()
363                    .filter(|t| now.duration_since(t.cached_at) >= self.cache_ttl)
364                    .count();
365                (total, expired)
366            })
367            .unwrap_or((0, 0));
368
369        CacheStats {
370            total_tools: total,
371            expired_tools: expired,
372            server_count: self.servers.read().map(|s| s.len()).unwrap_or(0),
373        }
374    }
375}
376
/// Cache statistics
///
/// Point-in-time snapshot produced by `McpDiscoveryManager::cache_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Total cached tool entries, including expired ones
    pub total_tools: usize,
    /// Entries older than the cache TTL that have not yet been removed by
    /// `clear_expired`
    pub expired_tools: usize,
    /// Number of registered servers
    pub server_count: usize,
}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a stdio-backed server config named `name`, other fields default.
    fn stdio_config(name: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            uri: "stdio://test".to_string(),
            ..Default::default()
        }
    }

    /// Build a tool with an empty input schema.
    fn make_tool(name: &str, description: &str, annotations: Option<ToolAnnotations>) -> Tool {
        Tool {
            name: name.to_string(),
            description: Some(description.to_string()),
            input_schema: serde_json::json!({}),
            annotations,
        }
    }

    #[test]
    fn test_server_registration() {
        let manager = McpDiscoveryManager::new();
        manager.register_server(stdio_config("test-server"));

        let servers = manager.list_servers();
        assert_eq!(servers.len(), 1);
        let only = &servers[0];
        assert_eq!(only.config.name, "test-server");
        assert_eq!(only.state, ServerState::Disconnected);
    }

    #[test]
    fn test_tool_caching() {
        let manager = McpDiscoveryManager::new();
        manager.register_server(stdio_config("test-server"));

        let destructive = ToolAnnotations {
            destructive_hint: true,
            ..Default::default()
        };
        manager.cache_tools(
            "test-server",
            vec![
                make_tool("read_file", "Read a file", None),
                make_tool("write_file", "Write to a file", Some(destructive)),
            ],
        );

        assert_eq!(manager.list_tools().len(), 2);

        // Category inference on the cached entries
        let reader = manager
            .get_tool("test-server", "read_file")
            .expect("read_file should be cached");
        assert_eq!(reader.metadata.category, ToolCategory::ReadOnly);

        let writer = manager
            .get_tool("test-server", "write_file")
            .expect("write_file should be cached");
        assert_eq!(writer.metadata.category, ToolCategory::Write);
        assert!(writer.metadata.has_side_effects);
    }

    #[test]
    fn test_category_inference() {
        let manager = McpDiscoveryManager::new();

        let shell = make_tool("run_command", "Execute a shell command", None);
        assert_eq!(manager.infer_category(&shell), ToolCategory::Shell);

        let git = make_tool("git_status", "Get git status", None);
        assert_eq!(manager.infer_category(&git), ToolCategory::Git);
    }

    #[test]
    fn test_cache_stats() {
        let manager = McpDiscoveryManager::new();
        manager.register_server(stdio_config("test"));

        let stats = manager.cache_stats();
        assert_eq!(stats.server_count, 1);
        assert_eq!(stats.total_tools, 0);
    }
}