// gestura_core_llm/model_listing.rs

1//! Dynamic and static model listing for all LLM providers.
2//!
3//! This module centralises model discovery so that both the GUI (Tauri commands) and
4//! the CLI can fetch / fall back to the same lists without duplicating HTTP logic.
5//!
6//! Each provider's listing is feature-gated to match the rest of `gestura-core-llm`.
7
8use gestura_core_foundation::AppError;
9use serde::{Deserialize, Serialize};
10
11use crate::default_models::{DEFAULT_GEMINI_BASE_URL, DEFAULT_OLLAMA_BASE_URL};
12#[cfg(feature = "openai")]
13use crate::openai::is_agent_capable_openai_model;
14
/// Timeout (seconds) for model-listing HTTP calls.
///
/// Deliberately shorter than inference timeouts: listing endpoints are cheap
/// and a hung call should fail fast so the UI can fall back gracefully.
const MODEL_LIST_TIMEOUT_SECS: u64 = 10;
17
/// A single model entry returned by listing endpoints.
///
/// Serialisable with `serde` so it can cross process boundaries (e.g. Tauri
/// commands) and be consumed by both the GUI and the CLI.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Provider-specific model identifier (e.g. `gpt-4o`, `claude-sonnet-4-20250514`).
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Provider key (e.g. `openai`, `anthropic`, `gemini`, `grok`, `ollama`).
    pub provider: String,
}
28
29/// Create a lightweight HTTP client for model listing (shorter timeout than inference).
30fn listing_client() -> reqwest::Client {
31    reqwest::Client::builder()
32        .timeout(std::time::Duration::from_secs(MODEL_LIST_TIMEOUT_SECS))
33        .connect_timeout(std::time::Duration::from_secs(5))
34        .build()
35        .unwrap_or_else(|_| reqwest::Client::new())
36}
37
38// ---------------------------------------------------------------------------
39// Public API
40// ---------------------------------------------------------------------------
41
/// List available models for the given provider.
///
/// Tries a live API call; returns an empty list when no API key is provided
/// or the API is unreachable.
///
/// # Arguments
/// * `provider` – one of `openai`, `anthropic`, `grok`, `gemini`, `ollama`
///   (matched case-insensitively)
/// * `api_key` – required for cloud providers; ignored for `ollama`
/// * `base_url` – optional override; uses provider default when `None`
///
/// # Errors
/// Returns [`AppError::Config`] for an unrecognised provider string, or when
/// the matching provider's feature flag was compiled out (its match arm is
/// removed by `#[cfg]`, so the name falls through to the catch-all).
pub async fn list_models_for_provider(
    provider: &str,
    api_key: Option<&str>,
    base_url: Option<&str>,
) -> Result<Vec<ModelInfo>, AppError> {
    match provider.to_lowercase().as_str() {
        #[cfg(feature = "openai")]
        "openai" => list_openai(api_key, base_url).await,
        #[cfg(feature = "anthropic")]
        "anthropic" => list_anthropic(api_key).await,
        #[cfg(feature = "grok")]
        "grok" => list_grok(api_key).await,
        #[cfg(feature = "gemini")]
        "gemini" => list_gemini(api_key, base_url).await,
        #[cfg(feature = "ollama")]
        "ollama" => list_ollama(base_url).await,
        other => Err(AppError::Config(format!(
            "Unknown or disabled provider: {other}"
        ))),
    }
}
72
73/// Return the static / fallback model list for a provider (no network).
74///
75/// Returns an empty list — static fallback lists have been removed.
76/// Kept for API compatibility.
77pub fn static_models_for_provider(_provider: &str) -> Vec<ModelInfo> {
78    Vec::new()
79}
80
81// ---------------------------------------------------------------------------
82// Provider-specific implementations
83// ---------------------------------------------------------------------------
84
#[cfg(feature = "openai")]
// Thin local alias over the shared OpenAI capability filter, giving the
// listing code and the tests below a single in-module call site.
fn is_agent_capable_openai_model_id(model_id: &str) -> bool {
    is_agent_capable_openai_model(model_id)
}
89
/// Fetch the OpenAI model list via `GET {base}/v1/models`.
///
/// Returns `Ok(vec![])` when no API key is supplied or the endpoint replies
/// with a non-success status; only transport / body-parse failures surface as
/// [`AppError::Llm`]. Results are filtered to agent-capable models and sorted
/// by display name.
#[cfg(feature = "openai")]
async fn list_openai(
    api_key: Option<&str>,
    base_url: Option<&str>,
) -> Result<Vec<ModelInfo>, AppError> {
    // No key → provider unconfigured; return empty rather than erroring.
    let key = match api_key.filter(|k| !k.is_empty()) {
        Some(k) => k,
        None => return Ok(Vec::new()),
    };
    let base = base_url
        .filter(|u| !u.is_empty())
        .unwrap_or("https://api.openai.com");

    let url = format!("{}/v1/models", base.trim_end_matches('/'));
    let resp = listing_client()
        .get(&url)
        .bearer_auth(key)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("openai model list failed: {e}")))?;

    if !resp.status().is_success() {
        // Degrade gracefully: an auth or server error should not break the UI.
        tracing::warn!("OpenAI /v1/models returned {}", resp.status());
        return Ok(Vec::new());
    }

    // Attach the same provider context as the `send` error above instead of
    // relying on the generic `From<reqwest::Error>` conversion.
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| AppError::Llm(format!("openai model list parse failed: {e}")))?;
    let mut models: Vec<ModelInfo> = data
        .get("data")
        .and_then(|d| d.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|m| {
                    let id = m.get("id")?.as_str()?;
                    // Hide models that cannot drive an agent session.
                    if !is_agent_capable_openai_model_id(id) {
                        return None;
                    }
                    Some(ModelInfo {
                        id: id.to_string(),
                        name: gestura_core_foundation::model_display::format_model_name(
                            "openai", id,
                        ),
                        provider: "openai".to_string(),
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    models.sort_by(|a, b| a.name.cmp(&b.name));
    Ok(models)
}
142
/// Fetch the Anthropic model list via `GET https://api.anthropic.com/v1/models`.
///
/// Returns `Ok(vec![])` when no API key is supplied or the endpoint replies
/// with a non-success status; only transport / body-parse failures surface as
/// [`AppError::Llm`]. Results are filtered to `claude-*` ids and sorted by
/// display name.
#[cfg(feature = "anthropic")]
async fn list_anthropic(api_key: Option<&str>) -> Result<Vec<ModelInfo>, AppError> {
    // No key → provider unconfigured; return empty rather than erroring.
    let key = match api_key.filter(|k| !k.is_empty()) {
        Some(k) => k,
        None => return Ok(Vec::new()),
    };

    let url = "https://api.anthropic.com/v1/models";
    let resp = listing_client()
        .get(url)
        .header("x-api-key", key)
        // Anthropic requires an explicit API version header.
        .header("anthropic-version", "2023-06-01")
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("anthropic model list failed: {e}")))?;

    if !resp.status().is_success() {
        // Degrade gracefully: an auth or server error should not break the UI.
        tracing::warn!("Anthropic /v1/models returned {}", resp.status());
        return Ok(Vec::new());
    }

    // Attach the same provider context as the `send` error above instead of
    // relying on the generic `From<reqwest::Error>` conversion.
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| AppError::Llm(format!("anthropic model list parse failed: {e}")))?;
    let mut models: Vec<ModelInfo> = data
        .get("data")
        .and_then(|d| d.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|m| {
                    let id = m.get("id")?.as_str()?;
                    // Only surface Claude chat models.
                    if !id.starts_with("claude-") {
                        return None;
                    }
                    Some(ModelInfo {
                        id: id.to_string(),
                        name: gestura_core_foundation::model_display::format_model_name(
                            "anthropic",
                            id,
                        ),
                        provider: "anthropic".to_string(),
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    models.sort_by(|a, b| a.name.cmp(&b.name));
    Ok(models)
}
191
/// Fetch the Grok (xAI) model list via `GET https://api.x.ai/v1/models`.
///
/// Returns `Ok(vec![])` when no API key is supplied or the endpoint replies
/// with a non-success status; only transport / body-parse failures surface as
/// [`AppError::Llm`]. Results exclude image models and are sorted by display
/// name.
#[cfg(feature = "grok")]
async fn list_grok(api_key: Option<&str>) -> Result<Vec<ModelInfo>, AppError> {
    // No key → provider unconfigured; return empty rather than erroring.
    let key = match api_key.filter(|k| !k.is_empty()) {
        Some(k) => k,
        None => return Ok(Vec::new()),
    };

    let url = "https://api.x.ai/v1/models";
    let resp = listing_client()
        .get(url)
        .bearer_auth(key)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("grok model list failed: {e}")))?;

    if !resp.status().is_success() {
        // Degrade gracefully: an auth or server error should not break the UI.
        tracing::warn!("Grok /v1/models returned {}", resp.status());
        return Ok(Vec::new());
    }

    // Attach the same provider context as the `send` error above instead of
    // relying on the generic `From<reqwest::Error>` conversion.
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| AppError::Llm(format!("grok model list parse failed: {e}")))?;
    let mut models: Vec<ModelInfo> = data
        .get("data")
        .and_then(|d| d.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|m| {
                    let id = m.get("id")?.as_str()?;
                    // Skip image-generation models (filtered by id substring).
                    if id.contains("image") {
                        return None;
                    }
                    Some(ModelInfo {
                        id: id.to_string(),
                        name: gestura_core_foundation::model_display::format_model_name("grok", id),
                        provider: "grok".to_string(),
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    models.sort_by(|a, b| a.name.cmp(&b.name));
    Ok(models)
}
236
/// Fetch the Gemini model list via `GET {base}/v1beta/models`.
///
/// Returns `Ok(vec![])` when no API key is supplied or the endpoint replies
/// with a non-success status; only transport / body-parse failures surface as
/// [`AppError::Llm`]. Results are limited to models supporting
/// `generateContent` and sorted by display name.
#[cfg(feature = "gemini")]
async fn list_gemini(
    api_key: Option<&str>,
    base_url: Option<&str>,
) -> Result<Vec<ModelInfo>, AppError> {
    // No key → provider unconfigured; return empty rather than erroring.
    let key = match api_key.filter(|k| !k.is_empty()) {
        Some(k) => k,
        None => return Ok(Vec::new()),
    };
    let base = base_url
        .filter(|u| !u.is_empty())
        .unwrap_or(DEFAULT_GEMINI_BASE_URL);

    // Gemini uses key as query param, not bearer auth.
    // NOTE(review): the key is therefore embedded in `url` — never log it.
    let url = format!("{}/v1beta/models?key={}", base.trim_end_matches('/'), key);
    let resp = listing_client()
        .get(&url)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("gemini model list failed: {e}")))?;

    if !resp.status().is_success() {
        // Degrade gracefully: an auth or server error should not break the UI.
        tracing::warn!("Gemini /v1beta/models returned {}", resp.status());
        return Ok(Vec::new());
    }

    // Attach the same provider context as the `send` error above instead of
    // relying on the generic `From<reqwest::Error>` conversion.
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| AppError::Llm(format!("gemini model list parse failed: {e}")))?;
    let mut models: Vec<ModelInfo> = data
        .get("models")
        .and_then(|d| d.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|m| {
                    // Gemini returns "name": "models/gemini-2.0-flash" — strip the prefix.
                    let raw_name = m.get("name")?.as_str()?;
                    let id = raw_name.strip_prefix("models/").unwrap_or(raw_name);
                    // Only include generative models (skip embedding, AQA, etc.).
                    let methods = m
                        .get("supportedGenerationMethods")
                        .and_then(|v| v.as_array());
                    let is_generative = methods
                        .map(|ms| ms.iter().any(|v| v.as_str() == Some("generateContent")))
                        .unwrap_or(false);
                    if !is_generative {
                        return None;
                    }
                    // Prefer the API-supplied display name; fall back to the id.
                    let display = m.get("displayName").and_then(|d| d.as_str()).unwrap_or(id);
                    Some(ModelInfo {
                        id: id.to_string(),
                        name: display.to_string(),
                        provider: "gemini".to_string(),
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    models.sort_by(|a, b| a.name.cmp(&b.name));
    Ok(models)
}
297
/// Fetch installed models from a local Ollama daemon via `GET {base}/api/tags`.
///
/// Unlike the cloud providers, a non-success HTTP status is reported as
/// [`AppError::Llm`] so the caller can tell "daemon reachable but unhealthy"
/// apart from "no models installed".
#[cfg(feature = "ollama")]
async fn list_ollama(base_url: Option<&str>) -> Result<Vec<ModelInfo>, AppError> {
    let base = base_url
        .filter(|u| !u.is_empty())
        .unwrap_or(DEFAULT_OLLAMA_BASE_URL);

    let url = format!("{}/api/tags", base.trim_end_matches('/'));
    let resp = listing_client()
        .get(&url)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("ollama model list failed: {e}")))?;

    if !resp.status().is_success() {
        return Err(AppError::Llm(format!(
            "Ollama at {} returned status {}",
            base,
            resp.status()
        )));
    }

    // Attach the same provider context as the `send` error above instead of
    // relying on the generic `From<reqwest::Error>` conversion.
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| AppError::Llm(format!("ollama model list parse failed: {e}")))?;
    let models: Vec<ModelInfo> = data
        .get("models")
        .and_then(|m| m.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|m| {
                    let name = m.get("name").and_then(|n| n.as_str())?;
                    Some(ModelInfo {
                        id: name.to_string(),
                        name: gestura_core_foundation::model_display::format_model_name(
                            "ollama", name,
                        ),
                        provider: "ollama".to_string(),
                    })
                })
                .collect()
        })
        .unwrap_or_default();

    Ok(models)
}
341
342// ---------------------------------------------------------------------------
343// Ollama connectivity check
344// ---------------------------------------------------------------------------
345
/// Timeout (seconds) for the lightweight Ollama connectivity ping.
///
/// Shorter than [`MODEL_LIST_TIMEOUT_SECS`] — the ping only needs to answer
/// "is the local daemon up?" quickly.
const OLLAMA_PING_TIMEOUT_SECS: u64 = 3;
348
/// Ping the Ollama endpoint to verify it is reachable.
///
/// Issues a lightweight `GET /api/tags` with a short timeout.
/// Returns `true` only on a successful HTTP response.
///
/// # Arguments
/// * `base_url` – Ollama base URL (e.g. `http://localhost:11434`).
///   Falls back to [`DEFAULT_OLLAMA_BASE_URL`] when empty.
#[cfg(feature = "ollama")]
pub async fn check_ollama_connectivity(base_url: &str) -> bool {
    let base = if base_url.is_empty() {
        DEFAULT_OLLAMA_BASE_URL
    } else {
        base_url
    };
    let url = format!("{}/api/tags", base.trim_end_matches('/'));

    // Dedicated short-timeout client; builder failure falls back to defaults.
    let ping_timeout = std::time::Duration::from_secs(OLLAMA_PING_TIMEOUT_SECS);
    let client = reqwest::Client::builder()
        .timeout(ping_timeout)
        .connect_timeout(ping_timeout)
        .build()
        .unwrap_or_else(|_| reqwest::Client::new());

    client
        .get(&url)
        .send()
        .await
        .map(|resp| resp.status().is_success())
        .unwrap_or(false)
}
375
376// ---------------------------------------------------------------------------
377// Tests
378// ---------------------------------------------------------------------------
379
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn static_openai_returns_empty() {
        let models = static_models_for_provider("openai");
        assert!(models.is_empty());
    }

    #[test]
    fn static_anthropic_returns_empty() {
        let models = static_models_for_provider("anthropic");
        assert!(models.is_empty());
    }

    #[test]
    fn static_grok_returns_empty() {
        let models = static_models_for_provider("grok");
        assert!(models.is_empty());
    }

    #[test]
    fn static_gemini_returns_empty() {
        let models = static_models_for_provider("gemini");
        assert!(models.is_empty());
    }

    #[test]
    fn static_unknown_returns_empty() {
        let models = static_models_for_provider("unknown_provider");
        assert!(models.is_empty());
    }

    // Cloud providers without an API key should return an empty list.
    // Gated on the `openai` feature: without it, the dispatcher hits the
    // catch-all arm and returns `AppError::Config`, so `.unwrap()` would
    // panic and fail this test for the wrong reason.
    #[cfg(feature = "openai")]
    #[tokio::test]
    async fn list_without_key_returns_empty() {
        let models = list_models_for_provider("openai", None, None)
            .await
            .unwrap();
        assert!(models.is_empty());
    }

    #[test]
    #[cfg(feature = "openai")]
    fn filters_openai_models_to_agent_capable_session_models() {
        for allowed in [
            "gpt-4o",
            "gpt-4.1",
            "o4-mini",
            "gpt-5.4",
            "gpt-5.3-codex",
            "codex-1",
            "chatgpt-4o-latest",
            "codex-mini-latest",
        ] {
            assert!(
                is_agent_capable_openai_model_id(allowed),
                "expected {allowed} to be allowed"
            );
        }

        for blocked in [
            "gpt-3.5-turbo-instruct",
            "gpt-4o-transcribe",
            "gpt-4o-audio-preview",
            "gpt-realtime",
            "gpt-image-1",
            "text-davinci-003",
        ] {
            assert!(
                !is_agent_capable_openai_model_id(blocked),
                "expected {blocked} to be filtered out"
            );
        }
    }
}