1use crate::config::StreamingConfig;
7use futures_util::StreamExt;
8use gestura_core_foundation::AppError;
9use gestura_core_llm::TokenUsage;
10use gestura_core_llm::openai::{
11 OpenAiApi, is_openai_model_incompatible_with_agent_session, openai_agent_session_model_message,
12 openai_api_for_model,
13};
14use gestura_core_retry::RetryPolicy;
15use gestura_core_tools::schemas::ProviderToolSchemas;
16use std::collections::{BTreeMap, HashMap, HashSet};
17use std::sync::Arc;
18use std::sync::atomic::{AtomicU8, Ordering};
19use std::time::Duration;
20use tokio::sync::mpsc;
21use tracing::Instrument as _;
22
/// Hard cap on how long a single streaming HTTP response may stay open.
const STREAMING_TIMEOUT_SECS: u64 = 300;
/// Buffer capacity for stream-chunk channels.
// NOTE(review): presumably used as the mpsc channel capacity — confirm at call sites.
const STREAM_CHUNK_BUFFER_CAPACITY: usize = 256;
/// How long to wait for the receiver before dropping a transient status chunk.
const STATUS_CHUNK_SEND_TIMEOUT: Duration = Duration::from_millis(100);
/// How long to wait for the receiver before dropping a transient token-usage chunk.
const TOKEN_USAGE_CHUNK_SEND_TIMEOUT: Duration = Duration::from_millis(100);
28
29async fn send_status_chunk_best_effort(tx: &mpsc::Sender<StreamChunk>, chunk: StreamChunk) {
30 debug_assert!(matches!(chunk, StreamChunk::Status { .. }));
31
32 match tokio::time::timeout(STATUS_CHUNK_SEND_TIMEOUT, tx.send(chunk)).await {
33 Ok(Ok(())) | Ok(Err(_)) => {}
34 Err(_) => {
35 tracing::debug!(
36 timeout_ms = STATUS_CHUNK_SEND_TIMEOUT.as_millis(),
37 "Dropping transient status chunk because the stream receiver is not draining fast enough"
38 );
39 }
40 }
41}
42
43async fn send_token_usage_chunk_best_effort(tx: &mpsc::Sender<StreamChunk>, chunk: StreamChunk) {
44 debug_assert!(matches!(chunk, StreamChunk::TokenUsageUpdate { .. }));
45
46 match tokio::time::timeout(TOKEN_USAGE_CHUNK_SEND_TIMEOUT, tx.send(chunk)).await {
47 Ok(Ok(())) | Ok(Err(_)) => {}
48 Err(_) => {
49 tracing::debug!(
50 timeout_ms = TOKEN_USAGE_CHUNK_SEND_TIMEOUT.as_millis(),
51 "Dropping transient token-usage chunk because the stream receiver is not draining fast enough"
52 );
53 }
54 }
55}
56
/// Per-model pricing constants used for cost estimation.
///
/// NOTE(review): values look like USD per million tokens — confirm the unit
/// against the cost-estimation call sites before relying on it.
pub mod pricing {
    // OpenAI GPT-4 Turbo.
    pub const OPENAI_GPT4_TURBO_INPUT: f64 = 10.0;
    pub const OPENAI_GPT4_TURBO_OUTPUT: f64 = 30.0;

    // OpenAI GPT-3.5 Turbo.
    pub const OPENAI_GPT35_TURBO_INPUT: f64 = 0.50;
    pub const OPENAI_GPT35_TURBO_OUTPUT: f64 = 1.50;

    // Anthropic Claude 3.5 Sonnet.
    pub const ANTHROPIC_CLAUDE_35_SONNET_INPUT: f64 = 3.0;
    pub const ANTHROPIC_CLAUDE_35_SONNET_OUTPUT: f64 = 15.0;

    // Anthropic Claude 3 Opus.
    pub const ANTHROPIC_CLAUDE_3_OPUS_INPUT: f64 = 15.0;
    pub const ANTHROPIC_CLAUDE_3_OPUS_OUTPUT: f64 = 75.0;

    // Anthropic Claude 3 Haiku.
    pub const ANTHROPIC_CLAUDE_3_HAIKU_INPUT: f64 = 0.25;
    pub const ANTHROPIC_CLAUDE_3_HAIKU_OUTPUT: f64 = 1.25;

    // Google Gemini 2.0 Flash.
    pub const GEMINI_20_FLASH_INPUT: f64 = 0.10;
    pub const GEMINI_20_FLASH_OUTPUT: f64 = 0.40;

    // Google Gemini 2.0 Flash-Lite.
    pub const GEMINI_20_FLASH_LITE_INPUT: f64 = 0.075;
    pub const GEMINI_20_FLASH_LITE_OUTPUT: f64 = 0.30;

    // Google Gemini 1.5 Pro.
    pub const GEMINI_15_PRO_INPUT: f64 = 1.25;
    pub const GEMINI_15_PRO_OUTPUT: f64 = 5.00;

    // Google Gemini 1.5 Flash.
    pub const GEMINI_15_FLASH_INPUT: f64 = 0.075;
    pub const GEMINI_15_FLASH_OUTPUT: f64 = 0.30;

    // xAI Grok.
    pub const XAI_GROK_INPUT: f64 = 5.0;
    pub const XAI_GROK_OUTPUT: f64 = 15.0;

    // Local Ollama models incur no per-token cost.
    pub const OLLAMA_INPUT: f64 = 0.0;
    pub const OLLAMA_OUTPUT: f64 = 0.0;

