1pub mod default_models;
49pub mod model_capabilities;
50pub mod model_discovery;
51pub mod model_listing;
52pub mod openai;
53pub mod token_tracker;
54
55use gestura_core_foundation::AppError;
56use serde::{Deserialize, Serialize};
57use std::time::Duration;
58
59#[cfg(feature = "openai")]
60use crate::openai::{
61 OpenAiApi, is_openai_model_incompatible_with_agent_session, openai_agent_session_model_message,
62 openai_api_for_model,
63};
64
/// Overall per-request timeout (seconds) for all provider HTTP calls.
const LLM_TIMEOUT_SECS: u64 = 120;
67
68fn create_http_client() -> reqwest::Client {
70 reqwest::Client::builder()
71 .timeout(Duration::from_secs(LLM_TIMEOUT_SECS))
72 .connect_timeout(Duration::from_secs(10))
73 .build()
74 .unwrap_or_else(|_| reqwest::Client::new())
75}
76
/// Token accounting for a single LLM call, with an optional cost estimate
/// and the model/provider the numbers were reported for.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/input.
    pub input_tokens: u32,
    /// Tokens generated by the model.
    pub output_tokens: u32,
    /// Sum of input and output tokens.
    pub total_tokens: u32,
    /// Estimated cost in USD, when pricing for the model is known.
    pub estimated_cost_usd: Option<f64>,
    /// Model identifier the usage belongs to, if known.
    pub model: Option<String>,
    /// Provider name (e.g. "openai", "anthropic"), if known.
    pub provider: Option<String>,
}
93
94impl TokenUsage {
95 pub fn new(input_tokens: u32, output_tokens: u32) -> Self {
97 Self {
98 input_tokens,
99 output_tokens,
100 total_tokens: input_tokens + output_tokens,
101 estimated_cost_usd: None,
102 model: None,
103 provider: None,
104 }
105 }
106
107 pub fn unknown() -> Self {
109 Self::default()
110 }
111
112 pub fn with_cost(mut self, cost_usd: f64) -> Self {
114 self.estimated_cost_usd = Some(cost_usd);
115 self
116 }
117
118 pub fn with_model(mut self, model: impl Into<String>) -> Self {
120 self.model = Some(model.into());
121 self
122 }
123
124 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
126 self.provider = Some(provider.into());
127 self
128 }
129
130 pub fn calculate_cost(&mut self, input_price_per_million: f64, output_price_per_million: f64) {
132 let input_cost = (self.input_tokens as f64 / 1_000_000.0) * input_price_per_million;
133 let output_cost = (self.output_tokens as f64 / 1_000_000.0) * output_price_per_million;
134 self.estimated_cost_usd = Some(input_cost + output_cost);
135 }
136}
137
/// A tool/function invocation requested by the model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallInfo {
    /// Provider-assigned call id (may be empty when the provider gives none).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as a raw JSON string.
    pub arguments: String,
}
148
/// Result of a single LLM call: reply text, token usage, and any tool calls
/// the model requested.
#[derive(Debug, Clone)]
pub struct LlmCallResponse {
    /// The model's text reply.
    pub text: String,
    /// Token accounting for this call (may be `TokenUsage::unknown()`).
    pub usage: TokenUsage,
    /// Tool calls requested by the model; empty when none.
    pub tool_calls: Vec<ToolCallInfo>,
}
159
160impl LlmCallResponse {
161 pub fn new(text: String, usage: TokenUsage) -> Self {
163 Self {
164 text,
165 usage,
166 tool_calls: Vec::new(),
167 }
168 }
169
170 pub fn with_unknown_usage(text: String) -> Self {
172 Self {
173 text,
174 usage: TokenUsage::unknown(),
175 tool_calls: Vec::new(),
176 }
177 }
178
179 pub fn with_tool_calls(text: String, usage: TokenUsage, tool_calls: Vec<ToolCallInfo>) -> Self {
181 Self {
182 text,
183 usage,
184 tool_calls,
185 }
186 }
187}
188
/// Context describing the agent on whose behalf an LLM call is made.
// NOTE(review): not referenced elsewhere in this file; semantics are defined
// by callers — confirm intended use before extending.
#[derive(Debug, Clone, Default)]
pub struct AgentContext {
    /// Identifier of the agent.
    pub agent_id: String,
}
194
/// Common interface implemented by every LLM backend.
///
/// Implementors must provide `call`; the usage- and tool-aware variants have
/// default implementations that degrade gracefully (unknown usage, tools
/// ignored) so minimal providers stay minimal.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends `prompt` and returns the model's text reply.
    async fn call(&self, prompt: &str) -> Result<String, AppError>;

    /// Like `call`, but also reports token usage. The default delegates to
    /// `call` and attaches `TokenUsage::unknown()`.
    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
        let text = self.call(prompt).await?;
        Ok(LlmCallResponse::with_unknown_usage(text))
    }

    /// Like `call_with_usage`, but may forward tool definitions to the
    /// model. The default implementation ignores `tools` entirely.
    async fn call_with_tools(
        &self,
        prompt: &str,
        _tools: Option<&[serde_json::Value]>,
    ) -> Result<LlmCallResponse, AppError> {
        self.call_with_usage(prompt).await
    }
}
226
/// Stand-in provider used when no real backend is configured; every call
/// fails with a descriptive error instead of panicking.
pub struct UnconfiguredProvider {
    /// Provider name shown in the error message (e.g. "openai").
    pub provider_name: String,
}
232
233#[async_trait::async_trait]
234impl LlmProvider for UnconfiguredProvider {
235 async fn call(&self, _prompt: &str) -> Result<String, AppError> {
236 Err(AppError::Llm(format!(
237 "LLM provider '{}' is not configured. Please configure it in Settings or run 'gestura config edit'.",
238 self.provider_name
239 )))
240 }
241}
242
/// OpenAI provider supporting both Chat Completions and the Responses API;
/// the endpoint is chosen per model via `openai_api_for_model`.
#[cfg(feature = "openai")]
pub struct OpenAiProvider {
    /// API key sent as a bearer token.
    pub api_key: String,
    /// Base URL; trailing slashes are trimmed before the path is appended.
    pub base_url: String,
    /// Model identifier; also drives endpoint selection and cost estimates.
    pub model: String,
}
250
#[cfg(feature = "openai")]
impl OpenAiProvider {
    /// Maps an OpenAI API family to its URL path (appended to `base_url`).
    fn endpoint_path(api: OpenAiApi) -> &'static str {
        match api {
            OpenAiApi::ChatCompletions => "/v1/chat/completions",
            OpenAiApi::Responses => "/v1/responses",
        }
    }

    /// Turns a non-success OpenAI HTTP response into a human-readable error
    /// string, adding a hint when a 404 suggests the model needed the
    /// /v1/responses endpoint rather than the one Gestura selected.
    fn enrich_openai_error(
        &self,
        api: OpenAiApi,
        status: reqwest::StatusCode,
        body: &str,
    ) -> String {
        // "This is not a chat model" in a 404 body is treated as the signal
        // that the model is Responses-only.
        if status == reqwest::StatusCode::NOT_FOUND && body.contains("This is not a chat model") {
            return format!(
                "OpenAI model '{}' appears to require /v1/responses, but Gestura selected {}. Raw OpenAI error: {body}",
                self.model,
                Self::endpoint_path(api)
            );
        }

        format!(
            "OpenAI {} HTTP {}: {}",
            Self::endpoint_path(api),
            status,
            body
        )
    }

    /// Extracts token counts from the response's `usage` object, accepting
    /// both naming schemes, and attaches a cost estimate.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let usage = &response["usage"];
        // Chat Completions reports prompt_tokens/completion_tokens; the
        // Responses API reports input_tokens/output_tokens — try both.
        let input_tokens = usage["prompt_tokens"]
            .as_u64()
            .or_else(|| usage["input_tokens"].as_u64())
            .unwrap_or(0) as u32;
        let output_tokens = usage["completion_tokens"]
            .as_u64()
            .or_else(|| usage["output_tokens"].as_u64())
            .unwrap_or(0) as u32;

        let mut token_usage = TokenUsage::new(input_tokens, output_tokens)
            .with_model(self.model.clone())
            .with_provider("openai");

        // USD per million tokens; first match wins, so the more specific
        // "gpt-4o" prefix must stay ahead of "gpt-4".
        // NOTE(review): hard-coded price snapshot — verify against current
        // OpenAI pricing before relying on estimated_cost_usd.
        let (input_price, output_price) = match self.model.as_str() {
            m if m.starts_with("gpt-4o") => (2.50, 10.0),
            m if m.starts_with("gpt-4") => (30.0, 60.0),
            m if m.starts_with("gpt-3.5") => (0.50, 1.50),
            _ => (2.50, 10.0), };
        token_usage.calculate_cost(input_price, output_price);

        token_usage
    }
}
313
/// Assembles a Chat Completions request body. Tools (plus tool_choice
/// "auto") are attached only when a non-empty tool list is supplied.
#[cfg(feature = "openai")]
fn build_openai_chat_request_body(
    model: &str,
    prompt: &str,
    tools: Option<&[serde_json::Value]>,
) -> serde_json::Value {
    let mut body = serde_json::json!({
        "model": model,
        "messages": [{"role":"user","content": prompt}]
    });

    match tools {
        Some(list) if !list.is_empty() => {
            body["tools"] = serde_json::Value::Array(list.to_vec());
            body["tool_choice"] = serde_json::json!("auto");
        }
        _ => {}
    }

    body
}
334
/// Assembles a Responses API request body (prompt goes under "input").
/// Tools (plus tool_choice "auto") are attached only when a non-empty tool
/// list is supplied.
#[cfg(feature = "openai")]
fn build_openai_responses_request_body(
    model: &str,
    prompt: &str,
    tools: Option<&[serde_json::Value]>,
) -> serde_json::Value {
    let mut body = serde_json::json!({
        "model": model,
        "input": [{"role":"user","content": prompt}]
    });

    match tools {
        Some(list) if !list.is_empty() => {
            body["tools"] = serde_json::Value::Array(list.to_vec());
            body["tool_choice"] = serde_json::json!("auto");
        }
        _ => {}
    }

    body
}
355
/// Extracts reply text from a Responses API payload: prefers the
/// convenience `output_text` field, otherwise concatenates every
/// "output_text" content part of every "message" item in `output`.
#[cfg(feature = "openai")]
fn extract_openai_responses_text(response: &serde_json::Value) -> String {
    if let Some(text) = response["output_text"].as_str() {
        return text.to_string();
    }

    let mut pieces: Vec<&str> = Vec::new();
    if let Some(items) = response["output"].as_array() {
        for item in items {
            if item["type"].as_str() != Some("message") {
                continue;
            }
            if let Some(contents) = item["content"].as_array() {
                for content in contents {
                    if content["type"].as_str() == Some("output_text") {
                        if let Some(t) = content["text"].as_str() {
                            pieces.push(t);
                        }
                    }
                }
            }
        }
    }
    pieces.join("")
}
375
/// Collects "function_call" items from a Responses API `output` array.
/// Items without a name are skipped; a missing `call_id` falls back to
/// `id`, then to an empty string.
#[cfg(feature = "openai")]
fn extract_openai_responses_tool_calls(response: &serde_json::Value) -> Vec<ToolCallInfo> {
    let mut calls = Vec::new();
    let Some(output) = response["output"].as_array() else {
        return calls;
    };

    for item in output {
        if item["type"].as_str() != Some("function_call") {
            continue;
        }
        let Some(name) = item["name"].as_str() else {
            continue;
        };
        let id = item["call_id"]
            .as_str()
            .or_else(|| item["id"].as_str())
            .unwrap_or_default()
            .to_string();
        calls.push(ToolCallInfo {
            id,
            name: name.to_string(),
            arguments: item["arguments"].as_str().unwrap_or("{}").to_string(),
        });
    }
    calls
}
399
#[cfg(feature = "openai")]
#[async_trait::async_trait]
impl LlmProvider for OpenAiProvider {
    /// Plain-text call; delegates to `call_with_usage` and drops usage data.
    async fn call(&self, prompt: &str) -> Result<String, AppError> {
        let response = self.call_with_usage(prompt).await?;
        Ok(response.text)
    }

