// gestura_core_llm/model_capabilities.rs

//! Dynamic model capabilities discovery and caching.
//!
//! This module provides runtime discovery of model capabilities (context length,
//! max output tokens, feature support) through multiple strategies:
//!
//! 1. **API Discovery** - Query provider model endpoints for metadata
//!    - Anthropic: `/v1/models/{id}` → `max_input_tokens`
//!    - Gemini: `/v1beta/models/{id}` → `inputTokenLimit`
//!    - Grok: `/v1/language-models` → context window per model
//!    - Ollama: `/api/show` → `model_info.{arch}.context_length`
//! 2. **Error-Driven Learning** - Parse limits from context_length_exceeded errors
//! 3. **Cached Knowledge** - Remember discovered limits across requests
//! 4. **Conservative Fallback** - Safe defaults for unknown models
//!
//! ## Design Goals
//!
//! - **Dynamic over static** - Learn limits at runtime, not hardcoded
//! - **Graceful degradation** - Work even when APIs are unavailable
//! - **Error recovery** - Extract actual limits from error messages
//!
//! ## Usage
//!
//! ```rust,ignore
//! use gestura_core_llm::model_capabilities::{ModelCapabilities, ModelCapabilitiesCache};
//!
//! let cache = ModelCapabilitiesCache::new();
//!
//! // Discover from API (async)
//! cache.discover_from_api("anthropic", "claude-sonnet-4-20250514", Some(api_key)).await;
//!
//! // Learn from an error (sync)
//! cache.learn_from_error("openai", "gpt-3.5-turbo",
//!     "maximum context length is 16385 tokens");
//!
//! // Get capabilities (uses discovered/learned value, falls back to heuristic)
//! let caps = cache.get("openai", "gpt-3.5-turbo");
//! ```

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Model capabilities describing limits and supported features.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelCapabilities {
    /// Maximum context window in tokens, always stored as **input + output
    /// combined** regardless of whether the provider exposes a single combined
    /// limit (e.g. OpenAI) or separate per-modality limits (e.g. Anthropic
    /// `max_input_tokens`, Gemini `inputTokenLimit`).
    ///
    /// **Do not use this field directly for prompt-budget decisions.**  Always
    /// call [`Self::max_input_tokens()`] instead, which subtracts
    /// `max_output_tokens` to yield the tokens available for the prompt.
    ///
    /// ### Invariant
    /// `context_length = max_input_tokens() + max_output_tokens`
    ///
    /// Discovery code for providers with separate input/output limits must
    /// store `input_limit + output_limit` here so that `max_input_tokens()`
    /// correctly recovers `input_limit` without double-subtracting.
    pub context_length: usize,
    /// Maximum output/completion tokens the model can generate
    pub max_output_tokens: usize,
    /// Whether the model supports native tool/function calling
    pub supports_tools: bool,
    /// Whether the model supports vision/image inputs
    pub supports_vision: bool,
    /// Whether the model supports streaming responses
    pub supports_streaming: bool,
    /// Provider name for reference (e.g. "openai", "anthropic")
    pub provider: String,
    /// Model ID for reference (as passed to the provider API)
    pub model_id: String,
    /// How this capability was discovered; see [`CapabilitySource`]
    pub source: CapabilitySource,
}

/// How the capability information was obtained.
///
/// Reliability is judged by [`ModelCapabilities::is_reliable`]: only
/// `ApiDiscovery` and `UserConfig` count as authoritative.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum CapabilitySource {
    /// Queried from provider API
    ApiDiscovery,
    /// Extracted from an error message
    ErrorLearned,
    /// User-configured override
    UserConfig,
    /// Static fallback (least reliable); also the `Default` variant
    #[default]
    StaticFallback,
}

