gestura_core_foundation/
stream_reconnect.rs

//! Stream Reconnection Logic
//!
//! Provides automatic reconnection support for dropped streaming connections,
//! with exponential backoff and preservation of stream state across attempts.
5
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU32, Ordering};
9use std::time::Instant;
10use tokio::sync::mpsc;
11
/// Number of reconnection attempts allowed when no explicit limit is configured.
pub const DEFAULT_MAX_RECONNECT_ATTEMPTS: u32 = 5;

/// Default delay before the first reconnection attempt, in milliseconds.
pub const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1_000;

/// Default upper bound on the computed backoff delay, in milliseconds.
pub const DEFAULT_MAX_BACKOFF_MS: u64 = 30_000;

/// Default factor by which the backoff delay grows from one attempt to the next.
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
23
24/// Reconnection state
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26pub enum ReconnectState {
27    /// Not attempting reconnection
28    Idle,
29    /// Waiting before next attempt
30    Waiting,
31    /// Currently attempting to reconnect
32    Connecting,
33    /// Successfully reconnected
34    Connected,
35    /// All attempts exhausted
36    Failed,
37}
38
39/// Reconnection event for frontend notification
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub enum ReconnectEvent {
42    /// Starting reconnection attempt
43    AttemptStarted {
44        /// Current attempt number (1-indexed)
45        attempt: u32,
46        /// Maximum attempts
47        max_attempts: u32,
48    },
49    /// Waiting before next attempt
50    Waiting {
51        /// Delay in milliseconds
52        delay_ms: u64,
53        /// Reason for reconnection
54        reason: String,
55    },
56    /// Reconnection succeeded
57    Connected {
58        /// Total attempts made
59        attempts: u32,
60        /// Total time spent reconnecting in milliseconds
61        total_time_ms: u64,
62    },
63    /// Reconnection failed
64    Failed {
65        /// Total attempts made
66        attempts: u32,
67        /// Final error message
68        error: String,
69    },
70}
71
72/// Configuration for reconnection behavior
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct ReconnectConfig {
75    /// Maximum number of reconnection attempts
76    pub max_attempts: u32,
77    /// Initial backoff delay in milliseconds
78    pub initial_backoff_ms: u64,
79    /// Maximum backoff delay in milliseconds
80    pub max_backoff_ms: u64,
81    /// Backoff multiplier for exponential backoff
82    pub backoff_multiplier: f64,
83    /// Whether to add jitter to backoff delays
84    pub jitter: bool,
85}
86
87impl Default for ReconnectConfig {
88    fn default() -> Self {
89        Self {
90            max_attempts: DEFAULT_MAX_RECONNECT_ATTEMPTS,
91            initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
92            max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
93            backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
94            jitter: true,
95        }
96    }
97}
98
99impl ReconnectConfig {
100    /// Calculate backoff delay for a given attempt
101    pub fn backoff_delay_ms(&self, attempt: u32) -> u64 {
102        let base_delay = self.initial_backoff_ms as f64
103            * self
104                .backoff_multiplier
105                .powi(attempt.saturating_sub(1) as i32);
106        let delay = (base_delay as u64).min(self.max_backoff_ms);
107
108        if self.jitter {
109            // Add up to 25% jitter
110            let jitter = (delay as f64 * 0.25 * rand_jitter()) as u64;
111            delay.saturating_add(jitter)
112        } else {
113            delay
114        }
115    }
116}
117
/// Simple pseudo-random jitter factor in the half-open range `[0.0, 1.0)`.
///
/// Not cryptographically secure; intended only to spread out retry delays.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    // Use the whole timestamp as input. Taking only `subsec_nanos() % 1000`
    // keeps just the sub-microsecond digits, which are constant or heavily
    // quantized on platforms whose wall clock ticks in microseconds or 100 ns
    // units, so the jitter would largely disappear there.
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);

    // Each `RandomState` carries its own randomly seeded keys, so the hash
    // varies between calls even when the clock value repeats.
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);

    // Keep the top 53 bits (the precision of an f64 mantissa) and scale by
    // 2^53, which gives a value strictly below 1.0.
    (hasher.finish() >> 11) as f64 / (1u64 << 53) as f64
}
127
128/// Stream state that can be preserved across reconnections
129#[derive(Debug, Clone, Default, Serialize, Deserialize)]
130pub struct StreamState {
131    /// Number of chunks received before disconnect
132    pub chunks_received: u64,
133    /// Total bytes received before disconnect
134    pub bytes_received: u64,
135    /// Last successful chunk timestamp (ms since stream start)
136    pub last_chunk_time_ms: u64,
137    /// Whether the stream was in the middle of a tool call
138    pub in_tool_call: bool,
139    /// Current tool call ID if in progress
140    pub current_tool_id: Option<String>,
141    /// Accumulated tool arguments if in progress
142    pub tool_args_buffer: String,
143}
144
145impl StreamState {
146    /// Create a new empty stream state
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    /// Record a chunk received
152    pub fn record_chunk(&mut self, bytes: u64, time_ms: u64) {
153        self.chunks_received += 1;
154        self.bytes_received += bytes;
155        self.last_chunk_time_ms = time_ms;
156    }
157
158    /// Start a tool call
159    pub fn start_tool_call(&mut self, id: String) {
160        self.in_tool_call = true;
161        self.current_tool_id = Some(id);
162        self.tool_args_buffer.clear();
163    }
164
165    /// Append tool arguments
166    pub fn append_tool_args(&mut self, args: &str) {
167        self.tool_args_buffer.push_str(args);
168    }
169
170    /// End tool call
171    pub fn end_tool_call(&mut self) {
172        self.in_tool_call = false;
173        self.current_tool_id = None;
174        self.tool_args_buffer.clear();
175    }
176
177    /// Check if stream can be resumed
178    pub fn can_resume(&self) -> bool {
179        // Can resume if we haven't received any chunks yet
180        // or if we're not in the middle of a tool call
181        self.chunks_received == 0 || !self.in_tool_call
182    }
183}
184
185/// Reconnection manager for streaming connections
186pub struct ReconnectManager {
187    config: ReconnectConfig,
188    state: ReconnectState,
189    attempt_count: Arc<AtomicU32>,
190    start_time: Option<Instant>,
191    stream_state: StreamState,
192    event_tx: Option<mpsc::Sender<ReconnectEvent>>,
193}
194
195impl ReconnectManager {
196    /// Create a new reconnection manager
197    pub fn new(config: ReconnectConfig) -> Self {
198        Self {
199            config,
200            state: ReconnectState::Idle,
201            attempt_count: Arc::new(AtomicU32::new(0)),
202            start_time: None,
203            stream_state: StreamState::new(),
204            event_tx: None,
205        }
206    }
207
208    /// Create with event channel for notifications
209    pub fn with_events(config: ReconnectConfig, tx: mpsc::Sender<ReconnectEvent>) -> Self {
210        Self {
211            config,
212            state: ReconnectState::Idle,
213            attempt_count: Arc::new(AtomicU32::new(0)),
214            start_time: None,
215            stream_state: StreamState::new(),
216            event_tx: Some(tx),
217        }
218    }
219
220    /// Get current reconnection state
221    pub fn state(&self) -> ReconnectState {
222        self.state
223    }
224
225    /// Get current attempt count
226    pub fn attempt_count(&self) -> u32 {
227        self.attempt_count.load(Ordering::SeqCst)
228    }
229
230    /// Get stream state
231    pub fn stream_state(&self) -> &StreamState {
232        &self.stream_state
233    }
234
235    /// Get mutable stream state
236    pub fn stream_state_mut(&mut self) -> &mut StreamState {
237        &mut self.stream_state
238    }
239
240    /// Check if more attempts are available
241    pub fn can_retry(&self) -> bool {
242        self.attempt_count() < self.config.max_attempts
243    }
244
245    /// Start a reconnection attempt
246    pub async fn start_attempt(&mut self) -> Option<u64> {
247        if !self.can_retry() {
248            self.state = ReconnectState::Failed;
249            if let Some(ref tx) = self.event_tx {
250                let _ = tx
251                    .send(ReconnectEvent::Failed {
252                        attempts: self.attempt_count(),
253                        error: "Maximum reconnection attempts exceeded".to_string(),
254                    })
255                    .await;
256            }
257            return None;
258        }
259
260        let attempt = self.attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
261
262        if self.start_time.is_none() {
263            self.start_time = Some(Instant::now());
264        }
265
266        // Calculate backoff delay
267        let delay_ms = self.config.backoff_delay_ms(attempt);
268
269        self.state = ReconnectState::Waiting;
270        if let Some(ref tx) = self.event_tx {
271            let _ = tx
272                .send(ReconnectEvent::Waiting {
273                    delay_ms,
274                    reason: format!("Attempt {} of {}", attempt, self.config.max_attempts),
275                })
276                .await;
277        }
278
279        Some(delay_ms)
280    }
281
282    /// Mark as connecting
283    pub async fn mark_connecting(&mut self) {
284        self.state = ReconnectState::Connecting;
285        if let Some(ref tx) = self.event_tx {
286            let _ = tx
287                .send(ReconnectEvent::AttemptStarted {
288                    attempt: self.attempt_count(),
289                    max_attempts: self.config.max_attempts,
290                })
291                .await;
292        }
293    }
294
295    /// Mark as successfully connected
296    pub async fn mark_connected(&mut self) {
297        self.state = ReconnectState::Connected;
298        let total_time_ms = self
299            .start_time
300            .map(|t| t.elapsed().as_millis() as u64)
301            .unwrap_or(0);
302
303        if let Some(ref tx) = self.event_tx {
304            let _ = tx
305                .send(ReconnectEvent::Connected {
306                    attempts: self.attempt_count(),
307                    total_time_ms,
308                })
309                .await;
310        }
311    }
312
313    /// Mark as failed with error
314    pub async fn mark_failed(&mut self, error: &str) {
315        self.state = ReconnectState::Failed;
316        if let Some(ref tx) = self.event_tx {
317            let _ = tx
318                .send(ReconnectEvent::Failed {
319                    attempts: self.attempt_count(),
320                    error: error.to_string(),
321                })
322                .await;
323        }
324    }
325
326    /// Reset for a new stream
327    pub fn reset(&mut self) {
328        self.state = ReconnectState::Idle;
329        self.attempt_count.store(0, Ordering::SeqCst);
330        self.start_time = None;
331        self.stream_state = StreamState::new();
332    }
333
334    /// Get configuration
335    pub fn config(&self) -> &ReconnectConfig {
336        &self.config
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Jitter-free configuration with a 1000 ms initial delay that doubles
    /// per attempt, capped at `max_backoff_ms`.
    fn doubling_config(max_backoff_ms: u64) -> ReconnectConfig {
        ReconnectConfig {
            initial_backoff_ms: 1000,
            backoff_multiplier: 2.0,
            max_backoff_ms,
            jitter: false,
            ..Default::default()
        }
    }

    #[test]
    fn test_reconnect_config_default() {
        let ReconnectConfig {
            max_attempts,
            initial_backoff_ms,
            jitter,
            ..
        } = ReconnectConfig::default();

        assert_eq!(max_attempts, DEFAULT_MAX_RECONNECT_ATTEMPTS);
        assert_eq!(initial_backoff_ms, DEFAULT_INITIAL_BACKOFF_MS);
        assert!(jitter);
    }

    #[test]
    fn test_backoff_delay_exponential() {
        let config = doubling_config(30000);

        // (attempt, expected delay in ms)
        let expectations = [(1, 1000), (2, 2000), (3, 4000), (4, 8000)];
        for (attempt, expected_ms) in expectations {
            assert_eq!(config.backoff_delay_ms(attempt), expected_ms);
        }
    }

    #[test]
    fn test_backoff_delay_max_cap() {
        let config = doubling_config(5000);

        assert_eq!(config.backoff_delay_ms(10), 5000);
    }

    #[test]
    fn test_stream_state_record_chunk() {
        let mut state = StreamState::new();
        state.record_chunk(100, 1000);

        let observed = (
            state.chunks_received,
            state.bytes_received,
            state.last_chunk_time_ms,
        );
        assert_eq!(observed, (1, 100, 1000));
    }

    #[test]
    fn test_stream_state_tool_call() {
        let mut state = StreamState::new();

        state.start_tool_call("tool-1".to_string());
        assert!(state.in_tool_call);
        assert_eq!(state.current_tool_id.as_deref(), Some("tool-1"));

        state.append_tool_args("{\"arg\":\"value\"}");
        assert_eq!(state.tool_args_buffer, "{\"arg\":\"value\"}");

        state.end_tool_call();
        assert!(!state.in_tool_call);
        assert_eq!(state.current_tool_id, None);
    }

    #[test]
    fn test_stream_state_can_resume() {
        let mut state = StreamState::new();
        // Nothing received yet.
        assert!(state.can_resume());

        // Chunks received, but no tool call in progress.
        state.record_chunk(100, 1000);
        assert!(state.can_resume());

        // A tool call is in progress.
        state.start_tool_call("tool-1".to_string());
        assert!(!state.can_resume());
    }

    #[test]
    fn test_reconnect_manager_can_retry() {
        let config = ReconnectConfig {
            max_attempts: 3,
            ..Default::default()
        };
        let manager = ReconnectManager::new(config);

        assert_eq!(manager.attempt_count(), 0);
        assert!(manager.can_retry());
    }

    #[tokio::test]
    async fn test_reconnect_manager_start_attempt() {
        let config = ReconnectConfig {
            max_attempts: 3,
            initial_backoff_ms: 1000,
            jitter: false,
            ..Default::default()
        };
        let mut manager = ReconnectManager::new(config);

        let delay = manager.start_attempt().await;

        assert_eq!(delay, Some(1000));
        assert_eq!(manager.attempt_count(), 1);
        assert_eq!(manager.state(), ReconnectState::Waiting);
    }

    #[tokio::test]
    async fn test_reconnect_manager_exhausted() {
        let config = ReconnectConfig {
            max_attempts: 1,
            ..Default::default()
        };
        let mut manager = ReconnectManager::new(config);

        // The single allowed attempt.
        let _ = manager.start_attempt().await;

        // One more than allowed.
        let delay = manager.start_attempt().await;
        assert_eq!(delay, None);
        assert_eq!(manager.state(), ReconnectState::Failed);
    }

    #[tokio::test]
    async fn test_reconnect_manager_reset() {
        let mut manager = ReconnectManager::new(ReconnectConfig::default());
        let _ = manager.start_attempt().await;

        manager.reset();

        assert_eq!(manager.state(), ReconnectState::Idle);
        assert_eq!(manager.attempt_count(), 0);
    }
}