    /// Call with usage but no tools.
    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
        self.call_with_tools(prompt, None).await
    }

    /// Sends a single-turn request to whichever OpenAI endpoint the model
    /// requires (Chat Completions or Responses), optionally with tools,
    /// then parses text, tool calls, and token usage.
    async fn call_with_tools(
        &self,
        prompt: &str,
        tools: Option<&[serde_json::Value]>,
    ) -> Result<LlmCallResponse, AppError> {
        // Fail fast for models that cannot be used in agent sessions.
        if is_openai_model_incompatible_with_agent_session(&self.model) {
            return Err(AppError::Llm(openai_agent_session_model_message(
                &self.model,
            )));
        }

        // Endpoint (and therefore request/response shape) is model-driven.
        let api = openai_api_for_model(&self.model);

        let url = format!(
            "{}{}",
            self.base_url.trim_end_matches('/'),
            Self::endpoint_path(api)
        );
        let body = match api {
            OpenAiApi::ChatCompletions => {
                build_openai_chat_request_body(&self.model, prompt, tools)
            }
            OpenAiApi::Responses => build_openai_responses_request_body(&self.model, prompt, tools),
        };

        let client = create_http_client();
        let resp = client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await
            .map_err(|e| AppError::Llm(format!("openai request failed: {}", e)))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(AppError::Llm(self.enrich_openai_error(api, status, &body)));
        }
        let v: serde_json::Value = resp.json().await?;
        // The two APIs put text and tool calls in different places.
        let (text, tool_calls) = match api {
            OpenAiApi::ChatCompletions => (
                v["choices"][0]["message"]["content"]
                    .as_str()
                    .unwrap_or("")
                    .to_string(),
                extract_openai_tool_calls(&v["choices"][0]["message"]),
            ),
            OpenAiApi::Responses => (
                extract_openai_responses_text(&v),
                extract_openai_responses_tool_calls(&v),
            ),
        };

        let usage = self.parse_usage(&v);
        tracing::debug!(
            endpoint = Self::endpoint_path(api),
            "OpenAI token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
            usage.input_tokens,
            usage.output_tokens,
            usage.estimated_cost_usd.unwrap_or(0.0),
            tool_calls.len()
        );

        Ok(LlmCallResponse::with_tool_calls(text, usage, tool_calls))
    }
}
481
/// Anthropic Messages API provider.
#[cfg(feature = "anthropic")]
pub struct AnthropicProvider {
    /// API key sent via the `x-api-key` header.
    pub api_key: String,
    /// Base URL, joined with "/v1/messages".
    pub base_url: String,
    /// Model identifier; also drives cost estimation.
    pub model: String,