92impl Default for ModelCapabilities {
93    fn default() -> Self {
94        Self {
95            context_length: 8_192, // Very conservative default
96            max_output_tokens: 4_096,
97            supports_tools: true,
98            supports_vision: false,
99            supports_streaming: true,
100            provider: "unknown".to_string(),
101            model_id: "unknown".to_string(),
102            source: CapabilitySource::StaticFallback,
103        }
104    }
105}
106
107impl ModelCapabilities {
108    /// Create capabilities with known values
109    pub fn new(
110        provider: &str,
111        model_id: &str,
112        context_length: usize,
113        max_output_tokens: usize,
114        source: CapabilitySource,
115    ) -> Self {
116        Self {
117            context_length,
118            max_output_tokens,
119            supports_tools: true,
120            supports_vision: false,
121            supports_streaming: true,
122            provider: provider.to_string(),
123            model_id: model_id.to_string(),
124            source,
125        }
126    }
127
128    /// Set vision support
129    pub fn with_vision(mut self, supports: bool) -> Self {
130        self.supports_vision = supports;
131        self
132    }
133
134    /// Set tool support
135    pub fn with_tools(mut self, supports: bool) -> Self {
136        self.supports_tools = supports;
137        self
138    }
139
140    /// Calculate the effective max input tokens (context - reserved output)
141    pub fn max_input_tokens(&self) -> usize {
142        self.context_length.saturating_sub(self.max_output_tokens)
143    }
144
145    /// Check if this capability is from a reliable source
146    pub fn is_reliable(&self) -> bool {
147        matches!(
148            self.source,
149            CapabilitySource::ApiDiscovery | CapabilitySource::UserConfig
150        )
151    }
152}
153
/// Thread-safe cache for learned model capabilities.
///
/// Capabilities are discovered dynamically and cached for future use.
/// The cache persists for the lifetime of the application.
///
/// Cloning is cheap: clones share the same underlying map via `Arc`, so a
/// limit learned through one handle is visible through all of them.
#[derive(Debug, Clone, Default)]
pub struct ModelCapabilitiesCache {
    // Keyed by lowercase "provider:model" — see `cache_key`.
    cache: Arc<RwLock<HashMap<String, ModelCapabilities>>>,
}