    // Fallback for unrecognized models.
    pub const DEFAULT_INPUT: f64 = 1.0;
    pub const DEFAULT_OUTPUT: f64 = 3.0;
}
108
/// Traffic-light classification of token/context consumption, carried in
/// `StreamChunk::TokenUsageUpdate`. Thresholds are decided by the producer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenUsageStatus {
    /// Comfortable headroom remains.
    Green,
    /// Usage is elevated.
    Yellow,
    /// Usage is critical.
    Red,
}
119
/// Which standard stream a piece of shell output came from.
/// Serialized lowercase (`"stdout"` / `"stderr"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ShellOutputStream {
    Stdout,
    Stderr,
}
129
/// Lifecycle state of a single shell process, carried in
/// `StreamChunk::ShellLifecycle`. Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShellProcessState {
    Started,
    Completed,
    Failed,
    Stopped,
    Paused,
    Resumed,
}
147
/// Lifecycle state of a persistent shell session, carried in
/// `StreamChunk::ShellSessionLifecycle`. Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShellSessionState {
    Starting,
    Idle,
    Busy,
    Interrupting,
    Stopping,
    Stopped,
    Failed,
}
167
/// Minimal serializable view of one task, as exposed in runtime snapshots.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TaskRuntimeTaskView {
    /// Task identifier.
    pub id: String,
    /// Human-readable task name.
    pub name: String,
    /// Task status as a free-form string.
    pub status: String,
}
178
/// Point-in-time snapshot of the task runtime, grouped by task disposition.
/// Empty task lists are omitted from the serialized form.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TaskRuntimeSnapshot {
    /// Identifier of the root task this snapshot describes.
    pub root_task_id: String,
    /// Task currently being worked on, if any.
    pub current_task: Option<TaskRuntimeTaskView>,
    /// Tasks that are ready to run.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub ready_tasks: Vec<TaskRuntimeTaskView>,
    /// Ready tasks eligible to run in parallel.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub parallel_ready_tasks: Vec<TaskRuntimeTaskView>,
    /// Tasks blocked on something else.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub blocked_tasks: Vec<TaskRuntimeTaskView>,
    /// Tasks that are open (not yet completed).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub open_tasks: Vec<TaskRuntimeTaskView>,
    /// Tasks already completed.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub completed_tasks: Vec<TaskRuntimeTaskView>,
    /// Requirements the runtime reported as missing.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub missing_requirements: Vec<String>,
    /// Human-readable summary of the runtime's state.
    pub status_message: String,
}
207
/// Pipeline stage a narration message belongs to. Serialized in snake_case;
/// see `NarrationStage::as_str` for the stable string form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum NarrationStage {
    Context,
    Planning,
    Execution,
    Verification,
    Blocked,
    Progress,
}
225
226impl NarrationStage {
227 pub const fn as_str(&self) -> &'static str {
229 match self {
230 Self::Context => "context",
231 Self::Planning => "planning",
232 Self::Execution => "execution",
233 Self::Verification => "verification",
234 Self::Blocked => "blocked",
235 Self::Progress => "progress",
236 }
237 }
238}
239
/// User-facing narration payload carried in `StreamChunk::Narration`.
/// Optional/empty fields are omitted from the serialized form.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PublicNarration {
    /// Short headline.
    pub title: String,
    /// Main narration text.
    pub message: String,
    /// Optional condensed summary.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary: Option<String>,
    /// Optional reason behind the message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Optional description of the next step.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub next_step: Option<String>,
    /// Supporting evidence items.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub evidence: Vec<String>,
}
260
/// Events emitted over the streaming channel while an agent attempt runs.
///
/// Delivery guarantees differ per variant (see `forward_attempt_stream`):
/// most chunks are forwarded reliably, `TokenUsageUpdate`/`Status` are sent
/// best-effort with a short timeout, and `Narration`/`TaskRuntimeSnapshot`
/// are dropped when the receiver is full.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// Model "thinking" text extracted from `<think>` blocks.
    Thinking(String),
    /// Structured, user-facing narration for a pipeline stage (droppable).
    Narration {
        narration: PublicNarration,
        stage: NarrationStage,
    },
    /// Plain model output text.
    Text(String),
    /// A tool invocation has started.
    ToolCallStart { id: String, name: String },
    /// Incremental argument text for the current tool call.
    ToolCallArgs(String),
    /// The current tool call's argument stream is complete.
    ToolCallEnd,
    /// Result of an executed tool call.
    ToolCallResult {
        name: String,
        success: bool,
        output: String,
        duration_ms: u64,
    },
    /// A retryable failure occurred and another attempt is scheduled.
    RetryAttempt {
        attempt: u32,
        max_attempts: u32,
        delay_ms: u64,
        error_message: String,
    },
    /// Conversation history was compacted to reclaim context space.
    ContextCompacted {
        messages_before: usize,
        messages_after: usize,
        tokens_saved: usize,
        summary: String,
    },
    /// Periodic context-budget report (best-effort delivery).
    TokenUsageUpdate {
        estimated: usize,
        limit: usize,
        percentage: u8,
        status: TokenUsageStatus,
        estimated_cost: f64,
    },

    /// Transient human-readable status line (best-effort delivery).
    Status {
        message: String,
    },
    /// The agent requested a configuration operation.
    ConfigRequest {
        operation: String,
        key: String,
        value: Option<String>,
        requires_confirmation: bool,
    },
    /// A tool call needs explicit confirmation before it may run.
    ToolConfirmationRequired {
        confirmation_id: String,
        tool_name: String,
        tool_args: String,
        description: String,
        risk_level: u8,
        category: String,
    },
    /// A tool call was refused.
    ToolBlocked {
        tool_name: String,
        reason: String,
    },
    /// A session snapshot was persisted to the memory bank.
    MemoryBankSaved {
        file_path: String,
        session_id: String,
        summary: String,
        messages_saved: usize,
    },
    /// The agent loop advanced to a new iteration.
    AgentLoopIteration {
        iteration: u32,
    },
    /// Snapshot of the task runtime (droppable delivery).
    TaskRuntimeSnapshot {
        snapshot: TaskRuntimeSnapshot,
    },
    /// Output captured from a shell process.
    ShellOutput {
        process_id: String,
        shell_session_id: Option<String>,
        stream: ShellOutputStream,
        data: String,
    },
    /// State change of a single shell process.
    ShellLifecycle {
        process_id: String,
        shell_session_id: Option<String>,
        state: ShellProcessState,
        exit_code: Option<i32>,
        duration_ms: Option<u64>,
        command: String,
        cwd: Option<String>,
    },
    /// State change of a persistent shell session.
    ShellSessionLifecycle {
        shell_session_id: String,
        state: ShellSessionState,
        cwd: Option<String>,
        active_process_id: Option<String>,
        active_command: Option<String>,
        available_for_reuse: bool,
        interactive: bool,
        user_managed: bool,
    },
    /// Terminal: the attempt finished, optionally with final token usage.
    Done(Option<TokenUsage>),
    /// Terminal: the attempt was cancelled.
    Cancelled,
    /// Terminal: the attempt was paused.
    Paused,
    /// Terminal: the attempt failed with the given message.
    Error(String),
    /// Terminal: the provider rejected the request as exceeding context.
    ContextOverflow {
        error_message: String,
    },
    /// A reflection pass started for the given reason.
    ReflectionStarted {
        reason: String,
    },
    /// A reflection pass finished.
    ReflectionComplete {
        summary: String,
        stored: bool,
        promoted: bool,
    },
}
508
/// How a single streaming attempt ended, as classified by
/// `forward_attempt_stream`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AttemptOutcome {
    /// A `Done` chunk was received.
    Success,
    /// An error arrived before any output was forwarded; safe to retry.
    RetryableError,
    /// The error matched a context-overflow signature.
    ContextOverflowError,
    /// An error arrived after output was already forwarded; cannot retry silently.
    FatalError,
    /// The attempt was cancelled.
    Cancelled,
    /// The attempt was paused.
    Paused,
    /// The attempt channel closed without a terminal chunk.
    UnexpectedEnd,
}
520
/// Summary returned by `forward_attempt_stream` for one attempt.
#[derive(Debug, Clone)]
struct AttemptForwardResult {
    /// Classification of how the attempt ended.
    outcome: AttemptOutcome,
    /// Whether any model/tool output chunk reached the caller's channel.
    forwarded_output: bool,
    /// Error message when the outcome is an error variant.
    error: Option<String>,
}
529
530async fn forward_attempt_stream(
536 attempt_rx: &mut mpsc::Receiver<StreamChunk>,
537 tx: &mpsc::Sender<StreamChunk>,
538) -> AttemptForwardResult {
539 let mut forwarded_output = false;
540
541 while let Some(chunk) = attempt_rx.recv().await {
542 match &chunk {
543 StreamChunk::Text(_)
544 | StreamChunk::Thinking(_)
545 | StreamChunk::ToolCallStart { .. }
546 | StreamChunk::ToolCallArgs(_)
547 | StreamChunk::ToolCallEnd
548 | StreamChunk::ToolCallResult { .. } => {
549 forwarded_output = true;
550 let _ = tx.send(chunk).await;
551 }
552 StreamChunk::RetryAttempt { .. } => {
553 let _ = tx.send(chunk).await;
555 }
556 StreamChunk::ContextCompacted { .. } => {
557 let _ = tx.send(chunk).await;
559 }
560 StreamChunk::TokenUsageUpdate { .. } => {
561 send_token_usage_chunk_best_effort(tx, chunk).await;
563 }
564 StreamChunk::Status { .. } => {
565 send_status_chunk_best_effort(tx, chunk).await;
567 }
568 StreamChunk::ConfigRequest { .. } => {
569 let _ = tx.send(chunk).await;
571 }
572 StreamChunk::ToolConfirmationRequired { .. } => {
573 let _ = tx.send(chunk).await;
575 }
576 StreamChunk::ToolBlocked { .. } => {
577 let _ = tx.send(chunk).await;
579 }
580 StreamChunk::MemoryBankSaved { .. } => {
581 let _ = tx.send(chunk).await;
583 }
584 StreamChunk::AgentLoopIteration { .. } => {
585 let _ = tx.send(chunk).await;
587 }
588 StreamChunk::Narration { .. } => {
589 let _ = tx.try_send(chunk);
591 }
592 StreamChunk::TaskRuntimeSnapshot { .. } => {
593 let _ = tx.try_send(chunk);
595 }
596 StreamChunk::ReflectionStarted { .. } | StreamChunk::ReflectionComplete { .. } => {
597 let _ = tx.send(chunk).await;
599 }
600 StreamChunk::ShellOutput { .. } => {
601 let _ = tx.send(chunk).await;
605 }
606 StreamChunk::ShellLifecycle { .. } => {
607 let _ = tx.send(chunk).await;
609 }
610 StreamChunk::ShellSessionLifecycle { .. } => {
611 let _ = tx.send(chunk).await;
613 }
614 StreamChunk::Done(_) => {
615 let _ = tx.send(chunk).await;
616 return AttemptForwardResult {
617 outcome: AttemptOutcome::Success,
618 forwarded_output,
619 error: None,
620 };
621 }
622 StreamChunk::Cancelled => {
623 let _ = tx.send(StreamChunk::Cancelled).await;
624 return AttemptForwardResult {
625 outcome: AttemptOutcome::Cancelled,
626 forwarded_output,
627 error: None,
628 };
629 }
630 StreamChunk::Paused => {
631 let _ = tx.send(StreamChunk::Paused).await;
632 return AttemptForwardResult {
633 outcome: AttemptOutcome::Paused,
634 forwarded_output,
635 error: None,
636 };
637 }
638 StreamChunk::Error(e) => {
639 if is_context_overflow_message(e) {
642 return AttemptForwardResult {
643 outcome: AttemptOutcome::ContextOverflowError,
644 forwarded_output,
645 error: Some(e.clone()),
646 };
647 }
648
649 if forwarded_output {
652 let _ = tx.send(StreamChunk::Error(e.clone())).await;
653 return AttemptForwardResult {
654 outcome: AttemptOutcome::FatalError,
655 forwarded_output,
656 error: Some(e.clone()),
657 };
658 }
659
660 return AttemptForwardResult {
661 outcome: AttemptOutcome::RetryableError,
662 forwarded_output,
663 error: Some(e.clone()),
664 };
665 }
666 StreamChunk::ContextOverflow { error_message } => {
667 return AttemptForwardResult {
669 outcome: AttemptOutcome::ContextOverflowError,
670 forwarded_output,
671 error: Some(error_message.clone()),
672 };
673 }
674 }
675 }
676
677 AttemptForwardResult {
678 outcome: AttemptOutcome::UnexpectedEnd,
679 forwarded_output,
680 error: None,
681 }
682}
683
/// Cheaply cloneable handle for signalling cancellation or pause to an
/// in-flight streaming attempt. All clones share one atomic state.
#[derive(Clone, Debug)]
pub struct CancellationToken {
    // Encodes a `CancellationDisposition` as its `u8` discriminant.
    disposition: Arc<AtomicU8>,
}
689
/// State stored in a `CancellationToken`'s atomic (`repr(u8)` so the
/// discriminant can be stored in an `AtomicU8`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum CancellationDisposition {
    /// No interruption requested.
    Running = 0,
    /// A hard cancel was requested.
    Cancelled = 1,
    /// A pause was requested.
    Paused = 2,
}
698
699impl CancellationToken {
700 pub fn new() -> Self {
702 Self {
703 disposition: Arc::new(AtomicU8::new(CancellationDisposition::Running as u8)),
704 }
705 }
706
707 pub fn cancel(&self) {
709 self.disposition
710 .store(CancellationDisposition::Cancelled as u8, Ordering::SeqCst);
711 }
712
713 pub fn pause(&self) {
715 let _ = self.disposition.compare_exchange(
716 CancellationDisposition::Running as u8,
717 CancellationDisposition::Paused as u8,
718 Ordering::SeqCst,
719 Ordering::SeqCst,
720 );
721 }
722
723 pub fn is_cancelled(&self) -> bool {
725 !matches!(self.disposition(), CancellationDisposition::Running)
726 }
727
728 pub fn is_pause_requested(&self) -> bool {
730 matches!(self.disposition(), CancellationDisposition::Paused)
731 }
732
733 pub fn disposition(&self) -> CancellationDisposition {
735 match self.disposition.load(Ordering::SeqCst) {
736 value if value == CancellationDisposition::Paused as u8 => {
737 CancellationDisposition::Paused
738 }
739 value if value == CancellationDisposition::Cancelled as u8 => {
740 CancellationDisposition::Cancelled
741 }
742 _ => CancellationDisposition::Running,
743 }
744 }
745
746 pub fn interruption_chunk(&self) -> StreamChunk {
748 match self.disposition() {
749 CancellationDisposition::Paused => StreamChunk::Paused,
750 CancellationDisposition::Cancelled | CancellationDisposition::Running => {
751 StreamChunk::Cancelled
752 }
753 }
754 }
755}
756
757impl Default for CancellationToken {
758 fn default() -> Self {
759 Self::new()
760 }
761}
762
/// Builds the HTTP client used for streaming requests: overall request
/// timeout of `STREAMING_TIMEOUT_SECS`, 10-second connect timeout. Falls
/// back to a default client if the builder fails.
fn create_streaming_client() -> reqwest::Client {
    reqwest::Client::builder()
        .timeout(Duration::from_secs(STREAMING_TIMEOUT_SECS))
        .connect_timeout(Duration::from_secs(10))
        .build()
        .unwrap_or_else(|_| reqwest::Client::new())
}
771
/// Incremental splitter of `<think>…</think>` blocks in streamed text.
struct ThinkingParser {
    /// True while inside an unterminated `<think>` block.
    in_think_block: bool,
    /// Held-back text that might be a partial tag split across chunks.
    buffer: String,
}
779
780impl ThinkingParser {
781 fn new() -> Self {
782 Self {
783 in_think_block: false,
784 buffer: String::new(),
785 }
786 }
787
788 fn process(&mut self, chunk: &str) -> Vec<StreamChunk> {
789 let mut chunks = Vec::new();
790
791 let input = if self.buffer.is_empty() {
793 chunk.to_string()
794 } else {
795 std::mem::take(&mut self.buffer) + chunk
796 };
797
798 let mut remaining = input.as_str();
799
800 while !remaining.is_empty() {
801 if self.in_think_block {
802 if let Some(end_idx) = remaining.find("</think>") {
803 let thinking_content = &remaining[..end_idx];
804 if !thinking_content.is_empty() {
805 chunks.push(StreamChunk::Thinking(thinking_content.to_string()));
806 }
807 self.in_think_block = false;
808 remaining = &remaining[end_idx + 8..];
809 } else {
810 let partial = Self::find_partial_end_tag(remaining);
812 if partial > 0 {
813 let safe_len = remaining.len() - partial;
814 if safe_len > 0 {
815 chunks.push(StreamChunk::Thinking(remaining[..safe_len].to_string()));
816 }
817 self.buffer = remaining[safe_len..].to_string();
818 } else {
819 chunks.push(StreamChunk::Thinking(remaining.to_string()));
820 }
821 break;
822 }
823 } else if let Some(start_idx) = remaining.find("<think>") {
824 let text_content = &remaining[..start_idx];
825 if !text_content.is_empty() {
826 chunks.push(StreamChunk::Text(text_content.to_string()));
827 }
828 self.in_think_block = true;
829 remaining = &remaining[start_idx + 7..];
830 } else {
831 let partial = Self::find_partial_start_tag(remaining);
833 if partial > 0 {
834 let safe_len = remaining.len() - partial;
835 if safe_len > 0 {
836 chunks.push(StreamChunk::Text(remaining[..safe_len].to_string()));
837 }
838 self.buffer = remaining[safe_len..].to_string();
839 } else {
840 chunks.push(StreamChunk::Text(remaining.to_string()));
841 }
842 break;
843 }
844 }
845 chunks
846 }
847
848 fn find_partial_start_tag(s: &str) -> usize {
850 const TAG: &str = "<think>";
851 for len in (1..TAG.len()).rev() {
852 if s.ends_with(&TAG[..len]) {
853 return len;
854 }
855 }
856 0
857 }
858
859 fn find_partial_end_tag(s: &str) -> usize {
861 const TAG: &str = "</think>";
862 for len in (1..TAG.len()).rev() {
863 if s.ends_with(&TAG[..len]) {
864 return len;
865 }
866 }
867 0
868 }
869}
870
871pub fn split_think_blocks(input: &str) -> (String, Option<String>) {
876 let mut parser = ThinkingParser::new();
877 let mut content = String::new();
878 let mut thinking = String::new();
879
880 for chunk in parser.process(input) {
881 match chunk {
882 StreamChunk::Text(t) => content.push_str(&t),
883 StreamChunk::Thinking(t) => thinking.push_str(&t),
884 _ => {}
885 }
886 }
887
888 let thinking = if thinking.trim().is_empty() {
889 None
890 } else {
891 Some(thinking)
892 };
893
894 (content, thinking)
895}
896
/// Appends `incoming` to `buffer` and pops every complete `\n`-terminated
/// line, leaving any unterminated tail in `buffer` for the next call.
/// Trailing `\r` characters are stripped from each returned line.
fn collect_complete_lines(buffer: &mut String, incoming: &str) -> Vec<String> {
    buffer.push_str(incoming);

    let mut lines = Vec::new();
    let mut consumed = 0usize;

    // Scan only the not-yet-consumed region for newline terminators.
    while let Some(offset) = buffer[consumed..].find('\n') {
        let end = consumed + offset;
        lines.push(buffer[consumed..end].trim_end_matches('\r').to_string());
        consumed = end + 1;
    }

    // Remove everything we emitted in one splice.
    if consumed > 0 {
        buffer.drain(..consumed);
    }

    lines
}
919
920fn build_openai_chat_request_body(
922 model: &str,
923 prompt: &str,
924 tools: Option<&[serde_json::Value]>,
925) -> serde_json::Value {
926 let mut body = serde_json::json!({
927 "model": model,
928 "messages": [{"role": "user", "content": prompt}],
929 "stream": true
930 });
931
932 if let Some(tools) = tools
934 && !tools.is_empty()
935 {
936 body["tools"] = serde_json::Value::Array(tools.to_vec());
937 body["tool_choice"] = serde_json::json!("auto");
938 }
939
940 body
941}
942
943fn build_openai_responses_request_body(
945 model: &str,
946 prompt: &str,
947 tools: Option<&[serde_json::Value]>,
948) -> serde_json::Value {
949 let mut body = serde_json::json!({
950 "model": model,
951 "input": [{"role": "user", "content": prompt}],
952 "stream": true
953 });
954
955 if let Some(tools) = tools
956 && !tools.is_empty()
957 {
958 body["tools"] = serde_json::Value::Array(tools.to_vec());
959 body["tool_choice"] = serde_json::json!("auto");
960 }
961
962 body
963}
964
/// Maps the selected OpenAI API flavor to its HTTP endpoint path.
fn openai_endpoint_path(api: OpenAiApi) -> &'static str {
    match api {
        OpenAiApi::ChatCompletions => "/v1/chat/completions",
        OpenAiApi::Responses => "/v1/responses",
    }
}
971
972fn format_openai_http_error(
973 status: reqwest::StatusCode,
974 provider_name: &str,
975 model: &str,
976 api: OpenAiApi,
977 body: &str,
978 retry_after: Option<Duration>,
979) -> String {
980 if status == reqwest::StatusCode::NOT_FOUND && body.contains("This is not a chat model") {
981 let mut message = format!(
982 "{provider_name} model '{}' appears to require /v1/responses, but Gestura selected {}. Raw provider error: {}",
983 model.trim(),
984 openai_endpoint_path(api),
985 body
986 );
987 if let Some(retry_after) = retry_after {
988 message.push_str(&format_retry_after_suffix(retry_after));
989 }
990 return message;
991 }
992
993 let mut message = format!(
994 "{provider_name} {} HTTP {}: {}",
995 openai_endpoint_path(api),
996 status,
997 body
998 );
999 if let Some(retry_after) = retry_after {
1000 message.push_str(&format_retry_after_suffix(retry_after));
1001 }
1002 message
1003}
1004
/// Renders the Retry-After hint appended to provider error messages.
/// `retry_after_hint_from_error_message` parses this back out, so the
/// wording must stay stable. Sub-second hints round up to 1 second.
fn format_retry_after_suffix(retry_after: Duration) -> String {
    let seconds = retry_after.as_secs().max(1);
    format!(" Provider suggested retrying after {seconds} seconds.")
}
1011
/// Parses a Retry-After value given in whole seconds. Non-numeric input
/// yields `None`; zero is bumped to a 1-second minimum.
fn parse_retry_after_value(value: &str) -> Option<Duration> {
    value
        .trim()
        .parse::<u64>()
        .ok()
        .map(|seconds| Duration::from_secs(seconds.max(1)))
}
1016
1017fn response_retry_after_hint(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
1018 headers
1019 .get(reqwest::header::RETRY_AFTER)?
1020 .to_str()
1021 .ok()
1022 .and_then(parse_retry_after_value)
1023}
1024
/// Recovers a Retry-After hint previously embedded in an error message by
/// `format_retry_after_suffix`. Matching is case-insensitive; the digits
/// immediately following the marker are read, and zero becomes 1 second.
fn retry_after_hint_from_error_message(message: &str) -> Option<Duration> {
    const MARKER: &str = "provider suggested retrying after ";

    let lower = message.to_ascii_lowercase();
    // MARKER is pure ASCII, so this byte offset is a valid char boundary.
    let digits_start = lower.find(MARKER)? + MARKER.len();

    let digits: String = lower[digits_start..]
        .chars()
        .take_while(char::is_ascii_digit)
        .collect();

    digits
        .parse::<u64>()
        .ok()
        .map(|seconds| Duration::from_secs(seconds.max(1)))
}
1038
/// Heuristically detects rate-limit failures from provider error text
/// (case-insensitive substring match).
fn error_is_rate_limited_message(message: &str) -> bool {
    const SIGNALS: [&str; 4] = ["http 429", "rate limit", "too many requests", "quota"];

    let lower = message.to_ascii_lowercase();
    SIGNALS.iter().any(|signal| lower.contains(signal))
}
1046
1047fn select_streaming_retry_delay(
1048 policy: &RetryPolicy,
1049 retry_attempt: u32,
1050 error_message: &str,
1051) -> Duration {
1052 let base_delay = policy.delay_for_attempt(retry_attempt);
1053
1054 if let Some(retry_after) = retry_after_hint_from_error_message(error_message) {
1055 return retry_after.max(base_delay);
1056 }
1057
1058 if error_is_rate_limited_message(error_message) {
1059 return base_delay.max(Duration::from_secs(5));
1060 }
1061
1062 base_delay
1063}
1064
/// Accumulator for one Chat Completions tool call assembled from deltas.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct PendingOpenAiToolCall {
    /// Provider-assigned call id (last non-empty value wins).
    id: String,
    /// Tool name (last non-empty value wins).
    name: String,
    /// Argument JSON, concatenated from streamed fragments.
    arguments: String,
}
1071
/// Accumulator for one Responses-API tool call assembled from events.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct PendingOpenAiResponsesToolCall {
    /// Call id (from `call_id`, falling back to item ids).
    id: String,
    /// Tool name (last non-empty value wins).
    name: String,
    /// Argument JSON, built from deltas or replaced by a full snapshot.
    arguments: String,
    /// Set once the provider marks the call complete; gates emission.
    finished: bool,
}
1079
1080fn merge_openai_tool_call_delta(
1081 pending: &mut BTreeMap<usize, PendingOpenAiToolCall>,
1082 call: &serde_json::Value,
1083 fallback_index: usize,
1084) {
1085 let index = call
1086 .get("index")
1087 .and_then(|value| value.as_u64())
1088 .map(|value| value as usize)
1089 .unwrap_or(fallback_index);
1090
1091 let entry = pending.entry(index).or_default();
1092
1093 if let Some(id) = call["id"].as_str()
1094 && !id.is_empty()
1095 {
1096 entry.id = id.to_string();
1097 }
1098
1099 if let Some(name) = call["function"]["name"].as_str()
1100 && !name.is_empty()
1101 {
1102 entry.name = name.to_string();
1103 }
1104
1105 if let Some(arguments) = call["function"]["arguments"].as_str()
1106 && !arguments.is_empty()
1107 {
1108 entry.arguments.push_str(arguments);
1109 }
1110}
1111
1112fn take_openai_tool_calls(
1113 pending: &mut BTreeMap<usize, PendingOpenAiToolCall>,
1114) -> Vec<(usize, PendingOpenAiToolCall)> {
1115 std::mem::take(pending)
1116 .into_iter()
1117 .filter(|(_, call)| !call.name.is_empty())
1118 .collect()
1119}
1120
1121async fn emit_openai_tool_calls(
1122 tx: &mpsc::Sender<StreamChunk>,
1123 pending: &mut BTreeMap<usize, PendingOpenAiToolCall>,
1124) {
1125 for (index, call) in take_openai_tool_calls(pending) {
1126 let id = if call.id.is_empty() {
1127 format!("openai-tool-{index}")
1128 } else {
1129 call.id
1130 };
1131
1132 let _ = tx
1133 .send(StreamChunk::ToolCallStart {
1134 id,
1135 name: call.name,
1136 })
1137 .await;
1138
1139 if !call.arguments.is_empty() {
1140 let _ = tx.send(StreamChunk::ToolCallArgs(call.arguments)).await;
1141 }
1142
1143 let _ = tx.send(StreamChunk::ToolCallEnd).await;
1144 }
1145}
1146
1147fn merge_openai_responses_tool_item(
1148 pending: &mut BTreeMap<usize, PendingOpenAiResponsesToolCall>,
1149 tool_indices: &mut HashMap<String, usize>,
1150 event: &serde_json::Value,
1151 fallback_index: usize,
1152) {
1153 let index = resolve_openai_responses_tool_index(tool_indices, event, fallback_index);
1154
1155 let item = event.get("item").unwrap_or(event);
1156 let entry = pending.entry(index).or_default();
1157
1158 if let Some(id) = item["call_id"].as_str().or_else(|| item["id"].as_str())
1159 && !id.is_empty()
1160 {
1161 entry.id = id.to_string();
1162 }
1163
1164 if let Some(name) = item["name"].as_str()
1165 && !name.is_empty()
1166 {
1167 entry.name = name.to_string();
1168 }
1169
1170 if let Some(arguments) = item["arguments"].as_str()
1171 && !arguments.is_empty()
1172 {
1173 entry.arguments = arguments.to_string();
1174 }
1175
1176 if event["type"].as_str() == Some("response.output_item.done")
1177 || item["status"].as_str() == Some("completed")
1178 {
1179 entry.finished = true;
1180 }
1181}
1182
1183fn merge_openai_responses_tool_argument_delta(
1184 pending: &mut BTreeMap<usize, PendingOpenAiResponsesToolCall>,
1185 tool_indices: &mut HashMap<String, usize>,
1186 event: &serde_json::Value,
1187 fallback_index: usize,
1188) {
1189 let index = resolve_openai_responses_tool_index(tool_indices, event, fallback_index);
1190
1191 let entry = pending.entry(index).or_default();
1192
1193 if let Some(id) = event["call_id"].as_str()
1194 && !id.is_empty()
1195 {
1196 entry.id = id.to_string();
1197 } else if entry.id.is_empty()
1198 && let Some(id) = event["item_id"].as_str()
1199 && !id.is_empty()
1200 {
1201 entry.id = id.to_string();
1202 }
1203
1204 if let Some(delta) = event["delta"].as_str()
1205 && !delta.is_empty()
1206 {
1207 entry.arguments.push_str(delta);
1208 }
1209}
1210
1211fn complete_openai_responses_tool_arguments(
1212 pending: &mut BTreeMap<usize, PendingOpenAiResponsesToolCall>,
1213 tool_indices: &mut HashMap<String, usize>,
1214 event: &serde_json::Value,
1215 fallback_index: usize,
1216) {
1217 let index = resolve_openai_responses_tool_index(tool_indices, event, fallback_index);
1218
1219 let entry = pending.entry(index).or_default();
1220
1221 if let Some(id) = event["call_id"].as_str()
1222 && !id.is_empty()
1223 {
1224 entry.id = id.to_string();
1225 } else if entry.id.is_empty()
1226 && let Some(id) = event["item_id"].as_str()
1227 && !id.is_empty()
1228 {
1229 entry.id = id.to_string();
1230 }
1231
1232 if let Some(arguments) = event["arguments"].as_str()
1233 && !arguments.is_empty()
1234 {
1235 entry.arguments = arguments.to_string();
1236 }
1237
1238 entry.finished = true;
1239}
1240
/// Emits pending Responses-API tool calls to the stream in index order.
///
/// Without `flush_all`, only a gap-free prefix of finished, named calls is
/// emitted — iteration stops at the first unnamed or unfinished entry so
/// calls never go out on the wire out of order. With `flush_all`, every
/// named call is emitted and unnamed stubs are skipped. `emitted_ids`
/// prevents the same call id from being emitted twice across calls.
async fn emit_ready_openai_responses_tool_calls(
    tx: &mpsc::Sender<StreamChunk>,
    pending: &mut BTreeMap<usize, PendingOpenAiResponsesToolCall>,
    emitted_ids: &mut HashSet<String>,
    flush_all: bool,
) {
    let mut ready = Vec::new();

    for (&index, call) in pending.iter() {
        if call.name.is_empty() {
            // Unnamed entries are incomplete; skip them only when flushing.
            if flush_all {
                continue;
            }
            break;
        }

        if flush_all || call.finished {
            ready.push(index);
            continue;
        }

        // First unfinished call blocks everything after it (preserves order).
        break;
    }

    for index in ready {
        if let Some(call) = pending.remove(&index) {
            let id = if call.id.is_empty() {
                // Deterministic fallback id when the provider sent none.
                format!("openai-response-tool-{index}")
            } else {
                call.id
            };

            if !emitted_ids.insert(id.clone()) {
                tracing::debug!(
                    tool_call_id = %id,
                    pending_index = index,
                    "Skipping duplicate OpenAI Responses tool-call emission"
                );
                continue;
            }

            let _ = tx
                .send(StreamChunk::ToolCallStart {
                    id,
                    name: call.name,
                })
                .await;

            if !call.arguments.is_empty() {
                let _ = tx.send(StreamChunk::ToolCallArgs(call.arguments)).await;
            }

            let _ = tx.send(StreamChunk::ToolCallEnd).await;
        }
    }
}
1297
1298fn openai_responses_output_index(event: &serde_json::Value) -> Option<usize> {
1299 event
1300 .get("output_index")
1301 .and_then(|value| value.as_u64())
1302 .map(|value| value as usize)
1303}
1304
1305fn openai_responses_tool_aliases(event: &serde_json::Value) -> Vec<String> {
1306 let item = event.get("item").unwrap_or(event);
1307 let mut aliases = Vec::with_capacity(4);
1308
1309 for candidate in [
1310 item["call_id"].as_str(),
1311 event["call_id"].as_str(),
1312 item["id"].as_str(),
1313 event["item_id"].as_str(),
1314 ] {
1315 if let Some(alias) = candidate.filter(|alias| !alias.is_empty())
1316 && !aliases.iter().any(|existing| existing == alias)
1317 {
1318 aliases.push(alias.to_string());
1319 }
1320 }
1321
1322 aliases
1323}
1324
1325fn resolve_openai_responses_tool_index(
1326 tool_indices: &mut HashMap<String, usize>,
1327 event: &serde_json::Value,
1328 fallback_index: usize,
1329) -> usize {
1330 let aliases = openai_responses_tool_aliases(event);
1331
1332 if let Some(existing_index) = aliases
1333 .iter()
1334 .find_map(|alias| tool_indices.get(alias).copied())
1335 {
1336 for alias in aliases {
1337 tool_indices.insert(alias, existing_index);
1338 }
1339 return existing_index;
1340 }
1341
1342 let index = openai_responses_output_index(event).unwrap_or(fallback_index);
1343 for alias in aliases {
1344 tool_indices.insert(alias, index);
1345 }
1346 index
1347}
1348
/// Streams a chat completion from an OpenAI-compatible
/// `/v1/chat/completions` endpoint, translating SSE `data:` lines into
/// `StreamChunk`s on `tx`.
///
/// Text deltas run through `ThinkingParser` so `<think>` blocks become
/// `Thinking` chunks; tool-call deltas accumulate per index and are flushed
/// on a `tool_calls` finish reason or at `[DONE]`. Non-success HTTP
/// responses become `AppError::ContextOverflow` when the error text matches
/// an overflow signature, otherwise `AppError::Llm`. Cancellation is
/// checked between network chunks.
async fn stream_openai_chat_compatible(
    api_key: &str,
    base_url: &str,
    model: &str,
    prompt: &str,
    tools: Option<&[serde_json::Value]>,
    tx: mpsc::Sender<StreamChunk>,
    cancel_token: CancellationToken,
) -> Result<(), AppError> {
    let url = format!(
        "{}{}",
        base_url.trim_end_matches('/'),
        openai_endpoint_path(OpenAiApi::ChatCompletions)
    );
    let body = build_openai_chat_request_body(model, prompt, tools);

    let client = create_streaming_client();
    let response = client
        .post(&url)
        .bearer_auth(api_key)
        .json(&body)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("OpenAI streaming request failed: {}", e)))?;

    if !response.status().is_success() {
        let status = response.status();
        let retry_after = response_retry_after_hint(response.headers());
        let body = response.text().await.unwrap_or_default();

        tracing::error!(
            status = %status,
            body_len = body.len(),
            "[CONTEXT_OVERFLOW_CHECK] HTTP error received in stream_openai_chat_compatible"
        );

        let error_msg = format_openai_http_error(
            status,
            "OpenAI",
            model,
            OpenAiApi::ChatCompletions,
            &body,
            retry_after,
        );

        // Check both the formatted message and the raw body: the overflow
        // signature may appear in only one of them.
        let is_overflow =
            is_context_overflow_message(&error_msg) || is_context_overflow_message(&body);
        tracing::error!(
            is_overflow = is_overflow,
            body_preview = %body.chars().take(300).collect::<String>(),
            "[CONTEXT_OVERFLOW_CHECK] Checking for context overflow"
        );

        if is_overflow {
            tracing::error!("[CONTEXT_OVERFLOW_CHECK] Returning AppError::ContextOverflow");
            return Err(AppError::ContextOverflow(error_msg));
        }

        return Err(AppError::Llm(error_msg));
    }

    let mut stream = response.bytes_stream();
    let mut parser = ThinkingParser::new();
    let mut line_buffer = String::new();
    // Tool-call deltas arrive fragmented; accumulate per index until flushed.
    let mut pending_tool_calls = BTreeMap::<usize, PendingOpenAiToolCall>::new();

    while let Some(chunk_result) = stream.next().await {
        // Cooperative cancellation between network chunks.
        if cancel_token.is_cancelled() {
            let _ = tx.send(cancel_token.interruption_chunk()).await;
            return Ok(());
        }

        match chunk_result {
            Ok(bytes) => {
                let text = String::from_utf8_lossy(&bytes);
                for line in collect_complete_lines(&mut line_buffer, &text) {
                    // Only SSE data lines are interesting.
                    let Some(data) = line.strip_prefix("data: ") else {
                        continue;
                    };
                    if data == "[DONE]" {
                        // Flush any tool calls finish_reason never flushed.
                        emit_openai_tool_calls(&tx, &mut pending_tool_calls).await;
                        let _ = tx.send(StreamChunk::Done(None)).await;
                        return Ok(());
                    }
                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
                        if let Some(content) = json["choices"][0]["delta"]["content"].as_str()
                            && !content.is_empty()
                        {
                            let chunks = parser.process(content);
                            for chunk in chunks {
                                let _ = tx.send(chunk).await;
                            }
                        }

                        if let Some(tool_calls) =
                            json["choices"][0]["delta"]["tool_calls"].as_array()
                        {
                            for (fallback_index, call) in tool_calls.iter().enumerate() {
                                merge_openai_tool_call_delta(
                                    &mut pending_tool_calls,
                                    call,
                                    fallback_index,
                                );
                            }
                        }

                        if let Some(finish_reason) = json["choices"][0]["finish_reason"].as_str()
                            && finish_reason == "tool_calls"
                        {
                            emit_openai_tool_calls(&tx, &mut pending_tool_calls).await;
                        }
                    }
                }
            }
            Err(e) => {
                let _ = tx
                    .send(StreamChunk::Error(format!("Stream error: {}", e)))
                    .await;
                return Err(AppError::Llm(format!("Stream error: {}", e)));
            }
        }
    }

    // Stream ended without [DONE]; report completion anyway.
    let _ = tx.send(StreamChunk::Done(None)).await;
    Ok(())
}
1485
/// Streams a completion via the OpenAI Responses API (`/v1/responses`).
///
/// Emits parsed [`StreamChunk`]s on `tx`: text deltas (run through
/// `ThinkingParser`), tool-call start/args/end sequences, and a terminal
/// `Done`/`Error`. Returns `Err` on HTTP or stream failures; context-window
/// errors are surfaced as `AppError::ContextOverflow` so the caller can
/// compact and retry.
async fn stream_openai_responses(
    api_key: &str,
    base_url: &str,
    model: &str,
    prompt: &str,
    tools: Option<&[serde_json::Value]>,
    tx: mpsc::Sender<StreamChunk>,
    cancel_token: CancellationToken,
) -> Result<(), AppError> {
    let url = format!(
        "{}{}",
        base_url.trim_end_matches('/'),
        openai_endpoint_path(OpenAiApi::Responses)
    );
    let body = build_openai_responses_request_body(model, prompt, tools);

    let client = create_streaming_client();
    let response = client
        .post(&url)
        .bearer_auth(api_key)
        .json(&body)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("OpenAI streaming request failed: {}", e)))?;

    if !response.status().is_success() {
        // Capture the Retry-After hint before consuming the body, since
        // `text()` moves the response.
        let status = response.status();
        let retry_after = response_retry_after_hint(response.headers());
        let body = response.text().await.unwrap_or_default();
        let error_msg = format_openai_http_error(
            status,
            "OpenAI",
            model,
            OpenAiApi::Responses,
            &body,
            retry_after,
        );

        // Context-window errors get their own variant so retries are skipped
        // in favor of compaction.
        if is_context_overflow_message(&error_msg) || is_context_overflow_message(&body) {
            return Err(AppError::ContextOverflow(error_msg));
        }

        return Err(AppError::Llm(error_msg));
    }

    let mut stream = response.bytes_stream();
    let mut parser = ThinkingParser::new();
    // SSE lines may be split across network chunks; `line_buffer` holds the
    // incomplete remainder between reads.
    let mut line_buffer = String::new();
    // Tool calls arrive incrementally across several event types; they are
    // accumulated here keyed by output index until complete.
    let mut pending_tool_calls = BTreeMap::<usize, PendingOpenAiResponsesToolCall>::new();
    // Maps provider item ids to indices in `pending_tool_calls`.
    let mut tool_call_indices = HashMap::<String, usize>::new();
    // Guards against emitting the same tool call twice across events.
    let mut emitted_tool_call_ids = HashSet::<String>::new();
    // Synthetic index used when an event carries no usable id/index.
    let mut fallback_index = 0usize;

    while let Some(chunk_result) = stream.next().await {
        // Cooperative cancellation: checked once per network chunk.
        if cancel_token.is_cancelled() {
            let _ = tx.send(cancel_token.interruption_chunk()).await;
            return Ok(());
        }

        match chunk_result {
            Ok(bytes) => {
                let text = String::from_utf8_lossy(&bytes);
                for line in collect_complete_lines(&mut line_buffer, &text) {
                    let Some(data) = line.strip_prefix("data: ") else {
                        continue;
                    };
                    if data == "[DONE]" {
                        // Flush any tool calls still pending (force=true),
                        // then terminate the stream.
                        emit_ready_openai_responses_tool_calls(
                            &tx,
                            &mut pending_tool_calls,
                            &mut emitted_tool_call_ids,
                            true,
                        )
                        .await;
                        let _ = tx.send(StreamChunk::Done(None)).await;
                        return Ok(());
                    }

                    // Malformed JSON lines are silently skipped.
                    let Ok(json) = serde_json::from_str::<serde_json::Value>(data) else {
                        continue;
                    };

                    match json["type"].as_str().unwrap_or_default() {
                        "response.output_text.delta" => {
                            if let Some(delta) = json["delta"].as_str()
                                && !delta.is_empty()
                            {
                                // ThinkingParser splits <think> blocks into
                                // Thinking vs Text chunks.
                                for chunk in parser.process(delta) {
                                    let _ = tx.send(chunk).await;
                                }
                            }
                        }
                        "response.output_item.added" | "response.output_item.done" => {
                            if json["item"]["type"].as_str() == Some("function_call") {
                                merge_openai_responses_tool_item(
                                    &mut pending_tool_calls,
                                    &mut tool_call_indices,
                                    &json,
                                    fallback_index,
                                );
                                // Non-forced emit: only calls that are fully
                                // assembled are sent.
                                emit_ready_openai_responses_tool_calls(
                                    &tx,
                                    &mut pending_tool_calls,
                                    &mut emitted_tool_call_ids,
                                    false,
                                )
                                .await;
                            }
                        }
                        "response.function_call_arguments.delta" => {
                            merge_openai_responses_tool_argument_delta(
                                &mut pending_tool_calls,
                                &mut tool_call_indices,
                                &json,
                                fallback_index,
                            );
                        }
                        "response.function_call_arguments.done" => {
                            complete_openai_responses_tool_arguments(
                                &mut pending_tool_calls,
                                &mut tool_call_indices,
                                &json,
                                fallback_index,
                            );
                            emit_ready_openai_responses_tool_calls(
                                &tx,
                                &mut pending_tool_calls,
                                &mut emitted_tool_call_ids,
                                false,
                            )
                            .await;
                        }
                        "response.completed" => {
                            // Normal termination event from the Responses API.
                            emit_ready_openai_responses_tool_calls(
                                &tx,
                                &mut pending_tool_calls,
                                &mut emitted_tool_call_ids,
                                true,
                            )
                            .await;
                            let _ = tx.send(StreamChunk::Done(None)).await;
                            return Ok(());
                        }
                        "response.failed" => {
                            let message = json["response"]["error"]["message"]
                                .as_str()
                                .unwrap_or("OpenAI Responses stream failed")
                                .to_string();
                            let _ = tx.send(StreamChunk::Error(message.clone())).await;
                            return Err(AppError::Llm(message));
                        }
                        _ => {}
                    }

                    // Advance the synthetic index once per processed line so
                    // id-less events never collide.
                    fallback_index = fallback_index.saturating_add(1);
                }
            }
            Err(e) => {
                let _ = tx
                    .send(StreamChunk::Error(format!("Stream error: {}", e)))
                    .await;
                return Err(AppError::Llm(format!("Stream error: {}", e)));
            }
        }
    }

    // Stream ended without a terminal event: flush pending tool calls and
    // signal completion anyway.
    emit_ready_openai_responses_tool_calls(
        &tx,
        &mut pending_tool_calls,
        &mut emitted_tool_call_ids,
        true,
    )
    .await;
    let _ = tx.send(StreamChunk::Done(None)).await;
    Ok(())
}
1663
1664pub async fn stream_openai(
1665 api_key: &str,
1666 base_url: &str,
1667 model: &str,
1668 prompt: &str,
1669 tools: Option<&[serde_json::Value]>,
1670 tx: mpsc::Sender<StreamChunk>,
1671 cancel_token: CancellationToken,
1672) -> Result<(), AppError> {
1673 if is_openai_model_incompatible_with_agent_session(model) {
1674 return Err(AppError::Llm(openai_agent_session_model_message(model)));
1675 }
1676
1677 match openai_api_for_model(model) {
1678 OpenAiApi::ChatCompletions => {
1679 stream_openai_chat_compatible(api_key, base_url, model, prompt, tools, tx, cancel_token)
1680 .await
1681 }
1682 OpenAiApi::Responses => {
1683 stream_openai_responses(api_key, base_url, model, prompt, tools, tx, cancel_token).await
1684 }
1685 }
1686}
1687
/// Bundled arguments for [`stream_anthropic`].
///
/// Grouped into a struct because the Anthropic call carries one more knob
/// (the thinking budget) than the other providers' positional signatures.
#[derive(Debug)]
pub struct AnthropicStreamRequest<'a> {
    /// Anthropic API key, sent via the `x-api-key` header.
    pub api_key: &'a str,
    /// Base URL; `/v1/messages` is appended by `stream_anthropic`.
    pub base_url: &'a str,
    /// Model identifier to request.
    pub model: &'a str,
    /// When `Some`, extended thinking is enabled with this token budget.
    pub thinking_budget_tokens: Option<u32>,
    /// Full prompt text, sent as a single user message.
    pub prompt: &'a str,
    /// Optional tool definitions (Anthropic tool schema JSON).
    pub tools: Option<&'a [serde_json::Value]>,
    /// Channel receiving parsed stream chunks.
    pub tx: mpsc::Sender<StreamChunk>,
    /// Cooperative cancellation handle, checked between network chunks.
    pub cancel_token: CancellationToken,
}
1703
1704fn build_anthropic_messages_body(
1705 model: &str,
1706 prompt: &str,
1707 thinking_budget_tokens: Option<u32>,
1708 tools: Option<&[serde_json::Value]>,
1709) -> serde_json::Value {
1710 let mut body = serde_json::json!({
1711 "model": model,
1712 "max_tokens": 4096,
1713 "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
1714 "stream": true
1715 });
1716
1717 if let Some(tools) = tools
1719 && !tools.is_empty()
1720 {
1721 body["tools"] = serde_json::Value::Array(tools.to_vec());
1722 }
1723
1724 if let Some(budget_tokens) = thinking_budget_tokens {
1726 body["thinking"] = serde_json::json!({ "type": "enabled", "budget_tokens": budget_tokens });
1727 }
1728
1729 body
1730}
1731
/// Streams a completion from Anthropic's Messages API.
///
/// Forwards text deltas (split by `ThinkingParser`), native `thinking`
/// deltas, and `tool_use` blocks as [`StreamChunk`]s on `req.tx`. Returns
/// `AppError::ContextOverflow` for context-window HTTP errors so the caller
/// can compact, and `AppError::Llm` for everything else.
pub async fn stream_anthropic(req: AnthropicStreamRequest<'_>) -> Result<(), AppError> {
    let AnthropicStreamRequest {
        api_key,
        base_url,
        model,
        thinking_budget_tokens,
        prompt,
        tools,
        tx,
        cancel_token,
    } = req;

    let url = format!("{}/v1/messages", base_url);
    let body = build_anthropic_messages_body(model, prompt, thinking_budget_tokens, tools);

    let client = create_streaming_client();
    let response = client
        .post(&url)
        .header("x-api-key", api_key)
        .header("anthropic-version", "2023-06-01")
        .json(&body)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("Anthropic streaming request failed: {}", e)))?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        let error_msg = format!("Anthropic HTTP {}: {}", status, body);

        // Context-window errors get their own variant so retries are skipped
        // in favor of compaction.
        if is_context_overflow_message(&error_msg) || is_context_overflow_message(&body) {
            return Err(AppError::ContextOverflow(error_msg));
        }

        return Err(AppError::Llm(error_msg));
    }

    let mut stream = response.bytes_stream();
    // SSE lines may be split across network chunks; the buffer carries the
    // incomplete remainder between reads.
    let mut line_buffer = String::new();
    let mut parser = ThinkingParser::new();
    // True between a tool_use content_block_start and its content_block_stop;
    // gates forwarding of partial_json argument deltas.
    let mut in_tool_block = false;

    while let Some(chunk_result) = stream.next().await {
        // Cooperative cancellation: checked once per network chunk.
        if cancel_token.is_cancelled() {
            let _ = tx.send(cancel_token.interruption_chunk()).await;
            return Ok(());
        }

        match chunk_result {
            Ok(bytes) => {
                let text = String::from_utf8_lossy(&bytes);
                for line in collect_complete_lines(&mut line_buffer, &text) {
                    let Some(data) = line.strip_prefix("data: ") else {
                        continue;
                    };
                    // Malformed JSON lines are silently skipped.
                    let Ok(json) = serde_json::from_str::<serde_json::Value>(data) else {
                        continue;
                    };
                    match json["type"].as_str() {
                        Some("message_stop") => {
                            // Normal end-of-message event.
                            let _ = tx.send(StreamChunk::Done(None)).await;
                            return Ok(());
                        }
                        Some("content_block_delta") => {
                            if let Some(delta) = json["delta"].as_object() {
                                // Plain text deltas pass through ThinkingParser
                                // to split any <think> tags.
                                if let Some(content) = delta.get("text").and_then(|v| v.as_str())
                                    && !content.is_empty()
                                {
                                    for chunk in parser.process(content) {
                                        let _ = tx.send(chunk).await;
                                    }
                                }

                                // Native extended-thinking deltas map directly
                                // to Thinking chunks.
                                if let Some(thinking) =
                                    delta.get("thinking").and_then(|v| v.as_str())
                                    && !thinking.is_empty()
                                {
                                    let _ =
                                        tx.send(StreamChunk::Thinking(thinking.to_string())).await;
                                }

                                // Tool argument fragments are only forwarded
                                // while inside an open tool_use block.
                                if in_tool_block
                                    && let Some(partial_json) =
                                        delta.get("partial_json").and_then(|v| v.as_str())
                                    && !partial_json.is_empty()
                                {
                                    let _ = tx
                                        .send(StreamChunk::ToolCallArgs(partial_json.to_string()))
                                        .await;
                                }
                            }
                        }
                        Some("content_block_start") => {
                            let block = &json["content_block"];
                            if block["type"].as_str() == Some("tool_use")
                                && let Some(name) = block["name"].as_str()
                            {
                                let id = block["id"].as_str().unwrap_or_default().to_string();
                                let _ = tx
                                    .send(StreamChunk::ToolCallStart {
                                        id,
                                        name: name.to_string(),
                                    })
                                    .await;
                                in_tool_block = true;
                            }
                        }
                        Some("content_block_stop") => {
                            if in_tool_block {
                                let _ = tx.send(StreamChunk::ToolCallEnd).await;
                                in_tool_block = false;
                            }
                        }
                        _ => {}
                    }
                }
            }
            Err(e) => {
                let _ = tx
                    .send(StreamChunk::Error(format!("Stream error: {}", e)))
                    .await;
                return Err(AppError::Llm(format!("Stream error: {}", e)));
            }
        }
    }

    // Stream closed without a message_stop event; signal completion anyway.
    let _ = tx.send(StreamChunk::Done(None)).await;
    Ok(())
}
1865
1866fn build_gemini_body(prompt: &str, tools: Option<&[serde_json::Value]>) -> serde_json::Value {
1868 let mut body = serde_json::json!({
1869 "contents": [{"role": "user", "parts": [{"text": prompt}]}]
1870 });
1871
1872 if let Some(tools) = tools
1873 && !tools.is_empty()
1874 {
1875 body["tools"] = serde_json::json!([{"functionDeclarations": tools}]);
1876 body["toolConfig"] = serde_json::json!({"functionCallingConfig": {"mode": "AUTO"}});
1877 }
1878
1879 body
1880}
1881
/// Streams a completion from Google Gemini via SSE.
///
/// Forwards text parts (split by `ThinkingParser`) and `functionCall` parts
/// as [`StreamChunk`]s; the last observed `usageMetadata` is attached to the
/// final `Done` chunk. Context-window HTTP errors become
/// `AppError::ContextOverflow`.
pub async fn stream_gemini(
    api_key: &str,
    base_url: &str,
    model: &str,
    prompt: &str,
    tools: Option<&[serde_json::Value]>,
    tx: mpsc::Sender<StreamChunk>,
    cancel_token: CancellationToken,
) -> Result<(), AppError> {
    // NOTE: the API key travels as a query parameter, per Gemini's scheme.
    let url = format!(
        "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
        base_url, model, api_key
    );
    let body = build_gemini_body(prompt, tools);

    let client = create_streaming_client();
    let response = client
        .post(&url)
        .json(&body)
        .send()
        .await
        .map_err(|e| AppError::Llm(format!("Gemini streaming request failed: {}", e)))?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        let error_msg = format!("Gemini HTTP {}: {}", status, body);

        // Context-window errors get their own variant so the caller compacts
        // instead of retrying.
        if is_context_overflow_message(&error_msg) || is_context_overflow_message(&body) {
            return Err(AppError::ContextOverflow(error_msg));
        }

        return Err(AppError::Llm(error_msg));
    }

    let mut stream = response.bytes_stream();
    // SSE lines may be split across network chunks; the buffer holds the
    // incomplete remainder between reads.
    let mut line_buffer = String::new();
    let mut parser = ThinkingParser::new();
    // Most recent usage metadata seen; reported with the final Done chunk.
    let mut last_usage: Option<TokenUsage> = None;

    while let Some(chunk_result) = stream.next().await {
        // Cooperative cancellation: checked once per network chunk.
        if cancel_token.is_cancelled() {
            let _ = tx.send(cancel_token.interruption_chunk()).await;
            return Ok(());
        }

        match chunk_result {
            Ok(bytes) => {
                let text = String::from_utf8_lossy(&bytes);
                for line in collect_complete_lines(&mut line_buffer, &text) {
                    let Some(data) = line.strip_prefix("data: ") else {
                        continue;
                    };
                    // Malformed JSON lines are silently skipped.
                    let Ok(json) = serde_json::from_str::<serde_json::Value>(data) else {
                        continue;
                    };

                    // Each event may carry cumulative usage; keep the latest.
                    if let Some(usage) = json.get("usageMetadata") {
                        let input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(0) as u32;
                        let output_tokens =
                            usage["candidatesTokenCount"].as_u64().unwrap_or(0) as u32;
                        last_usage = Some(
                            TokenUsage::new(input_tokens, output_tokens).with_provider("gemini"),
                        );
                    }

                    if let Some(parts) = json
                        .pointer("/candidates/0/content/parts")
                        .and_then(|v| v.as_array())
                    {
                        for (idx, part) in parts.iter().enumerate() {
                            // Text parts pass through ThinkingParser to split
                            // any <think> tags.
                            if let Some(content) = part["text"].as_str()
                                && !content.is_empty()
                            {
                                for chunk in parser.process(content) {
                                    let _ = tx.send(chunk).await;
                                }
                            }

                            // Gemini delivers each function call complete in
                            // one part, so Start/Args/End are emitted together.
                            if let Some(fc) = part.get("functionCall")
                                && let Some(name) = fc["name"].as_str()
                            {
                                // NOTE(review): the id is the part index within
                                // this event, so ids can repeat across events —
                                // presumably downstream keys off Start/End
                                // ordering rather than id uniqueness; confirm.
                                let id = format!("gemini-call-{}", idx);
                                let _ = tx
                                    .send(StreamChunk::ToolCallStart {
                                        id,
                                        name: name.to_string(),
                                    })
                                    .await;

                                let args = fc
                                    .get("args")
                                    .map(|a| a.to_string())
                                    .unwrap_or_else(|| "{}".to_string());
                                let _ = tx.send(StreamChunk::ToolCallArgs(args)).await;
                                let _ = tx.send(StreamChunk::ToolCallEnd).await;
                            }
                        }
                    }
                }
            }
            Err(e) => {
                let _ = tx
                    .send(StreamChunk::Error(format!("Stream error: {}", e)))
                    .await;
                return Err(AppError::Llm(format!("Stream error: {}", e)));
            }
        }
    }

    // Gemini's SSE stream has no explicit terminal event here; end of stream
    // means done. Attach the last usage snapshot if any was seen.
    let _ = tx.send(StreamChunk::Done(last_usage)).await;
    Ok(())
}
2005
// Interval between keepalive Status chunks while Ollama is loading a model
// or producing no output.
const OLLAMA_KEEPALIVE_INTERVAL_SECS: u64 = 30;

