gestura_core_context/
manager.rs

1//! Context manager for smart context reduction
2//!
3//! Manages context loading, caching, and reduction based on request analysis.
4
5use crate::analyzer::RequestAnalyzer;
6use crate::cache::{CacheStats, ContextCache};
7use gestura_core_foundation::context::{
8    ContextCategory, EntityType, FileContext, RequestAnalysis, ResolvedContext, ToolContext,
9};
10use std::collections::{HashMap, HashSet};
11use std::hash::{Hash, Hasher};
12use std::path::Path;
13use std::sync::{Arc, RwLock};
14use std::time::SystemTime;
15
/// Callback a [`ContextManager`] uses to discover which tools it may offer.
///
/// Every invocation yields a list of `(name, summary)` pairs. Modelling this as
/// a plain boxed closure lets context resolution surface tool metadata without
/// depending on any concrete tool registry. The `Send + Sync` bounds allow the
/// boxed closure to live inside values that are shared between threads.
pub type ToolProviderFn = Box<dyn Fn() -> Vec<(String, String)> + Sync + Send>;
22
/// Snapshot of a file's on-disk metadata, recorded alongside its cached
/// contents so a later lookup can detect whether the file has changed.
#[derive(Clone, Debug)]
struct FileMeta {
    /// Modification timestamp reported by the filesystem.
    mtime: SystemTime,
    /// Length of the file in bytes.
    size: u64,
}
31
/// A previously produced response, retained so that an equivalent follow-up
/// request can be answered without recomputing it.
#[derive(Clone, Debug)]
pub struct CachedResponse {
    /// Text of the stored response.
    pub response: String,
    /// Moment the entry was stored; used to age entries out.
    pub cached_at: std::time::Instant,
    /// Fingerprint of the request that produced this response.
    pub request_hash: u64,
}
42
43/// Manager for handling context in a smart, efficient way
44pub struct ContextManager {
45    /// Request analyzer
46    analyzer: RequestAnalyzer,
47    /// Cache for resolved contexts
48    context_cache: Arc<ContextCache<ResolvedContext>>,
49    /// Cache for file contents
50    file_cache: Arc<ContextCache<FileContext>>,
51    /// Cache for file metadata (for invalidation)
52    file_meta_cache: Arc<RwLock<HashMap<String, FileMeta>>>,
53    /// Cache for summarized history
54    history_cache: Arc<ContextCache<String>>,
55    /// Cache for similar request responses
56    response_cache: Arc<RwLock<Vec<CachedResponse>>>,
57    /// Maximum tokens for context
58    max_context_tokens: usize,
59    /// Whether to include tool schemas
60    include_tool_schemas: bool,
61    /// History summarization threshold (number of messages)
62    history_threshold: usize,
63    /// Maximum cached responses
64    max_cached_responses: usize,
65    /// Optional callback that provides tool name/summary pairs.
66    /// When `None`, tool resolution returns an empty list.
67    tool_provider: Option<ToolProviderFn>,
68}
69
70impl ContextManager {
71    /// Create a new context manager
72    pub fn new() -> Self {
73        Self {
74            analyzer: RequestAnalyzer::new(),
75            context_cache: Arc::new(ContextCache::with_ttl(600)), // 10 min TTL
76            file_cache: Arc::new(ContextCache::with_ttl(300)),    // 5 min TTL (LOW-2)
77            file_meta_cache: Arc::new(RwLock::new(HashMap::new())),
78            history_cache: Arc::new(ContextCache::with_ttl(300)), // 5 min TTL
79            response_cache: Arc::new(RwLock::new(Vec::new())),
80            max_context_tokens: 8000, // Conservative default
81            include_tool_schemas: true,
82            history_threshold: 10, // Summarize after 10 messages (matches max_history_messages)
83            max_cached_responses: 10, // Keep last 10 responses (LOW-3)
84            tool_provider: None,
85        }
86    }
87
88    /// Set a tool provider callback that supplies available tool name/summary pairs.
89    pub fn with_tool_provider(mut self, provider: ToolProviderFn) -> Self {
90        self.tool_provider = Some(provider);
91        self
92    }
93
94    /// Set maximum context tokens
95    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
96        self.max_context_tokens = max_tokens;
97        self
98    }
99
100    /// Set history summarization threshold
101    pub fn with_history_threshold(mut self, threshold: usize) -> Self {
102        self.history_threshold = threshold;
103        self
104    }
105
106    /// Disable tool schema inclusion for simpler contexts
107    pub fn without_tool_schemas(mut self) -> Self {
108        self.include_tool_schemas = false;
109        self
110    }
111
112    /// Analyze a request to determine what context is needed
113    pub fn analyze(&self, request: &str) -> RequestAnalysis {
114        self.analyzer.analyze(request)
115    }
116
117    /// Resolve context for a request
118    pub fn resolve_context<M>(
119        &self,
120        _request: &str,
121        analysis: &RequestAnalysis,
122        history: &[M],
123        workspace_dir: Option<&Path>,
124    ) -> ResolvedContext
125    where
126        M: AsRef<str>,
127    {
128        self.resolve_for_analysis_with_history(analysis, history, workspace_dir)
129    }
130
131    /// Simple resolve without history
132    pub fn resolve_simple(&self, request: &str, workspace_dir: Option<&Path>) -> ResolvedContext {
133        let analysis = self.analyze(request);
134        self.resolve_for_analysis(&analysis, workspace_dir)
135    }
136
137    /// Resolve context for a pre-analyzed request with history
138    pub fn resolve_for_analysis_with_history<M>(
139        &self,
140        analysis: &RequestAnalysis,
141        history: &[M],
142        workspace_dir: Option<&Path>,
143    ) -> ResolvedContext
144    where
145        M: AsRef<str>,
146    {
147        // G3: Use a fingerprint-based cache key that includes entity values so
148        // different file/entity requests get distinct cache entries.
149        let cache_key = self.cache_key_for(analysis);
150        if let Some(mut cached) = self.context_cache.get(&cache_key) {
151            // G3: History changes every turn — always recompute it even on a
152            // cache hit so the LLM sees the correct conversation summary.
153            let fresh_summary = self.summarize_history(history);
154            cached.history_summary = if fresh_summary.is_empty() {
155                None
156            } else {
157                Some(fresh_summary)
158            };
159            return cached;
160        }
161
162        // Build new context
163        let mut context = ResolvedContext {
164            categories: analysis.categories.clone(),
165            ..ResolvedContext::default()
166        };
167
168        // Add tools if needed
169        if analysis.needs_tools {
170            context.tools = self.get_tools_for_categories(&analysis.categories);
171            context.estimated_tokens += self.estimate_tool_tokens(&context.tools);
172        }
173
174        // Load files if mentioned (with mtime-based cache invalidation - LOW-2)
175        for entity in &analysis.entities {
176            if entity.entity_type != EntityType::FilePath {
177                continue;
178            }
179            if let Some(file_ctx) =
180                self.load_file_context_with_validation(&entity.value, workspace_dir)
181            {
182                context.estimated_tokens += estimate_tokens(&file_ctx.content);
183                context.files.push(file_ctx);
184            }
185        }
186
187        // Add history summary with threshold-based summarization (LOW-1)
188        let summary = self.summarize_history(history);
189        if !summary.is_empty() {
190            context.history_summary = Some(summary.clone());
191            context.estimated_tokens += estimate_tokens(&summary);
192        }
193
194        // Cache the result
195        self.context_cache.insert(cache_key, context.clone());
196
197        context
198    }
199
200    /// Summarize history with intelligent threshold (LOW-1)
201    pub fn summarize_history<M>(&self, history: &[M]) -> String
202    where
203        M: AsRef<str>,
204    {
205        if history.is_empty() {
206            return String::new();
207        }
208
209        // Generate cache key from history length and last message hash
210        let cache_key = format!(
211            "history:{}:{}",
212            history.len(),
213            history
214                .last()
215                .map(|m| {
216                    let mut hasher = std::collections::hash_map::DefaultHasher::new();
217                    m.as_ref().hash(&mut hasher);
218                    hasher.finish()
219                })
220                .unwrap_or(0)
221        );
222
223        // Check cache
224        if let Some(cached) = self.history_cache.get(&cache_key) {
225            return cached;
226        }
227
228        let summary = if history.len() > self.history_threshold {
229            // Summarize: keep first 3 messages (context), last 5 messages (recent)
230            let first_msgs: Vec<_> = history.iter().take(3).map(|m| m.as_ref()).collect();
231            let last_msgs: Vec<_> = history
232                .iter()
233                .rev()
234                .take(5)
235                .map(|m| m.as_ref())
236                .collect::<Vec<_>>()
237                .into_iter()
238                .rev()
239                .collect();
240
241            let summarized_count = history.len() - 8;
242            tracing::debug!(
243                total_messages = history.len(),
244                threshold = self.history_threshold,
245                summarized_messages = summarized_count,
246                "Context manager: applying history summarization"
247            );
248
249            format!(
250                "[Conversation start]\n{}\n[...{} messages summarized...]\n[Recent]\n{}",
251                first_msgs.join("\n---\n"),
252                summarized_count,
253                last_msgs.join("\n---\n")
254            )
255        } else if history.len() > 5 {
256            // Medium history: take last 5 messages
257            tracing::debug!(
258                total_messages = history.len(),
259                threshold = self.history_threshold,
260                "Context manager: medium history, taking last 5 messages"
261            );
262
263            history
264                .iter()
265                .rev()
266                .take(5)
267                .map(|m| m.as_ref())
268                .collect::<Vec<_>>()
269                .into_iter()
270                .rev()
271                .collect::<Vec<_>>()
272                .join("\n---\n")
273        } else {
274            // Short history: include all
275            tracing::debug!(
276                total_messages = history.len(),
277                "Context manager: short history, including all messages"
278            );
279
280            history
281                .iter()
282                .map(|m| m.as_ref())
283                .collect::<Vec<_>>()
284                .join("\n---\n")
285        };
286
287        // Cache the summary
288        self.history_cache.insert(cache_key, summary.clone());
289        summary
290    }
291
292    /// Resolve context for a pre-analyzed request (no history)
293    pub fn resolve_for_analysis(
294        &self,
295        analysis: &RequestAnalysis,
296        workspace_dir: Option<&Path>,
297    ) -> ResolvedContext {
298        let empty: Vec<String> = Vec::new();
299        self.resolve_for_analysis_with_history(analysis, &empty, workspace_dir)
300    }
301
302    /// Get tools relevant to the given categories
303    fn get_tools_for_categories(&self, categories: &HashSet<ContextCategory>) -> Vec<ToolContext> {
304        let all_tools: Vec<(String, String)> = match &self.tool_provider {
305            Some(provider) => provider(),
306            None => Vec::new(),
307        };
308        let mut result = Vec::new();
309
310        for (name, summary) in &all_tools {
311            let tool_cat = self.tool_to_category(name);
312            if categories.contains(&tool_cat) || categories.contains(&ContextCategory::Tools) {
313                result.push(ToolContext {
314                    name: name.clone(),
315                    description: summary.clone(),
316                    has_full_schema: self.include_tool_schemas,
317                });
318            }
319        }
320
321        result
322    }
323
324    /// Map tool name to category
325    fn tool_to_category(&self, name: &str) -> ContextCategory {
326        match name {
327            "file" => ContextCategory::FileSystem,
328            "shell" => ContextCategory::Shell,
329            "git" => ContextCategory::Git,
330            "code" => ContextCategory::Code,
331            "web" => ContextCategory::Web,
332            "permissions" => ContextCategory::Config,
333            _ => ContextCategory::General,
334        }
335    }
336
337    /// Load file context with mtime-based cache invalidation (LOW-2)
338    fn load_file_context_with_validation(
339        &self,
340        path: &str,
341        workspace_dir: Option<&Path>,
342    ) -> Option<FileContext> {
343        let actual_path = workspace_dir
344            .map(|w| w.join(path))
345            .unwrap_or_else(|| std::path::PathBuf::from(path));
346
347        // Check if file exists and get metadata
348        let metadata = std::fs::metadata(&actual_path).ok()?;
349        let mtime = metadata.modified().ok()?;
350        let size = metadata.len();
351
352        // Check if cached version is still valid
353        let cache_valid = {
354            let meta_cache = self.file_meta_cache.read().ok()?;
355            if let Some(cached_meta) = meta_cache.get(path) {
356                cached_meta.mtime == mtime && cached_meta.size == size
357            } else {
358                false
359            }
360        };
361
362        if cache_valid && let Some(cached) = self.file_cache.get(path) {
363            return Some(cached);
364        }
365
366        // Cache miss or invalidated - reload file
367        let ctx = self.load_file_context(path, workspace_dir)?;
368
369        // Update metadata cache
370        if let Ok(mut meta_cache) = self.file_meta_cache.write() {
371            meta_cache.insert(path.to_string(), FileMeta { mtime, size });
372        }
373
374        Some(ctx)
375    }
376
377    /// Load file context (with caching)
378    fn load_file_context(&self, path: &str, workspace_dir: Option<&Path>) -> Option<FileContext> {
379        // Check cache
380        if let Some(cached) = self.file_cache.get(path) {
381            return Some(cached);
382        }
383
384        let actual_path = workspace_dir
385            .map(|w| w.join(path))
386            .unwrap_or_else(|| std::path::PathBuf::from(path));
387
388        // Try to read file
389        match std::fs::read_to_string(&actual_path) {
390            Ok(content) => {
391                let lines: Vec<&str> = content.lines().collect();
392                let total_lines = lines.len();
393                let (content, truncated) = if total_lines > 100 {
394                    // Truncate to first 100 lines
395                    (lines[..100].join("\n"), true)
396                } else {
397                    (content, false)
398                };
399
400                let file_ctx = FileContext {
401                    path: path.to_string(),
402                    content,
403                    truncated,
404                    total_lines,
405                };
406                self.file_cache.insert(path.to_string(), file_ctx.clone());
407                Some(file_ctx)
408            }
409            Err(_) => None,
410        }
411    }
412
413    /// Estimate tokens for tools
414    fn estimate_tool_tokens(&self, tools: &[ToolContext]) -> usize {
415        tools
416            .iter()
417            .map(|t| {
418                let base = estimate_tokens(&t.name) + estimate_tokens(&t.description);
419                if t.has_full_schema { base + 100 } else { base }
420            })
421            .sum()
422    }
423
424    /// Generate cache key for analysis.
425    ///
426    /// G3: Uses `compute_request_hash` which hashes categories + sorted entity
427    /// values + needs_tools — distinct file/entity requests no longer collide.
428    fn cache_key_for(&self, analysis: &RequestAnalysis) -> String {
429        let hash = self.compute_request_hash(analysis);
430        format!("ctx:{:016x}", hash)
431    }
432
433    /// Get cache statistics
434    pub fn cache_stats(&self) -> ContextManagerStats {
435        ContextManagerStats {
436            context_cache: self.context_cache.stats(),
437            file_cache: self.file_cache.stats(),
438            history_cache: self.history_cache.stats(),
439        }
440    }
441
442    /// Clear all caches
443    pub fn clear_caches(&self) {
444        self.context_cache.clear();
445        self.file_cache.clear();
446        self.history_cache.clear();
447    }
448
449    /// Evict expired entries from all caches
450    pub fn cleanup(&self) {
451        self.context_cache.evict_expired();
452        self.file_cache.evict_expired();
453        self.history_cache.evict_expired();
454    }
455
456    // =========================================================================
457    // Request Similarity Detection (LOW-3)
458    // =========================================================================
459
460    /// Compute a hash for a request based on categories and key entities
461    pub fn compute_request_hash(&self, analysis: &RequestAnalysis) -> u64 {
462        use std::collections::hash_map::DefaultHasher;
463        let mut hasher = DefaultHasher::new();
464
465        // Hash categories (sorted for consistency)
466        let mut cats: Vec<_> = analysis
467            .categories
468            .iter()
469            .map(|c| format!("{:?}", c))
470            .collect();
471        cats.sort();
472        for cat in cats {
473            cat.hash(&mut hasher);
474        }
475
476        // Hash key entities (sorted)
477        let mut entities: Vec<_> = analysis
478            .entities
479            .iter()
480            .map(|e| format!("{}:{}", e.entity_type as u8, e.value))
481            .collect();
482        entities.sort();
483        for entity in entities {
484            entity.hash(&mut hasher);
485        }
486
487        // Hash tool requirement
488        analysis.needs_tools.hash(&mut hasher);
489
490        hasher.finish()
491    }
492
493    /// Check if we have a cached response for a similar request
494    pub fn get_cached_response(&self, analysis: &RequestAnalysis) -> Option<CachedResponse> {
495        let request_hash = self.compute_request_hash(analysis);
496        let cache = self.response_cache.read().ok()?;
497
498        // Find matching response (within 5 minutes)
499        let max_age = std::time::Duration::from_secs(300);
500        cache
501            .iter()
502            .find(|r| r.request_hash == request_hash && r.cached_at.elapsed() < max_age)
503            .cloned()
504    }
505
506    /// Cache a response for potential reuse
507    pub fn cache_response(&self, analysis: &RequestAnalysis, response: String) {
508        let request_hash = self.compute_request_hash(analysis);
509
510        if let Ok(mut cache) = self.response_cache.write() {
511            // Remove old entry with same hash if exists
512            cache.retain(|r| r.request_hash != request_hash);
513
514            // Add new entry
515            cache.push(CachedResponse {
516                response,
517                cached_at: std::time::Instant::now(),
518                request_hash,
519            });
520
521            // Trim to max size
522            while cache.len() > self.max_cached_responses {
523                cache.remove(0);
524            }
525        }
526    }
527
528    /// Check if a request is similar to a recent one (for deduplication)
529    pub fn is_similar_to_recent(&self, analysis: &RequestAnalysis) -> bool {
530        self.get_cached_response(analysis).is_some()
531    }
532}
533
534impl Default for ContextManager {
535    fn default() -> Self {
536        Self::new()
537    }
538}
539
540/// Statistics for the context manager
541#[derive(Debug, Clone)]
542pub struct ContextManagerStats {
543    /// Context cache stats
544    pub context_cache: CacheStats,
545    /// File cache stats
546    pub file_cache: CacheStats,
547    /// History cache stats
548    pub history_cache: CacheStats,
549}
550
/// Roughly estimate how many tokens `s` occupies.
///
/// Applies the common heuristic of about four bytes of text per token. The
/// result is never below 1, even for an empty string.
pub fn estimate_tokens(s: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    std::cmp::max(1, s.len() / BYTES_PER_TOKEN)
}
556
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_manager_new() {
        let manager = ContextManager::new();
        assert!(manager.max_context_tokens > 0);
    }

    #[test]
    fn test_analyze_file_request() {
        let manager = ContextManager::new();
        let analysis = manager.analyze("Read the file src/main.rs");

        assert!(analysis.categories.contains(&ContextCategory::FileSystem));
        assert!(analysis.needs_tools);
    }

    #[test]
    fn test_resolve_context_general() {
        let manager = ContextManager::new();
        let context = manager.resolve_simple("What is Rust?", None);

        assert!(context.categories.contains(&ContextCategory::General));
        assert!(context.tools.is_empty());
    }

    #[test]
    fn test_resolve_context_with_tools() {
        let manager = ContextManager::new().with_tool_provider(Box::new(|| {
            vec![
                ("file".to_string(), "Read/write files".to_string()),
                ("shell".to_string(), "Run shell commands".to_string()),
                ("git".to_string(), "Git operations".to_string()),
            ]
        }));
        let context = manager.resolve_simple("List files in the current directory", None);

        assert!(context.categories.contains(&ContextCategory::FileSystem));
        assert!(!context.tools.is_empty());
    }

    #[test]
    fn test_resolve_context_without_provider() {
        let manager = ContextManager::new();
        let context = manager.resolve_simple("List files in the current directory", None);

        assert!(context.categories.contains(&ContextCategory::FileSystem));
        // Without a tool provider, tools should be empty
        assert!(context.tools.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let manager = ContextManager::new();
        // Make some calls to populate cache
        let _ = manager.resolve_simple("Read a file", None);
        let _ = manager.resolve_simple("Git status", None);

        let stats = manager.cache_stats();
        assert!(stats.context_cache.size > 0);
    }

    #[test]
    fn test_workspace_dir_resolution() {
        // Per-process directory so concurrent test runs cannot clobber each other.
        let temp_dir = std::env::temp_dir()
            .join(format!("gestura_test_workspace_{}", std::process::id()));
        std::fs::create_dir_all(&temp_dir).expect("create temp workspace");
        // `fs::write` closes the handle immediately, so cleanup can remove the file.
        std::fs::write(temp_dir.join("test_file.txt"), "hello workspace\n")
            .expect("write test file");

        let manager = ContextManager::new();
        let req = "Read the file test_file.txt".to_string();

        let ctx = manager.resolve_simple(&req, Some(&temp_dir));

        // Cleanup before asserting so a failure does not leave residue behind.
        let _ = std::fs::remove_dir_all(&temp_dir);

        assert!(ctx.categories.contains(&ContextCategory::FileSystem));
        assert!(!ctx.files.is_empty());
        assert_eq!(ctx.files[0].content.trim(), "hello workspace");
    }

    #[test]
    fn test_estimate_tokens_floor_and_ratio() {
        assert_eq!(estimate_tokens(""), 1);
        assert_eq!(estimate_tokens("abc"), 1);
        assert_eq!(estimate_tokens("abcdefgh"), 2);
    }

    #[test]
    fn test_summarize_history_short_includes_all() {
        let manager = ContextManager::new();
        let history = ["first", "second"];
        assert_eq!(manager.summarize_history(&history[..]), "first\n---\nsecond");
    }

    #[test]
    fn test_summarize_history_long_is_summarized() {
        let manager = ContextManager::new();
        let history: Vec<String> = (0..12).map(|i| format!("msg{i}")).collect();
        let summary = manager.summarize_history(&history);

        assert!(summary.starts_with("[Conversation start]\nmsg0"));
        assert!(summary.contains("[...4 messages summarized...]"));
        assert!(summary.ends_with("msg11"));
    }

    #[test]
    fn test_cache_response_roundtrip() {
        let manager = ContextManager::new();
        let analysis = manager.analyze("Git status");
        assert!(!manager.is_similar_to_recent(&analysis));

        manager.cache_response(&analysis, "clean".to_string());
        let cached = manager
            .get_cached_response(&analysis)
            .expect("response should be cached");
        assert_eq!(cached.response, "clean");
        assert!(manager.is_similar_to_recent(&analysis));
    }
}