gestura_core/pipeline/
tool_router.rs

1//! LLM Pre-flight Tool Router
2//!
3//! Implements semantic tool selection via a cheap pre-flight LLM call.
4//! Reduces silent tool-selection failures on ambiguous or novel user requests.
5
6use std::collections::HashSet;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use dashmap::DashMap;
11use serde_json::Value;
12
13use crate::config::AppConfig;
14use crate::llm_provider::{AgentContext, select_provider};
15use crate::tools::registry::ToolDefinition;
16use gestura_core_pipeline::types::ToolRoutingStrategy;
17
18// ---------------------------------------------------------------------------
19// Result type
20// ---------------------------------------------------------------------------
21
/// Outcome of a single pre-flight tool-routing pass.
#[derive(Debug, Clone)]
pub struct RoutingResult {
    /// Names of the tools the router picked.
    ///
    /// When this list is **empty** the router abstained, and the caller should
    /// fall back to keyword/category routing instead.
    pub suggested_tools: Vec<String>,
    /// How confident the router is in this decision, in the range 0.0–1.0.
    ///
    /// Decisions produced by the LLM use `1.0`; a fallthrough uses `0.0`.
    pub confidence: f32,
}

impl RoutingResult {
    /// Sentinel meaning "no decision": the pipeline should use keyword routing.
    pub fn fallthrough() -> Self {
        Self {
            confidence: 0.0,
            suggested_tools: Vec::default(),
        }
    }

