gestura_core_llm/
lib.rs

1//! Feature-gated LLM provider implementations and shared provider abstractions.
2//!
3//! `gestura-core-llm` is the domain crate behind Gestura's provider layer. It
4//! defines the common `LlmProvider` trait, shared response/token models, model
5//! listing helpers, default model catalogs, and the concrete provider
6//! implementations used by the runtime.
7//!
8//! ## Supported providers
9//!
10//! Provider implementations are enabled with Cargo features and currently cover:
11//!
12//! - OpenAI
13//! - Anthropic
14//! - Grok (xAI)
15//! - Gemini
16//! - Ollama (local)
17//!
18//! ## Design role
19//!
20//! This crate owns provider-specific HTTP behavior and response normalization.
21//! Higher-level concerns such as configuration-driven provider selection,
22//! runtime overrides, and pipeline orchestration remain in `gestura-core`.
23//!
24//! The stable public import path for most consumers remains
25//! `gestura_core::llm_provider::*`.
26//!
27//! ## Shared abstractions
28//!
29//! - `LlmProvider`: async provider interface used by the runtime
30//! - `LlmCallResponse`: normalized response with text, usage, and tool calls
31//! - `TokenUsage`: provider-agnostic token accounting and estimated cost data
32//! - `ToolCallInfo`: normalized native function/tool call representation
//! - `default_models`, `model_capabilities`, `model_discovery`, `model_listing`,
//!   `token_tracker`: support modules for model defaults, capabilities,
//!   discovery/listing, and token accounting
35//!
36//! ## Native tool calling
37//!
38//! Where providers support it, Gestura normalizes native function/tool calling
39//! into a common `ToolCallInfo` representation so the pipeline can process tool
40//! calls consistently across providers.
41//!
42//! ## Feature-gated workspace design
43//!
44//! This crate is intentionally feature-gated so applications can compile only
45//! the providers they need. That keeps optional integrations isolated and makes
46//! the workspace easier to reason about in `cargo doc` and CI.
47
48pub mod default_models;
49pub mod model_capabilities;
50pub mod model_discovery;
51pub mod model_listing;
52pub mod openai;
53pub mod token_tracker;
54
55use gestura_core_foundation::AppError;
56use serde::{Deserialize, Serialize};
57use std::time::Duration;
58
59#[cfg(feature = "openai")]
60use crate::openai::{
61    OpenAiApi, is_openai_model_incompatible_with_agent_session, openai_agent_session_model_message,
62    openai_api_for_model,
63};
64
65/// Default timeout for LLM API calls (2 minutes for slow local models)
66const LLM_TIMEOUT_SECS: u64 = 120;
67
68/// Create a reqwest client with appropriate timeouts
69fn create_http_client() -> reqwest::Client {
70    reqwest::Client::builder()
71        .timeout(Duration::from_secs(LLM_TIMEOUT_SECS))
72        .connect_timeout(Duration::from_secs(10))
73        .build()
74        .unwrap_or_else(|_| reqwest::Client::new())
75}
76
77/// Token usage information from an LLM API call
78#[derive(Debug, Clone, Default, Serialize, Deserialize)]
79pub struct TokenUsage {
80    /// Number of tokens in the input/prompt
81    pub input_tokens: u32,
82    /// Number of tokens in the output/completion
83    pub output_tokens: u32,
84    /// Total tokens (input + output)
85    pub total_tokens: u32,
86    /// Estimated cost in USD (if available)
87    pub estimated_cost_usd: Option<f64>,
88    /// Model used for the request
89    pub model: Option<String>,
90    /// Provider name
91    pub provider: Option<String>,
92}
93
94impl TokenUsage {
95    /// Create a new TokenUsage with the given counts
96    pub fn new(input_tokens: u32, output_tokens: u32) -> Self {
97        Self {
98            input_tokens,
99            output_tokens,
100            total_tokens: input_tokens + output_tokens,
101            estimated_cost_usd: None,
102            model: None,
103            provider: None,
104        }
105    }
106
107    /// Create an empty/unknown token usage (for providers that don't report usage)
108    pub fn unknown() -> Self {
109        Self::default()
110    }
111
112    /// Set the estimated cost based on provider pricing
113    pub fn with_cost(mut self, cost_usd: f64) -> Self {
114        self.estimated_cost_usd = Some(cost_usd);
115        self
116    }
117
118    /// Set the model name
119    pub fn with_model(mut self, model: impl Into<String>) -> Self {
120        self.model = Some(model.into());
121        self
122    }
123
124    /// Set the provider name
125    pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
126        self.provider = Some(provider.into());
127        self
128    }
129
130    /// Calculate cost based on standard pricing (per 1M tokens)
131    pub fn calculate_cost(&mut self, input_price_per_million: f64, output_price_per_million: f64) {
132        let input_cost = (self.input_tokens as f64 / 1_000_000.0) * input_price_per_million;
133        let output_cost = (self.output_tokens as f64 / 1_000_000.0) * output_price_per_million;
134        self.estimated_cost_usd = Some(input_cost + output_cost);
135    }
136}
137
138/// A structured tool call returned by the LLM when using native function calling.
139#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
140pub struct ToolCallInfo {
141    /// Provider-assigned call ID (e.g. `call_abc123` for OpenAI, `toolu_xxx` for Anthropic)
142    pub id: String,
143    /// Tool name
144    pub name: String,
145    /// JSON-encoded arguments string
146    pub arguments: String,
147}
148
149/// Response from an LLM call including token usage
150#[derive(Debug, Clone)]
151pub struct LlmCallResponse {
152    /// The generated text
153    pub text: String,
154    /// Token usage information
155    pub usage: TokenUsage,
156    /// Structured tool calls returned by the model (empty when the model responds with text only)
157    pub tool_calls: Vec<ToolCallInfo>,
158}
159
160impl LlmCallResponse {
161    /// Create a new LlmCallResponse (text-only, no tool calls)
162    pub fn new(text: String, usage: TokenUsage) -> Self {
163        Self {
164            text,
165            usage,
166            tool_calls: Vec::new(),
167        }
168    }
169
170    /// Create a response with unknown token usage
171    pub fn with_unknown_usage(text: String) -> Self {
172        Self {
173            text,
174            usage: TokenUsage::unknown(),
175            tool_calls: Vec::new(),
176        }
177    }
178
179    /// Create a new LlmCallResponse with tool calls
180    pub fn with_tool_calls(text: String, usage: TokenUsage, tool_calls: Vec<ToolCallInfo>) -> Self {
181        Self {
182            text,
183            usage,
184            tool_calls,
185        }
186    }
187}
188
/// Per-request hints that may inform provider selection (agent, tenant, etc.).
///
/// Only the agent identifier is carried today.
#[derive(Clone, Debug, Default)]
pub struct AgentContext {
    /// Agent identifier; an empty string when built via `Default`.
    pub agent_id: String,
}
194
195/// Unified LLM interface (async)
196#[async_trait::async_trait]
197pub trait LlmProvider: Send + Sync {
198    /// Call the LLM with a prompt and return the generated text.
199    /// For backward compatibility, this returns just the text.
200    async fn call(&self, prompt: &str) -> Result<String, AppError>;
201
202    /// Call the LLM with a prompt and return the response with token usage.
203    /// Default implementation calls `call` and returns unknown usage.
204    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
205        let text = self.call(prompt).await?;
206        Ok(LlmCallResponse::with_unknown_usage(text))
207    }
208
209    /// Call the LLM with a prompt **and** optional tool schemas.
210    ///
211    /// When `tools` is `Some`, providers that support native tool/function calling
212    /// will include the schemas in the API request body, enabling the model to
213    /// return structured tool call responses.
214    ///
215    /// The default implementation ignores the tools parameter and delegates to
216    /// [`Self::call_with_usage`]. Providers should override this to pass tools
217    /// natively.
218    async fn call_with_tools(
219        &self,
220        prompt: &str,
221        _tools: Option<&[serde_json::Value]>,
222    ) -> Result<LlmCallResponse, AppError> {
223        self.call_with_usage(prompt).await
224    }
225}
226
227/// A provider that returns an error when no real provider is configured.
228/// Used when config is missing or invalid.
229pub struct UnconfiguredProvider {
230    pub provider_name: String,
231}
232
233#[async_trait::async_trait]
234impl LlmProvider for UnconfiguredProvider {
235    async fn call(&self, _prompt: &str) -> Result<String, AppError> {
236        Err(AppError::Llm(format!(
237            "LLM provider '{}' is not configured. Please configure it in Settings or run 'gestura config edit'.",
238            self.provider_name
239        )))
240    }
241}
242
/// OpenAI provider that talks HTTP to either the Chat Completions or the Responses endpoint.
#[cfg(feature = "openai")]
pub struct OpenAiProvider {
    /// Secret sent as the bearer token.
    pub api_key: String,
    /// Root URL; the endpoint path is appended after trimming any trailing `/`.
    pub base_url: String,
    /// Model identifier; it also determines which OpenAI API surface is used.
    pub model: String,
}
250
#[cfg(feature = "openai")]
impl OpenAiProvider {
    /// Map the selected OpenAI API surface to the URL path appended to `base_url`.
    fn endpoint_path(api: OpenAiApi) -> &'static str {
        match api {
            OpenAiApi::ChatCompletions => "/v1/chat/completions",
            OpenAiApi::Responses => "/v1/responses",
        }
    }

    /// Build the error message for a non-success OpenAI HTTP response.
    ///
    /// A 404 "This is not a chat model" from the Chat Completions endpoint gets an
    /// explanatory hint that the model seems to need `/v1/responses`. Every other
    /// failure is reported as `OpenAI <path> HTTP <status>: <body>`.
    fn enrich_openai_error(
        &self,
        api: OpenAiApi,
        status: reqwest::StatusCode,
        body: &str,
    ) -> String {
        // The hint only makes sense if we did *not* already call /v1/responses.
        if status == reqwest::StatusCode::NOT_FOUND
            && matches!(api, OpenAiApi::ChatCompletions)
            && body.contains("This is not a chat model")
        {
            return format!(
                "OpenAI model '{}' appears to require /v1/responses, but Gestura selected {}. Raw OpenAI error: {body}",
                self.model,
                Self::endpoint_path(api)
            );
        }

        format!(
            "OpenAI {} HTTP {}: {}",
            Self::endpoint_path(api),
            status,
            body
        )
    }

    /// Parse token usage from an OpenAI response and attach an estimated cost.
    ///
    /// Accepts both Chat Completions (`prompt_tokens`/`completion_tokens`) and
    /// Responses (`input_tokens`/`output_tokens`) naming. Missing counters count
    /// as zero; values too large for `u32` saturate instead of being truncated.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let usage = &response["usage"];
        let count = |primary: &str, fallback: &str| -> u32 {
            usage[primary]
                .as_u64()
                .or_else(|| usage[fallback].as_u64())
                .map_or(0, |n| u32::try_from(n).unwrap_or(u32::MAX))
        };
        let input_tokens = count("prompt_tokens", "input_tokens");
        let output_tokens = count("completion_tokens", "output_tokens");

        let mut token_usage = TokenUsage::new(input_tokens, output_tokens)
            .with_model(self.model.clone())
            .with_provider("openai");

        let (input_price, output_price) = Self::pricing_per_million(&self.model);
        token_usage.calculate_cost(input_price, output_price);

        token_usage
    }

    /// Approximate `(input, output)` USD price per 1M tokens for an OpenAI model.
    ///
    /// Prefixes are matched from most to least specific. Previously every
    /// `gpt-4o-mini` model was billed as `gpt-4o`, and every `gpt-4.1*` /
    /// `gpt-4-turbo` model as classic `gpt-4`. Unknown models fall back to
    /// GPT-4o pricing.
    fn pricing_per_million(model: &str) -> (f64, f64) {
        match model {
            m if m.starts_with("gpt-4o-mini") => (0.15, 0.60),
            m if m.starts_with("gpt-4o") => (2.50, 10.0),
            m if m.starts_with("gpt-4.1-nano") => (0.10, 0.40),
            m if m.starts_with("gpt-4.1-mini") => (0.40, 1.60),
            m if m.starts_with("gpt-4.1") => (2.00, 8.00),
            m if m.starts_with("gpt-4-turbo") => (10.0, 30.0),
            m if m.starts_with("gpt-4") => (30.0, 60.0),
            m if m.starts_with("gpt-3.5") => (0.50, 1.50),
            _ => (2.50, 10.0), // Default to GPT-4o pricing
        }
    }
}
313
/// Build the JSON body for an OpenAI `/v1/chat/completions` request.
///
/// The prompt is sent as a single user message. When `tools` is present and
/// non-empty, the schemas are attached verbatim with `tool_choice: "auto"`.
#[cfg(feature = "openai")]
fn build_openai_chat_request_body(
    model: &str,
    prompt: &str,
    tools: Option<&[serde_json::Value]>,
) -> serde_json::Value {
    let user_message = serde_json::json!({"role": "user", "content": prompt});
    let mut request = serde_json::json!({
        "model": model,
        "messages": [user_message]
    });

    if let Some(schemas) = tools.filter(|schemas| !schemas.is_empty()) {
        request["tools"] = serde_json::Value::Array(schemas.to_vec());
        request["tool_choice"] = serde_json::Value::String("auto".to_owned());
    }

    request
}
334
335#[cfg(feature = "openai")]
336fn build_openai_responses_request_body(
337    model: &str,
338    prompt: &str,
339    tools: Option<&[serde_json::Value]>,
340) -> serde_json::Value {
341    let mut body = serde_json::json!({
342        "model": model,
343        "input": [{"role":"user","content": prompt}]
344    });
345
346    if let Some(tools) = tools
347        && !tools.is_empty()
348    {
349        body["tools"] = serde_json::Value::Array(tools.to_vec());
350        body["tool_choice"] = serde_json::json!("auto");
351    }
352
353    body
354}
355
/// Extract the assistant text from an OpenAI Responses API payload.
///
/// Uses the top-level `output_text` string when present. Otherwise it
/// concatenates, in order and without separators, the `text` of every
/// `output_text` content part inside `message` items of `output`.
#[cfg(feature = "openai")]
fn extract_openai_responses_text(response: &serde_json::Value) -> String {
    match response["output_text"].as_str() {
        Some(aggregated) => aggregated.to_string(),
        None => {
            let mut collected = String::new();
            for item in response["output"].as_array().into_iter().flatten() {
                if item["type"].as_str() != Some("message") {
                    continue;
                }
                for part in item["content"].as_array().into_iter().flatten() {
                    if part["type"].as_str() != Some("output_text") {
                        continue;
                    }
                    if let Some(fragment) = part["text"].as_str() {
                        collected.push_str(fragment);
                    }
                }
            }
            collected
        }
    }
}
375
/// Collect `function_call` items from an OpenAI Responses API `output` array.
///
/// Items without a string `name` are skipped. The ID comes from `call_id`, then
/// `id`, then the empty string. Missing `arguments` become `"{}"`.
#[cfg(feature = "openai")]
fn extract_openai_responses_tool_calls(response: &serde_json::Value) -> Vec<ToolCallInfo> {
    let mut calls = Vec::new();

    for item in response["output"].as_array().into_iter().flatten() {
        if item["type"].as_str() != Some("function_call") {
            continue;
        }
        let Some(name) = item["name"].as_str() else {
            continue;
        };
        let id = item["call_id"]
            .as_str()
            .or_else(|| item["id"].as_str())
            .unwrap_or_default();
        let arguments = item["arguments"].as_str().unwrap_or("{}");

        calls.push(ToolCallInfo {
            id: id.to_string(),
            name: name.to_string(),
            arguments: arguments.to_string(),
        });
    }

    calls
}
399
#[cfg(feature = "openai")]
#[async_trait::async_trait]
impl LlmProvider for OpenAiProvider {
    /// Return only the generated text from [`LlmProvider::call_with_usage`].
    async fn call(&self, prompt: &str) -> Result<String, AppError> {
        self.call_with_usage(prompt)
            .await
            .map(|response| response.text)
    }

