use crate::config::AppConfig;
use crate::error::AppError;
use crate::llm_provider::{AgentContext, select_provider};
use std::collections::{HashMap, VecDeque};
use std::sync::{LazyLock, Mutex};
11
/// Bounded least-recently-used cache for prompt-enhancement results.
///
/// `lru_queue` tracks usage order (front = least recently used, back = most
/// recently used); when the cache is full, the front key is evicted.
struct PromptCache {
    /// Cache-key hash -> enhanced prompt text.
    cache: HashMap<String, String>,
    /// Keys ordered from least- to most-recently used.
    lru_queue: VecDeque<String>,
    /// Maximum number of entries held before eviction kicks in.
    max_size: usize,
}
19
20impl PromptCache {
21 fn new(max_size: usize) -> Self {
22 Self {
23 cache: HashMap::new(),
24 lru_queue: VecDeque::new(),
25 max_size,
26 }
27 }
28
29 fn get(&mut self, key: &str) -> Option<String> {
30 if let Some(value) = self.cache.get(key) {
31 self.lru_queue.retain(|k| k != key);
33 self.lru_queue.push_back(key.to_string());
34 Some(value.clone())
35 } else {
36 None
37 }
38 }
39
40 fn insert(&mut self, key: String, value: String) {
41 if self.cache.len() >= self.max_size
43 && !self.cache.contains_key(&key)
44 && let Some(lru_key) = self.lru_queue.pop_front()
45 {
46 self.cache.remove(&lru_key);
47 tracing::debug!(evicted_key = %lru_key, "Evicted LRU cache entry");
48 }
49
50 self.cache.insert(key.clone(), value);
52
53 self.lru_queue.retain(|k| k != &key);
55 self.lru_queue.push_back(key);
56 }
57
58 fn clear(&mut self) {
59 self.cache.clear();
60 self.lru_queue.clear();
61 }
62}
63
64lazy_static::lazy_static! {
65 static ref PROMPT_CACHE: Mutex<PromptCache> = Mutex::new(PromptCache::new(20));
66}
67
68fn generate_cache_key(prompt: &str, config: &AppConfig, context: &Option<PromptContext>) -> String {
74 use std::collections::hash_map::DefaultHasher;
75 use std::hash::{Hash, Hasher};
76
77 let mut hasher = DefaultHasher::new();
78
79 config.llm.primary.hash(&mut hasher);
81 match config.llm.primary.as_str() {
82 "openai" => {
83 if let Some(c) = &config.llm.openai {
84 c.base_url.hash(&mut hasher);
85 c.model.hash(&mut hasher);
86 }
87 }
88 "anthropic" => {
89 if let Some(c) = &config.llm.anthropic {
90 c.base_url.hash(&mut hasher);
91 c.model.hash(&mut hasher);
92 c.thinking_budget_tokens.hash(&mut hasher);
93 }
94 }
95 "grok" => {
96 if let Some(c) = &config.llm.grok {
97 c.base_url.hash(&mut hasher);
98 c.model.hash(&mut hasher);
99 }
100 }
101 "ollama" => {
102 if let Some(c) = &config.llm.ollama {
103 c.base_url.hash(&mut hasher);
104 c.model.hash(&mut hasher);
105 }
106 }
107 _ => {}
108 }
109 prompt.hash(&mut hasher);
110
111 if let Some(ctx) = context {
113 if let Some(history) = &ctx.session_history {
114 for (role, content) in history {
115 role.hash(&mut hasher);
116 content.hash(&mut hasher);
117 }
118 }
119 if let Some((path, content)) = &ctx.active_file {
120 path.hash(&mut hasher);
121 content.hash(&mut hasher);
122 }
123 if let Some(info) = &ctx.project_info {
124 info.hash(&mut hasher);
125 }
126 if let Some(entries) = &ctx.knowledge_entries {
127 for entry in entries {
128 entry.hash(&mut hasher);
129 }
130 }
131 }
132
133 format!("{:x}", hasher.finish())
134}
135
/// Optional contextual information used to ground a prompt enhancement.
///
/// Any populated field is rendered into the enhancement request (see
/// `format_context`); an all-`None` context contributes nothing.
#[derive(Debug, Clone, Default)]
pub struct PromptContext {
    /// Recent conversation turns as (role, content) pairs.
    pub session_history: Option<Vec<(String, String)>>,

    /// Currently active file as (path, content).
    pub active_file: Option<(String, String)>,

    /// Free-form project description.
    pub project_info: Option<String>,