/// Streams a completion from a local Ollama server (`/api/chat`, NDJSON).
///
/// Unlike the hosted providers this emits best-effort `Status` keepalive
/// chunks: once every `OLLAMA_KEEPALIVE_INTERVAL_SECS` while waiting for the
/// initial HTTP response (model load can be slow), and again whenever the
/// body stream stays silent for that long. Text passes through
/// `ThinkingParser`; tool calls arrive complete per NDJSON line and are
/// emitted as Start/Args/End triples.
pub async fn stream_ollama(
    base_url: &str,
    model: &str,
    prompt: &str,
    tools: Option<&[serde_json::Value]>,
    tx: mpsc::Sender<StreamChunk>,
    cancel_token: CancellationToken,
) -> Result<(), AppError> {
    let url = format!("{}/api/chat", base_url.trim_end_matches('/'));
    let mut body = serde_json::json!({
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": true
    });

    // Attach tool definitions only when a non-empty list was supplied.
    if let Some(tools) = tools
        && !tools.is_empty()
    {
        body["tools"] = serde_json::Value::Array(tools.to_vec());
    }

    tracing::debug!(
        model = model,
        url = %url,
        tools_count = tools.map(|t| t.len()).unwrap_or(0),
        has_tools = tools.map(|t| !t.is_empty()).unwrap_or(false),
        "[Ollama] Starting stream request"
    );

    let client = create_streaming_client();

    // Pre-connection keepalive: Ollama may block the HTTP response for a long
    // time while loading the model into memory, so a background task pings
    // the receiver with Status chunks until `send()` resolves.
    let pre_conn_tx = tx.clone();
    let pre_conn_model = model.to_string();
    let pre_conn_handle = tokio::spawn(
        {
            let interval = Duration::from_secs(OLLAMA_KEEPALIVE_INTERVAL_SECS);
            async move {
                loop {
                    tokio::time::sleep(interval).await;
                    tracing::debug!(
                        model = %pre_conn_model,
                        "[Ollama] Pre-connection keepalive: model still loading"
                    );
                    send_status_chunk_best_effort(
                        &pre_conn_tx,
                        StreamChunk::Status {
                            message: format!("Loading model '{pre_conn_model}'…"),
                        },
                    )
                    .await;
                }
            }
        }
        .instrument(tracing::Span::current()),
    );

    let send_result = client.post(&url).json(&body).send().await;

    // Stop the keepalive task as soon as the connection resolves (either way)
    // and wait for it to wind down before touching the result.
    pre_conn_handle.abort();
    let _ = pre_conn_handle.await;

    let response = send_result
        .map_err(|e| AppError::Llm(format!("Ollama streaming request failed: {}", e)))?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        let error_msg = format!("Ollama HTTP {}: {}", status, body);

        // Context-window errors get their own variant so the caller compacts
        // instead of retrying.
        if is_context_overflow_message(&error_msg) || is_context_overflow_message(&body) {
            return Err(AppError::ContextOverflow(error_msg));
        }

        return Err(AppError::Llm(error_msg));
    }

    send_status_chunk_best_effort(
        &tx,
        StreamChunk::Status {
            message: format!("Connected to Ollama — loading model '{}'…", model),
        },
    )
    .await;
    tracing::debug!(
        model = model,
        "[Ollama] HTTP connection established; 'Connected' status sent"
    );

    let mut stream = response.bytes_stream();
    let mut parser = ThinkingParser::new();
    // NDJSON lines may be split across network chunks; the buffer holds the
    // incomplete remainder between reads.
    let mut line_buffer = String::new();

    // In-stream keepalive: if no bytes arrive within the interval, emit a
    // Status chunk so the UI knows the model is still working.
    let keepalive_interval = Duration::from_secs(OLLAMA_KEEPALIVE_INTERVAL_SECS);
    let keepalive_sleep = tokio::time::sleep(keepalive_interval);
    tokio::pin!(keepalive_sleep);

    loop {
        tokio::select! {
            maybe_chunk = stream.next() => {
                let Some(chunk_result) = maybe_chunk else {
                    // Body stream exhausted without done=true; fall through to
                    // the final Done below.
                    break;
                };

                // Cooperative cancellation: checked once per network chunk.
                if cancel_token.is_cancelled() {
                    let _ = tx.send(cancel_token.interruption_chunk()).await;
                    return Ok(());
                }

                match chunk_result {
                    Ok(bytes) => {
                        let text = String::from_utf8_lossy(&bytes);
                        for line in collect_complete_lines(&mut line_buffer, &text) {
                            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
                                tracing::trace!(
                                    done = json["done"].as_bool().unwrap_or(false),
                                    has_content = json["message"]["content"].as_str().map(|s| !s.is_empty()).unwrap_or(false),
                                    has_tool_calls = json["message"]["tool_calls"].is_array(),
                                    "[Ollama] NDJSON line parsed"
                                );

                                // Text content passes through ThinkingParser to
                                // split any <think> tags.
                                if let Some(content) = json["message"]["content"].as_str()
                                    && !content.is_empty()
                                {
                                    let chunks = parser.process(content);
                                    for chunk in chunks {
                                        let _ = tx.send(chunk).await;
                                    }
                                }

                                // Ollama delivers tool calls whole per line, so
                                // each maps to one Start/Args/End triple.
                                if let Some(tool_calls) = json["message"]["tool_calls"].as_array() {
                                    tracing::debug!(
                                        model = model,
                                        count = tool_calls.len(),
                                        "[Ollama] Tool calls found in NDJSON line"
                                    );
                                    for call in tool_calls {
                                        let name = call["function"]["name"].as_str().unwrap_or_default();
                                        let args = &call["function"]["arguments"];

                                        if !name.is_empty() {
                                            tracing::debug!(
                                                tool = name,
                                                "[Ollama] Emitting ToolCallStart/Args/End"
                                            );
                                            // Ollama supplies no call id, so a
                                            // UUID-based one is synthesized.
                                            let id = format!("ollama-tool-{}", uuid::Uuid::new_v4());
                                            let _ = tx
                                                .send(StreamChunk::ToolCallStart {
                                                    id,
                                                    name: name.to_string(),
                                                })
                                                .await;

                                            // Arguments may arrive as JSON or a
                                            // pre-serialized string.
                                            let args_str = if args.is_object() || args.is_array() {
                                                serde_json::to_string(args).unwrap_or_default()
                                            } else {
                                                args.as_str().unwrap_or("{}").to_string()
                                            };

                                            let _ = tx.send(StreamChunk::ToolCallArgs(args_str)).await;
                                            let _ = tx.send(StreamChunk::ToolCallEnd).await;
                                            tracing::debug!(tool = name, "[Ollama] ToolCallEnd emitted");
                                        }
                                    }
                                }

                                // done=true is Ollama's terminal marker.
                                if json["done"].as_bool() == Some(true) {
                                    tracing::debug!(model = model, "[Ollama] done=true — sending Done chunk");
                                    let _ = tx.send(StreamChunk::Done(None)).await;
                                    return Ok(());
                                }
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx
                            .send(StreamChunk::Error(format!("Stream error: {}", e)))
                            .await;
                        return Err(AppError::Llm(format!("Stream error: {}", e)));
                    }
                }
            }

            // Keepalive branch: fires when no chunk arrived for the interval.
            () = &mut keepalive_sleep => {
                if cancel_token.is_cancelled() {
                    let _ = tx.send(cancel_token.interruption_chunk()).await;
                    return Ok(());
                }
                tracing::debug!(model = model, "[Ollama] Keepalive firing — sending Status chunk");
                send_status_chunk_best_effort(
                    &tx,
                    StreamChunk::Status {
                        message: format!("Working… (model '{}')", model),
                    },
                )
                .await;
                tracing::debug!(
                    model = model,
                    "[Ollama] Keepalive Status sent"
                );
                // Re-arm the timer for the next silent interval.
                keepalive_sleep
                    .as_mut()
                    .reset(tokio::time::Instant::now() + keepalive_interval);
            }
        }
    }

    // Stream ended without done=true; signal completion anyway.
    let _ = tx.send(StreamChunk::Done(None)).await;
    Ok(())
}
2254
2255async fn stream_unconfigured_error(
2260 provider_name: &str,
2261 tx: mpsc::Sender<StreamChunk>,
2262) -> Result<(), AppError> {
2263 let message = format!(
2264 "LLM provider '{}' is not configured. Please configure it in Settings or run 'gestura config edit'.",
2265 provider_name
2266 );
2267 send_status_chunk_best_effort(
2269 &tx,
2270 StreamChunk::Status {
2271 message: message.clone(),
2272 },
2273 )
2274 .await;
2275 let _ = tx.send(StreamChunk::Error(message.clone())).await;
2276 Err(AppError::Llm(message))
2277}
2278
/// Heuristic check for "provider is not configured" error messages.
///
/// Matching "not configured" alone is sufficient: it also covers the longer
/// "is not configured" phrasing emitted by `stream_unconfigured_error`, so
/// the previous second clause was redundant.
fn is_unconfigured_provider_message(message: &str) -> bool {
    message.contains("not configured")
}
2286
2287fn is_context_overflow_message(message: &str) -> bool {
2292 let msg_lower = message.to_lowercase();
2293 let is_overflow = msg_lower.contains("contextlengthexceeded")
2296 || msg_lower.contains("context_length_exceeded")
2297 || msg_lower.contains("context length")
2298 || msg_lower.contains("maximum context")
2299 || (msg_lower.contains("tokens") && msg_lower.contains("exceeds"))
2300 || (msg_lower.contains("token") && msg_lower.contains("limit"));
2301
2302 if is_overflow {
2303 tracing::warn!(
2304 message_preview = %message.chars().take(200).collect::<String>(),
2305 "Detected context overflow error"
2306 );
2307 }
2308
2309 is_overflow
2310}
2311
2312fn is_unconfigured_provider_error(err: &AppError) -> bool {
2314 match err {
2315 AppError::Llm(msg) => is_unconfigured_provider_message(msg),
2316 _ => false,
2317 }
2318}
2319
/// Starts one streaming request against the provider named by
/// `config.primary`, selecting provider-specific credentials, base URLs,
/// and tool schemas.
///
/// An unknown or unconfigured provider yields `stream_unconfigured_error`
/// (Status + Error chunks, then `Err`). The whole call runs inside an
/// `agent.streaming.request` tracing span.
pub async fn start_streaming(
    config: &StreamingConfig,
    prompt: &str,
    tool_schemas: Option<ProviderToolSchemas>,
    tx: mpsc::Sender<StreamChunk>,
    cancel_token: CancellationToken,
) -> Result<(), AppError> {
    async {
        match config.primary.as_str() {
            "openai" => {
                if let Some(c) = &config.openai {
                    // The tool schema flavor must match the wire API the
                    // model speaks (Chat Completions vs Responses).
                    let openai_tools =
                        tool_schemas
                            .as_ref()
                            .map(|schemas| match openai_api_for_model(&c.model) {
                                OpenAiApi::ChatCompletions => schemas.openai.as_slice(),
                                OpenAiApi::Responses => schemas.openai_responses.as_slice(),
                            });
                    stream_openai(
                        &c.api_key,
                        c.base_url.as_deref().unwrap_or("https://api.openai.com"),
                        &c.model,
                        prompt,
                        openai_tools,
                        tx,
                        cancel_token,
                    )
                    .await
                } else {
                    stream_unconfigured_error("openai", tx).await
                }
            }
            "anthropic" => {
                if let Some(c) = &config.anthropic {
                    stream_anthropic(AnthropicStreamRequest {
                        api_key: &c.api_key,
                        base_url: c.base_url.as_deref().unwrap_or("https://api.anthropic.com"),
                        model: &c.model,
                        thinking_budget_tokens: c.thinking_budget_tokens,
                        prompt,
                        tools: tool_schemas.as_ref().map(|s| s.anthropic.as_slice()),
                        tx,
                        cancel_token,
                    })
                    .await
                } else {
                    stream_unconfigured_error("anthropic", tx).await
                }
            }
            "grok" => {
                if let Some(c) = &config.grok {
                    // Grok exposes an OpenAI-compatible Chat Completions API,
                    // so it reuses the OpenAI chat path and tool schemas.
                    stream_openai_chat_compatible(
                        &c.api_key,
                        c.base_url.as_deref().unwrap_or("https://api.x.ai"),
                        &c.model,
                        prompt,
                        tool_schemas.as_ref().map(|s| s.openai.as_slice()),
                        tx,
                        cancel_token,
                    )
                    .await
                } else {
                    stream_unconfigured_error("grok", tx).await
                }
            }
            "gemini" => {
                if let Some(c) = &config.gemini {
                    stream_gemini(
                        &c.api_key,
                        c.base_url
                            .as_deref()
                            .unwrap_or("https://generativelanguage.googleapis.com"),
                        &c.model,
                        prompt,
                        tool_schemas.as_ref().map(|s| s.gemini.as_slice()),
                        tx,
                        cancel_token,
                    )
                    .await
                } else {
                    stream_unconfigured_error("gemini", tx).await
                }
            }
            "ollama" => {
                if let Some(c) = &config.ollama {
                    // Ollama has no default base URL: it is always explicit
                    // in the config. It consumes OpenAI-style tool schemas.
                    stream_ollama(
                        &c.base_url,
                        &c.model,
                        prompt,
                        tool_schemas.as_ref().map(|s| s.openai.as_slice()),
                        tx,
                        cancel_token,
                    )
                    .await
                } else {
                    stream_unconfigured_error("ollama", tx).await
                }
            }
            other => stream_unconfigured_error(other, tx).await,
        }
    }
    .instrument(tracing::info_span!(
        "agent.streaming.request",
        provider = %config.primary,
        has_tool_schemas = tool_schemas.is_some()
    ))
    .await
}
2432
/// Runs [`start_streaming`] with retries, then an optional fallback provider.
///
/// Each attempt is spawned as its own task with a private channel; chunks are
/// forwarded to `tx` while the forwarder classifies the attempt's outcome.
/// Unconfigured-provider and context-overflow errors short-circuit retries
/// (the former falls through to the fallback provider, the latter returns
/// `AppError::ContextOverflow` so the pipeline can compact). Retry delays
/// come from `select_streaming_retry_delay`, which honors provider
/// Retry-After hints.
pub async fn start_streaming_with_fallback(
    config: &StreamingConfig,
    prompt: &str,
    tool_schemas: Option<ProviderToolSchemas>,
    tx: mpsc::Sender<StreamChunk>,
    cancel_token: CancellationToken,
) -> Result<(), AppError> {
    let retry_policy = RetryPolicy::for_streaming();
    // At least one attempt, even for a degenerate policy.
    let total_attempts = retry_policy.max_attempts.max(1) as usize;
    let mut last_error: Option<AppError> = None;
    // Set when the primary provider is unconfigured: retries are pointless,
    // but the fallback provider may still work.
    let mut skipped_retries_due_to_unconfigured = false;

    for attempt in 0..total_attempts {
        if cancel_token.is_cancelled() {
            let _ = tx.send(cancel_token.interruption_chunk()).await;
            return Ok(());
        }

        // Each attempt gets a private channel so a failed attempt's partial
        // output can be classified before anything reaches the caller.
        let (attempt_tx, mut attempt_rx) =
            mpsc::channel::<StreamChunk>(STREAM_CHUNK_BUFFER_CAPACITY);
        let attempt_cancel = cancel_token.clone();
        let config_clone = config.clone();
        let prompt_clone = prompt.to_string();
        let tool_schemas_clone = tool_schemas.clone();

        let attempt_span = tracing::info_span!(
            "agent.streaming.fallback_attempt",
            attempt = attempt + 1,
            total_attempts = total_attempts
        );
        let handle = tokio::spawn(
            async move {
                start_streaming(
                    &config_clone,
                    &prompt_clone,
                    tool_schemas_clone,
                    attempt_tx,
                    attempt_cancel,
                )
                .await
            }
            .instrument(attempt_span),
        );

        // Forward chunks to the real receiver and classify how the attempt
        // ended.
        let forward = forward_attempt_stream(&mut attempt_rx, &tx).await;

        // Collect the task's own error (if any) for retry/fallback logging.
        match handle.await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                last_error = Some(e);
            }
            Err(e) => {
                last_error = Some(AppError::Llm(format!("Task failed: {}", e)));
            }
        }

        match forward.outcome {
            AttemptOutcome::Success => return Ok(()),
            // Interruption is a clean exit, not an error.
            AttemptOutcome::Cancelled | AttemptOutcome::Paused => return Ok(()),
            AttemptOutcome::FatalError => {
                let err = AppError::Llm(
                    forward
                        .error
                        .clone()
                        .unwrap_or_else(|| "Streaming failed".to_string()),
                );
                return Err(err);
            }
            AttemptOutcome::ContextOverflowError => {
                let error_msg = forward
                    .error
                    .clone()
                    .unwrap_or_else(|| "Context length exceeded".to_string());

                tracing::warn!(
                    error = %error_msg,
                    "Context overflow detected - returning to pipeline for compaction"
                );

                // Tell the receiver explicitly so the UI/pipeline can react.
                let _ = tx
                    .send(StreamChunk::ContextOverflow {
                        error_message: error_msg.clone(),
                    })
                    .await;

                return Err(AppError::ContextOverflow(error_msg));
            }
            AttemptOutcome::RetryableError => {
                if let Some(ref e) = forward.error {
                    last_error = Some(AppError::Llm(e.clone()));
                }
            }
            AttemptOutcome::UnexpectedEnd => {
                // If output was already forwarded, retrying would duplicate
                // it; fail instead. With no output forwarded, fall through
                // to the retry logic below.
                if forward.forwarded_output {
                    let err = AppError::Llm(
                        "Streaming ended unexpectedly (no terminal event received)".to_string(),
                    );
                    let _ = tx.send(StreamChunk::Error(err.to_string())).await;
                    return Err(err);
                }
            }
        }

        // Unconfigured provider: retrying can never help; go straight to the
        // fallback provider (if any).
        let unconfigured = forward
            .error
            .as_deref()
            .map(is_unconfigured_provider_message)
            .unwrap_or(false)
            || last_error
                .as_ref()
                .map(is_unconfigured_provider_error)
                .unwrap_or(false);

        if unconfigured {
            skipped_retries_due_to_unconfigured = true;
            break;
        }

        // Context overflow surfaced through the error path (rather than the
        // forwarder outcome): also skip retries and request compaction.
        let is_context_overflow = forward
            .error
            .as_deref()
            .map(is_context_overflow_message)
            .unwrap_or(false)
            || matches!(&last_error, Some(AppError::ContextOverflow(_)));

        if is_context_overflow {
            let error_msg = forward
                .error
                .clone()
                .or_else(|| last_error.as_ref().map(|e| e.to_string()))
                .unwrap_or_else(|| "Context length exceeded".to_string());

            tracing::warn!(
                error = %error_msg,
                "Context overflow detected - skipping retries, returning for compaction"
            );

            let _ = tx
                .send(StreamChunk::ContextOverflow {
                    error_message: error_msg.clone(),
                })
                .await;

            return Err(AppError::ContextOverflow(error_msg));
        }

        // Back off before the next attempt (skipped after the final one).
        if attempt + 1 < total_attempts {
            let error_msg = last_error
                .as_ref()
                .map(|e| e.to_string())
                .unwrap_or_else(|| "Unknown error".to_string());
            // The delay honors any provider Retry-After hint embedded in the
            // error message.
            let retry_delay =
                select_streaming_retry_delay(&retry_policy, attempt as u32 + 1, &error_msg);

            tracing::warn!(
                attempt = attempt + 1,
                delay_ms = retry_delay.as_millis(),
                error = %error_msg,
                "Primary LLM failed, retrying after backoff"
            );

            // Surface the retry to the receiver so the UI can show progress.
            let _ = tx
                .send(StreamChunk::RetryAttempt {
                    attempt: attempt as u32 + 1,
                    max_attempts: total_attempts as u32,
                    delay_ms: retry_delay.as_millis() as u64,
                    error_message: error_msg,
                })
                .await;

            tokio::time::sleep(retry_delay).await;
        }
    }

    // Primary exhausted (or unconfigured): try the fallback provider once.
    if let Some(ref fallback_provider) = config.fallback {
        if skipped_retries_due_to_unconfigured {
            tracing::info!(
                fallback = fallback_provider,
                "Primary LLM is not configured, trying fallback provider"
            );
        } else {
            tracing::info!(
                fallback = fallback_provider,
                "Primary LLM exhausted retries, trying fallback provider"
            );
        }

        // Rerun with the fallback substituted as the primary provider.
        let mut fallback_config = config.clone();
        fallback_config.primary = fallback_provider.clone();

        let result = start_streaming(
            &fallback_config,
            prompt,
            tool_schemas,
            tx.clone(),
            cancel_token,
        )
        .await;

        if result.is_ok() {
            return Ok(());
        }

        tracing::error!("Fallback provider also failed");
    }

    // Everything failed: emit a terminal Error chunk and propagate the most
    // recent error (or a generic one if none was captured).
    if let Some(error) = last_error {
        let _ = tx.send(StreamChunk::Error(error.to_string())).await;
        Err(error)
    } else {
        let err = AppError::Llm("All LLM providers failed".to_string());
        let _ = tx.send(StreamChunk::Error(err.to_string())).await;
        Err(err)
    }
}
2673
2674#[cfg(test)]
2675mod tests {
2676 use super::*;
2677
    #[test]
    fn openai_http_error_includes_retry_after_hint_when_present() {
        // A 429 with a Retry-After hint must surface both the status code
        // and the suggested wait in the formatted message.
        let message = format_openai_http_error(
            reqwest::StatusCode::TOO_MANY_REQUESTS,
            "OpenAI",
            "gpt-5.4",
            OpenAiApi::Responses,
            "rate limit reached",
            Some(Duration::from_secs(12)),
        );

        assert!(message.contains("HTTP 429"));
        assert!(message.contains("retrying after 12 seconds"));
    }
2692
    #[test]
    fn retry_delay_prefers_provider_retry_after_hint() {
        // With jitter disabled, an embedded "retrying after 12 seconds" hint
        // should win over the policy's exponential backoff.
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay_ms: 1_000,
            max_delay_ms: 8_000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        };

        let delay = select_streaming_retry_delay(
            &policy,
            1,
            "OpenAI /v1/responses HTTP 429: rate limit reached. Provider suggested retrying after 12 seconds.",
        );

        assert_eq!(delay, Duration::from_secs(12));
    }