    /// Tool-less call; delegates to [`LlmProvider::call_with_tools`] with `None`.
    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
        self.call_with_tools(prompt, None).await
    }

    /// Send one user prompt to OpenAI, choosing Chat Completions or Responses by model.
    ///
    /// # Errors
    /// - `AppError::Llm` if the model cannot be used in an agent session (checked
    ///   before any network I/O), if the request cannot be sent, or if the HTTP
    ///   status is not a success.
    /// - A response body that is not valid JSON is propagated through `?`.
    async fn call_with_tools(
        &self,
        prompt: &str,
        tools: Option<&[serde_json::Value]>,
    ) -> Result<LlmCallResponse, AppError> {
        if is_openai_model_incompatible_with_agent_session(&self.model) {
            let message = openai_agent_session_model_message(&self.model);
            return Err(AppError::Llm(message));
        }

        let api = openai_api_for_model(&self.model);
        let endpoint = Self::endpoint_path(api);
        let base = self.base_url.trim_end_matches('/');
        let url = format!("{base}{endpoint}");

        // `temperature` is deliberately left out. Some OpenAI(-compatible) models
        // accept only the default and answer HTTP 400 to any other value.
        let request_body = match api {
            OpenAiApi::ChatCompletions => {
                build_openai_chat_request_body(&self.model, prompt, tools)
            }
            OpenAiApi::Responses => {
                build_openai_responses_request_body(&self.model, prompt, tools)
            }
        };

        let response = create_http_client()
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| AppError::Llm(format!("openai request failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_default();
            return Err(AppError::Llm(
                self.enrich_openai_error(api, status, &error_body),
            ));
        }

        let payload: serde_json::Value = response.json().await?;
        let (text, tool_calls) = match api {
            OpenAiApi::ChatCompletions => {
                let message = &payload["choices"][0]["message"];
                let text = message["content"].as_str().unwrap_or("").to_string();
                (text, extract_openai_tool_calls(message))
            }
            OpenAiApi::Responses => (
                extract_openai_responses_text(&payload),
                extract_openai_responses_tool_calls(&payload),
            ),
        };

        let usage = self.parse_usage(&payload);
        tracing::debug!(
            endpoint = endpoint,
            "OpenAI token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
            usage.input_tokens,
            usage.output_tokens,
            usage.estimated_cost_usd.unwrap_or(0.0),
            tool_calls.len()
        );

        Ok(LlmCallResponse::with_tool_calls(text, usage, tool_calls))
    }
}
481
/// Anthropic Claude provider that calls the `/v1/messages` HTTP endpoint.
#[cfg(feature = "anthropic")]
pub struct AnthropicProvider {
    /// Key sent in the `x-api-key` header.
    pub api_key: String,
    /// Root URL that `/v1/messages` is appended to.
    pub base_url: String,
    /// Claude model identifier; it also selects the pricing tier for cost estimates.
    pub model: String,