    /// When set, extended thinking is enabled with this token budget.
    pub thinking_budget_tokens: Option<u32>,
}
493
#[cfg(feature = "anthropic")]
impl AnthropicProvider {
    /// Reads Anthropic's usage block and attaches a cost estimate from a
    /// hard-coded per-million-token price table keyed on model family.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let usage = &response["usage"];
        let read = |key: &str| usage[key].as_u64().unwrap_or(0) as u32;

        let mut token_usage = TokenUsage::new(read("input_tokens"), read("output_tokens"))
            .with_model(self.model.clone())
            .with_provider("anthropic");

        // USD per million tokens; first matching family wins, unknown
        // models fall back to sonnet-tier pricing.
        let model = self.model.as_str();
        let (input_price, output_price) = if model.contains("opus") {
            (15.0, 75.0)
        } else if model.contains("sonnet") {
            (3.0, 15.0)
        } else if model.contains("haiku") {
            (0.25, 1.25)
        } else {
            (3.0, 15.0)
        };
        token_usage.calculate_cost(input_price, output_price);

        token_usage
    }
}
521
/// Parses OpenAI-style `tool_calls` from a chat message object; shared by
/// the OpenAI, Grok, and Ollama providers, which all use this schema.
#[cfg(any(feature = "openai", feature = "grok", feature = "ollama"))]
fn extract_openai_tool_calls(message: &serde_json::Value) -> Vec<ToolCallInfo> {
    let mut calls = Vec::new();
    let Some(raw_calls) = message["tool_calls"].as_array() else {
        return calls;
    };

    for call in raw_calls {
        // A call without a function name is unusable; skip it.
        let Some(name) = call["function"]["name"].as_str() else {
            continue;
        };
        calls.push(ToolCallInfo {
            id: call["id"].as_str().unwrap_or_default().to_string(),
            name: name.to_string(),
            arguments: call["function"]["arguments"]
                .as_str()
                .unwrap_or("{}")
                .to_string(),
        });
    }
    calls
}
549
/// Intermediate decomposition of an Anthropic response's content blocks.
#[cfg(feature = "anthropic")]
struct AnthropicContent {
    /// Concatenated "text" blocks.
    text: String,
    /// Concatenated "thinking" blocks.
    thinking: String,
    /// Parsed "tool_use" blocks.
    tool_calls: Vec<ToolCallInfo>,
}
557
/// Splits an Anthropic Messages response into text, thinking, and tool
/// calls by walking its `content` block array. Unknown block types are
/// ignored; tool_use blocks without a name are dropped.
#[cfg(feature = "anthropic")]
fn anthropic_extract_content(response_json: &serde_json::Value) -> AnthropicContent {
    let mut out = AnthropicContent {
        text: String::new(),
        thinking: String::new(),
        tool_calls: Vec::new(),
    };

    let blocks = match response_json["content"].as_array() {
        Some(blocks) => blocks,
        None => return out,
    };

    for block in blocks {
        match block["type"].as_str().unwrap_or("") {
            "text" => {
                if let Some(t) = block.get("text").and_then(|v| v.as_str()) {
                    out.text.push_str(t);
                }
            }
            "thinking" => {
                // The payload lives under "thinking", with "text" accepted
                // as a fallback key.
                let thought = block
                    .get("thinking")
                    .and_then(|v| v.as_str())
                    .or_else(|| block.get("text").and_then(|v| v.as_str()));
                if let Some(t) = thought {
                    out.thinking.push_str(t);
                }
            }
            "tool_use" => {
                let name = block["name"].as_str().unwrap_or_default().to_string();
                if name.is_empty() {
                    continue;
                }
                // "input" is an arbitrary JSON object; serialize it back to
                // a string for the tool-call record.
                let arguments = block
                    .get("input")
                    .map(|input| serde_json::to_string(input).unwrap_or_default())
                    .unwrap_or_else(|| "{}".to_string());
                out.tool_calls.push(ToolCallInfo {
                    id: block["id"].as_str().unwrap_or_default().to_string(),
                    name,
                    arguments,
                });
            }
            _ => {}
        }
    }

    out
}
615
/// Test-only shim exposing just the (text, thinking) pair from
/// `anthropic_extract_content`.
#[cfg(all(test, feature = "anthropic"))]
fn anthropic_extract_text_and_thinking(response_json: &serde_json::Value) -> (String, String) {
    let AnthropicContent { text, thinking, .. } = anthropic_extract_content(response_json);
    (text, thinking)
}
624
625#[cfg(feature = "anthropic")]
626#[async_trait::async_trait]
627impl LlmProvider for AnthropicProvider {
628 async fn call(&self, prompt: &str) -> Result<String, AppError> {
629 let response = self.call_with_usage(prompt).await?;
630 Ok(response.text)
631 }
632
633 async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
634 self.call_with_tools(prompt, None).await
635 }
636
637 async fn call_with_tools(
638 &self,
639 prompt: &str,
640 tools: Option<&[serde_json::Value]>,
641 ) -> Result<LlmCallResponse, AppError> {
642 let url = format!("{}/v1/messages", self.base_url);
643 let mut body = serde_json::json!({
644 "model": self.model,
645 "max_tokens": 512,
646 "messages": [{"role":"user","content": [{"type":"text","text": prompt}]}]
647 });
648
649 if let Some(budget_tokens) = self.thinking_budget_tokens {
650 body["thinking"] =
652 serde_json::json!({ "type": "enabled", "budget_tokens": budget_tokens });
653 }
654
655 if let Some(tools) = tools
657 && !tools.is_empty()
658 {
659 body["tools"] = serde_json::Value::Array(tools.to_vec());
660 }
661
662 let client = create_http_client();
663 let resp = client
664 .post(&url)
665 .header("x-api-key", &self.api_key)
666 .header("anthropic-version", "2023-06-01")
667 .json(&body)
668 .send()
669 .await
670 .map_err(|e| AppError::Llm(format!("anthropic request failed: {}", e)))?;
671 if !resp.status().is_success() {
672 let status = resp.status();
673 let body = resp.text().await.unwrap_or_default();
674 return Err(AppError::Llm(format!(
675 "anthropic http {}: {}",
676 status, body
677 )));
678 }
679 let v: serde_json::Value = resp.json().await?;
680 let content = anthropic_extract_content(&v);
681 let text = if content.thinking.trim().is_empty() {
682 content.text
683 } else {
684 format!("<think>{}</think>{}", content.thinking, content.text)
687 };
688
689 let usage = self.parse_usage(&v);
690 tracing::debug!(
691 "Anthropic token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
692 usage.input_tokens,
693 usage.output_tokens,
694 usage.estimated_cost_usd.unwrap_or(0.0),
695 content.tool_calls.len()
696 );
697
698 Ok(LlmCallResponse::with_tool_calls(
699 text,
700 usage,
701 content.tool_calls,
702 ))
703 }
704}
705
/// xAI Grok provider (OpenAI-compatible chat completions endpoint).
#[cfg(feature = "grok")]
pub struct GrokProvider {
    /// API key sent as a bearer token.
    pub api_key: String,
    /// Base URL, joined with "/v1/chat/completions".
    pub base_url: String,
    /// Model identifier.
    pub model: String,
}
713
#[cfg(feature = "grok")]
impl GrokProvider {
    /// Builds TokenUsage from Grok's OpenAI-style usage block, applying a
    /// flat $2/$10 per-million-token estimate regardless of model.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let usage = &response["usage"];
        let prompt = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
        let completion = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;