2711
    #[test]
    fn retry_delay_uses_rate_limit_floor_without_retry_after_hint() {
        // A 429 without an explicit hint falls back to the 5-second
        // rate-limit floor rather than the 1-second policy backoff.
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay_ms: 1_000,
            max_delay_ms: 8_000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        };

        let delay = select_streaming_retry_delay(
            &policy,
            1,
            "OpenAI /v1/responses HTTP 429: Too many requests",
        );

        assert_eq!(delay, Duration::from_secs(5));
    }
2730
2731 #[test]
2732 fn test_cancellation_token() {
2733 let token = CancellationToken::new();
2734 assert!(!token.is_cancelled());
2735 token.cancel();
2736 assert!(token.is_cancelled());
2737 assert!(!token.is_pause_requested());
2738 assert!(matches!(token.interruption_chunk(), StreamChunk::Cancelled));
2739 }
2740
2741 #[test]
2742 fn test_cancellation_token_pause_intent() {
2743 let token = CancellationToken::new();
2744 token.pause();
2745
2746 assert!(token.is_cancelled());
2747 assert!(token.is_pause_requested());
2748 assert!(matches!(token.interruption_chunk(), StreamChunk::Paused));
2749
2750 token.cancel();
2751 assert!(token.is_cancelled());
2752 assert!(!token.is_pause_requested());
2753 assert!(matches!(token.interruption_chunk(), StreamChunk::Cancelled));
2754 }
2755
2756 #[test]
2757 fn split_think_blocks_extracts_thinking() {
2758 let input = "<think>plan</think>answer";
2759 let (content, thinking) = split_think_blocks(input);
2760 assert_eq!(content, "answer");
2761 assert_eq!(thinking.as_deref(), Some("plan"));
2762 }
2763
2764 #[test]
2765 fn thinking_parser_handles_complete_tags() {
2766 let mut parser = ThinkingParser::new();
2767 let chunks = parser.process("<think>thinking content</think>response text");
2768
2769 assert_eq!(chunks.len(), 2);
2770 assert!(matches!(&chunks[0], StreamChunk::Thinking(t) if t == "thinking content"));
2771 assert!(matches!(&chunks[1], StreamChunk::Text(t) if t == "response text"));
2772 }
2773
2774 #[test]
2775 fn thinking_parser_handles_split_start_tag() {
2776 let mut parser = ThinkingParser::new();
2777
2778 let chunks1 = parser.process("Hello <thi");
2780 assert_eq!(chunks1.len(), 1);
2781 assert!(matches!(&chunks1[0], StreamChunk::Text(t) if t == "Hello "));
2782
2783 let chunks2 = parser.process("nk>thinking</think>done");
2785 assert_eq!(chunks2.len(), 2);
2786 assert!(matches!(&chunks2[0], StreamChunk::Thinking(t) if t == "thinking"));
2787 assert!(matches!(&chunks2[1], StreamChunk::Text(t) if t == "done"));
2788 }
2789
2790 #[test]
2791 fn thinking_parser_handles_split_end_tag() {
2792 let mut parser = ThinkingParser::new();
2793
2794 let chunks1 = parser.process("<think>thinking content</th");
2796 assert_eq!(chunks1.len(), 1);
2797 assert!(matches!(&chunks1[0], StreamChunk::Thinking(t) if t == "thinking content"));
2798
2799 let chunks2 = parser.process("ink>response");
2801 assert_eq!(chunks2.len(), 1);
2802 assert!(matches!(&chunks2[0], StreamChunk::Text(t) if t == "response"));
2803 }
2804
2805 #[test]
2806 fn thinking_parser_handles_text_before_think() {
2807 let mut parser = ThinkingParser::new();
2808 let chunks = parser.process("prefix<think>thought</think>suffix");
2809
2810 assert_eq!(chunks.len(), 3);
2811 assert!(matches!(&chunks[0], StreamChunk::Text(t) if t == "prefix"));
2812 assert!(matches!(&chunks[1], StreamChunk::Thinking(t) if t == "thought"));
2813 assert!(matches!(&chunks[2], StreamChunk::Text(t) if t == "suffix"));
2814 }
2815
2816 #[test]
2817 fn thinking_parser_handles_no_think_tags() {
2818 let mut parser = ThinkingParser::new();
2819 let chunks = parser.process("just regular text");
2820
2821 assert_eq!(chunks.len(), 1);
2822 assert!(matches!(&chunks[0], StreamChunk::Text(t) if t == "just regular text"));
2823 }
2824
2825 #[test]
2826 fn openai_body_includes_tools_and_tool_choice_when_provided() {
2827 let tools = vec![serde_json::json!({
2828 "type": "function",
2829 "function": {
2830 "name": "shell",
2831 "description": "Run a command",
2832 "parameters": {"type": "object", "properties": {}}
2833 }
2834 })];
2835
2836 let body = build_openai_chat_request_body("gpt-test", "hi", Some(&tools));
2837 assert!(body.get("tools").is_some());
2838 assert_eq!(
2839 body.get("tool_choice").and_then(|v| v.as_str()),
2840 Some("auto")
2841 );
2842 }
2843
2844 #[test]
2845 fn openai_body_omits_tools_when_none() {
2846 let body = build_openai_chat_request_body("gpt-test", "hi", None);
2847 assert!(body.get("tools").is_none());
2848 assert!(body.get("tool_choice").is_none());
2849 }
2850
2851 #[test]
2852 fn openai_body_omits_temperature() {
2853 let body = build_openai_chat_request_body("gpt-test", "hi", None);
2854 assert!(body.get("temperature").is_none());
2855 }
2856
2857 #[test]
2858 fn openai_responses_body_uses_responses_shape() {
2859 let tools = vec![serde_json::json!({
2860 "type": "function",
2861 "name": "shell",
2862 "description": "Run a command",
2863 "parameters": {"type": "object", "properties": {}}
2864 })];
2865
2866 let body = build_openai_responses_request_body("gpt-5.4", "hi", Some(&tools));
2867 assert_eq!(body["model"], "gpt-5.4");
2868 assert_eq!(body["input"][0]["role"], "user");
2869 assert_eq!(body["input"][0]["content"], "hi");
2870 assert!(body.get("tools").is_some());
2871 assert_eq!(body["tool_choice"], "auto");
2872 }
2873
2874 #[test]
2875 fn openai_http_error_mentions_selected_endpoint() {
2876 let message = format_openai_http_error(
2877 reqwest::StatusCode::NOT_FOUND,
2878 "OpenAI",
2879 "gpt-5.3-codex",
2880 OpenAiApi::ChatCompletions,
2881 "This is not a chat model",
2882 None,
2883 );
2884 assert!(message.contains("/v1/responses"));
2885 assert!(message.contains("/v1/chat/completions"));
2886 }
2887
    // Two tool calls stream their JSON arguments interleaved; each delta must
    // be routed to the right call via `index` and concatenated in arrival
    // order so the final argument strings are valid JSON.
    #[test]
    fn openai_tool_call_deltas_are_assembled_by_index() {
        let mut pending = BTreeMap::new();

        // First delta for each call carries id + name and the opening half of
        // the argument string.
        merge_openai_tool_call_delta(
            &mut pending,
            &serde_json::json!({
                "index": 0,
                "id": "call_0",
                "function": {"name": "task", "arguments": "{\"operation\":\"update_status\",\"task_id\":\"abc"}
            }),
            0,
        );
        merge_openai_tool_call_delta(
            &mut pending,
            &serde_json::json!({
                "index": 1,
                "id": "call_1",
                "function": {"name": "shell", "arguments": "{\"command\":\"cargo check"}
            }),
            1,
        );
        // Follow-up deltas omit id/name and only append more argument text.
        merge_openai_tool_call_delta(
            &mut pending,
            &serde_json::json!({
                "index": 0,
                "function": {"arguments": "\",\"status\":\"completed\"}"}
            }),
            0,
        );
        merge_openai_tool_call_delta(
            &mut pending,
            &serde_json::json!({
                "index": 1,
                "function": {"arguments": "\",\"timeout_secs\":300}"}
            }),
            1,
        );

        // Draining yields both calls with fully reassembled argument JSON.
        let calls = take_openai_tool_calls(&mut pending);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].0, 0);
        assert_eq!(calls[0].1.id, "call_0");
        assert_eq!(calls[0].1.name, "task");
        assert_eq!(
            calls[0].1.arguments,
            "{\"operation\":\"update_status\",\"task_id\":\"abc\",\"status\":\"completed\"}"
        );
        assert_eq!(calls[1].0, 1);
        assert_eq!(calls[1].1.id, "call_1");
        assert_eq!(calls[1].1.name, "shell");
        assert_eq!(
            calls[1].1.arguments,
            "{\"command\":\"cargo check\",\"timeout_secs\":300}"
        );
        // `take_*` empties the buffer so nothing is emitted twice.
        assert!(pending.is_empty());
    }