    /// Optional extended-thinking budget for non-streaming calls.
    /// When `Some`, a `thinking` object carrying this `budget_tokens` is added to the request body.
    pub thinking_budget_tokens: Option<u32>,
}
493
#[cfg(feature = "anthropic")]
impl AnthropicProvider {
    /// Build a [`TokenUsage`] from the `usage` object of an Anthropic `messages` response.
    ///
    /// Missing counters count as zero. Cost is estimated from the first model-family
    /// keyword found in `self.model` (USD per 1M tokens, input/output):
    /// - `opus`: $15 / $75 (Claude 3 Opus)
    /// - `sonnet`: $3 / $15 (Claude 3.5 Sonnet; also the fallback)
    /// - `haiku`: $0.25 / $1.25 (Claude 3 Haiku)
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let counter = |field: &str| response["usage"][field].as_u64().unwrap_or(0) as u32;

        let family_prices: [(&str, f64, f64); 3] = [
            ("opus", 15.0, 75.0),
            ("sonnet", 3.0, 15.0),
            ("haiku", 0.25, 1.25),
        ];
        let (input_price, output_price) = family_prices
            .iter()
            .find(|(family, _, _)| self.model.contains(*family))
            .map_or((3.0, 15.0), |&(_, input, output)| (input, output));