163impl ModelCapabilitiesCache {
164    /// Create a new empty cache
165    pub fn new() -> Self {
166        Self {
167            cache: Arc::new(RwLock::new(HashMap::new())),
168        }
169    }
170
171    /// Generate cache key from provider and model
172    fn cache_key(provider: &str, model_id: &str) -> String {
173        format!("{}:{}", provider.to_lowercase(), model_id.to_lowercase())
174    }
175
176    /// Get capabilities for a model, using cache or falling back to heuristics
177    pub fn get(&self, provider: &str, model_id: &str) -> ModelCapabilities {
178        let key = Self::cache_key(provider, model_id);
179
180        // Check cache first
181        if let Some(caps) = self.cache.read().ok().and_then(|c| c.get(&key).cloned()) {
182            return caps;
183        }
184
185        // Fall back to heuristic-based capabilities
186        get_model_capabilities_heuristic(provider, model_id)
187    }
188
189    /// Learn model capabilities from a context_length_exceeded error message.
190    ///
191    /// Parses error messages like:
192    /// - "maximum context length is 16385 tokens"
193    /// - "your messages resulted in 17063 tokens"
194    ///
195    /// Returns the learned capabilities if parsing succeeded.
196    pub fn learn_from_error(
197        &self,
198        provider: &str,
199        model_id: &str,
200        error_message: &str,
201    ) -> Option<ModelCapabilities> {
202        let context_length = parse_context_length_from_error(error_message)?;
203
204        let caps = ModelCapabilities::new(
205            provider,
206            model_id,
207            context_length,
208            estimate_max_output(context_length),
209            CapabilitySource::ErrorLearned,
210        );
211
212        // Cache the learned capability
213        let key = Self::cache_key(provider, model_id);
214        if let Ok(mut cache) = self.cache.write() {
215            cache.insert(key, caps.clone());
216        }
217
218        tracing::info!(
219            provider = provider,
220            model = model_id,
221            context_length = context_length,
222            "Learned model context limit from error"
223        );
224
225        Some(caps)
226    }
227
228    /// Store capabilities discovered from API
229    pub fn store_from_api(&self, caps: ModelCapabilities) {
230        let key = Self::cache_key(&caps.provider, &caps.model_id);
231        if let Ok(mut cache) = self.cache.write() {
232            cache.insert(key, caps);
233        }
234    }
235
236    /// Store user-configured override
237    pub fn store_user_override(&self, provider: &str, model_id: &str, context_length: usize) {
238        let caps = ModelCapabilities::new(
239            provider,
240            model_id,
241            context_length,
242            estimate_max_output(context_length),
243            CapabilitySource::UserConfig,
244        );
245        let key = Self::cache_key(provider, model_id);
246        if let Ok(mut cache) = self.cache.write() {
247            cache.insert(key, caps);
248        }
249    }
250
251    /// Clear the cache (useful for testing)
252    pub fn clear(&self) {
253        if let Ok(mut cache) = self.cache.write() {
254            cache.clear();
255        }
256    }
257}
258
259/// Parse context length from an error message.
260///
261/// Handles various error formats from different providers:
262/// - OpenAI: "maximum context length is 16385 tokens"
263/// - Anthropic: "prompt is too long: X tokens > Y maximum"
264fn parse_context_length_from_error(error_message: &str) -> Option<usize> {
265    let msg = error_message.to_lowercase();
266
267    // OpenAI format: "maximum context length is 16385 tokens"
268    if let Some(idx) = msg.find("maximum context length is ") {
269        let start = idx + "maximum context length is ".len();
270        return extract_number_at(&msg[start..]);
271    }
272
273    // Alternative: "context length is X tokens"
274    if let Some(idx) = msg.find("context length is ") {
275        let start = idx + "context length is ".len();
276        return extract_number_at(&msg[start..]);
277    }
278
279    // Anthropic format: "X tokens > Y maximum"
280    if let Some(idx) = msg.find(" maximum") {
281        // Look backwards for the number
282        let before_max = &msg[..idx];
283        if let Some(gt_idx) = before_max.rfind("> ") {
284            let start = gt_idx + 2;
285            return extract_number_at(&before_max[start..]);
286        }
287    }
288
289    // Generic: look for "limit of X tokens"
290    if let Some(idx) = msg.find("limit of ") {
291        let start = idx + "limit of ".len();
292        return extract_number_at(&msg[start..]);
293    }
294
295    None
296}
297
/// Extract a number from the start of a string.
///
/// Tolerates leading whitespace and thousands separators, so
/// `"16385 tokens"`, `"  16385"`, and `"16,385 tokens"` all yield
/// `Some(16385)`. (The original digits-only scan returned `None` for a
/// leading space and truncated `"16,385"` to `16`.)
///
/// Returns `None` when no ASCII digit appears at the (trimmed) start.
fn extract_number_at(s: &str) -> Option<usize> {
    let digits: String = s
        .trim_start()
        .chars()
        // Stop at the first char that is neither a digit nor a ',' separator…
        .take_while(|c| c.is_ascii_digit() || *c == ',')
        // …then keep only the digits so "16,385" parses as 16385.
        .filter(char::is_ascii_digit)
        .collect();
    digits.parse().ok()
}

/// Estimate a reasonable max-output-token reservation for a given context
/// window: bigger windows get proportionally bigger output budgets.
fn estimate_max_output(context_length: usize) -> usize {
    if context_length <= 8_192 {
        2_048
    } else if context_length <= 32_000 {
        4_096
    } else if context_length <= 128_000 {
        8_192
    } else {
        16_384
    }
}