2945
    // Pending calls are inserted out of order (index 1 before 0) to prove the
    // emitter replays them sorted by index as Start/Args/End chunk triples.
    #[tokio::test]
    async fn emit_openai_tool_calls_streams_complete_calls_in_index_order() {
        let (tx, mut rx) = mpsc::channel(10);
        let mut pending = BTreeMap::new();
        pending.insert(
            1,
            PendingOpenAiToolCall {
                id: "call_1".to_string(),
                name: "shell".to_string(),
                arguments: "{\"command\":\"pwd\"}".to_string(),
            },
        );
        pending.insert(
            0,
            PendingOpenAiToolCall {
                id: "call_0".to_string(),
                name: "file".to_string(),
                arguments: "{\"operation\":\"list\"}".to_string(),
            },
        );

        emit_openai_tool_calls(&tx, &mut pending).await;

        // Index 0 is streamed first despite having been inserted second.
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallStart { id, name }) if id == "call_0" && name == "file"
        ));
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallArgs(args)) if args == "{\"operation\":\"list\"}"
        ));
        assert!(matches!(rx.recv().await, Some(StreamChunk::ToolCallEnd)));
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallStart { id, name }) if id == "call_1" && name == "shell"
        ));
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallArgs(args)) if args == "{\"command\":\"pwd\"}"
        ));
        assert!(matches!(rx.recv().await, Some(StreamChunk::ToolCallEnd)));
        // Emission drains the buffer completely.
        assert!(pending.is_empty());
    }