        let mut token_usage = TokenUsage::new(counter("input_tokens"), counter("output_tokens"))
            .with_model(self.model.clone())
            .with_provider("anthropic");
        token_usage.calculate_cost(input_price, output_price);
        token_usage
    }
}
521
/// Extract structured tool calls from an OpenAI-compatible `message` object.
///
/// Reads `message.tool_calls[].{id, function.name, function.arguments}`; it is shared by
/// the OpenAI, Grok, and Ollama providers. Entries without a string `function.name`
/// are skipped, and a missing `id` becomes the empty string.
///
/// `function.arguments` is normally a JSON-encoded *string*. Some OpenAI-compatible
/// servers (e.g. Ollama's native chat API) send a JSON *object* instead. Such values
/// are serialized back to a string rather than silently replaced with `"{}"`.
/// Missing or `null` arguments become `"{}"`.
#[cfg(any(feature = "openai", feature = "grok", feature = "ollama"))]
fn extract_openai_tool_calls(message: &serde_json::Value) -> Vec<ToolCallInfo> {
    let Some(tool_calls) = message["tool_calls"].as_array() else {
        return Vec::new();
    };

    tool_calls
        .iter()
        .filter_map(|call| {
            let function = &call["function"];
            let name = function["name"].as_str()?;
            let id = call["id"].as_str().unwrap_or_default().to_string();
            let arguments = match &function["arguments"] {
                serde_json::Value::String(raw) => raw.clone(),
                serde_json::Value::Null => "{}".to_string(),
                structured => {
                    serde_json::to_string(structured).unwrap_or_else(|_| "{}".to_string())
                }
            };
            Some(ToolCallInfo {
                id,
                name: name.to_string(),
                arguments,
            })
        })
        .collect()
}
549
/// Pieces extracted from an Anthropic `messages` response body.
#[cfg(feature = "anthropic")]
struct AnthropicContent {
    /// Concatenation of every `text` block, in order.
    text: String,
    /// Concatenation of every `thinking` block, in order.
    thinking: String,
    /// Every `tool_use` block that carried a non-empty name.
    tool_calls: Vec<ToolCallInfo>,
}
557
/// Split an Anthropic `messages` response into text, thinking, and tool calls.
///
/// Anthropic returns `content` as an array of typed blocks. `text` blocks are
/// appended to `text`. `thinking` blocks are appended to `thinking`, reading the
/// `thinking` key and falling back to `text`, since schemas differ. `tool_use`
/// blocks with a non-empty name become [`ToolCallInfo`]s whose `input` object is
/// serialized to a JSON string. Other block types are ignored. A missing or
/// non-array `content` yields empty results.
#[cfg(feature = "anthropic")]
fn anthropic_extract_content(response_json: &serde_json::Value) -> AnthropicContent {
    let mut text = String::new();
    let mut thinking = String::new();
    let mut tool_calls = Vec::new();

    for block in response_json["content"].as_array().into_iter().flatten() {
        match block["type"].as_str() {
            Some("text") => {
                if let Some(chunk) = block.get("text").and_then(serde_json::Value::as_str) {
                    text.push_str(chunk);
                }
            }
            Some("thinking") => {
                let reasoning = block
                    .get("thinking")
                    .and_then(serde_json::Value::as_str)
                    .or_else(|| block.get("text").and_then(serde_json::Value::as_str));
                if let Some(reasoning) = reasoning {
                    thinking.push_str(reasoning);
                }
            }
            Some("tool_use") => {
                let name = block["name"].as_str().unwrap_or_default();
                if name.is_empty() {
                    continue;
                }
                let arguments = match block.get("input") {
                    Some(input) => serde_json::to_string(input).unwrap_or_default(),
                    None => "{}".to_string(),
                };
                tool_calls.push(ToolCallInfo {
                    id: block["id"].as_str().unwrap_or_default().to_string(),
                    name: name.to_string(),
                    arguments,
                });
            }
            _ => {}
        }
    }

    AnthropicContent {
        text,
        thinking,
        tool_calls,
    }
}
615
/// Test-only helper that returns just the `(text, thinking)` pair of an Anthropic response.
///
/// Lets tests check extraction without depending on the full `AnthropicContent` struct.
#[cfg(all(test, feature = "anthropic"))]
fn anthropic_extract_text_and_thinking(response_json: &serde_json::Value) -> (String, String) {
    let AnthropicContent { text, thinking, .. } = anthropic_extract_content(response_json);
    (text, thinking)
}
624
625#[cfg(feature = "anthropic")]
626#[async_trait::async_trait]
627impl LlmProvider for AnthropicProvider {
628    async fn call(&self, prompt: &str) -> Result<String, AppError> {
629        let response = self.call_with_usage(prompt).await?;
630        Ok(response.text)
631    }
632
633    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
634        self.call_with_tools(prompt, None).await
635    }
636
637    async fn call_with_tools(
638        &self,
639        prompt: &str,
640        tools: Option<&[serde_json::Value]>,
641    ) -> Result<LlmCallResponse, AppError> {
642        let url = format!("{}/v1/messages", self.base_url);
643        let mut body = serde_json::json!({
644            "model": self.model,
645            "max_tokens": 512,
646            "messages": [{"role":"user","content": [{"type":"text","text": prompt}]}]
647        });
648
649        if let Some(budget_tokens) = self.thinking_budget_tokens {
650            // `body` is created from a JSON object literal above, so direct indexing is safe.
651            body["thinking"] =
652                serde_json::json!({ "type": "enabled", "budget_tokens": budget_tokens });
653        }
654
655        // Anthropic uses its own tool schema format: {name, description, input_schema}.
656        if let Some(tools) = tools
657            && !tools.is_empty()
658        {
659            body["tools"] = serde_json::Value::Array(tools.to_vec());
660        }
661
662        let client = create_http_client();
663        let resp = client
664            .post(&url)
665            .header("x-api-key", &self.api_key)
666            .header("anthropic-version", "2023-06-01")
667            .json(&body)
668            .send()
669            .await
670            .map_err(|e| AppError::Llm(format!("anthropic request failed: {}", e)))?;
671        if !resp.status().is_success() {
672            let status = resp.status();
673            let body = resp.text().await.unwrap_or_default();
674            return Err(AppError::Llm(format!(
675                "anthropic http {}: {}",
676                status, body
677            )));
678        }
679        let v: serde_json::Value = resp.json().await?;
680        let content = anthropic_extract_content(&v);
681        let text = if content.thinking.trim().is_empty() {
682            content.text
683        } else {
684            // Normalize provider-native thinking into our generic <think> format so the rest of the
685            // pipeline can split it consistently.
686            format!("<think>{}</think>{}", content.thinking, content.text)
687        };
688
689        let usage = self.parse_usage(&v);
690        tracing::debug!(
691            "Anthropic token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
692            usage.input_tokens,
693            usage.output_tokens,
694            usage.estimated_cost_usd.unwrap_or(0.0),
695            content.tool_calls.len()
696        );
697
698        Ok(LlmCallResponse::with_tool_calls(
699            text,
700            usage,
701            content.tool_calls,
702        ))
703    }
704}
705
/// Grok (xAI) provider that calls xAI's OpenAI-compatible HTTP endpoint.
#[cfg(feature = "grok")]
pub struct GrokProvider {
    /// Secret sent as the bearer token.
    pub api_key: String,
    /// Root URL; `/v1/chat/completions` is appended after trimming any trailing `/`.
    pub base_url: String,
    /// Grok model identifier.
    pub model: String,
}
713
#[cfg(feature = "grok")]
impl GrokProvider {
    /// Build a [`TokenUsage`] from a Grok response's OpenAI-style `usage` object.
    ///
    /// Missing counters count as zero. Cost uses a flat estimate of xAI pricing
    /// (Grok-2: $2 input / $10 output per 1M tokens) for every model.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        const INPUT_USD_PER_MILLION: f64 = 2.0;
        const OUTPUT_USD_PER_MILLION: f64 = 10.0;