314/// Get capabilities using heuristics (static fallback).
315///
316/// This is used when no cached/learned capabilities exist.
317/// Prefer using `ModelCapabilitiesCache::get()` which checks the cache first.
318pub fn get_model_capabilities_heuristic(provider: &str, model_id: &str) -> ModelCapabilities {
319    let model_lower = model_id.to_lowercase();
320
321    match provider.to_lowercase().as_str() {
322        "openai" => get_openai_capabilities(&model_lower, model_id),
323        "anthropic" => get_anthropic_capabilities(&model_lower, model_id),
324        "gemini" => get_gemini_capabilities(&model_lower, model_id),
325        "grok" => get_grok_capabilities(&model_lower, model_id),
326        "ollama" => get_ollama_capabilities(&model_lower, model_id),
327        _ => ModelCapabilities {
328            provider: provider.to_string(),
329            model_id: model_id.to_string(),
330            ..Default::default()
331        },
332    }
333}
334
/// Convenience function - get capabilities without a cache (uses heuristics only).
///
/// Prefer `ModelCapabilitiesCache::get()` in long-running code so that
/// API-discovered and error-learned limits are actually used.
pub fn get_model_capabilities(provider: &str, model_id: &str) -> ModelCapabilities {
    get_model_capabilities_heuristic(provider, model_id)
}