2989
    // Responses-API tool calls arrive as three event kinds (item added,
    // argument delta, arguments done); all three must land in the same
    // pending slot keyed by `output_index`.
    #[test]
    fn openai_responses_tool_calls_are_buffered_by_output_index() {
        let mut pending = BTreeMap::new();
        let mut tool_indices = HashMap::new();

        // "output_item.added" registers the call with its id and name.
        merge_openai_responses_tool_item(
            &mut pending,
            &mut tool_indices,
            &serde_json::json!({
                "type": "response.output_item.added",
                "output_index": 0,
                "item": {
                    "type": "function_call",
                    "id": "fc_0",
                    "call_id": "call_0",
                    "name": "file"
                }
            }),
            0,
        );
        // Argument deltas append to the buffered argument string.
        merge_openai_responses_tool_argument_delta(
            &mut pending,
            &mut tool_indices,
            &serde_json::json!({
                "type": "response.function_call_arguments.delta",
                "output_index": 0,
                "item_id": "fc_0",
                "delta": "{\"operation\":\"list\"}"
            }),
            0,
        );
        // "arguments.done" finalizes the call and marks it finished.
        complete_openai_responses_tool_arguments(
            &mut pending,
            &mut tool_indices,
            &serde_json::json!({
                "type": "response.function_call_arguments.done",
                "output_index": 0,
                "item_id": "fc_0",
                "arguments": "{\"operation\":\"list\"}"
            }),
            0,
        );

        // A single fully-assembled, finished entry under output_index 0.
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[&0].id, "call_0");
        assert_eq!(pending[&0].name, "file");
        assert_eq!(pending[&0].arguments, "{\"operation\":\"list\"}");
        assert!(pending[&0].finished);
    }