        let usage = &response["usage"];
        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;

        let mut token_usage = TokenUsage::new(prompt_tokens, completion_tokens)
            .with_provider("grok")
            .with_model(self.model.clone());
        token_usage.calculate_cost(INPUT_USD_PER_MILLION, OUTPUT_USD_PER_MILLION);
        token_usage
    }
}
733
#[cfg(feature = "grok")]
#[async_trait::async_trait]
impl LlmProvider for GrokProvider {
    /// Return only the generated text from [`LlmProvider::call_with_usage`].
    async fn call(&self, prompt: &str) -> Result<String, AppError> {
        self.call_with_usage(prompt)
            .await
            .map(|response| response.text)
    }

    /// Tool-less call; delegates to [`LlmProvider::call_with_tools`] with `None`.
    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
        self.call_with_tools(prompt, None).await
    }

    /// POST a single user message to xAI's OpenAI-compatible `/v1/chat/completions`.
    ///
    /// Non-empty `tools` are forwarded in OpenAI schema format with `tool_choice: "auto"`.
    ///
    /// # Errors
    /// `AppError::Llm` if the request cannot be sent or the HTTP status is not a
    /// success. A body that is not valid JSON is propagated through `?`.
    async fn call_with_tools(
        &self,
        prompt: &str,
        tools: Option<&[serde_json::Value]>,
    ) -> Result<LlmCallResponse, AppError> {
        let base = self.base_url.trim_end_matches('/');
        let url = format!("{base}/v1/chat/completions");

        let mut request_body = serde_json::json!({
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
        });
        if let Some(schemas) = tools.filter(|schemas| !schemas.is_empty()) {
            request_body["tools"] = serde_json::Value::Array(schemas.to_vec());
            request_body["tool_choice"] = serde_json::json!("auto");
        }

        let response = create_http_client()
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| AppError::Llm(format!("grok request failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            let error_body = response.text().await.unwrap_or_default();
            return Err(AppError::Llm(format!("grok http {}: {}", status, error_body)));
        }

        let payload: serde_json::Value = response.json().await?;
        let message = &payload["choices"][0]["message"];
        let text = message["content"].as_str().unwrap_or("").to_string();
        let tool_calls = extract_openai_tool_calls(message);
        let usage = self.parse_usage(&payload);

        tracing::debug!(
            "Grok token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
            usage.input_tokens,
            usage.output_tokens,
            usage.estimated_cost_usd.unwrap_or(0.0),
            tool_calls.len()
        );

        Ok(LlmCallResponse::with_tool_calls(text, usage, tool_calls))
    }
}
802
/// Google Gemini provider for the Generative Language API over HTTP.
///
/// Gemini authenticates differently from the other providers (API key as a query
/// parameter). It returns text and tool calls as `parts` inside
/// `candidates[0].content`.
#[cfg(feature = "gemini")]
pub struct GeminiProvider {
    /// Key for the Generative Language API.
    pub api_key: String,
    /// API root URL (e.g. `https://generativelanguage.googleapis.com`).
    pub base_url: String,
    /// Gemini model identifier (e.g. `gemini-2.0-flash`).
    pub model: String,
}
817
#[cfg(feature = "gemini")]
impl GeminiProvider {
    /// Parse token usage from a Gemini `generateContent` response.
    ///
    /// Input tokens come from `usageMetadata.promptTokenCount`. Output tokens are
    /// `candidatesTokenCount` plus `thoughtsTokenCount`: thinking models report
    /// reasoning tokens separately, but they are billed as output, so leaving them
    /// out under-reports usage and cost. Missing counters count as zero, and
    /// oversized values saturate at `u32::MAX`.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let usage = &response["usageMetadata"];
        let count = |field: &str| -> u32 {
            usage[field]
                .as_u64()
                .map_or(0, |n| u32::try_from(n).unwrap_or(u32::MAX))
        };
        let input_tokens = count("promptTokenCount");
        let output_tokens =
            count("candidatesTokenCount").saturating_add(count("thoughtsTokenCount"));