    /// Relevant knowledge-base entries.
    pub knowledge_entries: Option<Vec<String>>,
}
155
156impl PromptContext {
157 pub fn new() -> Self {
159 Self::default()
160 }
161
162 pub fn with_session_history(mut self, history: Vec<(String, String)>) -> Self {
164 self.session_history = Some(history);
165 self
166 }
167
168 pub fn with_active_file(mut self, path: String, content: String) -> Self {
170 self.active_file = Some((path, content));
171 self
172 }
173
174 pub fn with_project_info(mut self, info: String) -> Self {
176 self.project_info = Some(info);
177 self
178 }
179
180 pub fn with_knowledge(mut self, entries: Vec<String>) -> Self {
182 self.knowledge_entries = Some(entries);
183 self
184 }
185
186 pub fn is_empty(&self) -> bool {
188 self.session_history.is_none()
189 && self.active_file.is_none()
190 && self.project_info.is_none()
191 && self.knowledge_entries.is_none()
192 }
193}
194
/// Builds the system prompt for the enhancement call.
///
/// `style` selects the tone guidance (`"detailed"`, `"technical"`, or
/// `"concise"`); any unrecognized style falls back to the concise guidance.
/// `max_length_multiplier` is interpolated (one decimal place) into the
/// guidelines to cap how much the model may expand the original prompt.
fn get_enhancement_system_prompt(style: &str, max_length_multiplier: f64) -> String {
    let style_guidance = match style {
        "detailed" => {
            "Be thorough and comprehensive. Add detailed context, examples, and step-by-step breakdowns. Explain the reasoning behind requests."
        }
        "technical" => {
            "Use precise technical language. Include specific implementation details, edge cases, and technical constraints. Reference relevant technologies and best practices."
        }
        // "concise" and unknown styles share the same guidance; a single
        // catch-all arm replaces the previously duplicated string literal.
        _ => {
            "Be brief and to the point. Add only essential context and clarity. Avoid unnecessary elaboration."
        }
    };

    format!(
        r#"You are a prompt enhancement assistant. Your task is to improve user prompts to be more effective for AI assistants.

Style: {}

Guidelines:
1. Preserve the user's intent and core request
2. Add relevant context and specificity where helpful
3. Structure complex requests into clear steps
4. Include success criteria when appropriate
5. Keep enhancements within {:.1}x original length
6. Maintain the user's tone and style
7. If the prompt is already clear and well-structured, make minimal changes
8. When context is provided (conversation history, files, project info), use it to make the prompt more specific and actionable
9. Reference relevant context naturally without being verbose

Respond with ONLY the enhanced prompt, no explanations or meta-commentary."#,
        style_guidance, max_length_multiplier
    )
}
232
233fn format_context(context: &PromptContext) -> String {
235 let mut sections = Vec::new();
236
237 if let Some(history) = context
239 .session_history
240 .as_ref()
241 .filter(|history| !history.is_empty())
242 {
243 let mut history_text = String::from("## Recent Conversation:\n");
244 for (role, content) in history {
245 let truncated = if content.len() > 500 {
247 format!("{}...", &content[..500])
248 } else {
249 content.clone()
250 };
251 history_text.push_str(&format!("{}: {}\n", role, truncated));
252 }
253 sections.push(history_text);
254 }
255
256 if let Some((path, content)) = &context.active_file {
258 let mut file_text = format!("## Active File: {}\n", path);
259 let truncated = if content.len() > 1000 {
261 format!("{}...\n[truncated]", &content[..1000])
262 } else {
263 content.clone()
264 };
265 file_text.push_str(&truncated);
266 sections.push(file_text);
267 }
268
269 if let Some(info) = &context.project_info {
271 sections.push(format!("## Project Context:\n{}\n", info));
272 }
273
274 if let Some(entries) = context
276 .knowledge_entries
277 .as_ref()
278 .filter(|entries| !entries.is_empty())
279 {
280 let mut knowledge_text = String::from("## Relevant Knowledge:\n");
281 for entry in entries {
282 let truncated = if entry.len() > 300 {
284 format!("- {}...\n", &entry[..300])
285 } else {
286 format!("- {}\n", entry)
287 };
288 knowledge_text.push_str(&truncated);
289 }
290 sections.push(knowledge_text);
291 }
292
293 if sections.is_empty() {
294 String::new()
295 } else {
296 format!("# Available Context:\n\n{}", sections.join("\n"))
297 }
298}
299
300pub async fn enhance_prompt_with_llm(
346 prompt: &str,
347 config: &AppConfig,
348 context: Option<PromptContext>,
349) -> Result<String, AppError> {
350 let trimmed_prompt = prompt.trim();
352 if trimmed_prompt.is_empty() {
353 return Err(AppError::Llm("Prompt cannot be empty".to_string()));
354 }
355
356 let cache_key = generate_cache_key(trimmed_prompt, config, &context);
358 {
359 let mut cache = PROMPT_CACHE.lock().unwrap();
360 if let Some(cached_result) = cache.get(&cache_key) {
361 tracing::debug!(
362 cache_key = %cache_key,
363 "Returning cached prompt enhancement"
364 );
365 return Ok(cached_result);
366 }
367 }
368
369 let agent_context = AgentContext {
371 agent_id: "prompt_enhancer".to_string(),
372 };
373
374 let provider = select_provider(config, &agent_context);
376
377 let enhancement_settings = &config.prompt_enhancement;
379 let system_prompt = get_enhancement_system_prompt(
380 &enhancement_settings.style,
381 enhancement_settings.max_length_multiplier(),
382 );
383
384 let context_section = if let Some(ctx) = context.clone() {
386 format_context(&ctx)
387 } else {
388 String::new()
389 };
390
391 let full_prompt = if context_section.is_empty() {
393 format!(
394 "{}\n\nUser prompt to enhance:\n{}\n\nEnhanced prompt:",
395 system_prompt, trimmed_prompt
396 )
397 } else {
398 format!(
399 "{}\n\n{}\n\nUser prompt to enhance:\n{}\n\nEnhanced prompt:",
400 system_prompt, context_section, trimmed_prompt
401 )
402 };
403
404 tracing::debug!(
405 original_length = trimmed_prompt.len(),
406 has_context = !context_section.is_empty(),
407 cache_key = %cache_key,
408 "Enhancing prompt with LLM (cache miss)"
409 );
410
411 let enhanced = provider.call(&full_prompt).await?;
413
414 let cleaned = enhanced.trim().trim_matches('"').trim();
419
420 if cleaned.is_empty() {
421 tracing::warn!("LLM returned empty enhancement, using original prompt");
422 return Ok(trimmed_prompt.to_string());
423 }
424
425 tracing::debug!(
426 original_length = trimmed_prompt.len(),
427 enhanced_length = cleaned.len(),
428 expansion_ratio = cleaned.len() as f64 / trimmed_prompt.len() as f64,
429 "Prompt enhancement complete"
430 );
431
432 {
434 let mut cache = PROMPT_CACHE.lock().unwrap();
435 cache.insert(cache_key, cleaned.to_string());
436 tracing::debug!("Cached prompt enhancement");
437 }
438
439 Ok(cleaned.to_string())
440}
441
442pub fn clear_prompt_cache() {
445 let mut cache = PROMPT_CACHE.lock().unwrap();
446 cache.clear();
447 tracing::info!("Cleared prompt enhancement cache");
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::PromptEnhancementSettings;

    // A default AppConfig has no provider configured, so enhancement should
    // surface the provider's "not configured" error.
    #[tokio::test]
    async fn test_enhance_prompt_with_unconfigured_provider() {
        let config = AppConfig::default();

        let original = "fix the bug";
        let result = enhance_prompt_with_llm(original, &config, None).await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not configured"));
    }

    // Empty input is rejected before any provider work happens.
    #[test]
    fn test_empty_prompt_validation() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let config = AppConfig::default();
            let result = enhance_prompt_with_llm("", &config, None).await;
            assert!(result.is_err());
            assert!(
                result
                    .unwrap_err()
                    .to_string()
                    .contains("Prompt cannot be empty")
            );
        });
    }

    // Whitespace-only input trims down to empty and is rejected the same way.
    #[test]
    fn test_whitespace_only_prompt_validation() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let config = AppConfig::default();
            let result = enhance_prompt_with_llm(" \n\t ", &config, None).await;
            assert!(result.is_err());
        });
    }

    // Supplying context does not bypass the provider-configuration check.
    #[test]
    fn test_enhance_with_session_context_unconfigured() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let config = AppConfig::default();

            let context = PromptContext::new().with_session_history(vec![
                (
                    "user".to_string(),
                    "I'm working on authentication".to_string(),
                ),
                ("assistant".to_string(), "I can help with that".to_string()),
            ]);

            let result = enhance_prompt_with_llm("add login", &config, Some(context)).await;
            assert!(result.is_err());
        });
    }

    // format_context emits one headed section per populated field.
    #[test]
    fn test_context_formatting() {
        let context = PromptContext::new()
            .with_session_history(vec![
                ("user".to_string(), "Hello".to_string()),
                ("assistant".to_string(), "Hi there!".to_string()),
            ])
            .with_project_info("A Rust project".to_string());

        let formatted = format_context(&context);
        assert!(formatted.contains("Recent Conversation"));
        assert!(formatted.contains("Project Context"));
        assert!(formatted.contains("Hello"));
        assert!(formatted.contains("A Rust project"));
    }

    // A context with no fields set formats to an empty string.
    #[test]
    fn test_empty_context() {
        let context = PromptContext::new();
        let formatted = format_context(&context);
        assert!(formatted.is_empty());
    }

    // Clearing must be safe to call repeatedly, including on an empty cache.
    #[test]
    fn test_cache_operations() {
        clear_prompt_cache();

        clear_prompt_cache();
    }

    // Cache keys: stable for identical input; distinct across prompts,
    // provider/model settings, and context.
    #[test]
    fn test_cache_key_generation() {
        let prompt1 = "test prompt";
        let prompt2 = "test prompt";
        let prompt3 = "different prompt";

        let config = AppConfig::default();

        // Same prompt + config + context => same key.
        let key1 = generate_cache_key(prompt1, &config, &None);
        let key2 = generate_cache_key(prompt2, &config, &None);
        assert_eq!(key1, key2);

        // Different prompt => different key.
        let key3 = generate_cache_key(prompt3, &config, &None);
        assert_ne!(key1, key3);

        // Changing only the model must change the key.
        let mut openai_cfg = AppConfig::default();
        openai_cfg.llm.primary = "openai".to_string();
        openai_cfg.llm.openai = Some(crate::config::OpenAiConfig {
            api_key: String::new(),
            base_url: None,
            model: "gpt-4o".to_string(),
        });
        let mut openai_cfg_2 = openai_cfg.clone();
        if let Some(c) = openai_cfg_2.llm.openai.as_mut() {
            c.model = "gpt-4o-mini".to_string();
        }
        let key_model_1 = generate_cache_key(prompt1, &openai_cfg, &None);
        let key_model_2 = generate_cache_key(prompt1, &openai_cfg_2, &None);
        assert_ne!(key_model_1, key_model_2);

        // Adding context must change the key.
        let context = Some(
            PromptContext::new()
                .with_session_history(vec![("user".to_string(), "context".to_string())]),
        );
        let key4 = generate_cache_key(prompt1, &config, &context);
        assert_ne!(key1, key4);
    }

    // Each style selects its guidance text; unknown styles fall back to the
    // concise guidance; the multiplier renders with one decimal place.
    #[test]
    fn test_enhancement_style_system_prompts() {
        let concise_prompt = get_enhancement_system_prompt("concise", 3.0);
        assert!(concise_prompt.contains("Be brief and to the point"));
        assert!(concise_prompt.contains("3.0x"));

        let detailed_prompt = get_enhancement_system_prompt("detailed", 4.0);
        assert!(detailed_prompt.contains("Be thorough and comprehensive"));
        assert!(detailed_prompt.contains("4.0x"));

        let technical_prompt = get_enhancement_system_prompt("technical", 2.5);
        assert!(technical_prompt.contains("Use precise technical language"));
        assert!(technical_prompt.contains("2.5x"));

        let unknown_prompt = get_enhancement_system_prompt("unknown", 3.0);
        assert!(unknown_prompt.contains("Be brief and to the point"));
    }

    // Style and multiplier settings round-trip through the config accessors.
    #[test]
    fn test_user_preferences_settings() {
        let mut config = AppConfig::default();

        config.prompt_enhancement.style = "detailed".to_string();
        config.prompt_enhancement.set_max_length_multiplier(4.0);
        assert_eq!(config.prompt_enhancement.style, "detailed");
        assert_eq!(config.prompt_enhancement.max_length_multiplier(), 4.0);

        config.prompt_enhancement.style = "technical".to_string();
        config.prompt_enhancement.set_max_length_multiplier(2.0);
        assert_eq!(config.prompt_enhancement.style, "technical");
        assert_eq!(config.prompt_enhancement.max_length_multiplier(), 2.0);
    }

    // The multiplier defaults to 3.0; the setter clamps values into
    // [1.0, 5.0] (0.5 -> 1.0, 10.0 -> 5.0 per the assertions below).
    #[test]
    fn test_max_length_multiplier_conversion() {
        let mut settings = PromptEnhancementSettings::default();

        assert_eq!(settings.max_length_multiplier(), 3.0);

        settings.set_max_length_multiplier(2.5);
        assert_eq!(settings.max_length_multiplier(), 2.5);

        settings.set_max_length_multiplier(4.0);
        assert_eq!(settings.max_length_multiplier(), 4.0);

        settings.set_max_length_multiplier(0.5);
        assert_eq!(settings.max_length_multiplier(), 1.0);

        settings.set_max_length_multiplier(10.0);
        assert_eq!(settings.max_length_multiplier(), 5.0);
    }
}