3039
    // When events omit `output_index`, the helpers fall back to the supplied
    // event index plus the id/call_id alias map so that all three events for
    // one call still resolve to a single stable slot — here the slot created
    // by the first event (index 3), even though later events use 8 and 13.
    #[test]
    fn openai_responses_tool_calls_reuse_stable_aliases_when_output_index_is_missing() {
        let mut pending = BTreeMap::new();
        let mut tool_indices = HashMap::new();

        // No "output_index" in the payload: the fallback index 3 creates the slot.
        merge_openai_responses_tool_item(
            &mut pending,
            &mut tool_indices,
            &serde_json::json!({
                "type": "response.output_item.added",
                "item": {
                    "type": "function_call",
                    "id": "fc_0",
                    "call_id": "call_0",
                    "name": "file"
                }
            }),
            3,
        );
        // Delta identified only by "item_id"; fallback index 8 must not fork
        // a second slot.
        merge_openai_responses_tool_argument_delta(
            &mut pending,
            &mut tool_indices,
            &serde_json::json!({
                "type": "response.function_call_arguments.delta",
                "item_id": "fc_0",
                "delta": "{\"operation\":\"list\"}"
            }),
            8,
        );
        // Completion identified only by "call_id"; fallback index 13 likewise
        // resolves to the original slot.
        complete_openai_responses_tool_arguments(
            &mut pending,
            &mut tool_indices,
            &serde_json::json!({
                "type": "response.function_call_arguments.done",
                "call_id": "call_0",
                "arguments": "{\"operation\":\"list\"}"
            }),
            13,
        );

        // Exactly one entry, keyed by the first-seen index.
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[&3].id, "call_0");
        assert_eq!(pending[&3].arguments, "{\"operation\":\"list\"}");
        assert!(pending[&3].finished);
    }
