gestura_core_llm/
model_discovery.rs

1//! Dynamic model metadata discovery via provider APIs.
2//!
3//! This module queries provider APIs at runtime to discover actual model
4//! capabilities (context length, output limits, features) rather than relying
5//! on static mappings.
6//!
7//! ## Supported Providers
8//!
9//! | Provider | Endpoint | API field(s) used | Stored as `context_length` |
10//! |----------|----------|-------------------|---------------------------|
11//! | Gemini | `GET /v1beta/models` | `inputTokenLimit` + `outputTokenLimit` | `input + output` |
12//! | Anthropic | `GET /v1/models` | `max_input_tokens` + `max_output_tokens` | `input + output` |
13//! | Grok (xAI) | `GET /v1/language-models` | `input_modalities.text.token_limit` + output | `input + output` |
14//! | Ollama | `POST /api/show` | `model_info.*.context_length` (combined) | as-is (already combined) |
15//! | OpenAI | N/A | Uses error-driven learning | N/A |
16//!
17//! For providers that expose separate input and output limits, discovery stores
18//! `input_limit + output_limit` as `context_length` so that
19//! `ModelCapabilities::max_input_tokens()` (which subtracts `max_output_tokens`)
20//! recovers the correct prompt budget without double-subtracting.
21
22use crate::model_capabilities::{CapabilitySource, ModelCapabilities, ModelCapabilitiesCache};
23use gestura_core_foundation::AppError;
24use std::time::Duration;
25
26/// Timeout for metadata discovery API calls (shorter than inference)
27const DISCOVERY_TIMEOUT_SECS: u64 = 10;
28
29/// Create a lightweight HTTP client for discovery calls
30fn discovery_client() -> reqwest::Client {
31    reqwest::Client::builder()
32        .timeout(Duration::from_secs(DISCOVERY_TIMEOUT_SECS))
33        .connect_timeout(Duration::from_secs(5))
34        .build()
35        .unwrap_or_else(|_| reqwest::Client::new())
36}
37
38/// Discover model capabilities from provider API and store in cache.
39///
40/// Returns the discovered capabilities, or None if discovery failed.
41/// Failures are logged but don't prevent operation - we fall back to heuristics.
42pub async fn discover_model_capabilities(
43    cache: &ModelCapabilitiesCache,
44    provider: &str,
45    model_id: &str,
46    api_key: Option<&str>,
47    base_url: Option<&str>,
48) -> Option<ModelCapabilities> {
49    let result = match provider.to_lowercase().as_str() {
50        "gemini" => discover_gemini(model_id, api_key, base_url).await,
51        "anthropic" => discover_anthropic(model_id, api_key).await,
52        "grok" => discover_grok(model_id, api_key).await,
53        "ollama" => discover_ollama(model_id, base_url).await,
54        // OpenAI doesn't expose context length in API - use error-driven learning
55        "openai" => {
56            tracing::debug!("OpenAI doesn't expose context length in API - using heuristics");
57            return None;
58        }
59        _ => {
60            tracing::debug!(provider = provider, "Unknown provider for discovery");
61            return None;
62        }
63    };
64
65    match result {
66        Ok(caps) => {
67            tracing::info!(
68                provider = provider,
69                model = model_id,
70                context_length = caps.context_length,
71                "Discovered model capabilities from API"
72            );
73            cache.store_from_api(caps.clone());
74            Some(caps)
75        }
76        Err(e) => {
77            tracing::debug!(
78                provider = provider,
79                model = model_id,
80                error = %e,
81                "Failed to discover model capabilities - using heuristics"
82            );
83            None
84        }
85    }
86}
87
88/// Discover capabilities for a Gemini model via Google's API.
89///
90/// Endpoint: `GET /v1beta/models/{model_id}`
91/// Returns: `inputTokenLimit`, `outputTokenLimit`
92async fn discover_gemini(
93    model_id: &str,
94    api_key: Option<&str>,
95    base_url: Option<&str>,
96) -> Result<ModelCapabilities, AppError> {
97    let key = api_key
98        .filter(|k| !k.is_empty())
99        .ok_or_else(|| AppError::Config("Gemini API key required for discovery".into()))?;
100
101    let base = base_url
102        .filter(|u| !u.is_empty())
103        .unwrap_or("https://generativelanguage.googleapis.com");
104
105    // Gemini model IDs may or may not have "models/" prefix
106    let model_path = if model_id.starts_with("models/") {
107        model_id.to_string()
108    } else {
109        format!("models/{}", model_id)
110    };
111
112    let url = format!(
113        "{}/v1beta/{}?key={}",
114        base.trim_end_matches('/'),
115        model_path,
116        key
117    );
118
119    let resp = discovery_client()
120        .get(&url)
121        .send()
122        .await
123        .map_err(|e| AppError::Llm(format!("Gemini discovery failed: {e}")))?;
124
125    if !resp.status().is_success() {
126        return Err(AppError::Llm(format!(
127            "Gemini API returned {}",
128            resp.status()
129        )));
130    }
131
132    let data: serde_json::Value = resp.json().await?;
133
134    let input_limit = data
135        .get("inputTokenLimit")
136        .and_then(|v| v.as_u64())
137        .unwrap_or(32_000) as usize;
138
139    let output_limit = data
140        .get("outputTokenLimit")
141        .and_then(|v| v.as_u64())
142        .unwrap_or(8_192) as usize;
143
144    // Gemini reports input and output limits independently.  Store their sum
145    // as `context_length` (the "combined window" invariant) so that
146    // `max_input_tokens() = context_length - max_output_tokens` correctly
147    // returns `input_limit` without double-subtracting.
148    Ok(ModelCapabilities::new(
149        "gemini",
150        model_id,
151        input_limit + output_limit,
152        output_limit,
153        CapabilitySource::ApiDiscovery,
154    )
155    .with_vision(true))
156}
157
158/// Discover capabilities for an Anthropic model via their API.
159///
160/// Endpoint: `GET /v1/models/{model_id}`
161/// Returns: `max_input_tokens`, `max_output_tokens`
162async fn discover_anthropic(
163    model_id: &str,
164    api_key: Option<&str>,
165) -> Result<ModelCapabilities, AppError> {
166    let key = api_key
167        .filter(|k| !k.is_empty())
168        .ok_or_else(|| AppError::Config("Anthropic API key required for discovery".into()))?;
169
170    let url = format!("https://api.anthropic.com/v1/models/{}", model_id);
171
172    let resp = discovery_client()
173        .get(&url)
174        .header("x-api-key", key)
175        .header("anthropic-version", "2023-06-01")
176        .send()
177        .await
178        .map_err(|e| AppError::Llm(format!("Anthropic discovery failed: {e}")))?;
179
180    if !resp.status().is_success() {
181        return Err(AppError::Llm(format!(
182            "Anthropic API returned {}",
183            resp.status()
184        )));
185    }
186
187    let data: serde_json::Value = resp.json().await?;
188
189    // Anthropic returns max_input_tokens and max_output_tokens as independent
190    // limits (not a shared combined window).  Store their sum as
191    // `context_length` so that `max_input_tokens() = context_length -
192    // max_output_tokens` correctly recovers `input_limit`.
193    let input_limit = data
194        .get("max_input_tokens")
195        .and_then(|v| v.as_u64())
196        .unwrap_or(200_000) as usize;
197
198    let output_limit = data
199        .get("max_output_tokens")
200        .and_then(|v| v.as_u64())
201        .unwrap_or(8_192) as usize;
202
203    Ok(ModelCapabilities::new(
204        "anthropic",
205        model_id,
206        input_limit + output_limit,
207        output_limit,
208        CapabilitySource::ApiDiscovery,
209    )
210    .with_vision(true))
211}
212
213/// Discover capabilities for a Grok (xAI) model via their API.
214///
215/// Endpoint: `GET /v1/language-models`
216/// Returns: List of models with context info
217async fn discover_grok(
218    model_id: &str,
219    api_key: Option<&str>,
220) -> Result<ModelCapabilities, AppError> {
221    let key = api_key
222        .filter(|k| !k.is_empty())
223        .ok_or_else(|| AppError::Config("Grok API key required for discovery".into()))?;
224
225    let url = "https://api.x.ai/v1/language-models";
226
227    let resp = discovery_client()
228        .get(url)
229        .header("Authorization", format!("Bearer {}", key))
230        .send()
231        .await
232        .map_err(|e| AppError::Llm(format!("Grok discovery failed: {e}")))?;
233
234    if !resp.status().is_success() {
235        return Err(AppError::Llm(format!(
236            "Grok API returned {}",
237            resp.status()
238        )));
239    }
240
241    let data: serde_json::Value = resp.json().await?;
242
243    // Find the model in the list
244    let models = data.get("models").and_then(|m| m.as_array());
245    let model_data = models.and_then(|list| {
246        list.iter().find(|m| {
247            m.get("id")
248                .and_then(|id| id.as_str())
249                .map(|id| id == model_id)
250                .unwrap_or(false)
251        })
252    });
253
254    let (input_limit, output_limit) = if let Some(model) = model_data {
255        let input = model
256            .get("input_modalities")
257            .and_then(|m| m.get("text"))
258            .and_then(|t| t.get("token_limit"))
259            .and_then(|v| v.as_u64())
260            .unwrap_or(131_072) as usize;
261
262        let output = model
263            .get("output_modalities")
264            .and_then(|m| m.get("text"))
265            .and_then(|t| t.get("token_limit"))
266            .and_then(|v| v.as_u64())
267            .unwrap_or(8_192) as usize;
268
269        (input, output)
270    } else {
271        // Model not in list, use conservative defaults
272        (32_000, 4_096)
273    };
274
275    // Grok reports input and output limits independently.  Store their sum
276    // as `context_length` so that `max_input_tokens()` returns `input_limit`.
277    Ok(ModelCapabilities::new(
278        "grok",
279        model_id,
280        input_limit + output_limit,
281        output_limit,
282        CapabilitySource::ApiDiscovery,
283    ))
284}
285
286/// Discover capabilities for an Ollama model via local API.
287///
288/// Endpoint: `POST /api/show`
289/// Returns: `model_info.{architecture}.context_length`
290async fn discover_ollama(
291    model_id: &str,
292    base_url: Option<&str>,
293) -> Result<ModelCapabilities, AppError> {
294    let base = base_url
295        .filter(|u| !u.is_empty())
296        .unwrap_or("http://localhost:11434");
297
298    let url = format!("{}/api/show", base.trim_end_matches('/'));
299
300    let resp = discovery_client()
301        .post(&url)
302        .json(&serde_json::json!({ "name": model_id }))
303        .send()
304        .await
305        .map_err(|e| AppError::Llm(format!("Ollama discovery failed: {e}")))?;
306
307    if !resp.status().is_success() {
308        return Err(AppError::Llm(format!(
309            "Ollama API returned {}",
310            resp.status()
311        )));
312    }
313
314    let data: serde_json::Value = resp.json().await?;
315
316    // Ollama stores context_length in model_info under architecture-specific keys
317    // e.g., "llama.context_length", "gemma.context_length"
318    let model_info = data.get("model_info");
319
320    let context_length = model_info
321        .and_then(|info| {
322            // Try to find any key ending with ".context_length"
323            info.as_object().and_then(|obj| {
324                obj.iter()
325                    .find(|(k, _)| k.ends_with(".context_length"))
326                    .and_then(|(_, v)| v.as_u64())
327            })
328        })
329        .unwrap_or(4_096) as usize;
330
331    // Also check num_ctx in parameters (user override)
332    let num_ctx = data
333        .get("parameters")
334        .and_then(|p| p.as_str())
335        .and_then(|params| {
336            // Parse "num_ctx N" from parameters string
337            params
338                .split_whitespace()
339                .collect::<Vec<_>>()
340                .windows(2)
341                .find(|w| w[0] == "num_ctx")
342                .and_then(|w| w[1].parse::<usize>().ok())
343        });
344
345    let effective_context = num_ctx.unwrap_or(context_length);
346
347    Ok(ModelCapabilities::new(
348        "ollama",
349        model_id,
350        effective_context,
351        effective_context / 4, // Rough estimate for output
352        CapabilitySource::ApiDiscovery,
353    ))
354}
355
#[cfg(test)]
mod tests {
    /// Verify the model_info context_length extraction logic used by
    /// `discover_ollama`: any key ending in ".context_length" is accepted,
    /// regardless of the architecture prefix.
    ///
    /// This is pure JSON inspection with no I/O and no `.await`, so a plain
    /// synchronous `#[test]` is correct - spinning up a tokio runtime via
    /// `#[tokio::test]` on an `async fn` with nothing to await is pure
    /// overhead.
    #[test]
    fn test_ollama_parsing() {
        // Test that we can parse context_length from model_info
        let json = serde_json::json!({
            "model_info": {
                "llama.context_length": 8192,
                "llama.embedding_length": 4096
            }
        });

        let context = json.get("model_info").and_then(|info| {
            info.as_object().and_then(|obj| {
                obj.iter()
                    .find(|(k, _)| k.ends_with(".context_length"))
                    .and_then(|(_, v)| v.as_u64())
            })
        });

        assert_eq!(context, Some(8192));
    }
}