        let mut tracked = TokenUsage::new(prompt, completion)
            .with_model(self.model.clone())
            .with_provider("grok");
        // NOTE(review): a single flat rate for every Grok model — confirm
        // against current xAI pricing.
        tracked.calculate_cost(2.0, 10.0);

        tracked
    }
}
733
#[cfg(feature = "grok")]
#[async_trait::async_trait]
impl LlmProvider for GrokProvider {
    /// Plain-text call; delegates to `call_with_usage` and drops usage data.
    async fn call(&self, prompt: &str) -> Result<String, AppError> {
        let response = self.call_with_usage(prompt).await?;
        Ok(response.text)
    }

    /// Call with usage but no tools.
    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
        self.call_with_tools(prompt, None).await
    }

    /// Sends a single-turn request to Grok's OpenAI-compatible chat
    /// completions endpoint, optionally with tools.
    async fn call_with_tools(
        &self,
        prompt: &str,
        tools: Option<&[serde_json::Value]>,
    ) -> Result<LlmCallResponse, AppError> {
        let url = format!(
            "{}/v1/chat/completions",
            self.base_url.trim_end_matches('/')
        );
        let mut body = serde_json::json!({
            "model": self.model,
            "messages": [{"role":"user","content": prompt}],
        });

        // Attach tools only when a non-empty list was supplied.
        if let Some(tools) = tools
            && !tools.is_empty()
        {
            body["tools"] = serde_json::Value::Array(tools.to_vec());
            body["tool_choice"] = serde_json::json!("auto");
        }

        let client = create_http_client();
        let resp = client
            .post(&url)
            .bearer_auth(&self.api_key)
            .json(&body)
            .send()
            .await
            .map_err(|e| AppError::Llm(format!("grok request failed: {}", e)))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(AppError::Llm(format!("grok http {}: {}", status, body)));
        }
        let v: serde_json::Value = resp.json().await?;
        let text = v["choices"][0]["message"]["content"]
            .as_str()
            .unwrap_or("")
            .to_string();

        // Grok uses the OpenAI tool_calls message shape.
        let tool_calls = extract_openai_tool_calls(&v["choices"][0]["message"]);

        let usage = self.parse_usage(&v);
        tracing::debug!(
            "Grok token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
            usage.input_tokens,
            usage.output_tokens,
            usage.estimated_cost_usd.unwrap_or(0.0),
            tool_calls.len()
        );

        Ok(LlmCallResponse::with_tool_calls(text, usage, tool_calls))
    }
}
802
/// Google Gemini provider (generateContent API).
#[cfg(feature = "gemini")]
pub struct GeminiProvider {
    /// API key, passed as the `key` query parameter.
    pub api_key: String,
    /// Base URL, joined with "/v1beta/models/{model}:generateContent".
    pub base_url: String,
    /// Model identifier; also drives cost estimation.
    pub model: String,
}
817
#[cfg(feature = "gemini")]
impl GeminiProvider {
    /// Reads Gemini's usageMetadata counters and estimates cost from a
    /// hard-coded per-million-token price table keyed on model substring.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let meta = &response["usageMetadata"];
        let prompt_tokens = meta["promptTokenCount"].as_u64().unwrap_or(0) as u32;
        let candidate_tokens = meta["candidatesTokenCount"].as_u64().unwrap_or(0) as u32;