340fn get_openai_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
341    let src = CapabilitySource::StaticFallback;
342
343    // GPT-4o family (128K context)
344    if model_lower.starts_with("gpt-4o") || model_lower.starts_with("chatgpt-4o") {
345        return ModelCapabilities::new("openai", model_id, 128_000, 16_384, src).with_vision(true);
346    }
347
348    // GPT-4 Turbo (128K context)
349    if model_lower.contains("gpt-4-turbo") || model_lower.contains("gpt-4-1106") {
350        return ModelCapabilities::new("openai", model_id, 128_000, 4_096, src).with_vision(true);
351    }
352
353    // GPT-4 base (8K context)
354    if model_lower.starts_with("gpt-4") && !model_lower.contains("turbo") {
355        return ModelCapabilities::new("openai", model_id, 8_192, 4_096, src);
356    }
357
358    // GPT-3.5-turbo - use CONSERVATIVE default since we don't know which version
359    if model_lower.contains("gpt-3.5-turbo") {
360        // Conservative: assume older 4K limit, will learn actual limit from errors
361        return ModelCapabilities::new("openai", model_id, 4_096, 2_048, src);
362    }
363
364    // o1/o3/o4/o5 reasoning models (128K+ context)
365    if model_lower.starts_with("o1")
366        || model_lower.starts_with("o3")
367        || model_lower.starts_with("o4")
368        || model_lower.starts_with("o5")
369    {
370        return ModelCapabilities::new("openai", model_id, 128_000, 32_768, src);
371    }
372
373    // GPT-5.x and codex models (assume large context)
374    if model_lower.starts_with("gpt-5") || model_lower.contains("codex") {
375        return ModelCapabilities::new("openai", model_id, 128_000, 16_384, src);
376    }
377
378    // Unknown OpenAI model - use VERY conservative defaults
379    // Better to compact too early than hit API errors
380    ModelCapabilities::new("openai", model_id, 8_192, 4_096, src)
381}
382
383fn get_anthropic_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
384    let src = CapabilitySource::StaticFallback;
385
386    // Claude 3.x / 4.x (200 K input, 8 K output — independent limits).
387    // context_length = input + output so that max_input_tokens() = 200 000.
388    if model_lower.contains("claude-3")
389        || model_lower.contains("claude-sonnet-4")
390        || model_lower.contains("claude-opus-4")
391    {
392        return ModelCapabilities::new("anthropic", model_id, 200_000 + 8_192, 8_192, src)
393            .with_vision(true);
394    }
395
396    // Claude 2.x (100 K input, 4 K output — independent limits).
397    if model_lower.contains("claude-2") {
398        return ModelCapabilities::new("anthropic", model_id, 100_000 + 4_096, 4_096, src);
399    }
400
401    // Unknown Anthropic model — conservative (32 K input, 4 K output).
402    ModelCapabilities::new("anthropic", model_id, 32_000 + 4_096, 4_096, src)
403}
404
405fn get_gemini_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
406    let src = CapabilitySource::StaticFallback;
407
408    // Gemini 2.0 (1 M input, 8 K output — independent limits).
409    // context_length = input + output so that max_input_tokens() = 1 000 000.
410    if model_lower.contains("gemini-2") {
411        return ModelCapabilities::new("gemini", model_id, 1_000_000 + 8_192, 8_192, src)
412            .with_vision(true);
413    }
414
415    // Gemini 1.5 Pro (1 M input, 8 K output).
416    if model_lower.contains("1.5-pro") || model_lower.contains("1.5pro") {
417        return ModelCapabilities::new("gemini", model_id, 1_000_000 + 8_192, 8_192, src)
418            .with_vision(true);
419    }
420
421    // Gemini 1.5 Flash (1 M input, 8 K output).
422    if model_lower.contains("1.5-flash") || model_lower.contains("flash") {
423        return ModelCapabilities::new("gemini", model_id, 1_000_000 + 8_192, 8_192, src)
424            .with_vision(true);
425    }
426
427    // Unknown Gemini model — conservative (32 K input, 8 K output).
428    ModelCapabilities::new("gemini", model_id, 32_000 + 8_192, 8_192, src)
429}
430
431fn get_grok_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
432    let src = CapabilitySource::StaticFallback;
433
434    // Grok-2 and Grok-3 (131 072 input, 8 192 output — independent limits).
435    // context_length = input + output so that max_input_tokens() = 131 072.
436    if model_lower.contains("grok-2") || model_lower.contains("grok-3") {
437        return ModelCapabilities::new("grok", model_id, 131_072 + 8_192, 8_192, src)
438            .with_vision(true);
439    }
440
441    // Grok-1 (8 K input, 4 K output).
442    if model_lower.contains("grok-1") || model_lower.contains("grok-beta") {
443        return ModelCapabilities::new("grok", model_id, 8_192 + 4_096, 4_096, src);
444    }
445
446    // Unknown Grok model — conservative (32 K input, 4 K output).
447    ModelCapabilities::new("grok", model_id, 32_000 + 4_096, 4_096, src)
448}
449
450fn get_ollama_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
451    let src = CapabilitySource::StaticFallback;
452
453    // Llama 3.2 (128K context)
454    if model_lower.contains("llama3.2") || model_lower.contains("llama-3.2") {
455        return ModelCapabilities::new("ollama", model_id, 128_000, 4_096, src);
456    }
457
458    // Llama 3.1 (128K context)
459    if model_lower.contains("llama3.1") || model_lower.contains("llama-3.1") {
460        return ModelCapabilities::new("ollama", model_id, 128_000, 4_096, src);
461    }
462
463    // Llama 3 (8K context)
464    if model_lower.contains("llama3") || model_lower.contains("llama-3") {
465        return ModelCapabilities::new("ollama", model_id, 8_192, 4_096, src);
466    }
467
468    // Mistral models (32K context)
469    if model_lower.contains("mistral") {
470        return ModelCapabilities::new("ollama", model_id, 32_000, 4_096, src);
471    }
472
473    // Mixtral (32K context)
474    if model_lower.contains("mixtral") {
475        return ModelCapabilities::new("ollama", model_id, 32_000, 4_096, src);
476    }
477
478    // CodeLlama (16K context)
479    if model_lower.contains("codellama") {
480        return ModelCapabilities::new("ollama", model_id, 16_384, 4_096, src);
481    }
482
483    // Qwen models (32K context for most)
484    if model_lower.contains("qwen") {
485        return ModelCapabilities::new("ollama", model_id, 32_000, 4_096, src);
486    }
487
488    // DeepSeek models (64K context)
489    if model_lower.contains("deepseek") {
490        return ModelCapabilities::new("ollama", model_id, 64_000, 4_096, src);
491    }
492
493    // Unknown Ollama model - very conservative default
494    ModelCapabilities::new("ollama", model_id, 4_096, 2_048, src)
495}
496
#[cfg(test)]
mod tests {
    // Unit tests covering the static heuristics, the context_length
    // invariant, error-message parsing, and the learning cache.
    use super::*;