        let mut token_usage = TokenUsage::new(input_tokens, output_tokens)
            .with_model(self.model.clone())
            .with_provider("gemini");

        // Approximate Gemini pricing (USD per 1M tokens, input / output).
        // Specific model families are matched before the generic `flash` fallbacks:
        // - 2.5 Pro:        $1.25 / $10.00 (prompts <= 200k tokens)
        // - 2.5 Flash-Lite: $0.10 / $0.40
        // - 2.5 Flash:      $0.30 / $2.50
        // - 1.5 Pro:        $1.25 / $5.00
        // - 2.0 Flash-Lite: $0.075 / $0.30
        // - 1.5 Flash:      $0.075 / $0.30
        // - 2.0 Flash:      $0.10 / $0.40 (also the default)
        let (input_price, output_price) = match self.model.as_str() {
            m if m.contains("2.5-pro") => (1.25, 10.00),
            m if m.contains("2.5-flash-lite") => (0.10, 0.40),
            m if m.contains("2.5-flash") => (0.30, 2.50),
            m if m.contains("1.5-pro") => (1.25, 5.00),
            m if m.contains("flash-lite") => (0.075, 0.30),
            m if m.contains("1.5-flash") => (0.075, 0.30),
            m if m.contains("flash") => (0.10, 0.40), // 2.0 Flash default
            _ => (0.10, 0.40),
        };
        token_usage.calculate_cost(input_price, output_price);

        token_usage
    }
}
849
/// Text and tool calls pulled from a Gemini `generateContent` response.
#[cfg(feature = "gemini")]
struct GeminiContent {
    /// Text parts of the first candidate, joined with newlines.
    text: String,
    /// `functionCall` parts of the first candidate, with synthesized IDs.
    tool_calls: Vec<ToolCallInfo>,
}
856
/// Pull text and `functionCall` parts out of a Gemini response.
///
/// Walks `candidates[0].content.parts[]`:
/// - A part's `text` is appended, preceded by `'\n'` once some text has already
///   been collected.
/// - A `functionCall` with a non-empty `name` becomes a [`ToolCallInfo`] whose
///   `args` are serialized to JSON (`"{}"` when absent).
///
/// Gemini does not assign call IDs, so each call gets `gemini-call-<part index>`.
#[cfg(feature = "gemini")]
fn gemini_extract_content(response: &serde_json::Value) -> GeminiContent {
    let mut text = String::new();
    let mut tool_calls = Vec::new();

    let parts = response["candidates"][0]["content"]["parts"].as_array();
    for (position, part) in parts.into_iter().flatten().enumerate() {
        if let Some(fragment) = part["text"].as_str() {
            if !text.is_empty() {
                text.push('\n');
            }
            text.push_str(fragment);
        }

        let Some(call) = part.get("functionCall") else {
            continue;
        };
        let name = call["name"].as_str().unwrap_or_default();
        if name.is_empty() {
            continue;
        }
        let arguments = call.get("args").map_or_else(
            || "{}".to_string(),
            |args| serde_json::to_string(args).unwrap_or_default(),
        );
        tool_calls.push(ToolCallInfo {
            id: format!("gemini-call-{position}"),
            name: name.to_string(),
            arguments,
        });
    }

    GeminiContent { text, tool_calls }
}
899
900#[cfg(feature = "gemini")]
901#[async_trait::async_trait]
902impl LlmProvider for GeminiProvider {
903    async fn call(&self, prompt: &str) -> Result<String, AppError> {
904        let response = self.call_with_usage(prompt).await?;
905        Ok(response.text)
906    }
907
908    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
909        self.call_with_tools(prompt, None).await
910    }
911
912    async fn call_with_tools(
913        &self,
914        prompt: &str,
915        tools: Option<&[serde_json::Value]>,
916    ) -> Result<LlmCallResponse, AppError> {
917        // Gemini authenticates via query parameter, not Bearer token.
918        let url = format!(
919            "{}/v1beta/models/{}:generateContent?key={}",
920            self.base_url, self.model, self.api_key
921        );
922
923        let mut body = serde_json::json!({
924            "contents": [{"role": "user", "parts": [{"text": prompt}]}]
925        });
926
927        // Gemini wraps tool schemas inside `functionDeclarations`.
928        if let Some(tools) = tools
929            && !tools.is_empty()
930        {
931            body["tools"] = serde_json::json!([{"functionDeclarations": tools}]);
932            body["toolConfig"] = serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}});
933        }
934
935        let client = create_http_client();
936        let resp = client
937            .post(&url)
938            .header("Content-Type", "application/json")
939            .json(&body)
940            .send()
941            .await
942            .map_err(|e| AppError::Llm(format!("gemini request failed: {e}")))?;
943
944        if !resp.status().is_success() {
945            let status = resp.status();
946            let body = resp.text().await.unwrap_or_default();
947            return Err(AppError::Llm(format!("gemini http {status}: {body}")));
948        }
949
950        let v: serde_json::Value = resp.json().await?;
951        let content = gemini_extract_content(&v);
952
953        let usage = self.parse_usage(&v);
954        tracing::debug!(
955            "Gemini token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
956            usage.input_tokens,
957            usage.output_tokens,
958            usage.estimated_cost_usd.unwrap_or(0.0),
959            content.tool_calls.len()
960        );
961
962        Ok(LlmCallResponse::with_tool_calls(
963            content.text,
964            usage,
965            content.tool_calls,
966        ))
967    }
968}
969
#[cfg(feature = "ollama")]
/// HTTP-based Ollama local provider
///
/// Sends non-streaming requests to the server's `/api/chat` endpoint. There is no
/// API-key field; requests carry no credentials.
pub struct OllamaProvider {
    /// Root URL of the Ollama server (a trailing `/` is stripped before use).
    pub base_url: String,
    /// Ollama model name, sent verbatim as the request's `model` field.
    pub model: String,
}
976
#[cfg(feature = "ollama")]
impl OllamaProvider {
    /// Build a [`TokenUsage`] record from an Ollama `/api/chat` response.
    ///
    /// The counts are read as follows:
    /// - `prompt_eval_count` is the input token count.
    /// - `eval_count` is the output token count.
    /// - An absent or non-numeric field counts as zero.
    ///
    /// The cost is pinned to zero because the model runs locally.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let read_count = |field: &str| response[field].as_u64().unwrap_or(0) as u32;