        let mut tracked = TokenUsage::new(prompt_tokens, candidate_tokens)
            .with_model(self.model.clone())
            .with_provider("gemini");

        // First match wins: the specific substrings ("1.5-pro",
        // "flash-lite", "1.5-flash") are checked before the generic
        // "flash" bucket, which is also the default.
        let model = self.model.as_str();
        let (input_price, output_price) = if model.contains("1.5-pro") {
            (1.25, 5.00)
        } else if model.contains("flash-lite") {
            (0.075, 0.30)
        } else if model.contains("1.5-flash") {
            (0.075, 0.30)
        } else if model.contains("flash") {
            (0.10, 0.40)
        } else {
            (0.10, 0.40)
        };
        tracked.calculate_cost(input_price, output_price);

        tracked
    }
}
849
/// Intermediate decomposition of a Gemini candidate's content parts.
#[cfg(feature = "gemini")]
struct GeminiContent {
    /// Concatenated text parts (newline-joined).
    text: String,
    /// Parsed functionCall parts.
    tool_calls: Vec<ToolCallInfo>,
}
856
/// Walks the first candidate's content parts, joining text parts with
/// newlines and converting functionCall parts into tool calls. Tool-call
/// ids are synthesized from the part index (e.g. "gemini-call-0").
#[cfg(feature = "gemini")]
fn gemini_extract_content(response: &serde_json::Value) -> GeminiContent {
    let mut out = GeminiContent {
        text: String::new(),
        tool_calls: Vec::new(),
    };

    let parts = match response["candidates"][0]["content"]["parts"].as_array() {
        Some(parts) => parts,
        None => return out,
    };

    for (idx, part) in parts.iter().enumerate() {
        if let Some(chunk) = part["text"].as_str() {
            // Separate consecutive text parts with a single newline.
            if !out.text.is_empty() {
                out.text.push('\n');
            }
            out.text.push_str(chunk);
        }
        let Some(fc) = part.get("functionCall") else {
            continue;
        };
        let name = fc["name"].as_str().unwrap_or_default().to_string();
        if name.is_empty() {
            continue;
        }
        // "args" is an arbitrary JSON object; re-serialize it for the record.
        let arguments = fc
            .get("args")
            .map(|a| serde_json::to_string(a).unwrap_or_default())
            .unwrap_or_else(|| "{}".to_string());
        out.tool_calls.push(ToolCallInfo {
            id: format!("gemini-call-{idx}"),
            name,
            arguments,
        });
    }

    out
}
899
900#[cfg(feature = "gemini")]
901#[async_trait::async_trait]
902impl LlmProvider for GeminiProvider {
903 async fn call(&self, prompt: &str) -> Result<String, AppError> {
904 let response = self.call_with_usage(prompt).await?;
905 Ok(response.text)
906 }
907
908 async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
909 self.call_with_tools(prompt, None).await
910 }
911
912 async fn call_with_tools(
913 &self,
914 prompt: &str,
915 tools: Option<&[serde_json::Value]>,
916 ) -> Result<LlmCallResponse, AppError> {
917 let url = format!(
919 "{}/v1beta/models/{}:generateContent?key={}",
920 self.base_url, self.model, self.api_key
921 );
922
923 let mut body = serde_json::json!({
924 "contents": [{"role": "user", "parts": [{"text": prompt}]}]
925 });
926
927 if let Some(tools) = tools
929 && !tools.is_empty()
930 {
931 body["tools"] = serde_json::json!([{"functionDeclarations": tools}]);
932 body["toolConfig"] = serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}});
933 }
934
935 let client = create_http_client();
936 let resp = client
937 .post(&url)
938 .header("Content-Type", "application/json")
939 .json(&body)
940 .send()
941 .await
942 .map_err(|e| AppError::Llm(format!("gemini request failed: {e}")))?;
943
944 if !resp.status().is_success() {
945 let status = resp.status();
946 let body = resp.text().await.unwrap_or_default();
947 return Err(AppError::Llm(format!("gemini http {status}: {body}")));
948 }
949
950 let v: serde_json::Value = resp.json().await?;
951 let content = gemini_extract_content(&v);
952
953 let usage = self.parse_usage(&v);
954 tracing::debug!(
955 "Gemini token usage: {} input, {} output, ${:.6} estimated, {} tool calls",
956 usage.input_tokens,
957 usage.output_tokens,
958 usage.estimated_cost_usd.unwrap_or(0.0),
959 content.tool_calls.len()
960 );
961
962 Ok(LlmCallResponse::with_tool_calls(
963 content.text,
964 usage,
965 content.tool_calls,
966 ))
967 }
968}
969
/// Ollama provider; talks to a local server's /api/chat endpoint (no API
/// key) and reports zero cost.
#[cfg(feature = "ollama")]
pub struct OllamaProvider {
    /// Base URL of the Ollama server.
    pub base_url: String,
    /// Model name to request.
    pub model: String,
}
976
#[cfg(feature = "ollama")]
impl OllamaProvider {
    /// Maps Ollama's eval counters onto TokenUsage; local inference is
    /// priced at zero.
    fn parse_usage(&self, response: &serde_json::Value) -> TokenUsage {
        let prompt_tokens = response["prompt_eval_count"].as_u64().unwrap_or(0) as u32;
        let completion_tokens = response["eval_count"].as_u64().unwrap_or(0) as u32;

        TokenUsage::new(prompt_tokens, completion_tokens)
            .with_model(self.model.clone())
            .with_provider("ollama")
            .with_cost(0.0)
    }
}
992
#[cfg(feature = "ollama")]
#[async_trait::async_trait]
impl LlmProvider for OllamaProvider {
    /// Plain-text call; delegates to `call_with_usage` and drops usage data.
    async fn call(&self, prompt: &str) -> Result<String, AppError> {
        let response = self.call_with_usage(prompt).await?;
        Ok(response.text)
    }

