1use std::collections::HashSet;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use dashmap::DashMap;
11use serde_json::Value;
12
13use crate::config::AppConfig;
14use crate::llm_provider::{AgentContext, select_provider};
15use crate::tools::registry::ToolDefinition;
16use gestura_core_pipeline::types::ToolRoutingStrategy;
17
/// Outcome of a single tool-routing decision.
///
/// An empty `suggested_tools` list means "no selection" (see [`RoutingResult::has_selection`]);
/// [`RoutingResult::fallthrough`] produces exactly that value so the caller can fall back to
/// keyword routing.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct RoutingResult {
    /// Names of the tools the router selected, in the order the router produced them.
    pub suggested_tools: Vec<String>,
    /// The router's confidence in `suggested_tools`: `0.0` for a fallthrough result;
    /// the LLM router reports `1.0` after a successful call.
    pub confidence: f32,
}

impl RoutingResult {
    /// Result meaning "no routing decision": no tools and zero confidence.
    #[must_use]
    pub fn fallthrough() -> Self {
        // The derived `Default` is exactly an empty tool list with 0.0 confidence.
        Self::default()
    }

    /// Returns `true` when at least one tool was suggested.
    #[must_use]
    pub fn has_selection(&self) -> bool {
        !self.suggested_tools.is_empty()
    }
}
50
#[async_trait]
/// A strategy for choosing which tools a request needs.
///
/// The trait requires `Send + Sync`, so routers can be shared and boxed as
/// `Box<dyn ToolRouter>` (as returned by `build_tool_router`).
pub trait ToolRouter: Send + Sync {
    /// Selects tools for `request` from the candidate list `tools`.
    ///
    /// * `request` — the raw user request text.
    /// * `tools` — the candidate tool definitions the router may choose from.
    /// * `keyword_confidence` — the confidence score produced by keyword routing.
    ///   `HybridToolRouter` compares it against its threshold; `LlmToolRouter` ignores it.
    ///
    /// A result with no suggested tools (see `RoutingResult::has_selection`) means the
    /// router made no selection and the caller should use keyword routing instead.
    async fn route(
        &self,
        request: &str,
        tools: &[&'static ToolDefinition],
        keyword_confidence: f32,
    ) -> RoutingResult;
}
77
78type CacheKey = [u8; 32];
84
85fn cache_key(request: &str) -> CacheKey {
86 *blake3::hash(request.trim().to_lowercase().as_bytes()).as_bytes()
87}
88
/// A tool router that asks an LLM which tools a request needs, and memoises the answers.
pub struct LlmToolRouter {
    /// Application config; the LLM provider is selected from it on each uncached call.
    config: Arc<AppConfig>,
    /// Memoised routing results, keyed by a 32-byte `CacheKey` digest.
    /// NOTE(review): no eviction is visible in this module, so the map grows with the
    /// number of distinct keys — confirm this is acceptable for long-running processes.
    cache: DashMap<CacheKey, Arc<RoutingResult>>,
}
99
100impl LlmToolRouter {
101 pub fn new(config: Arc<AppConfig>) -> Self {
103 Self {
104 config,
105 cache: DashMap::new(),
106 }
107 }
108
109 fn build_routing_prompt(request: &str, tools: &[&'static ToolDefinition]) -> String {
111 let mut prompt = String::from(
112 "You are a tool selector. Given a user request and a list of tools, \
113 respond with ONLY a JSON array of tool names needed to fulfill the request.\n\
114 Choose the minimal set (1-4 tools). If no tool is needed respond with [].\n\
115 Do not include any explanation — only the JSON array.\n\n\
116 Available tools:\n",
117 );
118 for tool in tools {
119 prompt.push_str(&format!("- {}: {}\n", tool.name, tool.description));
120 }
121 prompt.push_str(&format!(
122 "\nUser request: \"{}\"\n\nJSON array:",
123 request.trim()
124 ));
125 prompt
126 }
127
128 fn parse_tool_names(response: &str, tools: &[&'static ToolDefinition]) -> Vec<String> {
131 let start = response.find('[');
132 let end = response.rfind(']');
133 let (Some(s), Some(e)) = (start, end) else {
134 return Vec::new();
135 };
136 let Ok(Value::Array(arr)) = serde_json::from_str::<Value>(&response[s..=e]) else {
137 return Vec::new();
138 };
139 let valid: HashSet<&str> = tools.iter().map(|t| t.name).collect();
140 arr.into_iter()
141 .filter_map(|v| v.as_str().map(str::to_lowercase))
142 .filter(|name| valid.contains(name.as_str()))
143 .collect()
144 }
145}
146
147#[async_trait]
148impl ToolRouter for LlmToolRouter {
149 async fn route(
150 &self,
151 request: &str,
152 tools: &[&'static ToolDefinition],
153 _keyword_confidence: f32,
154 ) -> RoutingResult {
155 let key = cache_key(request);
156
157 if let Some(cached) = self.cache.get(&key) {
159 tracing::debug!("LlmToolRouter: cache hit");
160 return (**cached).clone();
161 }
162
163 let prompt = Self::build_routing_prompt(request, tools);
164 let ctx = AgentContext::default();
165 let provider = select_provider(self.config.as_ref(), &ctx);
166
167 let result = match provider.call(&prompt).await {
168 Ok(response) => {
169 let suggested = Self::parse_tool_names(&response, tools);
170 tracing::debug!(tools = ?suggested, "LlmToolRouter: routed request to tools");
171 RoutingResult {
172 suggested_tools: suggested,
173 confidence: 1.0,
174 }
175 }
176 Err(e) => {
177 tracing::warn!(
178 error = %e,
179 "LlmToolRouter: LLM call failed, falling through to keyword routing"
180 );
181 RoutingResult::fallthrough()
182 }
183 };
184
185 let shared = Arc::new(result.clone());
186 self.cache.insert(key, shared);
187 result
188 }
189}
190
191pub struct HybridToolRouter {
201 llm_router: LlmToolRouter,
202 confidence_threshold: f32,
203}
204
205impl HybridToolRouter {
206 pub fn new(config: Arc<AppConfig>, confidence_threshold: f32) -> Self {
211 Self {
212 llm_router: LlmToolRouter::new(config),
213 confidence_threshold,
214 }
215 }
216}
217
218#[async_trait]
219impl ToolRouter for HybridToolRouter {
220 async fn route(
221 &self,
222 request: &str,
223 tools: &[&'static ToolDefinition],
224 keyword_confidence: f32,
225 ) -> RoutingResult {
226 if keyword_confidence >= self.confidence_threshold {
227 tracing::debug!(
228 keyword_confidence,
229 threshold = self.confidence_threshold,
230 "HybridToolRouter: above threshold, using keyword routing"
231 );
232 return RoutingResult::fallthrough();
233 }
234 tracing::debug!(
235 keyword_confidence,
236 threshold = self.confidence_threshold,
237 "HybridToolRouter: below threshold, invoking LLM router"
238 );
239 self.llm_router
240 .route(request, tools, keyword_confidence)
241 .await
242 }
243}
244
245pub fn build_tool_router(
254 strategy: &ToolRoutingStrategy,
255 config: Arc<AppConfig>,
256) -> Option<Box<dyn ToolRouter>> {
257 match strategy {
258 ToolRoutingStrategy::Keyword => None,
259 ToolRoutingStrategy::Llm => Some(Box::new(LlmToolRouter::new(config))),
260 ToolRoutingStrategy::Hybrid {
261 confidence_threshold,
262 } => Some(Box::new(HybridToolRouter::new(
263 config,
264 *confidence_threshold,
265 ))),
266 }
267}
268
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers `LlmToolRouter::parse_tool_names` and `cache_key`.
    //! They need no LLM provider; candidate tools come from the real registry.

    use super::*;
    use crate::tools::registry::all_tools;

    /// Names absent from the candidate tool list are dropped; valid ones keep their order.
    #[test]
    fn parse_tool_names_valid_json_filters_invalid_names() {
        let tools: Vec<&'static ToolDefinition> = all_tools().iter().collect();
        let response = r#"["file", "web", "not_a_real_tool", "web_search"]"#;
        let result = LlmToolRouter::parse_tool_names(response, &tools);
        assert_eq!(result, vec!["file", "web", "web_search"]);
    }

    /// A reply with no `[`/`]` pair yields no tools.
    #[test]
    fn parse_tool_names_no_brackets_returns_empty() {
        let tools: Vec<&'static ToolDefinition> = all_tools().iter().collect();
        let response = "file, web, web_search";
        let result = LlmToolRouter::parse_tool_names(response, &tools);
        assert!(result.is_empty(), "expected empty without JSON brackets");
    }

    /// Brackets around unquoted names are not valid JSON, so nothing is returned.
    #[test]
    fn parse_tool_names_invalid_json_returns_empty() {
        let tools: Vec<&'static ToolDefinition> = all_tools().iter().collect();
        let response = "[file, web]";
        let result = LlmToolRouter::parse_tool_names(response, &tools);
        assert!(result.is_empty(), "expected empty for invalid JSON");
    }

    /// An explicit empty array ("no tool needed") parses to an empty selection.
    #[test]
    fn parse_tool_names_empty_array_returns_empty() {
        let tools: Vec<&'static ToolDefinition> = all_tools().iter().collect();
        let result = LlmToolRouter::parse_tool_names("[]", &tools);
        assert!(result.is_empty());
    }

    /// Returned names are lowercased before matching against the tool list.
    #[test]
    fn parse_tool_names_normalises_case() {
        let tools: Vec<&'static ToolDefinition> = all_tools().iter().collect();
        let response = r#"["FILE", "Web", "WEB_SEARCH"]"#;
        let result = LlmToolRouter::parse_tool_names(response, &tools);
        assert_eq!(result, vec!["file", "web", "web_search"]);
    }

    /// A JSON array embedded in surrounding prose is still found and parsed.
    #[test]
    fn parse_tool_names_extracts_from_prose_with_brackets() {
        let tools: Vec<&'static ToolDefinition> = all_tools().iter().collect();
        let response = r#"The tools you need are: ["shell", "git"]."#;
        let result = LlmToolRouter::parse_tool_names(response, &tools);
        assert_eq!(result, vec!["shell", "git"]);
    }

    /// Hashing the same request twice yields the same key.
    #[test]
    fn cache_key_is_deterministic() {
        assert_eq!(cache_key("fetch gestura.ai"), cache_key("fetch gestura.ai"));
    }

    /// Surrounding whitespace and letter case do not affect the key.
    #[test]
    fn cache_key_normalises_whitespace_and_case() {
        assert_eq!(
            cache_key(" Fetch Gestura.ai "),
            cache_key("fetch gestura.ai")
        );
    }

    /// Distinct requests map to distinct keys.
    #[test]
    fn cache_key_differs_for_different_requests() {
        assert_ne!(
            cache_key("take a screenshot"),
            cache_key("fetch gestura.ai")
        );
    }
}