        let usage = TokenUsage::new(read_count("prompt_eval_count"), read_count("eval_count"));
        usage
            .with_model(self.model.clone())
            .with_provider("ollama")
            .with_cost(0.0)
    }
}
992
#[cfg(feature = "ollama")]
#[async_trait::async_trait]
impl LlmProvider for OllamaProvider {
    async fn call(&self, prompt: &str) -> Result<String, AppError> {
        self.call_with_usage(prompt).await.map(|reply| reply.text)
    }

    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
        self.call_with_tools(prompt, None).await
    }

    /// Send one user message to Ollama's `/api/chat` endpoint with streaming disabled.
    ///
    /// Non-empty `tools` are forwarded unchanged as the `tools` array, since Ollama
    /// accepts OpenAI-style schemas. The reply's `message.content` becomes the text
    /// (empty if absent). Tool calls are read from `message` in the OpenAI-compatible
    /// layout.
    ///
    /// # Errors
    /// Returns `AppError::Llm` if the request cannot be sent or the status is not
    /// successful (the body is included). A JSON decoding failure is propagated with `?`.
    async fn call_with_tools(
        &self,
        prompt: &str,
        tools: Option<&[serde_json::Value]>,
    ) -> Result<LlmCallResponse, AppError> {
        let endpoint = format!("{}/api/chat", self.base_url.trim_end_matches('/'));

        let mut payload = serde_json::json!({
            "model": self.model,
            "messages": [{"role":"user","content": prompt}],
            "stream": false
        });
        // Leave the `tools` key out entirely when there is nothing to advertise.
        if let Some(schemas) = tools.filter(|schemas| !schemas.is_empty()) {
            payload["tools"] = serde_json::Value::Array(schemas.to_vec());
        }

        let http = create_http_client();
        let response = http
            .post(&endpoint)
            .json(&payload)
            .send()
            .await
            .map_err(|e| AppError::Llm(format!("ollama request failed: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            let detail = response.text().await.unwrap_or_default();
            return Err(AppError::Llm(format!("ollama http {}: {}", status, detail)));
        }

        let reply: serde_json::Value = response.json().await?;
        let message = &reply["message"];
        let text = message["content"]
            .as_str()
            .map(str::to_owned)
            .unwrap_or_default();
        let tool_calls = extract_openai_tool_calls(message);

        let usage = self.parse_usage(&reply);
        tracing::debug!(
            "Ollama token usage: {} input, {} output (local, no cost), {} tool calls",
            usage.input_tokens,
            usage.output_tokens,
            tool_calls.len()
        );

        Ok(LlmCallResponse::with_tool_calls(text, usage, tool_calls))
    }
}
1053
1054/// Create an unconfigured provider that returns an error when called.
1055/// Used when a provider is not properly configured.
1056pub fn unconfigured_provider(provider_name: &str) -> Box<dyn LlmProvider> {
1057    Box::new(UnconfiguredProvider {
1058        provider_name: provider_name.to_string(),
1059    })
1060}
1061
#[cfg(test)]
mod tests {
    //! Unit tests for response parsing and usage accounting. These tests need no network.

    use super::*;
    #[cfg(any(feature = "anthropic", feature = "gemini", feature = "ollama"))]
    use serde_json::json;

    #[tokio::test]
    async fn test_unconfigured_provider_returns_error() {
        let provider = UnconfiguredProvider {
            provider_name: "test".to_string(),
        };
        let result = provider.call("Hello").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not configured"));
    }

    #[test]
    #[cfg(feature = "openai")]
    fn test_openai_responses_output_extraction() {
        let response = serde_json::json!({
            "output_text": "final answer",
            "output": [
                {
                    "type": "function_call",
                    "id": "fc_123",
                    "call_id": "call_123",
                    "name": "shell",
                    "arguments": "{\"command\":\"pwd\"}"
                }
            ]
        });

        assert_eq!(extract_openai_responses_text(&response), "final answer");
        assert_eq!(
            extract_openai_responses_tool_calls(&response),
            vec![ToolCallInfo {
                id: "call_123".to_string(),
                name: "shell".to_string(),
                arguments: "{\"command\":\"pwd\"}".to_string(),
            }]
        );
    }

    #[test]
    #[cfg(feature = "anthropic")]
    fn test_anthropic_extract_text_and_thinking() {
        let v = json!({
            "content": [
                {"type": "thinking", "thinking": "plan\n"},
                {"type": "text", "text": "answer"}
            ]
        });
        let (text, thinking) = anthropic_extract_text_and_thinking(&v);
        assert_eq!(text, "answer");
        assert_eq!(thinking, "plan\n");
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_text_only() {
        let v = json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello, world!"}],
                    "role": "model"
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        });
        let content = gemini_extract_content(&v);
        assert_eq!(content.text, "Hello, world!");
        assert!(content.tool_calls.is_empty());
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_with_tool_calls() {
        let v = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Let me check that file."},
                        {"functionCall": {
                            "name": "file_read",
                            "args": {"path": "/tmp/test.txt"}
                        }}
                    ],
                    "role": "model"
                }
            }]
        });
        let content = gemini_extract_content(&v);
        assert_eq!(content.text, "Let me check that file.");
        assert_eq!(content.tool_calls.len(), 1);
        assert_eq!(content.tool_calls[0].name, "file_read");
        // IDs are derived from the part index, not the tool-call ordinal.
        assert_eq!(content.tool_calls[0].id, "gemini-call-1");
        let args: serde_json::Value =
            serde_json::from_str(&content.tool_calls[0].arguments).unwrap();
        assert_eq!(args["path"], "/tmp/test.txt");
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_multiple_tool_calls() {
        let v = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"functionCall": {
                            "name": "file_read",
                            "args": {"path": "a.txt"}
                        }},
                        {"functionCall": {
                            "name": "shell_exec",
                            "args": {"command": "ls"}
                        }}
                    ],
                    "role": "model"
                }
            }]
        });
        let content = gemini_extract_content(&v);
        assert!(content.text.is_empty());
        assert_eq!(content.tool_calls.len(), 2);
        assert_eq!(content.tool_calls[0].name, "file_read");
        assert_eq!(content.tool_calls[0].id, "gemini-call-0");
        assert_eq!(content.tool_calls[1].name, "shell_exec");
        assert_eq!(content.tool_calls[1].id, "gemini-call-1");
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_joins_text_and_defaults_args() {
        let v = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "first"},
                        {"text": "second"},
                        {"functionCall": {"name": "list_tools"}},
                        {"functionCall": {"name": "", "args": {"x": 1}}}
                    ],
                    "role": "model"
                }
            }]
        });
        let content = gemini_extract_content(&v);
        // Consecutive text parts are newline-separated.
        assert_eq!(content.text, "first\nsecond");
        // A missing `args` becomes "{}"; a call with an empty name is dropped.
        assert_eq!(content.tool_calls.len(), 1);
        assert_eq!(content.tool_calls[0].id, "gemini-call-2");
        assert_eq!(content.tool_calls[0].name, "list_tools");
        assert_eq!(content.tool_calls[0].arguments, "{}");
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_empty_response() {
        let v = json!({"candidates": [{"content": {"parts": []}}]});
        let content = gemini_extract_content(&v);
        assert!(content.text.is_empty());
        assert!(content.tool_calls.is_empty());
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_without_candidates() {
        // e.g. a blocked prompt: no `candidates` array at all.
        let v = json!({"promptFeedback": {"blockReason": "SAFETY"}});
        let content = gemini_extract_content(&v);
        assert!(content.text.is_empty());
        assert!(content.tool_calls.is_empty());
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_parse_usage() {
        let provider = GeminiProvider {
            api_key: "test".to_string(),
            base_url: "https://example.com".to_string(),
            model: "gemini-2.0-flash".to_string(),
        };
        let v = json!({
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "totalTokenCount": 150
            }
        });
        let usage = provider.parse_usage(&v);
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
        assert_eq!(usage.provider.as_deref(), Some("gemini"));
        assert_eq!(usage.model.as_deref(), Some("gemini-2.0-flash"));
        // 2.0 Flash: $0.10/1M input, $0.40/1M output
        // 100 input → 0.00001, 50 output → 0.00002 → total 0.00003
        let cost = usage.estimated_cost_usd.unwrap();
        assert!((cost - 0.00003).abs() < 1e-9);
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_parse_usage_pro_pricing() {
        let provider = GeminiProvider {
            api_key: "test".to_string(),
            base_url: "https://example.com".to_string(),
            model: "gemini-1.5-pro".to_string(),
        };
        let v = json!({
            "usageMetadata": {
                "promptTokenCount": 1_000_000,
                "candidatesTokenCount": 1_000_000,
                "totalTokenCount": 2_000_000
            }
        });
        let usage = provider.parse_usage(&v);
        // 1.5 Pro: $1.25/1M input, $5.00/1M output → total $6.25
        let cost = usage.estimated_cost_usd.unwrap();
        assert!((cost - 6.25).abs() < 1e-6);
    }

    #[test]
    #[cfg(feature = "ollama")]
    fn test_ollama_parse_usage() {
        let provider = OllamaProvider {
            base_url: "http://localhost:11434".to_string(),
            model: "llama3.2".to_string(),
        };
        let usage = provider.parse_usage(&json!({"prompt_eval_count": 12, "eval_count": 7}));
        assert_eq!(usage.input_tokens, 12);
        assert_eq!(usage.output_tokens, 7);
        assert_eq!(usage.provider.as_deref(), Some("ollama"));
        assert_eq!(usage.model.as_deref(), Some("llama3.2"));

        // Missing counters read as zero rather than failing.
        let empty = provider.parse_usage(&json!({}));
        assert_eq!(empty.input_tokens, 0);
        assert_eq!(empty.output_tokens, 0);
    }
}