    #[test]
    fn test_gpt4o_capabilities() {
        let caps = get_model_capabilities("openai", "gpt-4o");
        assert_eq!(caps.context_length, 128_000);
        assert_eq!(caps.max_output_tokens, 16_384);
        assert!(caps.supports_vision);
        assert!(caps.supports_tools);
    }

    #[test]
    fn test_gpt35_turbo_uses_conservative_default() {
        // gpt-3.5-turbo uses conservative default since we can't know which version
        let caps = get_model_capabilities("openai", "gpt-3.5-turbo");
        assert_eq!(caps.context_length, 4_096); // Conservative - will learn actual limit
    }

    #[test]
    fn test_claude_capabilities() {
        let caps = get_model_capabilities("anthropic", "claude-sonnet-4-20250514");
        // context_length = 200 000 (input) + 8 192 (output) per the combined invariant.
        assert_eq!(caps.context_length, 200_000 + 8_192);
        assert!(caps.supports_vision);
    }

    #[test]
    fn test_gemini_capabilities() {
        let caps = get_model_capabilities("gemini", "gemini-2.0-flash");
        // context_length = 1 000 000 (input) + 8 192 (output).
        assert_eq!(caps.context_length, 1_000_000 + 8_192);
    }

    #[test]
    fn test_unknown_model_conservative_defaults() {
        let caps = get_model_capabilities("openai", "unknown-model-xyz");
        // OpenAI static fallback is a combined window, so context_length is
        // used directly without further adjustment.
        assert_eq!(caps.context_length, 8_192); // Very conservative default
    }

    #[test]
    fn test_max_input_tokens() {
        let caps = get_model_capabilities("openai", "gpt-4o");
        // OpenAI uses a combined window: 128 K total − 16 K output = 112 K input.
        assert_eq!(caps.max_input_tokens(), 128_000 - 16_384);
    }

    // -----------------------------------------------------------------------
    // Invariant: max_input_tokens() must equal the provider's stated input
    // limit for every heuristic entry, without double-subtracting output.
    // -----------------------------------------------------------------------

    #[test]
    fn test_anthropic_max_input_tokens_equals_stated_input_limit() {
        // Claude 3/4 — API states 200 000 input tokens.
        let caps = get_model_capabilities("anthropic", "claude-sonnet-4-20250514");
        assert_eq!(caps.max_input_tokens(), 200_000);

        // Claude 2 — API states 100 000 input tokens.
        let caps2 = get_model_capabilities("anthropic", "claude-2.1");
        assert_eq!(caps2.max_input_tokens(), 100_000);
    }

    #[test]
    fn test_gemini_max_input_tokens_equals_stated_input_limit() {
        // Gemini 2.0 — API states 1 000 000 input tokens.
        let caps = get_model_capabilities("gemini", "gemini-2.0-flash");
        assert_eq!(caps.max_input_tokens(), 1_000_000);

        // Gemini 1.5 Pro — API states 1 000 000 input tokens.
        let caps2 = get_model_capabilities("gemini", "gemini-1.5-pro");
        assert_eq!(caps2.max_input_tokens(), 1_000_000);
    }

    #[test]
    fn test_grok_max_input_tokens_equals_stated_input_limit() {
        // Grok-2/3 — API states 131 072 input tokens.
        let caps = get_model_capabilities("grok", "grok-2");
        assert_eq!(caps.max_input_tokens(), 131_072);

        // Grok-1 — API states 8 192 input tokens.
        let caps2 = get_model_capabilities("grok", "grok-1");
        assert_eq!(caps2.max_input_tokens(), 8_192);
    }

