1use crate::analyzer::RequestAnalyzer;
6use crate::cache::{CacheStats, ContextCache};
7use gestura_core_foundation::context::{
8 ContextCategory, EntityType, FileContext, RequestAnalysis, ResolvedContext, ToolContext,
9};
10use std::collections::{HashMap, HashSet};
11use std::hash::{Hash, Hasher};
12use std::path::Path;
13use std::sync::{Arc, RwLock};
14use std::time::SystemTime;
15
/// Callback that reports the tools available to the manager as
/// `(name, summary)` pairs.
///
/// It is invoked whenever the manager builds a tool list for a context. The
/// `Send + Sync` bounds let a [`ContextManager`] holding it be shared across threads.
pub type ToolProviderFn = Box<dyn Fn() -> Vec<(String, String)> + Send + Sync + 'static>;
22
/// File metadata captured when a file's content was loaded into the cache.
///
/// If the file's current metadata differs from this snapshot, the cached
/// content is treated as outdated and re-read.
#[derive(Clone, Debug)]
struct FileMeta {
    /// Last-modification time reported by the filesystem.
    mtime: SystemTime,
    /// File length in bytes.
    size: u64,
}
31
/// A previously produced response, remembered so that a repeat of an
/// equivalent request analysis can be answered without recomputation.
#[derive(Clone, Debug)]
pub struct CachedResponse {
    /// The response text exactly as it was stored.
    pub response: String,
    /// Monotonic timestamp of when the response was stored; used for expiry.
    pub cached_at: std::time::Instant,
    /// Hash of the request analysis that produced this response.
    pub request_hash: u64,
}
42
43pub struct ContextManager {
45 analyzer: RequestAnalyzer,
47 context_cache: Arc<ContextCache<ResolvedContext>>,
49 file_cache: Arc<ContextCache<FileContext>>,
51 file_meta_cache: Arc<RwLock<HashMap<String, FileMeta>>>,
53 history_cache: Arc<ContextCache<String>>,
55 response_cache: Arc<RwLock<Vec<CachedResponse>>>,
57 max_context_tokens: usize,
59 include_tool_schemas: bool,
61 history_threshold: usize,
63 max_cached_responses: usize,
65 tool_provider: Option<ToolProviderFn>,
68}
69
70impl ContextManager {
71 pub fn new() -> Self {
73 Self {
74 analyzer: RequestAnalyzer::new(),
75 context_cache: Arc::new(ContextCache::with_ttl(600)), file_cache: Arc::new(ContextCache::with_ttl(300)), file_meta_cache: Arc::new(RwLock::new(HashMap::new())),
78 history_cache: Arc::new(ContextCache::with_ttl(300)), response_cache: Arc::new(RwLock::new(Vec::new())),
80 max_context_tokens: 8000, include_tool_schemas: true,
82 history_threshold: 10, max_cached_responses: 10, tool_provider: None,
85 }
86 }
87
88 pub fn with_tool_provider(mut self, provider: ToolProviderFn) -> Self {
90 self.tool_provider = Some(provider);
91 self
92 }
93
94 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
96 self.max_context_tokens = max_tokens;
97 self
98 }
99
100 pub fn with_history_threshold(mut self, threshold: usize) -> Self {
102 self.history_threshold = threshold;
103 self
104 }
105
106 pub fn without_tool_schemas(mut self) -> Self {
108 self.include_tool_schemas = false;
109 self
110 }
111
112 pub fn analyze(&self, request: &str) -> RequestAnalysis {
114 self.analyzer.analyze(request)
115 }
116
117 pub fn resolve_context<M>(
119 &self,
120 _request: &str,
121 analysis: &RequestAnalysis,
122 history: &[M],
123 workspace_dir: Option<&Path>,
124 ) -> ResolvedContext
125 where
126 M: AsRef<str>,
127 {
128 self.resolve_for_analysis_with_history(analysis, history, workspace_dir)
129 }
130
131 pub fn resolve_simple(&self, request: &str, workspace_dir: Option<&Path>) -> ResolvedContext {
133 let analysis = self.analyze(request);
134 self.resolve_for_analysis(&analysis, workspace_dir)
135 }
136
137 pub fn resolve_for_analysis_with_history<M>(
139 &self,
140 analysis: &RequestAnalysis,
141 history: &[M],
142 workspace_dir: Option<&Path>,
143 ) -> ResolvedContext
144 where
145 M: AsRef<str>,
146 {
147 let cache_key = self.cache_key_for(analysis);
150 if let Some(mut cached) = self.context_cache.get(&cache_key) {
151 let fresh_summary = self.summarize_history(history);
154 cached.history_summary = if fresh_summary.is_empty() {
155 None
156 } else {
157 Some(fresh_summary)
158 };
159 return cached;
160 }
161
162 let mut context = ResolvedContext {
164 categories: analysis.categories.clone(),
165 ..ResolvedContext::default()
166 };
167
168 if analysis.needs_tools {
170 context.tools = self.get_tools_for_categories(&analysis.categories);
171 context.estimated_tokens += self.estimate_tool_tokens(&context.tools);
172 }
173
174 for entity in &analysis.entities {
176 if entity.entity_type != EntityType::FilePath {
177 continue;
178 }
179 if let Some(file_ctx) =
180 self.load_file_context_with_validation(&entity.value, workspace_dir)
181 {
182 context.estimated_tokens += estimate_tokens(&file_ctx.content);
183 context.files.push(file_ctx);
184 }
185 }
186
187 let summary = self.summarize_history(history);
189 if !summary.is_empty() {
190 context.history_summary = Some(summary.clone());
191 context.estimated_tokens += estimate_tokens(&summary);
192 }
193
194 self.context_cache.insert(cache_key, context.clone());
196
197 context
198 }
199
200 pub fn summarize_history<M>(&self, history: &[M]) -> String
202 where
203 M: AsRef<str>,
204 {
205 if history.is_empty() {
206 return String::new();
207 }
208
209 let cache_key = format!(
211 "history:{}:{}",
212 history.len(),
213 history
214 .last()
215 .map(|m| {
216 let mut hasher = std::collections::hash_map::DefaultHasher::new();
217 m.as_ref().hash(&mut hasher);
218 hasher.finish()
219 })
220 .unwrap_or(0)
221 );
222
223 if let Some(cached) = self.history_cache.get(&cache_key) {
225 return cached;
226 }
227
228 let summary = if history.len() > self.history_threshold {
229 let first_msgs: Vec<_> = history.iter().take(3).map(|m| m.as_ref()).collect();
231 let last_msgs: Vec<_> = history
232 .iter()
233 .rev()
234 .take(5)
235 .map(|m| m.as_ref())
236 .collect::<Vec<_>>()
237 .into_iter()
238 .rev()
239 .collect();
240
241 let summarized_count = history.len() - 8;
242 tracing::debug!(
243 total_messages = history.len(),
244 threshold = self.history_threshold,
245 summarized_messages = summarized_count,
246 "Context manager: applying history summarization"
247 );
248
249 format!(
250 "[Conversation start]\n{}\n[...{} messages summarized...]\n[Recent]\n{}",
251 first_msgs.join("\n---\n"),
252 summarized_count,
253 last_msgs.join("\n---\n")
254 )
255 } else if history.len() > 5 {
256 tracing::debug!(
258 total_messages = history.len(),
259 threshold = self.history_threshold,
260 "Context manager: medium history, taking last 5 messages"
261 );
262
263 history
264 .iter()
265 .rev()
266 .take(5)
267 .map(|m| m.as_ref())
268 .collect::<Vec<_>>()
269 .into_iter()
270 .rev()
271 .collect::<Vec<_>>()
272 .join("\n---\n")
273 } else {
274 tracing::debug!(
276 total_messages = history.len(),
277 "Context manager: short history, including all messages"
278 );
279
280 history
281 .iter()
282 .map(|m| m.as_ref())
283 .collect::<Vec<_>>()
284 .join("\n---\n")
285 };
286
287 self.history_cache.insert(cache_key, summary.clone());
289 summary
290 }
291
292 pub fn resolve_for_analysis(
294 &self,
295 analysis: &RequestAnalysis,
296 workspace_dir: Option<&Path>,
297 ) -> ResolvedContext {
298 let empty: Vec<String> = Vec::new();
299 self.resolve_for_analysis_with_history(analysis, &empty, workspace_dir)
300 }
301
302 fn get_tools_for_categories(&self, categories: &HashSet<ContextCategory>) -> Vec<ToolContext> {
304 let all_tools: Vec<(String, String)> = match &self.tool_provider {
305 Some(provider) => provider(),
306 None => Vec::new(),
307 };
308 let mut result = Vec::new();
309
310 for (name, summary) in &all_tools {
311 let tool_cat = self.tool_to_category(name);
312 if categories.contains(&tool_cat) || categories.contains(&ContextCategory::Tools) {
313 result.push(ToolContext {
314 name: name.clone(),
315 description: summary.clone(),
316 has_full_schema: self.include_tool_schemas,
317 });
318 }
319 }
320
321 result
322 }
323
324 fn tool_to_category(&self, name: &str) -> ContextCategory {
326 match name {
327 "file" => ContextCategory::FileSystem,
328 "shell" => ContextCategory::Shell,
329 "git" => ContextCategory::Git,
330 "code" => ContextCategory::Code,
331 "web" => ContextCategory::Web,
332 "permissions" => ContextCategory::Config,
333 _ => ContextCategory::General,
334 }
335 }
336
337 fn load_file_context_with_validation(
339 &self,
340 path: &str,
341 workspace_dir: Option<&Path>,
342 ) -> Option<FileContext> {
343 let actual_path = workspace_dir
344 .map(|w| w.join(path))
345 .unwrap_or_else(|| std::path::PathBuf::from(path));
346
347 let metadata = std::fs::metadata(&actual_path).ok()?;
349 let mtime = metadata.modified().ok()?;
350 let size = metadata.len();
351
352 let cache_valid = {
354 let meta_cache = self.file_meta_cache.read().ok()?;
355 if let Some(cached_meta) = meta_cache.get(path) {
356 cached_meta.mtime == mtime && cached_meta.size == size
357 } else {
358 false
359 }
360 };
361
362 if cache_valid && let Some(cached) = self.file_cache.get(path) {
363 return Some(cached);
364 }
365
366 let ctx = self.load_file_context(path, workspace_dir)?;
368
369 if let Ok(mut meta_cache) = self.file_meta_cache.write() {
371 meta_cache.insert(path.to_string(), FileMeta { mtime, size });
372 }
373
374 Some(ctx)
375 }
376
377 fn load_file_context(&self, path: &str, workspace_dir: Option<&Path>) -> Option<FileContext> {
379 if let Some(cached) = self.file_cache.get(path) {
381 return Some(cached);
382 }
383
384 let actual_path = workspace_dir
385 .map(|w| w.join(path))
386 .unwrap_or_else(|| std::path::PathBuf::from(path));
387
388 match std::fs::read_to_string(&actual_path) {
390 Ok(content) => {
391 let lines: Vec<&str> = content.lines().collect();
392 let total_lines = lines.len();
393 let (content, truncated) = if total_lines > 100 {
394 (lines[..100].join("\n"), true)
396 } else {
397 (content, false)
398 };
399
400 let file_ctx = FileContext {
401 path: path.to_string(),
402 content,
403 truncated,
404 total_lines,
405 };
406 self.file_cache.insert(path.to_string(), file_ctx.clone());
407 Some(file_ctx)
408 }
409 Err(_) => None,
410 }
411 }
412
413 fn estimate_tool_tokens(&self, tools: &[ToolContext]) -> usize {
415 tools
416 .iter()
417 .map(|t| {
418 let base = estimate_tokens(&t.name) + estimate_tokens(&t.description);
419 if t.has_full_schema { base + 100 } else { base }
420 })
421 .sum()
422 }
423
424 fn cache_key_for(&self, analysis: &RequestAnalysis) -> String {
429 let hash = self.compute_request_hash(analysis);
430 format!("ctx:{:016x}", hash)
431 }
432
433 pub fn cache_stats(&self) -> ContextManagerStats {
435 ContextManagerStats {
436 context_cache: self.context_cache.stats(),
437 file_cache: self.file_cache.stats(),
438 history_cache: self.history_cache.stats(),
439 }
440 }
441
442 pub fn clear_caches(&self) {
444 self.context_cache.clear();
445 self.file_cache.clear();
446 self.history_cache.clear();
447 }
448
449 pub fn cleanup(&self) {
451 self.context_cache.evict_expired();
452 self.file_cache.evict_expired();
453 self.history_cache.evict_expired();
454 }
455
456 pub fn compute_request_hash(&self, analysis: &RequestAnalysis) -> u64 {
462 use std::collections::hash_map::DefaultHasher;
463 let mut hasher = DefaultHasher::new();
464
465 let mut cats: Vec<_> = analysis
467 .categories
468 .iter()
469 .map(|c| format!("{:?}", c))
470 .collect();
471 cats.sort();
472 for cat in cats {
473 cat.hash(&mut hasher);
474 }
475
476 let mut entities: Vec<_> = analysis
478 .entities
479 .iter()
480 .map(|e| format!("{}:{}", e.entity_type as u8, e.value))
481 .collect();
482 entities.sort();
483 for entity in entities {
484 entity.hash(&mut hasher);
485 }
486
487 analysis.needs_tools.hash(&mut hasher);
489
490 hasher.finish()
491 }
492
493 pub fn get_cached_response(&self, analysis: &RequestAnalysis) -> Option<CachedResponse> {
495 let request_hash = self.compute_request_hash(analysis);
496 let cache = self.response_cache.read().ok()?;
497
498 let max_age = std::time::Duration::from_secs(300);
500 cache
501 .iter()
502 .find(|r| r.request_hash == request_hash && r.cached_at.elapsed() < max_age)
503 .cloned()
504 }
505
506 pub fn cache_response(&self, analysis: &RequestAnalysis, response: String) {
508 let request_hash = self.compute_request_hash(analysis);
509
510 if let Ok(mut cache) = self.response_cache.write() {
511 cache.retain(|r| r.request_hash != request_hash);
513
514 cache.push(CachedResponse {
516 response,
517 cached_at: std::time::Instant::now(),
518 request_hash,
519 });
520
521 while cache.len() > self.max_cached_responses {
523 cache.remove(0);
524 }
525 }
526 }
527
528 pub fn is_similar_to_recent(&self, analysis: &RequestAnalysis) -> bool {
530 self.get_cached_response(analysis).is_some()
531 }
532}
533
534impl Default for ContextManager {
535 fn default() -> Self {
536 Self::new()
537 }
538}
539
540#[derive(Debug, Clone)]
542pub struct ContextManagerStats {
543 pub context_cache: CacheStats,
545 pub file_cache: CacheStats,
547 pub history_cache: CacheStats,
549}
550
/// Rough token estimate: one token per four bytes of UTF-8, rounded down,
/// with a floor of one token (even for an empty string).
pub fn estimate_tokens(s: &str) -> usize {
    let byte_len = s.len();
    std::cmp::max(byte_len / 4, 1)
}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a uniquely named scratch directory.
    ///
    /// The name includes the process id and a timestamp, so concurrent test
    /// runs (other processes, repeated invocations) never share or delete each
    /// other's files.
    fn unique_temp_dir(label: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos())
            .unwrap_or_default();
        let dir = std::env::temp_dir().join(format!(
            "gestura_test_{label}_{}_{nanos}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).expect("failed to create scratch directory");
        dir
    }

    #[test]
    fn test_context_manager_new() {
        let manager = ContextManager::new();
        assert!(manager.max_context_tokens > 0);
    }

    #[test]
    fn test_analyze_file_request() {
        let manager = ContextManager::new();
        let analysis = manager.analyze("Read the file src/main.rs");

        assert!(analysis.categories.contains(&ContextCategory::FileSystem));
        assert!(analysis.needs_tools);
    }

    #[test]
    fn test_resolve_context_general() {
        let manager = ContextManager::new();
        let context = manager.resolve_simple("What is Rust?", None);

        assert!(context.categories.contains(&ContextCategory::General));
        assert!(context.tools.is_empty());
    }

    #[test]
    fn test_resolve_context_with_tools() {
        let manager = ContextManager::new().with_tool_provider(Box::new(|| {
            vec![
                ("file".to_string(), "Read/write files".to_string()),
                ("shell".to_string(), "Run shell commands".to_string()),
                ("git".to_string(), "Git operations".to_string()),
            ]
        }));
        let context = manager.resolve_simple("List files in the current directory", None);

        assert!(context.categories.contains(&ContextCategory::FileSystem));
        assert!(!context.tools.is_empty());
    }

    #[test]
    fn test_resolve_context_without_provider() {
        let manager = ContextManager::new();
        let context = manager.resolve_simple("List files in the current directory", None);

        assert!(context.categories.contains(&ContextCategory::FileSystem));
        // Without a tool provider no tools can be resolved.
        assert!(context.tools.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let manager = ContextManager::new();
        let _ = manager.resolve_simple("Read a file", None);
        let _ = manager.resolve_simple("Git status", None);

        let stats = manager.cache_stats();
        assert!(stats.context_cache.size > 0);
    }

    #[test]
    fn test_summarize_history_empty() {
        let manager = ContextManager::new();
        let history: Vec<String> = Vec::new();
        assert_eq!(manager.summarize_history(&history), "");
    }

    #[test]
    fn test_summarize_history_short_keeps_all_messages() {
        let manager = ContextManager::new();
        let history = ["first", "second"];
        assert_eq!(manager.summarize_history(&history), "first\n---\nsecond");
    }

    #[test]
    fn test_summarize_history_medium_keeps_last_five() {
        let manager = ContextManager::new();
        let history: Vec<String> = (1..=7).map(|i| format!("m{i}")).collect();
        assert_eq!(
            manager.summarize_history(&history),
            "m3\n---\nm4\n---\nm5\n---\nm6\n---\nm7"
        );
    }

    #[test]
    fn test_workspace_dir_resolution() {
        use std::io::Write;

        let temp_dir = unique_temp_dir("workspace");
        {
            // Scope the handle so the file is closed before it is read back.
            let mut file = std::fs::File::create(temp_dir.join("test_file.txt")).unwrap();
            writeln!(file, "hello workspace").unwrap();
        }

        let manager = ContextManager::new();
        let req = "Read the file test_file.txt".to_string();

        let ctx = manager.resolve_simple(&req, Some(&temp_dir));

        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_dir_all(&temp_dir);

        assert!(ctx.categories.contains(&ContextCategory::FileSystem));
        assert!(!ctx.files.is_empty());
        assert_eq!(ctx.files[0].content.trim(), "hello workspace");
    }
}