1use gestura_core_foundation::error::AppError;
9use rand::Rng as _;
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
/// Coarse classification of an [`AppError`] used to drive retry decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// Temporary failure (timeout, connection issue, rate limit, 5xx); worth retrying.
    Transient,
    /// Definitive failure (auth, bad request, config); retrying cannot help.
    Permanent,
    /// The request exceeded the model's context window; needs compaction rather than a plain retry.
    ContextOverflow,
    /// Unclassified error; retried optimistically (see `should_retry`).
    Unknown,
}
25
26impl ErrorClass {
27 pub fn classify(error: &AppError) -> Self {
29 match error {
30 AppError::Timeout(_) => Self::Transient,
32 AppError::Http(e) => {
33 if e.is_timeout() || e.is_connect() {
34 Self::Transient
35 } else if let Some(status) = e.status() {
36 match status.as_u16() {
37 429 => Self::Transient, 500..=599 => Self::Transient, 401 | 403 => Self::Permanent, 400 | 404 => Self::Permanent, _ => Self::Unknown,
42 }
43 } else {
44 Self::Unknown
45 }
46 }
47 AppError::Llm(msg) => {
48 let msg_lower = msg.to_lowercase();
49 if msg_lower.contains("context_length_exceeded")
51 || msg_lower.contains("context length")
52 || msg_lower.contains("maximum context")
53 || msg_lower.contains("token limit")
54 || (msg_lower.contains("tokens") && msg_lower.contains("exceeds"))
55 {
56 Self::ContextOverflow
57 } else if msg_lower.contains("rate limit")
58 || msg_lower.contains("429")
59 || msg_lower.contains("timeout")
60 || msg_lower.contains("connection")
61 || msg_lower.contains("temporarily")
62 {
63 Self::Transient
64 } else if msg_lower.contains("401")
65 || msg_lower.contains("403")
66 || msg_lower.contains("unauthorized")
67 || msg_lower.contains("invalid api key")
68 || msg_lower.contains("not configured")
69 {
70 Self::Permanent
71 } else {
72 Self::Unknown
73 }
74 }
75 AppError::ContextOverflow(_) => Self::ContextOverflow,
77 AppError::Config(_) => Self::Permanent,
79 AppError::PermissionDenied(_) => Self::Permanent,
80 AppError::InvalidInput(_) => Self::Permanent,
81 AppError::NotFound(_) => Self::Permanent,
82 _ => Self::Unknown,
84 }
85 }
86
87 pub fn should_retry(&self) -> bool {
89 matches!(self, Self::Transient | Self::Unknown)
90 }
91
92 pub fn needs_compaction(&self) -> bool {
94 matches!(self, Self::ContextOverflow)
95 }
96
97 pub fn is_recoverable(&self) -> bool {
99 !matches!(self, Self::Permanent)
100 }
101}
102
/// Configuration for exponential-backoff retries with symmetric jitter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of retries performed after the initial attempt.
    pub max_attempts: u32,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound on any single delay, in milliseconds (applied before jitter).
    pub max_delay_ms: u64,
    /// Factor by which the delay grows for each successive retry.
    pub backoff_multiplier: f64,
    /// Fraction of the (capped) delay used as the +/- jitter range; 0.0 disables jitter.
    pub jitter_factor: f64,
}
117
118impl Default for RetryPolicy {
119 fn default() -> Self {
120 Self {
121 max_attempts: 3,
122 initial_delay_ms: 1000,
123 max_delay_ms: 30000,
124 backoff_multiplier: 2.0,
125 jitter_factor: 0.25,
126 }
127 }
128}
129
130impl RetryPolicy {
131 pub fn for_api() -> Self {
133 Self::default()
134 }
135
136 pub fn for_tools() -> Self {
138 Self {
139 max_attempts: 2,
140 initial_delay_ms: 500,
141 max_delay_ms: 5000,
142 backoff_multiplier: 2.0,
143 jitter_factor: 0.1,
144 }
145 }
146
147 pub fn for_streaming() -> Self {
149 Self {
150 max_attempts: 3,
151 initial_delay_ms: 1000,
152 max_delay_ms: 8000,
153 backoff_multiplier: 2.0,
154 jitter_factor: 0.25,
155 }
156 }
157
158 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
160 if attempt == 0 {
161 return Duration::ZERO;
162 }
163
164 let base_delay =
165 self.initial_delay_ms as f64 * self.backoff_multiplier.powi((attempt - 1) as i32);
166 let capped_delay = base_delay.min(self.max_delay_ms as f64);
167
168 let jitter_range = capped_delay * self.jitter_factor;
170 let jitter = if jitter_range > 0.0 {
171 rand::thread_rng().gen_range(-jitter_range..jitter_range)
172 } else {
173 0.0
174 };
175 let final_delay = (capped_delay + jitter).max(0.0);
176
177 Duration::from_millis(final_delay as u64)
178 }
179}
180
/// Snapshot handed to the retry callback just before a retry is scheduled.
#[derive(Debug, Clone)]
pub struct RetryEvent {
    /// 1-based index of the upcoming retry.
    pub attempt: u32,
    /// Total number of retries the policy allows.
    pub max_attempts: u32,
    /// Backoff delay computed for the upcoming retry.
    pub delay: Duration,
    /// Display form of the error that triggered the retry.
    pub error_message: String,
    /// Classification of that error.
    pub error_class: ErrorClass,
}
195
/// Observer invoked with a [`RetryEvent`] before each retry; `Send + Sync`
/// so the manager can be shared across threads/tasks.
pub type RetryCallback = Box<dyn Fn(RetryEvent) + Send + Sync>;