    /// Call with usage but no tools.
    async fn call_with_usage(&self, prompt: &str) -> Result<LlmCallResponse, AppError> {
        self.call_with_tools(prompt, None).await
    }

    /// Sends a single-turn request to the Ollama /api/chat endpoint,
    /// optionally with tools.
    async fn call_with_tools(
        &self,
        prompt: &str,
        tools: Option<&[serde_json::Value]>,
    ) -> Result<LlmCallResponse, AppError> {
        let url = format!("{}/api/chat", self.base_url.trim_end_matches('/'));
        // stream:false so the server returns a single JSON object instead
        // of a streamed response.
        let mut body = serde_json::json!({
            "model": self.model,
            "messages": [{"role":"user","content": prompt}],
            "stream": false
        });

        // Attach tools only when a non-empty list was supplied.
        if let Some(tools) = tools
            && !tools.is_empty()
        {
            body["tools"] = serde_json::Value::Array(tools.to_vec());
        }

        let client = create_http_client();
        let resp = client
            .post(&url)
            .json(&body)
            .send()
            .await
            .map_err(|e| AppError::Llm(format!("ollama request failed: {}", e)))?;
        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(AppError::Llm(format!("ollama http {}: {}", status, body)));
        }
        let v: serde_json::Value = resp.json().await?;
        let text = v["message"]["content"].as_str().unwrap_or("").to_string();