    /// Whether the router made an explicit selection (the tool list is non-empty).
    pub fn has_selection(&self) -> bool {
        self.suggested_tools.first().is_some()
    }
}
50
51// ---------------------------------------------------------------------------
52// Trait
53// ---------------------------------------------------------------------------
54
55/// Async trait for tool selection strategies.
56///
57/// Implementations decide which built-in tools to expose to the LLM for a
58/// given user request.  Return [`RoutingResult::fallthrough()`] to defer to
59/// the existing keyword/category routing path.
60#[async_trait]
61pub trait ToolRouter: Send + Sync {
62    /// Select the most relevant tools for a request.
63    ///
64    /// # Parameters
65    /// - `request`: the raw user input string
66    /// - `tools`: full list of available [`ToolDefinition`]s
67    /// - `keyword_confidence`: confidence from the keyword [`RequestAnalyzer`]
68    ///
69    /// [`RequestAnalyzer`]: crate::context::RequestAnalyzer
70    async fn route(
71        &self,
72        request: &str,
73        tools: &[&'static ToolDefinition],
74        keyword_confidence: f32,
75    ) -> RoutingResult;
76}
77
78// ---------------------------------------------------------------------------
79// LlmToolRouter
80// ---------------------------------------------------------------------------
81
82/// Cache key: blake3 hash of normalised request text.
83type CacheKey = [u8; 32];
84
85fn cache_key(request: &str) -> CacheKey {
86    *blake3::hash(request.trim().to_lowercase().as_bytes()).as_bytes()
87}
88
89/// LLM-based tool router.
90///
91/// Fires a single, cheap pre-flight LLM call whose sole job is to pick
92/// the minimal set of tool names relevant to the request.  Results are
93/// cached in-process by request hash to avoid repeated calls for identical
94/// inputs within a session.
95pub struct LlmToolRouter {
96    config: Arc<AppConfig>,
97    cache: DashMap<CacheKey, Arc<RoutingResult>>,
98}
99
100impl LlmToolRouter {
101    /// Create a new router backed by `config`.
102    pub fn new(config: Arc<AppConfig>) -> Self {
103        Self {
104            config,
105            cache: DashMap::new(),
106        }
107    }
108
109    /// Build the routing prompt.
110    fn build_routing_prompt(request: &str, tools: &[&'static ToolDefinition]) -> String {
111        let mut prompt = String::from(
112            "You are a tool selector. Given a user request and a list of tools, \
113             respond with ONLY a JSON array of tool names needed to fulfill the request.\n\
114             Choose the minimal set (1-4 tools). If no tool is needed respond with [].\n\
115             Do not include any explanation — only the JSON array.\n\n\
116             Available tools:\n",
117        );
118        for tool in tools {
119            prompt.push_str(&format!("- {}: {}\n", tool.name, tool.description));
120        }
121        prompt.push_str(&format!(
122            "\nUser request: \"{}\"\n\nJSON array:",
123            request.trim()
124        ));
125        prompt
126    }
127
128    /// Parse a JSON array of tool names from the LLM response, validating
129    /// each name against the known tool set.
130    fn parse_tool_names(response: &str, tools: &[&'static ToolDefinition]) -> Vec<String> {
131        let start = response.find('[');
132        let end = response.rfind(']');
133        let (Some(s), Some(e)) = (start, end) else {
134            return Vec::new();
135        };
136        let Ok(Value::Array(arr)) = serde_json::from_str::<Value>(&response[s..=e]) else {
137            return Vec::new();
138        };
139        let valid: HashSet<&str> = tools.iter().map(|t| t.name).collect();
140        arr.into_iter()
141            .filter_map(|v| v.as_str().map(str::to_lowercase))
142            .filter(|name| valid.contains(name.as_str()))
143            .collect()
144    }
145}
146
147#[async_trait]
148impl ToolRouter for LlmToolRouter {
149    async fn route(
150        &self,
151        request: &str,
152        tools: &[&'static ToolDefinition],
153        _keyword_confidence: f32,
154    ) -> RoutingResult {
155        let key = cache_key(request);
156
157        // Cache hit — avoid redundant LLM calls for identical requests.
158        if let Some(cached) = self.cache.get(&key) {
159            tracing::debug!("LlmToolRouter: cache hit");
160            return (**cached).clone();
161        }
162
163        let prompt = Self::build_routing_prompt(request, tools);
164        let ctx = AgentContext::default();
165        let provider = select_provider(self.config.as_ref(), &ctx);
166
167        let result = match provider.call(&prompt).await {
168            Ok(response) => {
169                let suggested = Self::parse_tool_names(&response, tools);
170                tracing::debug!(tools = ?suggested, "LlmToolRouter: routed request to tools");
171                RoutingResult {
172                    suggested_tools: suggested,
173                    confidence: 1.0,
174                }
175            }
176            Err(e) => {
177                tracing::warn!(
178                    error = %e,
179                    "LlmToolRouter: LLM call failed, falling through to keyword routing"
180                );
181                RoutingResult::fallthrough()
182            }
183        };
184
185        let shared = Arc::new(result.clone());
186        self.cache.insert(key, shared);
187        result
188    }
189}
190
191// ---------------------------------------------------------------------------
192// HybridToolRouter
193// ---------------------------------------------------------------------------
194
195/// Hybrid tool router.
196///
197/// Uses the keyword analyzer's confidence score to decide when to invoke the
198/// LLM router.  Requests where keyword analysis is confident enough bypass the
199/// extra round-trip entirely.
200pub struct HybridToolRouter {
201    llm_router: LlmToolRouter,
202    confidence_threshold: f32,
203}
204
205impl HybridToolRouter {
206    /// Create a new hybrid router.
207    ///
208    /// `confidence_threshold` is the minimum keyword-analysis confidence above
209    /// which the LLM call is skipped.  Values in `[0.2, 0.5]` are recommended.
210    pub fn new(config: Arc<AppConfig>, confidence_threshold: f32) -> Self {
211        Self {
212            llm_router: LlmToolRouter::new(config),
213            confidence_threshold,
214        }
215    }
216}
217
218#[async_trait]
219impl ToolRouter for HybridToolRouter {
220    async fn route(
221        &self,
222        request: &str,
223        tools: &[&'static ToolDefinition],
224        keyword_confidence: f32,
225    ) -> RoutingResult {
226        if keyword_confidence >= self.confidence_threshold {
227            tracing::debug!(
228                keyword_confidence,
229                threshold = self.confidence_threshold,
230                "HybridToolRouter: above threshold, using keyword routing"
231            );
232            return RoutingResult::fallthrough();
233        }
234        tracing::debug!(
235            keyword_confidence,
236            threshold = self.confidence_threshold,
237            "HybridToolRouter: below threshold, invoking LLM router"
238        );
239        self.llm_router
240            .route(request, tools, keyword_confidence)
241            .await
242    }
243}
244
245// ---------------------------------------------------------------------------
246// Factory
247// ---------------------------------------------------------------------------
248
249/// Build a [`ToolRouter`] from a [`ToolRoutingStrategy`] and app config.
250///
251/// Returns `None` for [`ToolRoutingStrategy::Keyword`] — no extra router
252/// object is needed; the pipeline's existing keyword path runs as-is.
253pub fn build_tool_router(
254    strategy: &ToolRoutingStrategy,
255    config: Arc<AppConfig>,
256) -> Option<Box<dyn ToolRouter>> {
257    match strategy {
258        ToolRoutingStrategy::Keyword => None,
259        ToolRoutingStrategy::Llm => Some(Box::new(LlmToolRouter::new(config))),
260        ToolRoutingStrategy::Hybrid {
261            confidence_threshold,
262        } => Some(Box::new(HybridToolRouter::new(
263            config,
264            *confidence_threshold,
265        ))),
266    }
267}
268
269// ---------------------------------------------------------------------------
270// Tests (private helpers)
271// ---------------------------------------------------------------------------
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::registry::all_tools;

    /// Every registered tool, in the shape `parse_tool_names` expects.
    fn registry() -> Vec<&'static ToolDefinition> {
        all_tools().iter().collect()
    }

    /// Parse `response` against the full tool registry.
    fn parse(response: &str) -> Vec<String> {
        LlmToolRouter::parse_tool_names(response, &registry())
    }

    // ------------------------------------------------------------------
    // parse_tool_names
    // ------------------------------------------------------------------

    #[test]
    fn parse_tool_names_valid_json_filters_invalid_names() {
        let parsed = parse(r#"["file", "web", "not_a_real_tool", "web_search"]"#);
        assert_eq!(parsed, vec!["file", "web", "web_search"]);
    }

    #[test]
    fn parse_tool_names_no_brackets_returns_empty() {
        let parsed = parse("file, web, web_search");
        assert!(parsed.is_empty(), "expected empty without JSON brackets");
    }

    #[test]
    fn parse_tool_names_invalid_json_returns_empty() {
        // Bare identifiers are not valid JSON strings.
        let parsed = parse("[file, web]");
        assert!(parsed.is_empty(), "expected empty for invalid JSON");
    }

    #[test]
    fn parse_tool_names_empty_array_returns_empty() {
        assert!(parse("[]").is_empty());
    }

    #[test]
    fn parse_tool_names_normalises_case() {
        // An LLM may answer in mixed case; matching must be case-insensitive.
        let parsed = parse(r#"["FILE", "Web", "WEB_SEARCH"]"#);
        assert_eq!(parsed, vec!["file", "web", "web_search"]);
    }

    #[test]
    fn parse_tool_names_extracts_from_prose_with_brackets() {
        // An LLM sometimes wraps the array in prose; the array is still found.
        let parsed = parse(r#"The tools you need are: ["shell", "git"]."#);
        assert_eq!(parsed, vec!["shell", "git"]);
    }

    // ------------------------------------------------------------------
    // cache_key
    // ------------------------------------------------------------------

    #[test]
    fn cache_key_is_deterministic() {
        let first = cache_key("fetch gestura.ai");
        let second = cache_key("fetch gestura.ai");
        assert_eq!(first, second);
    }

    #[test]
    fn cache_key_normalises_whitespace_and_case() {
        // Surrounding whitespace and letter case must not affect the key, so
        // near-duplicate requests hit the cache.
        let messy = cache_key("  Fetch Gestura.ai  ");
        let clean = cache_key("fetch gestura.ai");
        assert_eq!(messy, clean);
    }

    #[test]
    fn cache_key_differs_for_different_requests() {
        let screenshot = cache_key("take a screenshot");
        let fetch = cache_key("fetch gestura.ai");
        assert_ne!(screenshot, fetch);
    }
}