/// Executes fallible async operations with policy-driven retries.
pub struct RetryManager {
    // Backoff/attempt configuration used by `execute`.
    policy: RetryPolicy,
    // Optional hook notified before each retry sleep.
    on_retry: Option<RetryCallback>,
}
204
205impl RetryManager {
206 pub fn new(policy: RetryPolicy) -> Self {
208 Self {
209 policy,
210 on_retry: None,
211 }
212 }
213
214 pub fn for_api() -> Self {
216 Self::new(RetryPolicy::for_api())
217 }
218
219 pub fn for_streaming() -> Self {
221 Self::new(RetryPolicy::for_streaming())
222 }
223
224 pub fn for_tools() -> Self {
226 Self::new(RetryPolicy::for_tools())
227 }
228
229 pub fn with_retry_callback(mut self, callback: RetryCallback) -> Self {
231 self.on_retry = Some(callback);
232 self
233 }
234
235 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T, AppError>
239 where
240 F: FnMut() -> Fut,
241 Fut: std::future::Future<Output = Result<T, AppError>>,
242 {
243 let mut last_error: Option<AppError> = None;
244
245 for attempt in 0..=self.policy.max_attempts {
246 if attempt > 0 {
248 let delay = self.policy.delay_for_attempt(attempt);
249 tokio::time::sleep(delay).await;
250 }
251
252 match operation().await {
253 Ok(result) => return Ok(result),
254 Err(e) => {
255 let error_class = ErrorClass::classify(&e);
256
257 if !error_class.should_retry() {
259 return Err(e);
260 }
261
262 if attempt < self.policy.max_attempts {
264 let delay = self.policy.delay_for_attempt(attempt + 1);
265
266 if let Some(ref callback) = self.on_retry {
268 callback(RetryEvent {
269 attempt: attempt + 1,
270 max_attempts: self.policy.max_attempts,
271 delay,
272 error_message: e.to_string(),
273 error_class,
274 });
275 }
276
277 tracing::warn!(
278 attempt = attempt + 1,
279 max_attempts = self.policy.max_attempts,
280 delay_ms = delay.as_millis(),
281 error = %e,
282 error_class = ?error_class,
283 "Operation failed, will retry"
284 );
285 }
286
287 last_error = Some(e);
288 }
289 }
290 }
291
292 Err(last_error.unwrap_or_else(|| AppError::Internal("Retry exhausted".to_string())))
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_classification_transient() {
        // Timeouts and rate-limit messages should both be retryable.
        let timeout = AppError::Timeout("connection timed out".to_string());
        let rate_limited = AppError::Llm("rate limit exceeded (429)".to_string());

        assert_eq!(ErrorClass::classify(&timeout), ErrorClass::Transient);
        assert_eq!(ErrorClass::classify(&rate_limited), ErrorClass::Transient);
    }

    #[test]
    fn test_error_classification_permanent() {
        // Config problems and auth failures must never be retried.
        let bad_config = AppError::Config("missing API key".to_string());
        let unauthorized = AppError::Llm("401 unauthorized".to_string());

        assert_eq!(ErrorClass::classify(&bad_config), ErrorClass::Permanent);
        assert_eq!(ErrorClass::classify(&unauthorized), ErrorClass::Permanent);
    }

    #[test]
    fn test_error_classification_context_overflow() {
        // Message heuristics, the explicit variant, and "tokens exceeds"
        // wording should all map to ContextOverflow.
        let cases = [
            AppError::Llm("maximum context length is 16385 tokens".to_string()),
            AppError::ContextOverflow("context too large".to_string()),
            AppError::Llm("Request tokens exceeds limit".to_string()),
        ];
        for err in &cases {
            assert_eq!(ErrorClass::classify(err), ErrorClass::ContextOverflow);
        }
    }

    #[test]
    fn test_context_overflow_needs_compaction() {
        assert!(ErrorClass::ContextOverflow.needs_compaction());
        assert!(!ErrorClass::Transient.needs_compaction());
        assert!(!ErrorClass::Permanent.needs_compaction());
    }

    #[test]
    fn test_delay_calculation() {
        // Jitter disabled so delays are fully deterministic.
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay_ms: 1000,
            max_delay_ms: 10000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.0,
        };

        // (attempt, expected delay in ms): doubling from 1s, capped at 10s.
        let expected = [(0u32, 0u64), (1, 1000), (2, 2000), (3, 4000), (5, 10000)];
        for (attempt, ms) in expected {
            assert_eq!(policy.delay_for_attempt(attempt), Duration::from_millis(ms));
        }
    }

    #[test]
    fn test_should_retry() {
        assert!(ErrorClass::Transient.should_retry());
        assert!(ErrorClass::Unknown.should_retry());
        assert!(!ErrorClass::Permanent.should_retry());
    }

    #[tokio::test]
    async fn test_retry_manager_success_first_try() {
        let manager = RetryManager::new(RetryPolicy::default());
        let outcome: Result<i32, AppError> = manager.execute(|| async { Ok(42) }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_manager_permanent_error_no_retry() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;

        let manager = RetryManager::new(RetryPolicy::default());
        let calls = Arc::new(AtomicU32::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome: Result<i32, AppError> = manager
            .execute(move || {
                let counter = Arc::clone(&calls_in_op);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(AppError::Config("permanent error".to_string()))
                }
            })
            .await;

        // A permanent error must fail immediately, after exactly one call.
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}