        // Ollama mirrors the OpenAI tool_calls message shape.
        let tool_calls = extract_openai_tool_calls(&v["message"]);

        let usage = self.parse_usage(&v);
        tracing::debug!(
            "Ollama token usage: {} input, {} output (local, no cost), {} tool calls",
            usage.input_tokens,
            usage.output_tokens,
            tool_calls.len()
        );

        Ok(LlmCallResponse::with_tool_calls(text, usage, tool_calls))
    }
}
1053
1054pub fn unconfigured_provider(provider_name: &str) -> Box<dyn LlmProvider> {
1057 Box::new(UnconfiguredProvider {
1058 provider_name: provider_name.to_string(),
1059 })
1060}
1061
#[cfg(test)]
mod tests {
    // All tests exercise pure parsing/extraction helpers with inline JSON
    // fixtures — no network calls are made.
    use super::*;
    #[cfg(any(feature = "anthropic", feature = "gemini"))]
    use serde_json::json;

    #[tokio::test]
    async fn test_unconfigured_provider_returns_error() {
        let provider = UnconfiguredProvider {
            provider_name: "test".to_string(),
        };
        let result = provider.call("Hello").await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("not configured"));
    }

    // The convenience `output_text` field wins for text extraction, while
    // tool calls still come from the `output` array (call_id preferred
    // over id).
    #[test]
    #[cfg(feature = "openai")]
    fn test_openai_responses_output_extraction() {
        let response = serde_json::json!({
            "output_text": "final answer",
            "output": [
                {
                    "type": "function_call",
                    "id": "fc_123",
                    "call_id": "call_123",
                    "name": "shell",
                    "arguments": "{\"command\":\"pwd\"}"
                }
            ]
        });

        assert_eq!(extract_openai_responses_text(&response), "final answer");
        assert_eq!(
            extract_openai_responses_tool_calls(&response),
            vec![ToolCallInfo {
                id: "call_123".to_string(),
                name: "shell".to_string(),
                arguments: "{\"command\":\"pwd\"}".to_string(),
            }]
        );
    }

    #[test]
    #[cfg(feature = "anthropic")]
    fn test_anthropic_extract_text_and_thinking() {
        let v = json!({
            "content": [
                {"type": "thinking", "thinking": "plan\n"},
                {"type": "text", "text": "answer"}
            ]
        });
        let (text, thinking) = anthropic_extract_text_and_thinking(&v);
        assert_eq!(text, "answer");
        assert_eq!(thinking, "plan\n");
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_text_only() {
        let v = json!({
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello, world!"}],
                    "role": "model"
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 3,
                "totalTokenCount": 8
            }
        });
        let content = gemini_extract_content(&v);
        assert_eq!(content.text, "Hello, world!");
        assert!(content.tool_calls.is_empty());
    }

    // Tool-call ids are synthesized from the part index, so a call that
    // follows one text part gets index 1.
    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_with_tool_calls() {
        let v = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"text": "Let me check that file."},
                        {"functionCall": {
                            "name": "file_read",
                            "args": {"path": "/tmp/test.txt"}
                        }}
                    ],
                    "role": "model"
                }
            }]
        });
        let content = gemini_extract_content(&v);
        assert_eq!(content.text, "Let me check that file.");
        assert_eq!(content.tool_calls.len(), 1);
        assert_eq!(content.tool_calls[0].name, "file_read");
        assert_eq!(content.tool_calls[0].id, "gemini-call-1");
        let args: serde_json::Value =
            serde_json::from_str(&content.tool_calls[0].arguments).unwrap();
        assert_eq!(args["path"], "/tmp/test.txt");
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_multiple_tool_calls() {
        let v = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"functionCall": {
                            "name": "file_read",
                            "args": {"path": "a.txt"}
                        }},
                        {"functionCall": {
                            "name": "shell_exec",
                            "args": {"command": "ls"}
                        }}
                    ],
                    "role": "model"
                }
            }]
        });
        let content = gemini_extract_content(&v);
        assert!(content.text.is_empty());
        assert_eq!(content.tool_calls.len(), 2);
        assert_eq!(content.tool_calls[0].name, "file_read");
        assert_eq!(content.tool_calls[0].id, "gemini-call-0");
        assert_eq!(content.tool_calls[1].name, "shell_exec");
        assert_eq!(content.tool_calls[1].id, "gemini-call-1");
    }

    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_extract_content_empty_response() {
        let v = json!({"candidates": [{"content": {"parts": []}}]});
        let content = gemini_extract_content(&v);
        assert!(content.text.is_empty());
        assert!(content.tool_calls.is_empty());
    }

    // "gemini-2.0-flash" hits the generic "flash" bucket ($0.10/$0.40 per
    // million): 100/1M*0.10 + 50/1M*0.40 = $0.00003.
    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_parse_usage() {
        let provider = GeminiProvider {
            api_key: "test".to_string(),
            base_url: "https://example.com".to_string(),
            model: "gemini-2.0-flash".to_string(),
        };
        let v = json!({
            "usageMetadata": {
                "promptTokenCount": 100,
                "candidatesTokenCount": 50,
                "totalTokenCount": 150
            }
        });
        let usage = provider.parse_usage(&v);
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
        assert_eq!(usage.provider.as_deref(), Some("gemini"));
        assert_eq!(usage.model.as_deref(), Some("gemini-2.0-flash"));
        let cost = usage.estimated_cost_usd.unwrap();
        assert!((cost - 0.00003).abs() < 1e-9);
    }

    // "1.5-pro" pricing: 1M*1.25 + 1M*5.00 = $6.25 for a million tokens
    // each way.
    #[test]
    #[cfg(feature = "gemini")]
    fn test_gemini_parse_usage_pro_pricing() {
        let provider = GeminiProvider {
            api_key: "test".to_string(),
            base_url: "https://example.com".to_string(),
            model: "gemini-1.5-pro".to_string(),
        };
        let v = json!({
            "usageMetadata": {
                "promptTokenCount": 1_000_000,
                "candidatesTokenCount": 1_000_000,
                "totalTokenCount": 2_000_000
            }
        });
        let usage = provider.parse_usage(&v);
        let cost = usage.estimated_cost_usd.unwrap();
        assert!((cost - 6.25).abs() < 1e-6);
    }
}