1#[cfg(test)]
7use super::types::ToolAnnotations;
8use super::types::{Tool, ToolsCapability};
9use crate::execution_mode::ToolCategory;
10use crate::tool_inspection::ToolMetadata;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::RwLock;
14use std::time::{Duration, Instant};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct McpServerConfig {
19 pub name: String,
21 pub uri: String,
23 pub enabled: bool,
25 pub timeout_secs: u64,
27 pub auto_reconnect: bool,
29}
30
31impl Default for McpServerConfig {
32 fn default() -> Self {
33 Self {
34 name: "default".to_string(),
35 uri: String::new(),
36 enabled: true,
37 timeout_secs: 30,
38 auto_reconnect: true,
39 }
40 }
41}
42
43#[derive(Debug, Clone)]
45pub struct CachedTool {
46 pub tool: Tool,
48 pub metadata: ToolMetadata,
50 pub server_name: String,
52 pub cached_at: Instant,
54}
55
56#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
58pub enum ServerState {
59 Disconnected,
61 Connecting,
63 Connected,
65 Failed,
67}
68
69#[derive(Debug, Clone)]
71pub struct ServerInfo {
72 pub config: McpServerConfig,
74 pub state: ServerState,
76 pub tools_capability: Option<ToolsCapability>,
78 pub tool_count: usize,
80 pub last_connected: Option<Instant>,
82 pub last_error: Option<String>,
84}
85
86pub struct McpDiscoveryManager {
91 servers: RwLock<HashMap<String, ServerInfo>>,
93 tool_cache: RwLock<HashMap<String, CachedTool>>,
95 cache_ttl: Duration,
97}
98
99impl Default for McpDiscoveryManager {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl McpDiscoveryManager {
106 pub fn new() -> Self {
108 Self {
109 servers: RwLock::new(HashMap::new()),
110 tool_cache: RwLock::new(HashMap::new()),
111 cache_ttl: Duration::from_secs(300), }
113 }
114
115 pub fn with_cache_ttl(cache_ttl: Duration) -> Self {
117 Self {
118 servers: RwLock::new(HashMap::new()),
119 tool_cache: RwLock::new(HashMap::new()),
120 cache_ttl,
121 }
122 }
123
124 pub fn register_server(&self, config: McpServerConfig) {
126 if let Ok(mut servers) = self.servers.write() {
127 let info = ServerInfo {
128 config: config.clone(),
129 state: ServerState::Disconnected,
130 tools_capability: None,
131 tool_count: 0,
132 last_connected: None,
133 last_error: None,
134 };
135 servers.insert(config.name.clone(), info);
136 tracing::info!("Registered MCP server: {}", config.name);
137 }
138 }
139
140 pub fn unregister_server(&self, name: &str) {
142 if let Ok(mut servers) = self.servers.write() {
143 servers.remove(name);
144 tracing::info!("Unregistered MCP server: {}", name);
145 }
146 if let Ok(mut cache) = self.tool_cache.write() {
148 cache.retain(|_, v| v.server_name != name);
149 }
150 }
151
152 pub fn list_servers(&self) -> Vec<ServerInfo> {
154 self.servers
155 .read()
156 .map(|s| s.values().cloned().collect())
157 .unwrap_or_default()
158 }
159
160 pub fn get_server(&self, name: &str) -> Option<ServerInfo> {
162 self.servers.read().ok().and_then(|s| s.get(name).cloned())
163 }
164
165 pub fn update_server_state(&self, name: &str, state: ServerState, error: Option<String>) {
167 if let Ok(mut servers) = self.servers.write()
168 && let Some(info) = servers.get_mut(name)
169 {
170 info.state = state;
171 if state == ServerState::Connected {
172 info.last_connected = Some(Instant::now());
173 info.last_error = None;
174 } else if let Some(err) = error {
175 info.last_error = Some(err);
176 }
177 }
178 }
179
180 pub fn cache_tools(&self, server_name: &str, tools: Vec<Tool>) {
182 let now = Instant::now();
183 if let Ok(mut cache) = self.tool_cache.write() {
184 for tool in tools {
185 let metadata = self.derive_metadata(&tool, server_name);
186 let key = format!("{}:{}", server_name, tool.name);
187 cache.insert(
188 key,
189 CachedTool {
190 tool,
191 metadata,
192 server_name: server_name.to_string(),
193 cached_at: now,
194 },
195 );
196 }
197 }
198 if let Ok(mut servers) = self.servers.write()
200 && let Some(info) = servers.get_mut(server_name)
201 {
202 info.tool_count = self
203 .tool_cache
204 .read()
205 .map(|c| c.values().filter(|t| t.server_name == server_name).count())
206 .unwrap_or(0);
207 }
208 }
209
210 fn derive_metadata(&self, tool: &Tool, server_name: &str) -> ToolMetadata {
212 let category = self.infer_category(tool);
213 let risk_level = self.infer_risk_level(tool, category);
214 let has_side_effects = tool
215 .annotations
216 .as_ref()
217 .map(|a| a.destructive_hint || a.open_world_hint)
218 .unwrap_or(category != ToolCategory::ReadOnly);
219
220 ToolMetadata {
221 name: format!("{}:{}", server_name, tool.name),
222 description: tool
223 .description
224 .clone()
225 .unwrap_or_else(|| format!("MCP tool from {}", server_name)),
226 category,
227 has_side_effects,
228 risk_level,
229 required_capabilities: vec!["mcp".to_string(), server_name.to_string()],
230 }
231 }
232
233 fn infer_category(&self, tool: &Tool) -> ToolCategory {
235 let name = tool.name.to_lowercase();
236 let desc = tool
237 .description
238 .as_ref()
239 .map(|s| s.to_lowercase())
240 .unwrap_or_default();
241
242 if let Some(annotations) = &tool.annotations
244 && annotations.destructive_hint
245 {
246 return ToolCategory::Write;
247 }
248
249 if name.contains("read") || name.contains("get") || name.contains("list") {
251 ToolCategory::ReadOnly
252 } else if name.contains("write")
253 || name.contains("create")
254 || name.contains("delete")
255 || name.contains("update")
256 {
257 ToolCategory::Write
258 } else if name.contains("shell")
259 || name.contains("exec")
260 || name.contains("run")
261 || name.contains("command")
262 {
263 ToolCategory::Shell
264 } else if name.contains("git") || desc.contains("git") {
265 ToolCategory::Git
266 } else if name.contains("http")
267 || name.contains("fetch")
268 || name.contains("request")
269 || desc.contains("network")
270 {
271 ToolCategory::Network
272 } else {
273 ToolCategory::Shell
275 }
276 }
277
278 fn infer_risk_level(&self, tool: &Tool, category: ToolCategory) -> u8 {
280 let base_risk = match category {
281 ToolCategory::ReadOnly => 0,
282 ToolCategory::Network => 2,
283 ToolCategory::Write => 4,
284 ToolCategory::Git => 5,
285 ToolCategory::Shell => 7,
286 ToolCategory::System => 9,
287 };
288
289 if let Some(annotations) = &tool.annotations {
291 if annotations.destructive_hint {
292 return (base_risk + 2).min(10);
293 }
294 if annotations.idempotent_hint {
295 return base_risk.saturating_sub(1);
296 }
297 }
298
299 base_risk
300 }
301
302 pub fn list_tools(&self) -> Vec<CachedTool> {
304 let now = Instant::now();
305 self.tool_cache
306 .read()
307 .map(|c| {
308 c.values()
309 .filter(|t| now.duration_since(t.cached_at) < self.cache_ttl)
310 .cloned()
311 .collect()
312 })
313 .unwrap_or_default()
314 }
315
316 pub fn get_tool(&self, server_name: &str, tool_name: &str) -> Option<CachedTool> {
318 let key = format!("{}:{}", server_name, tool_name);
319 self.tool_cache
320 .read()
321 .ok()
322 .and_then(|c| c.get(&key).cloned())
323 }
324
325 pub fn tools_from_server(&self, server_name: &str) -> Vec<CachedTool> {
327 self.tool_cache
328 .read()
329 .map(|c| {
330 c.values()
331 .filter(|t| t.server_name == server_name)
332 .cloned()
333 .collect()
334 })
335 .unwrap_or_default()
336 }
337
338 pub fn clear_expired(&self) {
340 let now = Instant::now();
341 if let Ok(mut cache) = self.tool_cache.write() {
342 cache.retain(|_, v| now.duration_since(v.cached_at) < self.cache_ttl);
343 }
344 }
345
346 pub fn clear_cache(&self) {
348 if let Ok(mut cache) = self.tool_cache.write() {
349 cache.clear();
350 }
351 }
352
353 pub fn cache_stats(&self) -> CacheStats {
355 let now = Instant::now();
356 let (total, expired) = self
357 .tool_cache
358 .read()
359 .map(|c| {
360 let total = c.len();
361 let expired = c
362 .values()
363 .filter(|t| now.duration_since(t.cached_at) >= self.cache_ttl)
364 .count();
365 (total, expired)
366 })
367 .unwrap_or((0, 0));
368
369 CacheStats {
370 total_tools: total,
371 expired_tools: expired,
372 server_count: self.servers.read().map(|s| s.len()).unwrap_or(0),
373 }
374 }
375}
376
377#[derive(Debug, Clone, Serialize, Deserialize)]
379pub struct CacheStats {
380 pub total_tools: usize,
382 pub expired_tools: usize,
384 pub server_count: usize,
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// A config named `name` pointing at `stdio://test`, other fields defaulted.
    fn stdio_config(name: &str) -> McpServerConfig {
        McpServerConfig {
            name: name.to_string(),
            uri: "stdio://test".to_string(),
            ..Default::default()
        }
    }

    /// A tool with the given name/description/annotations and an empty input schema.
    fn make_tool(name: &str, description: &str, annotations: Option<ToolAnnotations>) -> Tool {
        Tool {
            name: name.to_string(),
            description: Some(description.to_string()),
            input_schema: serde_json::json!({}),
            annotations,
        }
    }

    #[test]
    fn test_server_registration() {
        let manager = McpDiscoveryManager::new();
        manager.register_server(stdio_config("test-server"));

        let servers = manager.list_servers();
        assert_eq!(servers.len(), 1);
        let registered = &servers[0];
        assert_eq!(registered.config.name, "test-server");
        assert_eq!(registered.state, ServerState::Disconnected);
    }

    #[test]
    fn test_tool_caching() {
        let manager = McpDiscoveryManager::new();
        manager.register_server(stdio_config("test-server"));

        let destructive = ToolAnnotations {
            destructive_hint: true,
            ..Default::default()
        };
        manager.cache_tools(
            "test-server",
            vec![
                make_tool("read_file", "Read a file", None),
                make_tool("write_file", "Write to a file", Some(destructive)),
            ],
        );

        assert_eq!(manager.list_tools().len(), 2);

        let reader = manager.get_tool("test-server", "read_file").unwrap();
        assert_eq!(reader.metadata.category, ToolCategory::ReadOnly);

        let writer = manager.get_tool("test-server", "write_file").unwrap();
        assert_eq!(writer.metadata.category, ToolCategory::Write);
        assert!(writer.metadata.has_side_effects);
    }

    #[test]
    fn test_category_inference() {
        let manager = McpDiscoveryManager::new();

        let shell = make_tool("run_command", "Execute a shell command", None);
        assert_eq!(manager.infer_category(&shell), ToolCategory::Shell);

        let git = make_tool("git_status", "Get git status", None);
        assert_eq!(manager.infer_category(&git), ToolCategory::Git);
    }

    #[test]
    fn test_cache_stats() {
        let manager = McpDiscoveryManager::new();
        manager.register_server(stdio_config("test"));

        let stats = manager.cache_stats();
        assert_eq!(stats.server_count, 1);
        assert_eq!(stats.total_tools, 0);
    }
}