gestura_core_retry/
retry.rs

1//! Retry management for transient failures
2//!
3//! Provides configurable retry policies with exponential backoff and jitter,
4//! error classification, and user notification support.
5//!
6//! Based on patterns from Block Goose's RetryManager architecture.
7
8use gestura_core_foundation::error::AppError;
9use rand::Rng as _;
10use serde::{Deserialize, Serialize};
11use std::time::Duration;
12
/// Error classification for retry decisions
///
/// Produced by [`ErrorClass::classify`] and consulted by `RetryManager` (via
/// `should_retry`) to decide whether a failed operation is worth retrying.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
    /// Transient error that may succeed on retry (rate limits, timeouts, network issues)
    Transient,
    /// Permanent error that will not succeed on retry (auth failure, invalid input)
    Permanent,
    /// Context overflow error that requires compaction before retry
    ContextOverflow,
    /// Unknown error classification - treat as transient with limited retries
    Unknown,
}
25
26impl ErrorClass {
27    /// Classify an AppError for retry decisions
28    pub fn classify(error: &AppError) -> Self {
29        match error {
30            // Transient errors - worth retrying
31            AppError::Timeout(_) => Self::Transient,
32            AppError::Http(e) => {
33                if e.is_timeout() || e.is_connect() {
34                    Self::Transient
35                } else if let Some(status) = e.status() {
36                    match status.as_u16() {
37                        429 => Self::Transient,       // Rate limit
38                        500..=599 => Self::Transient, // Server errors
39                        401 | 403 => Self::Permanent, // Auth errors
40                        400 | 404 => Self::Permanent, // Client errors
41                        _ => Self::Unknown,
42                    }
43                } else {
44                    Self::Unknown
45                }
46            }
47            AppError::Llm(msg) => {
48                let msg_lower = msg.to_lowercase();
49                // Context overflow errors - need compaction, not blind retry
50                if msg_lower.contains("context_length_exceeded")
51                    || msg_lower.contains("context length")
52                    || msg_lower.contains("maximum context")
53                    || msg_lower.contains("token limit")
54                    || (msg_lower.contains("tokens") && msg_lower.contains("exceeds"))
55                {
56                    Self::ContextOverflow
57                } else if msg_lower.contains("rate limit")
58                    || msg_lower.contains("429")
59                    || msg_lower.contains("timeout")
60                    || msg_lower.contains("connection")
61                    || msg_lower.contains("temporarily")
62                {
63                    Self::Transient
64                } else if msg_lower.contains("401")
65                    || msg_lower.contains("403")
66                    || msg_lower.contains("unauthorized")
67                    || msg_lower.contains("invalid api key")
68                    || msg_lower.contains("not configured")
69                {
70                    Self::Permanent
71                } else {
72                    Self::Unknown
73                }
74            }
75            // Context overflow - needs compaction, not retry
76            AppError::ContextOverflow(_) => Self::ContextOverflow,
77            // Permanent errors - don't retry
78            AppError::Config(_) => Self::Permanent,
79            AppError::PermissionDenied(_) => Self::Permanent,
80            AppError::InvalidInput(_) => Self::Permanent,
81            AppError::NotFound(_) => Self::Permanent,
82            // Unknown - treat conservatively
83            _ => Self::Unknown,
84        }
85    }
86
87    /// Whether this error class should be retried with standard backoff
88    pub fn should_retry(&self) -> bool {
89        matches!(self, Self::Transient | Self::Unknown)
90    }
91
92    /// Whether this error requires context compaction before retry
93    pub fn needs_compaction(&self) -> bool {
94        matches!(self, Self::ContextOverflow)
95    }
96
97    /// Whether this error is recoverable (either by retry or compaction)
98    pub fn is_recoverable(&self) -> bool {
99        !matches!(self, Self::Permanent)
100    }
101}
102
/// Retry policy configuration
///
/// Delays follow exponential backoff: attempt `n` waits roughly
/// `initial_delay_ms * backoff_multiplier^(n-1)` ms, capped at `max_delay_ms`,
/// with `jitter_factor` adding randomness (see `delay_for_attempt`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts (0 = no retries)
    pub max_attempts: u32,
    /// Initial delay before first retry (milliseconds)
    pub initial_delay_ms: u64,
    /// Maximum delay between retries (milliseconds)
    pub max_delay_ms: u64,
    /// Multiplier for exponential backoff (e.g., 2.0 = double each time)
    pub backoff_multiplier: f64,
    /// Jitter factor (0.0-1.0) to add randomness to delays
    pub jitter_factor: f64,
}
117
118impl Default for RetryPolicy {
119    fn default() -> Self {
120        Self {
121            max_attempts: 3,
122            initial_delay_ms: 1000,
123            max_delay_ms: 30000,
124            backoff_multiplier: 2.0,
125            jitter_factor: 0.25,
126        }
127    }
128}
129
130impl RetryPolicy {
131    /// Create a policy for API calls (moderate retries)
132    pub fn for_api() -> Self {
133        Self::default()
134    }
135
136    /// Create a policy for tool execution (fewer retries)
137    pub fn for_tools() -> Self {
138        Self {
139            max_attempts: 2,
140            initial_delay_ms: 500,
141            max_delay_ms: 5000,
142            backoff_multiplier: 2.0,
143            jitter_factor: 0.1,
144        }
145    }
146
147    /// Create a policy for streaming (quick retries)
148    pub fn for_streaming() -> Self {
149        Self {
150            max_attempts: 3,
151            initial_delay_ms: 1000,
152            max_delay_ms: 8000,
153            backoff_multiplier: 2.0,
154            jitter_factor: 0.25,
155        }
156    }
157
158    /// Calculate delay for a given attempt number (0-indexed)
159    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
160        if attempt == 0 {
161            return Duration::ZERO;
162        }
163
164        let base_delay =
165            self.initial_delay_ms as f64 * self.backoff_multiplier.powi((attempt - 1) as i32);
166        let capped_delay = base_delay.min(self.max_delay_ms as f64);
167
168        // Add jitter (only if jitter_factor > 0)
169        let jitter_range = capped_delay * self.jitter_factor;
170        let jitter = if jitter_range > 0.0 {
171            rand::thread_rng().gen_range(-jitter_range..jitter_range)
172        } else {
173            0.0
174        };
175        let final_delay = (capped_delay + jitter).max(0.0);
176
177        Duration::from_millis(final_delay as u64)
178    }
179}
180
/// Retry event for notification callbacks
///
/// Passed to the [`RetryCallback`] registered via
/// `RetryManager::with_retry_callback` when an attempt fails and a retry is
/// scheduled.
#[derive(Debug, Clone)]
pub struct RetryEvent {
    /// Current attempt number (1-indexed)
    pub attempt: u32,
    /// Maximum attempts configured
    pub max_attempts: u32,
    /// Delay before next retry
    pub delay: Duration,
    /// Error that triggered the retry
    pub error_message: String,
    /// Error classification
    pub error_class: ErrorClass,
}
195
/// Callback type for retry notifications
///
/// Invoked with a [`RetryEvent`] each time `RetryManager::execute` schedules
/// a retry.
pub type RetryCallback = Box<dyn Fn(RetryEvent) + Send + Sync>;
198
/// Retry manager for executing operations with automatic retry
pub struct RetryManager {
    /// Backoff configuration governing attempt count and delays.
    policy: RetryPolicy,
    /// Optional observer invoked with a `RetryEvent` when a retry is scheduled.
    on_retry: Option<RetryCallback>,
}
204
205impl RetryManager {
206    /// Create a new retry manager with the given policy
207    pub fn new(policy: RetryPolicy) -> Self {
208        Self {
209            policy,
210            on_retry: None,
211        }
212    }
213
214    /// Create a retry manager with default API policy
215    pub fn for_api() -> Self {
216        Self::new(RetryPolicy::for_api())
217    }
218
219    /// Create a retry manager for streaming operations
220    pub fn for_streaming() -> Self {
221        Self::new(RetryPolicy::for_streaming())
222    }
223
224    /// Create a retry manager for tool execution
225    pub fn for_tools() -> Self {
226        Self::new(RetryPolicy::for_tools())
227    }
228
229    /// Set a callback to be notified on retry attempts
230    pub fn with_retry_callback(mut self, callback: RetryCallback) -> Self {
231        self.on_retry = Some(callback);
232        self
233    }
234
235    /// Execute an async operation with retry logic
236    ///
237    /// Returns the result of the operation, or the last error if all retries fail.
238    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T, AppError>
239    where
240        F: FnMut() -> Fut,
241        Fut: std::future::Future<Output = Result<T, AppError>>,
242    {
243        let mut last_error: Option<AppError> = None;
244
245        for attempt in 0..=self.policy.max_attempts {
246            // Wait before retry (skip for first attempt)
247            if attempt > 0 {
248                let delay = self.policy.delay_for_attempt(attempt);
249                tokio::time::sleep(delay).await;
250            }
251
252            match operation().await {
253                Ok(result) => return Ok(result),
254                Err(e) => {
255                    let error_class = ErrorClass::classify(&e);
256
257                    // Don't retry permanent errors
258                    if !error_class.should_retry() {
259                        return Err(e);
260                    }
261
262                    // Check if we have more attempts
263                    if attempt < self.policy.max_attempts {
264                        let delay = self.policy.delay_for_attempt(attempt + 1);
265
266                        // Notify callback if set
267                        if let Some(ref callback) = self.on_retry {
268                            callback(RetryEvent {
269                                attempt: attempt + 1,
270                                max_attempts: self.policy.max_attempts,
271                                delay,
272                                error_message: e.to_string(),
273                                error_class,
274                            });
275                        }
276
277                        tracing::warn!(
278                            attempt = attempt + 1,
279                            max_attempts = self.policy.max_attempts,
280                            delay_ms = delay.as_millis(),
281                            error = %e,
282                            error_class = ?error_class,
283                            "Operation failed, will retry"
284                        );
285                    }
286
287                    last_error = Some(e);
288                }
289            }
290        }
291
292        Err(last_error.unwrap_or_else(|| AppError::Internal("Retry exhausted".to_string())))
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    // Timeout variant and rate-limit-flavored LLM messages classify as Transient.
    #[test]
    fn test_error_classification_transient() {
        let timeout_err = AppError::Timeout("connection timed out".to_string());
        assert_eq!(ErrorClass::classify(&timeout_err), ErrorClass::Transient);

        let rate_limit_err = AppError::Llm("rate limit exceeded (429)".to_string());
        assert_eq!(ErrorClass::classify(&rate_limit_err), ErrorClass::Transient);
    }

    // Config errors and auth-flavored LLM messages classify as Permanent.
    #[test]
    fn test_error_classification_permanent() {
        let config_err = AppError::Config("missing API key".to_string());
        assert_eq!(ErrorClass::classify(&config_err), ErrorClass::Permanent);

        let auth_err = AppError::Llm("401 unauthorized".to_string());
        assert_eq!(ErrorClass::classify(&auth_err), ErrorClass::Permanent);
    }

    // ContextOverflow is detected both from message heuristics and the
    // explicit AppError variant.
    #[test]
    fn test_error_classification_context_overflow() {
        // From error message
        let overflow_err = AppError::Llm("maximum context length is 16385 tokens".to_string());
        assert_eq!(
            ErrorClass::classify(&overflow_err),
            ErrorClass::ContextOverflow
        );

        // From explicit variant
        let explicit_err = AppError::ContextOverflow("context too large".to_string());
        assert_eq!(
            ErrorClass::classify(&explicit_err),
            ErrorClass::ContextOverflow
        );

        // From different message format
        let token_err = AppError::Llm("Request tokens exceeds limit".to_string());
        assert_eq!(
            ErrorClass::classify(&token_err),
            ErrorClass::ContextOverflow
        );
    }

    // Only ContextOverflow requests compaction.
    #[test]
    fn test_context_overflow_needs_compaction() {
        assert!(ErrorClass::ContextOverflow.needs_compaction());
        assert!(!ErrorClass::Transient.needs_compaction());
        assert!(!ErrorClass::Permanent.needs_compaction());
    }

    // Exponential backoff doubles each attempt and caps at max_delay_ms.
    #[test]
    fn test_delay_calculation() {
        let policy = RetryPolicy {
            max_attempts: 3,
            initial_delay_ms: 1000,
            max_delay_ms: 10000,
            backoff_multiplier: 2.0,
            jitter_factor: 0.0, // No jitter for predictable testing
        };

        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(1000));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(2000));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(4000));
        // Should cap at max_delay_ms
        assert_eq!(policy.delay_for_attempt(5), Duration::from_millis(10000));
    }

    // Transient and Unknown retry; Permanent does not.
    #[test]
    fn test_should_retry() {
        assert!(ErrorClass::Transient.should_retry());
        assert!(ErrorClass::Unknown.should_retry());
        assert!(!ErrorClass::Permanent.should_retry());
    }

    // A successful first attempt returns immediately with no retries.
    #[tokio::test]
    async fn test_retry_manager_success_first_try() {
        let manager = RetryManager::new(RetryPolicy::default());
        let result: Result<i32, AppError> = manager.execute(|| async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    // Permanent errors short-circuit: the operation runs exactly once.
    #[tokio::test]
    async fn test_retry_manager_permanent_error_no_retry() {
        let manager = RetryManager::new(RetryPolicy::default());
        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0));
        let call_count_clone = call_count.clone();

        let result: Result<i32, AppError> = manager
            .execute(|| {
                let count = call_count_clone.clone();
                async move {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    Err(AppError::Config("permanent error".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        // Should only be called once (no retries for permanent errors)
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
}
401}