    #[test]
    fn test_openai_combined_window_is_not_double_subtracted() {
        // OpenAI uses a true combined window — context_length IS input+output
        // already, so subtraction in max_input_tokens() is correct and must
        // NOT be applied a second time.
        let caps = get_model_capabilities("openai", "gpt-4o");
        assert_eq!(caps.context_length, 128_000);
        assert_eq!(caps.max_output_tokens, 16_384);
        assert_eq!(caps.max_input_tokens(), 128_000 - 16_384);
    }

    #[test]
    fn test_context_length_invariant_holds_for_all_static_providers() {
        // For every static heuristic:
        //   context_length == max_input_tokens() + max_output_tokens
        let models = [
            ("openai", "gpt-4o"),
            ("openai", "gpt-3.5-turbo"),
            ("anthropic", "claude-sonnet-4-20250514"),
            ("anthropic", "claude-2.1"),
            ("gemini", "gemini-2.0-flash"),
            ("gemini", "gemini-1.5-pro"),
            ("grok", "grok-2"),
            ("grok", "grok-1"),
            ("ollama", "llama3.1"),
            ("ollama", "mistral"),
        ];
        for (provider, model) in models {
            let caps = get_model_capabilities(provider, model);
            assert_eq!(
                caps.context_length,
                caps.max_input_tokens() + caps.max_output_tokens,
                "{provider}/{model}: context_length invariant violated \
                 (context_length={}, max_input_tokens()={}, max_output_tokens={})",
                caps.context_length,
                caps.max_input_tokens(),
                caps.max_output_tokens,
            );
        }
    }

    #[test]
    fn test_parse_openai_error() {
        let error = "This model's maximum context length is 16385 tokens. \
                     However, your messages resulted in 17063 tokens";
        let length = parse_context_length_from_error(error);
        assert_eq!(length, Some(16385));
    }

    #[test]
    fn test_parse_generic_error() {
        let error = "Request exceeds limit of 8192 tokens";
        let length = parse_context_length_from_error(error);
        assert_eq!(length, Some(8192));
    }

    #[test]
    fn test_cache_learns_from_error() {
        let cache = ModelCapabilitiesCache::new();

        // Initially uses heuristic (conservative)
        let caps_before = cache.get("openai", "gpt-3.5-turbo");
        assert_eq!(caps_before.context_length, 4_096);

        // Learn from error
        cache.learn_from_error(
            "openai",
            "gpt-3.5-turbo",
            "maximum context length is 16385 tokens",
        );

        // Now uses learned value
        let caps_after = cache.get("openai", "gpt-3.5-turbo");
        assert_eq!(caps_after.context_length, 16385);
        assert_eq!(caps_after.source, CapabilitySource::ErrorLearned);
    }

    #[test]
    fn test_cache_user_override() {
        let cache = ModelCapabilitiesCache::new();

        cache.store_user_override("openai", "custom-model", 32_000);

        let caps = cache.get("openai", "custom-model");
        assert_eq!(caps.context_length, 32_000);
        assert_eq!(caps.source, CapabilitySource::UserConfig);
    }

    #[test]
    fn test_capability_source_reliability() {
        let api_caps =
            ModelCapabilities::new("test", "test", 1000, 100, CapabilitySource::ApiDiscovery);
        let error_caps =
            ModelCapabilities::new("test", "test", 1000, 100, CapabilitySource::ErrorLearned);
        let static_caps =
            ModelCapabilities::new("test", "test", 1000, 100, CapabilitySource::StaticFallback);

        assert!(api_caps.is_reliable());
        assert!(!error_caps.is_reliable()); // Learned is useful but not "reliable"
        assert!(!static_caps.is_reliable());
    }
}