1use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use std::sync::{Arc, RwLock};
42
/// Capability record for a single provider/model pair: token limits plus
/// feature flags, tagged with how the values were obtained.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ModelCapabilities {
    /// Total token window. For some providers (see the per-provider
    /// heuristics below) this stores input window + output budget combined,
    /// so that `max_input_tokens()` yields the stated input limit.
    pub context_length: usize,
    /// Maximum tokens the model may generate in one response.
    pub max_output_tokens: usize,
    /// Whether the model supports tool / function calling.
    pub supports_tools: bool,
    /// Whether the model accepts image inputs.
    pub supports_vision: bool,
    /// Whether streamed responses are supported.
    pub supports_streaming: bool,
    /// Provider name, e.g. "openai" or "anthropic".
    pub provider: String,
    /// Provider-specific model identifier.
    pub model_id: String,
    /// Where these values came from; see `is_reliable()`.
    pub source: CapabilitySource,
}
77
/// Provenance of a `ModelCapabilities` record. Only `ApiDiscovery` and
/// `UserConfig` are treated as reliable (see `ModelCapabilities::is_reliable`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum CapabilitySource {
    /// Reported by the provider's model-listing/discovery API.
    ApiDiscovery,
    /// Inferred at runtime from a context-limit error message.
    ErrorLearned,
    /// Explicitly configured by the user.
    UserConfig,
    /// Built-in heuristic defaults (least trustworthy).
    #[default]
    StaticFallback,
}
91
92impl Default for ModelCapabilities {
93 fn default() -> Self {
94 Self {
95 context_length: 8_192, max_output_tokens: 4_096,
97 supports_tools: true,
98 supports_vision: false,
99 supports_streaming: true,
100 provider: "unknown".to_string(),
101 model_id: "unknown".to_string(),
102 source: CapabilitySource::StaticFallback,
103 }
104 }
105}
106
107impl ModelCapabilities {
108 pub fn new(
110 provider: &str,
111 model_id: &str,
112 context_length: usize,
113 max_output_tokens: usize,
114 source: CapabilitySource,
115 ) -> Self {
116 Self {
117 context_length,
118 max_output_tokens,
119 supports_tools: true,
120 supports_vision: false,
121 supports_streaming: true,
122 provider: provider.to_string(),
123 model_id: model_id.to_string(),
124 source,
125 }
126 }
127
128 pub fn with_vision(mut self, supports: bool) -> Self {
130 self.supports_vision = supports;
131 self
132 }
133
134 pub fn with_tools(mut self, supports: bool) -> Self {
136 self.supports_tools = supports;
137 self
138 }
139
140 pub fn max_input_tokens(&self) -> usize {
142 self.context_length.saturating_sub(self.max_output_tokens)
143 }
144
145 pub fn is_reliable(&self) -> bool {
147 matches!(
148 self.source,
149 CapabilitySource::ApiDiscovery | CapabilitySource::UserConfig
150 )
151 }
152}
153
/// Thread-safe cache of per-model capabilities, keyed by lower-cased
/// "provider:model". Cloning the cache shares the same underlying map
/// (the `HashMap` lives behind an `Arc`).
#[derive(Debug, Clone, Default)]
pub struct ModelCapabilitiesCache {
    cache: Arc<RwLock<HashMap<String, ModelCapabilities>>>,
}
162
163impl ModelCapabilitiesCache {
164 pub fn new() -> Self {
166 Self {
167 cache: Arc::new(RwLock::new(HashMap::new())),
168 }
169 }
170
171 fn cache_key(provider: &str, model_id: &str) -> String {
173 format!("{}:{}", provider.to_lowercase(), model_id.to_lowercase())
174 }
175
176 pub fn get(&self, provider: &str, model_id: &str) -> ModelCapabilities {
178 let key = Self::cache_key(provider, model_id);
179
180 if let Some(caps) = self.cache.read().ok().and_then(|c| c.get(&key).cloned()) {
182 return caps;
183 }
184
185 get_model_capabilities_heuristic(provider, model_id)
187 }
188
189 pub fn learn_from_error(
197 &self,
198 provider: &str,
199 model_id: &str,
200 error_message: &str,
201 ) -> Option<ModelCapabilities> {
202 let context_length = parse_context_length_from_error(error_message)?;
203
204 let caps = ModelCapabilities::new(
205 provider,
206 model_id,
207 context_length,
208 estimate_max_output(context_length),
209 CapabilitySource::ErrorLearned,
210 );
211
212 let key = Self::cache_key(provider, model_id);
214 if let Ok(mut cache) = self.cache.write() {
215 cache.insert(key, caps.clone());
216 }
217
218 tracing::info!(
219 provider = provider,
220 model = model_id,
221 context_length = context_length,
222 "Learned model context limit from error"
223 );
224
225 Some(caps)
226 }
227
228 pub fn store_from_api(&self, caps: ModelCapabilities) {
230 let key = Self::cache_key(&caps.provider, &caps.model_id);
231 if let Ok(mut cache) = self.cache.write() {
232 cache.insert(key, caps);
233 }
234 }
235
236 pub fn store_user_override(&self, provider: &str, model_id: &str, context_length: usize) {
238 let caps = ModelCapabilities::new(
239 provider,
240 model_id,
241 context_length,
242 estimate_max_output(context_length),
243 CapabilitySource::UserConfig,
244 );
245 let key = Self::cache_key(provider, model_id);
246 if let Ok(mut cache) = self.cache.write() {
247 cache.insert(key, caps);
248 }
249 }
250
251 pub fn clear(&self) {
253 if let Ok(mut cache) = self.cache.write() {
254 cache.clear();
255 }
256 }
257}
258
259fn parse_context_length_from_error(error_message: &str) -> Option<usize> {
265 let msg = error_message.to_lowercase();
266
267 if let Some(idx) = msg.find("maximum context length is ") {
269 let start = idx + "maximum context length is ".len();
270 return extract_number_at(&msg[start..]);
271 }
272
273 if let Some(idx) = msg.find("context length is ") {
275 let start = idx + "context length is ".len();
276 return extract_number_at(&msg[start..]);
277 }
278
279 if let Some(idx) = msg.find(" maximum") {
281 let before_max = &msg[..idx];
283 if let Some(gt_idx) = before_max.rfind("> ") {
284 let start = gt_idx + 2;
285 return extract_number_at(&before_max[start..]);
286 }
287 }
288
289 if let Some(idx) = msg.find("limit of ") {
291 let start = idx + "limit of ".len();
292 return extract_number_at(&msg[start..]);
293 }
294
295 None
296}
297
/// Parses the integer at the start of `s`, tolerating thousands separators:
/// both "16385 tokens" and "16,385 tokens" yield `Some(16385)`.
///
/// Previously a comma terminated the scan, so "16,385" parsed as 16 — a
/// bogus tiny limit that would then be cached as the model's context size.
/// Returns `None` when `s` does not start with a digit.
fn extract_number_at(s: &str) -> Option<usize> {
    let digits: String = s
        .chars()
        // Keep scanning through digit groups separated by commas…
        .take_while(|c| c.is_ascii_digit() || *c == ',')
        // …but only the digits participate in the parse.
        .filter(|c| c.is_ascii_digit())
        .collect();
    digits.parse().ok()
}
303
/// Picks a sensible output-token budget for a learned/overridden context
/// size: roughly a quarter of small windows, capped at 16K for huge ones.
fn estimate_max_output(context_length: usize) -> usize {
    if context_length <= 8_192 {
        2_048
    } else if context_length <= 32_000 {
        4_096
    } else if context_length <= 128_000 {
        8_192
    } else {
        16_384
    }
}
313
314pub fn get_model_capabilities_heuristic(provider: &str, model_id: &str) -> ModelCapabilities {
319 let model_lower = model_id.to_lowercase();
320
321 match provider.to_lowercase().as_str() {
322 "openai" => get_openai_capabilities(&model_lower, model_id),
323 "anthropic" => get_anthropic_capabilities(&model_lower, model_id),
324 "gemini" => get_gemini_capabilities(&model_lower, model_id),
325 "grok" => get_grok_capabilities(&model_lower, model_id),
326 "ollama" => get_ollama_capabilities(&model_lower, model_id),
327 _ => ModelCapabilities {
328 provider: provider.to_string(),
329 model_id: model_id.to_string(),
330 ..Default::default()
331 },
332 }
333}
334
/// Convenience wrapper around [`get_model_capabilities_heuristic`] for
/// callers that do not go through a `ModelCapabilitiesCache`.
pub fn get_model_capabilities(provider: &str, model_id: &str) -> ModelCapabilities {
    get_model_capabilities_heuristic(provider, model_id)
}
339
340fn get_openai_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
341 let src = CapabilitySource::StaticFallback;
342
343 if model_lower.starts_with("gpt-4o") || model_lower.starts_with("chatgpt-4o") {
345 return ModelCapabilities::new("openai", model_id, 128_000, 16_384, src).with_vision(true);
346 }
347
348 if model_lower.contains("gpt-4-turbo") || model_lower.contains("gpt-4-1106") {
350 return ModelCapabilities::new("openai", model_id, 128_000, 4_096, src).with_vision(true);
351 }
352
353 if model_lower.starts_with("gpt-4") && !model_lower.contains("turbo") {
355 return ModelCapabilities::new("openai", model_id, 8_192, 4_096, src);
356 }
357
358 if model_lower.contains("gpt-3.5-turbo") {
360 return ModelCapabilities::new("openai", model_id, 4_096, 2_048, src);
362 }
363
364 if model_lower.starts_with("o1")
366 || model_lower.starts_with("o3")
367 || model_lower.starts_with("o4")
368 || model_lower.starts_with("o5")
369 {
370 return ModelCapabilities::new("openai", model_id, 128_000, 32_768, src);
371 }
372
373 if model_lower.starts_with("gpt-5") || model_lower.contains("codex") {
375 return ModelCapabilities::new("openai", model_id, 128_000, 16_384, src);
376 }
377
378 ModelCapabilities::new("openai", model_id, 8_192, 4_096, src)
381}
382
383fn get_anthropic_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
384 let src = CapabilitySource::StaticFallback;
385
386 if model_lower.contains("claude-3")
389 || model_lower.contains("claude-sonnet-4")
390 || model_lower.contains("claude-opus-4")
391 {
392 return ModelCapabilities::new("anthropic", model_id, 200_000 + 8_192, 8_192, src)
393 .with_vision(true);
394 }
395
396 if model_lower.contains("claude-2") {
398 return ModelCapabilities::new("anthropic", model_id, 100_000 + 4_096, 4_096, src);
399 }
400
401 ModelCapabilities::new("anthropic", model_id, 32_000 + 4_096, 4_096, src)
403}
404
405fn get_gemini_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
406 let src = CapabilitySource::StaticFallback;
407
408 if model_lower.contains("gemini-2") {
411 return ModelCapabilities::new("gemini", model_id, 1_000_000 + 8_192, 8_192, src)
412 .with_vision(true);
413 }
414
415 if model_lower.contains("1.5-pro") || model_lower.contains("1.5pro") {
417 return ModelCapabilities::new("gemini", model_id, 1_000_000 + 8_192, 8_192, src)
418 .with_vision(true);
419 }
420
421 if model_lower.contains("1.5-flash") || model_lower.contains("flash") {
423 return ModelCapabilities::new("gemini", model_id, 1_000_000 + 8_192, 8_192, src)
424 .with_vision(true);
425 }
426
427 ModelCapabilities::new("gemini", model_id, 32_000 + 8_192, 8_192, src)
429}
430
431fn get_grok_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
432 let src = CapabilitySource::StaticFallback;
433
434 if model_lower.contains("grok-2") || model_lower.contains("grok-3") {
437 return ModelCapabilities::new("grok", model_id, 131_072 + 8_192, 8_192, src)
438 .with_vision(true);
439 }
440
441 if model_lower.contains("grok-1") || model_lower.contains("grok-beta") {
443 return ModelCapabilities::new("grok", model_id, 8_192 + 4_096, 4_096, src);
444 }
445
446 ModelCapabilities::new("grok", model_id, 32_000 + 4_096, 4_096, src)
448}
449
450fn get_ollama_capabilities(model_lower: &str, model_id: &str) -> ModelCapabilities {
451 let src = CapabilitySource::StaticFallback;
452
453 if model_lower.contains("llama3.2") || model_lower.contains("llama-3.2") {
455 return ModelCapabilities::new("ollama", model_id, 128_000, 4_096, src);
456 }
457
458 if model_lower.contains("llama3.1") || model_lower.contains("llama-3.1") {
460 return ModelCapabilities::new("ollama", model_id, 128_000, 4_096, src);
461 }
462
463 if model_lower.contains("llama3") || model_lower.contains("llama-3") {
465 return ModelCapabilities::new("ollama", model_id, 8_192, 4_096, src);
466 }
467
468 if model_lower.contains("mistral") {
470 return ModelCapabilities::new("ollama", model_id, 32_000, 4_096, src);
471 }
472
473 if model_lower.contains("mixtral") {
475 return ModelCapabilities::new("ollama", model_id, 32_000, 4_096, src);
476 }
477
478 if model_lower.contains("codellama") {
480 return ModelCapabilities::new("ollama", model_id, 16_384, 4_096, src);
481 }
482
483 if model_lower.contains("qwen") {
485 return ModelCapabilities::new("ollama", model_id, 32_000, 4_096, src);
486 }
487
488 if model_lower.contains("deepseek") {
490 return ModelCapabilities::new("ollama", model_id, 64_000, 4_096, src);
491 }
492
493 ModelCapabilities::new("ollama", model_id, 4_096, 2_048, src)
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    // --- Static heuristic lookups: spot-check known models per provider ---

    #[test]
    fn test_gpt4o_capabilities() {
        let caps = get_model_capabilities("openai", "gpt-4o");
        assert_eq!(caps.context_length, 128_000);
        assert_eq!(caps.max_output_tokens, 16_384);
        assert!(caps.supports_vision);
        assert!(caps.supports_tools);
    }

    #[test]
    fn test_gpt35_turbo_uses_conservative_default() {
        let caps = get_model_capabilities("openai", "gpt-3.5-turbo");
        assert_eq!(caps.context_length, 4_096);
    }

    #[test]
    fn test_claude_capabilities() {
        // Anthropic entries store input window + output budget combined.
        let caps = get_model_capabilities("anthropic", "claude-sonnet-4-20250514");
        assert_eq!(caps.context_length, 200_000 + 8_192);
        assert!(caps.supports_vision);
    }

    #[test]
    fn test_gemini_capabilities() {
        let caps = get_model_capabilities("gemini", "gemini-2.0-flash");
        assert_eq!(caps.context_length, 1_000_000 + 8_192);
    }

    #[test]
    fn test_unknown_model_conservative_defaults() {
        let caps = get_model_capabilities("openai", "unknown-model-xyz");
        assert_eq!(caps.context_length, 8_192);
    }

    // --- max_input_tokens derivation ---

    #[test]
    fn test_max_input_tokens() {
        let caps = get_model_capabilities("openai", "gpt-4o");
        assert_eq!(caps.max_input_tokens(), 128_000 - 16_384);
    }

    // For providers whose context_length bakes in the output budget, the
    // derived input budget must equal the provider's stated input limit.

    #[test]
    fn test_anthropic_max_input_tokens_equals_stated_input_limit() {
        let caps = get_model_capabilities("anthropic", "claude-sonnet-4-20250514");
        assert_eq!(caps.max_input_tokens(), 200_000);

        let caps2 = get_model_capabilities("anthropic", "claude-2.1");
        assert_eq!(caps2.max_input_tokens(), 100_000);
    }

    #[test]
    fn test_gemini_max_input_tokens_equals_stated_input_limit() {
        let caps = get_model_capabilities("gemini", "gemini-2.0-flash");
        assert_eq!(caps.max_input_tokens(), 1_000_000);

        let caps2 = get_model_capabilities("gemini", "gemini-1.5-pro");
        assert_eq!(caps2.max_input_tokens(), 1_000_000);
    }

    #[test]
    fn test_grok_max_input_tokens_equals_stated_input_limit() {
        let caps = get_model_capabilities("grok", "grok-2");
        assert_eq!(caps.max_input_tokens(), 131_072);

        let caps2 = get_model_capabilities("grok", "grok-1");
        assert_eq!(caps2.max_input_tokens(), 8_192);
    }

    #[test]
    fn test_openai_combined_window_is_not_double_subtracted() {
        // OpenAI context_length is the raw window, not input+output;
        // the input budget is the window minus the output budget.
        let caps = get_model_capabilities("openai", "gpt-4o");
        assert_eq!(caps.context_length, 128_000);
        assert_eq!(caps.max_output_tokens, 16_384);
        assert_eq!(caps.max_input_tokens(), 128_000 - 16_384);
    }

    #[test]
    fn test_context_length_invariant_holds_for_all_static_providers() {
        // Invariant: context_length == max_input_tokens() + max_output_tokens
        // for every static table entry, regardless of provider convention.
        let models = [
            ("openai", "gpt-4o"),
            ("openai", "gpt-3.5-turbo"),
            ("anthropic", "claude-sonnet-4-20250514"),
            ("anthropic", "claude-2.1"),
            ("gemini", "gemini-2.0-flash"),
            ("gemini", "gemini-1.5-pro"),
            ("grok", "grok-2"),
            ("grok", "grok-1"),
            ("ollama", "llama3.1"),
            ("ollama", "mistral"),
        ];
        for (provider, model) in models {
            let caps = get_model_capabilities(provider, model);
            assert_eq!(
                caps.context_length,
                caps.max_input_tokens() + caps.max_output_tokens,
                "{provider}/{model}: context_length invariant violated \
                 (context_length={}, max_input_tokens()={}, max_output_tokens={})",
                caps.context_length,
                caps.max_input_tokens(),
                caps.max_output_tokens,
            );
        }
    }

    // --- Error-message parsing ---

    #[test]
    fn test_parse_openai_error() {
        let error = "This model's maximum context length is 16385 tokens. \
                     However, your messages resulted in 17063 tokens";
        let length = parse_context_length_from_error(error);
        assert_eq!(length, Some(16385));
    }

    #[test]
    fn test_parse_generic_error() {
        let error = "Request exceeds limit of 8192 tokens";
        let length = parse_context_length_from_error(error);
        assert_eq!(length, Some(8192));
    }

    // --- Cache behavior ---

    #[test]
    fn test_cache_learns_from_error() {
        let cache = ModelCapabilitiesCache::new();

        // Before learning: static heuristic value.
        let caps_before = cache.get("openai", "gpt-3.5-turbo");
        assert_eq!(caps_before.context_length, 4_096);

        cache.learn_from_error(
            "openai",
            "gpt-3.5-turbo",
            "maximum context length is 16385 tokens",
        );

        // After learning: cached value with ErrorLearned provenance.
        let caps_after = cache.get("openai", "gpt-3.5-turbo");
        assert_eq!(caps_after.context_length, 16385);
        assert_eq!(caps_after.source, CapabilitySource::ErrorLearned);
    }

    #[test]
    fn test_cache_user_override() {
        let cache = ModelCapabilitiesCache::new();

        cache.store_user_override("openai", "custom-model", 32_000);

        let caps = cache.get("openai", "custom-model");
        assert_eq!(caps.context_length, 32_000);
        assert_eq!(caps.source, CapabilitySource::UserConfig);
    }

    #[test]
    fn test_capability_source_reliability() {
        let api_caps =
            ModelCapabilities::new("test", "test", 1000, 100, CapabilitySource::ApiDiscovery);
        let error_caps =
            ModelCapabilities::new("test", "test", 1000, 100, CapabilitySource::ErrorLearned);
        let static_caps =
            ModelCapabilities::new("test", "test", 1000, 100, CapabilitySource::StaticFallback);

        assert!(api_caps.is_reliable());
        assert!(!error_caps.is_reliable());
        assert!(!static_caps.is_reliable());
    }
}