3085
    // With two finished calls buffered, emission streams them as complete
    // Start/Args/End triples in ascending output-index order.
    #[tokio::test]
    async fn emit_openai_responses_tool_calls_waits_for_lowest_ready_index() {
        let (tx, mut rx) = mpsc::channel(10);
        let mut pending = BTreeMap::new();
        let mut emitted_ids = HashSet::new();
        pending.insert(
            0,
            PendingOpenAiResponsesToolCall {
                id: "call_0".to_string(),
                name: "file".to_string(),
                arguments: "{\"operation\":\"list\"}".to_string(),
                finished: true,
            },
        );
        pending.insert(
            1,
            PendingOpenAiResponsesToolCall {
                id: "call_1".to_string(),
                name: "shell".to_string(),
                arguments: "{\"command\":\"pwd\"}".to_string(),
                finished: true,
            },
        );

        // `false` = the response is not yet complete; only finished calls may
        // be flushed.
        emit_ready_openai_responses_tool_calls(&tx, &mut pending, &mut emitted_ids, false).await;

        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallStart { id, name }) if id == "call_0" && name == "file"
        ));
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallArgs(args)) if args == "{\"operation\":\"list\"}"
        ));
        assert!(matches!(rx.recv().await, Some(StreamChunk::ToolCallEnd)));
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallStart { id, name }) if id == "call_1" && name == "shell"
        ));
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallArgs(args)) if args == "{\"command\":\"pwd\"}"
        ));
        assert!(matches!(rx.recv().await, Some(StreamChunk::ToolCallEnd)));
        // Both calls were flushed out of the buffer.
        assert!(pending.is_empty());
    }
3132
    // Two pending entries share the same call id; the `emitted_ids` set must
    // deduplicate so the call is streamed exactly once.
    #[tokio::test]
    async fn emit_openai_responses_tool_calls_skips_duplicate_call_ids() {
        let (tx, mut rx) = mpsc::channel(10);
        let mut pending = BTreeMap::new();
        let mut emitted_ids = HashSet::new();
        pending.insert(
            0,
            PendingOpenAiResponsesToolCall {
                id: "call_dup".to_string(),
                name: "file".to_string(),
                arguments: "{\"operation\":\"list\"}".to_string(),
                finished: true,
            },
        );
        pending.insert(
            1,
            PendingOpenAiResponsesToolCall {
                id: "call_dup".to_string(),
                name: "file".to_string(),
                arguments: "{\"operation\":\"list\"}".to_string(),
                finished: true,
            },
        );

        emit_ready_openai_responses_tool_calls(&tx, &mut pending, &mut emitted_ids, false).await;

        // One complete Start/Args/End triple...
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallStart { id, name }) if id == "call_dup" && name == "file"
        ));
        assert!(matches!(
            rx.recv().await,
            Some(StreamChunk::ToolCallArgs(args)) if args == "{\"operation\":\"list\"}"
        ));
        assert!(matches!(rx.recv().await, Some(StreamChunk::ToolCallEnd)));
        // ...and nothing further: the duplicate was swallowed, and both
        // entries were removed from the buffer.
        assert!(rx.try_recv().is_err());
        assert!(pending.is_empty());
    }
3171
3172 #[test]
3173 fn anthropic_body_includes_tools_when_provided() {
3174 let tools = vec![serde_json::json!({
3175 "name": "shell",
3176 "description": "Run a command",
3177 "input_schema": {"type": "object", "properties": {}}
3178 })];
3179
3180 let body = build_anthropic_messages_body("claude-test", "hi", None, Some(&tools));
3181 assert!(body.get("tools").is_some());
3182 }
3183
3184 #[tokio::test]
3185 async fn test_stream_chunk_types() {
3186 let (tx, mut rx) = mpsc::channel(10);
3187
3188 tx.send(StreamChunk::Text("Hello".to_string()))
3189 .await
3190 .unwrap();
3191 tx.send(StreamChunk::Done(None)).await.unwrap();
3192
3193 if let Some(StreamChunk::Text(text)) = rx.recv().await {
3194 assert_eq!(text, "Hello");
3195 } else {
3196 panic!("Expected Text chunk");
3197 }
3198
3199 if let Some(StreamChunk::Done(_)) = rx.recv().await {
3200 } else {
3202 panic!("Expected Done chunk");
3203 }
3204 }
3205
3206 #[tokio::test]
3207 async fn start_streaming_unconfigured_provider_returns_error() {
3208 let cfg = StreamingConfig {
3209 primary: "openai".to_string(),
3210 openai: None,
3211 ..Default::default()
3212 };
3213
3214 let (tx, mut rx) = mpsc::channel(128);
3215 let cancel = CancellationToken::new();
3216
3217 tokio::spawn(async move {
3218 let prompt = "hello world";
3219 let _ = start_streaming(&cfg, prompt, None, tx, cancel).await;
3220 });
3221
3222 match rx.recv().await {
3224 Some(StreamChunk::Status { message }) => {
3225 assert!(message.contains("not configured"));
3226 }
3227 other => panic!("Expected Status chunk, got: {other:?}"),
3228 }
3229
3230 match rx.recv().await {
3232 Some(StreamChunk::Error(msg)) => {
3233 assert!(msg.contains("not configured"));
3234 }
3235 other => panic!("Expected Error chunk, got: {other:?}"),
3236 }
3237 }
3238
3239 #[tokio::test]
3240 async fn forward_attempt_stream_forwards_immediately() {
3241 let (outer_tx, mut outer_rx) = mpsc::channel::<StreamChunk>(10);
3242 let (attempt_tx, mut attempt_rx) = mpsc::channel::<StreamChunk>(10);
3243
3244 let forward_handle =
3245 tokio::spawn(async move { forward_attempt_stream(&mut attempt_rx, &outer_tx).await });
3246
3247 attempt_tx
3248 .send(StreamChunk::Text("A".to_string()))
3249 .await
3250 .unwrap();
3251
3252 match outer_rx.recv().await {
3254 Some(StreamChunk::Text(t)) => assert_eq!(t, "A"),
3255 other => panic!("Expected Text chunk, got: {other:?}"),
3256 }
3257
3258 attempt_tx.send(StreamChunk::Done(None)).await.unwrap();
3259 match outer_rx.recv().await {
3260 Some(StreamChunk::Done(_)) => {}
3261 other => panic!("Expected Done chunk, got: {other:?}"),
3262 }
3263
3264 let result = forward_handle.await.unwrap();
3265 assert_eq!(result.outcome, AttemptOutcome::Success);
3266 }
3267
3268 #[tokio::test]
3269 async fn forward_attempt_stream_retryable_error_before_output_is_not_forwarded() {
3270 let (outer_tx, mut outer_rx) = mpsc::channel::<StreamChunk>(10);
3271 let (attempt_tx, mut attempt_rx) = mpsc::channel::<StreamChunk>(10);
3272
3273 let forward_handle =
3274 tokio::spawn(async move { forward_attempt_stream(&mut attempt_rx, &outer_tx).await });
3275
3276 attempt_tx
3277 .send(StreamChunk::Error("nope".to_string()))
3278 .await
3279 .unwrap();
3280
3281 let recv =
3284 tokio::time::timeout(std::time::Duration::from_millis(50), outer_rx.recv()).await;
3285 match recv {
3286 Err(_) => {} Ok(None) => {} Ok(Some(other)) => panic!("did not expect any forwarded chunk, got: {other:?}"),
3289 }
3290
3291 let result = forward_handle.await.unwrap();
3292 assert_eq!(result.outcome, AttemptOutcome::RetryableError);
3293 }
3294
3295 #[tokio::test]
3296 async fn forward_attempt_stream_fatal_error_after_output_is_forwarded() {
3297 let (outer_tx, mut outer_rx) = mpsc::channel::<StreamChunk>(10);
3298 let (attempt_tx, mut attempt_rx) = mpsc::channel::<StreamChunk>(10);
3299
3300 let forward_handle =
3301 tokio::spawn(async move { forward_attempt_stream(&mut attempt_rx, &outer_tx).await });
3302
3303 attempt_tx
3304 .send(StreamChunk::Text("hello".to_string()))
3305 .await
3306 .unwrap();
3307 match outer_rx.recv().await {
3308 Some(StreamChunk::Text(t)) => assert_eq!(t, "hello"),
3309 other => panic!("Expected Text chunk, got: {other:?}"),
3310 }
3311
3312 attempt_tx
3313 .send(StreamChunk::Error("boom".to_string()))
3314 .await
3315 .unwrap();
3316 match outer_rx.recv().await {
3317 Some(StreamChunk::Error(e)) => assert_eq!(e, "boom"),
3318 other => panic!("Expected Error chunk, got: {other:?}"),
3319 }
3320
3321 let result = forward_handle.await.unwrap();
3322 assert_eq!(result.outcome, AttemptOutcome::FatalError);
3323 }
3324
    // The outer channel (capacity 1) is pre-filled so it exerts backpressure;
    // a transient status chunk must then be dropped (best-effort send) rather
    // than stalling the forwarder and delaying the retry decision.
    #[tokio::test]
    async fn forward_attempt_stream_drops_status_under_backpressure_without_blocking_retry() {
        let (outer_tx, mut outer_rx) = mpsc::channel::<StreamChunk>(1);
        // Occupy the single slot so any further send would block.
        outer_tx
            .send(StreamChunk::Status {
                message: "occupied".to_string(),
            })
            .await
            .unwrap();

        let (attempt_tx, mut attempt_rx) = mpsc::channel::<StreamChunk>(10);
        let forward_handle =
            tokio::spawn(async move { forward_attempt_stream(&mut attempt_rx, &outer_tx).await });

        // A status chunk under backpressure, then a retryable error.
        attempt_tx
            .send(StreamChunk::Status {
                message: "keepalive".to_string(),
            })
            .await
            .unwrap();
        attempt_tx
            .send(StreamChunk::Error("retry me".to_string()))
            .await
            .unwrap();
        drop(attempt_tx);

        // The forwarder must finish well within 300ms despite the full
        // outer channel (the best-effort status send times out at 100ms).
        let result = tokio::time::timeout(Duration::from_millis(300), forward_handle)
            .await
            .expect("status backpressure should not stall forwarder")
            .expect("forwarder join should succeed");

        // The error is classified retryable and not forwarded (no output yet).
        assert_eq!(result.outcome, AttemptOutcome::RetryableError);
        assert!(!result.forwarded_output);
        assert_eq!(result.error.as_deref(), Some("retry me"));

        // Only the pre-filled status remains on the outer channel...
        match outer_rx.recv().await {
            Some(StreamChunk::Status { message }) => assert_eq!(message, "occupied"),
            other => panic!("expected only the pre-filled status chunk, got: {other:?}"),
        }

        // ...and nothing else arrives: the keepalive status was dropped.
        let recv = tokio::time::timeout(Duration::from_millis(50), outer_rx.recv()).await;
        match recv {
            Err(_) => {}
            Ok(None) => {}
            Ok(Some(other)) => {
                panic!("did not expect forwarded status/error chunk, got: {other:?}")
            }
        }
    }
3374
    // Same backpressure scenario as the status test, but for transient
    // token-usage updates: they too are best-effort and must be dropped
    // instead of blocking the forwarder's retry decision.
    #[tokio::test]
    async fn forward_attempt_stream_drops_token_usage_under_backpressure_without_blocking_retry() {
        let (outer_tx, mut outer_rx) = mpsc::channel::<StreamChunk>(1);
        // Occupy the single slot so any further send would block.
        outer_tx
            .send(StreamChunk::TokenUsageUpdate {
                estimated: 42,
                limit: 100,
                percentage: 42,
                status: TokenUsageStatus::Green,
                estimated_cost: 0.0001,
            })
            .await
            .unwrap();

        let (attempt_tx, mut attempt_rx) = mpsc::channel::<StreamChunk>(10);
        let forward_handle =
            tokio::spawn(async move { forward_attempt_stream(&mut attempt_rx, &outer_tx).await });

        // A token-usage update under backpressure, then a retryable error.
        attempt_tx
            .send(StreamChunk::TokenUsageUpdate {
                estimated: 50,
                limit: 100,
                percentage: 50,
                status: TokenUsageStatus::Green,
                estimated_cost: 0.0002,
            })
            .await
            .unwrap();
        attempt_tx
            .send(StreamChunk::Error("retry me".to_string()))
            .await
            .unwrap();
        drop(attempt_tx);

        // Must complete within 300ms: the best-effort usage send gives up
        // after its 100ms timeout instead of waiting on the full channel.
        let result = tokio::time::timeout(Duration::from_millis(300), forward_handle)
            .await
            .expect("token-usage backpressure should not stall forwarder")
            .expect("forwarder join should succeed");

        // Retryable, nothing forwarded, error text preserved for the retry loop.
        assert_eq!(result.outcome, AttemptOutcome::RetryableError);
        assert!(!result.forwarded_output);
        assert_eq!(result.error.as_deref(), Some("retry me"));

        // Only the pre-filled update (estimated == 42) is on the channel...
        match outer_rx.recv().await {
            Some(StreamChunk::TokenUsageUpdate { estimated, .. }) => assert_eq!(estimated, 42),
            other => panic!("expected only the pre-filled token-usage chunk, got: {other:?}"),
        }

        // ...and the estimated == 50 update was dropped, not queued.
        let recv = tokio::time::timeout(Duration::from_millis(50), outer_rx.recv()).await;
        match recv {
            Err(_) => {}
            Ok(None) => {}
            Ok(Some(other)) => {
                panic!("did not expect forwarded token-usage/error chunk, got: {other:?}")
            }
        }
    }
3432}