// gestura_core_llm/model_discovery.rs

use crate::model_capabilities::{CapabilitySource, ModelCapabilities, ModelCapabilitiesCache};
use gestura_core_foundation::AppError;
use std::time::Duration;

/// Hard cap (seconds) on any capability-discovery HTTP request.
const DISCOVERY_TIMEOUT_SECS: u64 = 10;
29fn discovery_client() -> reqwest::Client {
31 reqwest::Client::builder()
32 .timeout(Duration::from_secs(DISCOVERY_TIMEOUT_SECS))
33 .connect_timeout(Duration::from_secs(5))
34 .build()
35 .unwrap_or_else(|_| reqwest::Client::new())
36}
37
38pub async fn discover_model_capabilities(
43 cache: &ModelCapabilitiesCache,
44 provider: &str,
45 model_id: &str,
46 api_key: Option<&str>,
47 base_url: Option<&str>,
48) -> Option<ModelCapabilities> {
49 let result = match provider.to_lowercase().as_str() {
50 "gemini" => discover_gemini(model_id, api_key, base_url).await,
51 "anthropic" => discover_anthropic(model_id, api_key).await,
52 "grok" => discover_grok(model_id, api_key).await,
53 "ollama" => discover_ollama(model_id, base_url).await,
54 "openai" => {
56 tracing::debug!("OpenAI doesn't expose context length in API - using heuristics");
57 return None;
58 }
59 _ => {
60 tracing::debug!(provider = provider, "Unknown provider for discovery");
61 return None;
62 }
63 };
64
65 match result {
66 Ok(caps) => {
67 tracing::info!(
68 provider = provider,
69 model = model_id,
70 context_length = caps.context_length,
71 "Discovered model capabilities from API"
72 );
73 cache.store_from_api(caps.clone());
74 Some(caps)
75 }
76 Err(e) => {
77 tracing::debug!(
78 provider = provider,
79 model = model_id,
80 error = %e,
81 "Failed to discover model capabilities - using heuristics"
82 );
83 None
84 }
85 }
86}
87
88async fn discover_gemini(
93 model_id: &str,
94 api_key: Option<&str>,
95 base_url: Option<&str>,
96) -> Result<ModelCapabilities, AppError> {
97 let key = api_key
98 .filter(|k| !k.is_empty())
99 .ok_or_else(|| AppError::Config("Gemini API key required for discovery".into()))?;
100
101 let base = base_url
102 .filter(|u| !u.is_empty())
103 .unwrap_or("https://generativelanguage.googleapis.com");
104
105 let model_path = if model_id.starts_with("models/") {
107 model_id.to_string()
108 } else {
109 format!("models/{}", model_id)
110 };
111
112 let url = format!(
113 "{}/v1beta/{}?key={}",
114 base.trim_end_matches('/'),
115 model_path,
116 key
117 );
118
119 let resp = discovery_client()
120 .get(&url)
121 .send()
122 .await
123 .map_err(|e| AppError::Llm(format!("Gemini discovery failed: {e}")))?;
124
125 if !resp.status().is_success() {
126 return Err(AppError::Llm(format!(
127 "Gemini API returned {}",
128 resp.status()
129 )));
130 }
131
132 let data: serde_json::Value = resp.json().await?;
133
134 let input_limit = data
135 .get("inputTokenLimit")
136 .and_then(|v| v.as_u64())
137 .unwrap_or(32_000) as usize;
138
139 let output_limit = data
140 .get("outputTokenLimit")
141 .and_then(|v| v.as_u64())
142 .unwrap_or(8_192) as usize;
143
144 Ok(ModelCapabilities::new(
149 "gemini",
150 model_id,
151 input_limit + output_limit,
152 output_limit,
153 CapabilitySource::ApiDiscovery,
154 )
155 .with_vision(true))
156}
157
158async fn discover_anthropic(
163 model_id: &str,
164 api_key: Option<&str>,
165) -> Result<ModelCapabilities, AppError> {
166 let key = api_key
167 .filter(|k| !k.is_empty())
168 .ok_or_else(|| AppError::Config("Anthropic API key required for discovery".into()))?;
169
170 let url = format!("https://api.anthropic.com/v1/models/{}", model_id);
171
172 let resp = discovery_client()
173 .get(&url)
174 .header("x-api-key", key)
175 .header("anthropic-version", "2023-06-01")
176 .send()
177 .await
178 .map_err(|e| AppError::Llm(format!("Anthropic discovery failed: {e}")))?;
179
180 if !resp.status().is_success() {
181 return Err(AppError::Llm(format!(
182 "Anthropic API returned {}",
183 resp.status()
184 )));
185 }
186
187 let data: serde_json::Value = resp.json().await?;
188
189 let input_limit = data
194 .get("max_input_tokens")
195 .and_then(|v| v.as_u64())
196 .unwrap_or(200_000) as usize;
197
198 let output_limit = data
199 .get("max_output_tokens")
200 .and_then(|v| v.as_u64())
201 .unwrap_or(8_192) as usize;
202
203 Ok(ModelCapabilities::new(
204 "anthropic",
205 model_id,
206 input_limit + output_limit,
207 output_limit,
208 CapabilitySource::ApiDiscovery,
209 )
210 .with_vision(true))
211}
212
213async fn discover_grok(
218 model_id: &str,
219 api_key: Option<&str>,
220) -> Result<ModelCapabilities, AppError> {
221 let key = api_key
222 .filter(|k| !k.is_empty())
223 .ok_or_else(|| AppError::Config("Grok API key required for discovery".into()))?;
224
225 let url = "https://api.x.ai/v1/language-models";
226
227 let resp = discovery_client()
228 .get(url)
229 .header("Authorization", format!("Bearer {}", key))
230 .send()
231 .await
232 .map_err(|e| AppError::Llm(format!("Grok discovery failed: {e}")))?;
233
234 if !resp.status().is_success() {
235 return Err(AppError::Llm(format!(
236 "Grok API returned {}",
237 resp.status()
238 )));
239 }
240
241 let data: serde_json::Value = resp.json().await?;
242
243 let models = data.get("models").and_then(|m| m.as_array());
245 let model_data = models.and_then(|list| {
246 list.iter().find(|m| {
247 m.get("id")
248 .and_then(|id| id.as_str())
249 .map(|id| id == model_id)
250 .unwrap_or(false)
251 })
252 });
253
254 let (input_limit, output_limit) = if let Some(model) = model_data {
255 let input = model
256 .get("input_modalities")
257 .and_then(|m| m.get("text"))
258 .and_then(|t| t.get("token_limit"))
259 .and_then(|v| v.as_u64())
260 .unwrap_or(131_072) as usize;
261
262 let output = model
263 .get("output_modalities")
264 .and_then(|m| m.get("text"))
265 .and_then(|t| t.get("token_limit"))
266 .and_then(|v| v.as_u64())
267 .unwrap_or(8_192) as usize;
268
269 (input, output)
270 } else {
271 (32_000, 4_096)
273 };
274
275 Ok(ModelCapabilities::new(
278 "grok",
279 model_id,
280 input_limit + output_limit,
281 output_limit,
282 CapabilitySource::ApiDiscovery,
283 ))
284}
285
286async fn discover_ollama(
291 model_id: &str,
292 base_url: Option<&str>,
293) -> Result<ModelCapabilities, AppError> {
294 let base = base_url
295 .filter(|u| !u.is_empty())
296 .unwrap_or("http://localhost:11434");
297
298 let url = format!("{}/api/show", base.trim_end_matches('/'));
299
300 let resp = discovery_client()
301 .post(&url)
302 .json(&serde_json::json!({ "name": model_id }))
303 .send()
304 .await
305 .map_err(|e| AppError::Llm(format!("Ollama discovery failed: {e}")))?;
306
307 if !resp.status().is_success() {
308 return Err(AppError::Llm(format!(
309 "Ollama API returned {}",
310 resp.status()
311 )));
312 }
313
314 let data: serde_json::Value = resp.json().await?;
315
316 let model_info = data.get("model_info");
319
320 let context_length = model_info
321 .and_then(|info| {
322 info.as_object().and_then(|obj| {
324 obj.iter()
325 .find(|(k, _)| k.ends_with(".context_length"))
326 .and_then(|(_, v)| v.as_u64())
327 })
328 })
329 .unwrap_or(4_096) as usize;
330
331 let num_ctx = data
333 .get("parameters")
334 .and_then(|p| p.as_str())
335 .and_then(|params| {
336 params
338 .split_whitespace()
339 .collect::<Vec<_>>()
340 .windows(2)
341 .find(|w| w[0] == "num_ctx")
342 .and_then(|w| w[1].parse::<usize>().ok())
343 });
344
345 let effective_context = num_ctx.unwrap_or(context_length);
346
347 Ok(ModelCapabilities::new(
348 "ollama",
349 model_id,
350 effective_context,
351 effective_context / 4, CapabilitySource::ApiDiscovery,
353 ))
354}
355
#[cfg(test)]
mod tests {
    /// Mirrors the `model_info` suffix-match lookup used by `discover_ollama`.
    ///
    /// Plain `#[test]` — the body is fully synchronous (no `.await`), so
    /// spinning up a tokio runtime via `#[tokio::test]` was pure overhead.
    #[test]
    fn test_ollama_parsing() {
        let json = serde_json::json!({
            "model_info": {
                "llama.context_length": 8192,
                "llama.embedding_length": 4096
            }
        });

        let context = json.get("model_info").and_then(|info| {
            info.as_object().and_then(|obj| {
                obj.iter()
                    .find(|(k, _)| k.ends_with(".context_length"))
                    .and_then(|(_, v)| v.as_u64())
            })
        });

        assert_eq!(context, Some(8192));
    }
}