1use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU32, Ordering};
9use std::time::Instant;
10use tokio::sync::mpsc;
11
/// Default for [`ReconnectConfig::max_attempts`]: how many attempts may be started.
pub const DEFAULT_MAX_RECONNECT_ATTEMPTS: u32 = 5;

/// Default for [`ReconnectConfig::initial_backoff_ms`]: base delay of the first attempt (ms).
pub const DEFAULT_INITIAL_BACKOFF_MS: u64 = 1_000;

/// Default for [`ReconnectConfig::max_backoff_ms`]: cap on the exponential delay (ms).
pub const DEFAULT_MAX_BACKOFF_MS: u64 = 30_000;

/// Default for [`ReconnectConfig::backoff_multiplier`]: per-attempt growth factor.
pub const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
23
/// Lifecycle state of a [`ReconnectManager`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReconnectState {
    /// No reconnection in progress: the state after `new`/`with_events` and after `reset`.
    Idle,
    /// `start_attempt` began an attempt and returned its backoff delay.
    Waiting,
    /// Set by `mark_connecting`.
    Connecting,
    /// Set by `mark_connected`.
    Connected,
    /// Attempts were exhausted in `start_attempt`, or `mark_failed` was called.
    Failed,
}
38
/// Notification sent by a [`ReconnectManager`] that was built with `with_events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReconnectEvent {
    /// Sent by `mark_connecting`.
    AttemptStarted {
        /// The manager's attempt counter at the time the event was sent.
        attempt: u32,
        /// The configured `max_attempts`.
        max_attempts: u32,
    },
    /// Sent by `start_attempt` after it begins a new attempt.
    Waiting {
        /// Backoff delay (ms) that `start_attempt` returned for this attempt.
        delay_ms: u64,
        /// Human-readable progress text of the form `"Attempt N of M"`.
        reason: String,
    },
    /// Sent by `mark_connected`.
    Connected {
        /// The manager's attempt counter at the time the event was sent.
        attempts: u32,
        /// Milliseconds since the first attempt after construction/reset, or 0 if none started.
        total_time_ms: u64,
    },
    /// Sent when attempts are exhausted in `start_attempt`, or by `mark_failed`.
    Failed {
        /// The manager's attempt counter at the time the event was sent.
        attempts: u32,
        /// Failure description (either the caller's message or the exhaustion message).
        error: String,
    },
}
71
/// Retry limits and exponential-backoff parameters used by [`ReconnectManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReconnectConfig {
    /// Number of attempts `start_attempt` may begin before it reports failure.
    pub max_attempts: u32,
    /// Base delay (ms) for attempt 1.
    pub initial_backoff_ms: u64,
    /// Cap (ms) applied to the exponential delay; jitter, if enabled, is added after the cap.
    pub max_backoff_ms: u64,
    /// Growth factor: the delay for attempt `n` is `initial_backoff_ms * multiplier^(n - 1)`.
    pub backoff_multiplier: f64,
    /// When `true`, a random extra of less than 25% of the capped delay is added.
    pub jitter: bool,
}
86
87impl Default for ReconnectConfig {
88 fn default() -> Self {
89 Self {
90 max_attempts: DEFAULT_MAX_RECONNECT_ATTEMPTS,
91 initial_backoff_ms: DEFAULT_INITIAL_BACKOFF_MS,
92 max_backoff_ms: DEFAULT_MAX_BACKOFF_MS,
93 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
94 jitter: true,
95 }
96 }
97}
98
99impl ReconnectConfig {
100 pub fn backoff_delay_ms(&self, attempt: u32) -> u64 {
102 let base_delay = self.initial_backoff_ms as f64
103 * self
104 .backoff_multiplier
105 .powi(attempt.saturating_sub(1) as i32);
106 let delay = (base_delay as u64).min(self.max_backoff_ms);
107
108 if self.jitter {
109 let jitter = (delay as f64 * 0.25 * rand_jitter()) as u64;
111 delay.saturating_add(jitter)
112 } else {
113 delay
114 }
115 }
116}
117
/// Returns a pseudo-random value in `[0.0, 1.0)` used to jitter backoff delays.
///
/// The randomness comes from std's randomly keyed SipHash (`RandomState`). Each
/// `RandomState::new()` has distinct random keys, so successive calls and separate processes
/// get unrelated values. This needs no extra dependency. It is not cryptographically strong,
/// and does not need to be for jitter.
///
/// Deriving the value from the sub-microsecond digits of the wall clock is unreliable: on
/// clocks with microsecond (or coarser) resolution those digits are constant, and jitter
/// silently collapses.
fn rand_jitter() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let bits = RandomState::new().build_hasher().finish();
    // Keep the top 53 bits: every such integer is exactly representable as f64, and dividing
    // by 2^53 yields a value strictly below 1.0.
    (bits >> 11) as f64 / (1u64 << 53) as f64
}
127
/// Progress counters and tool-call tracking for a stream; see `StreamState::can_resume`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StreamState {
    /// Number of `record_chunk` calls.
    pub chunks_received: u64,
    /// Sum of the `bytes` arguments passed to `record_chunk`.
    pub bytes_received: u64,
    /// `time_ms` from the most recent `record_chunk` call (clock/epoch chosen by the caller — TODO confirm).
    pub last_chunk_time_ms: u64,
    /// `true` between `start_tool_call` and `end_tool_call`.
    pub in_tool_call: bool,
    /// ID passed to `start_tool_call`; `None` when no tool call is open.
    pub current_tool_id: Option<String>,
    /// Argument text accumulated by `append_tool_args`; cleared on start and end of a tool call.
    pub tool_args_buffer: String,
}
144
145impl StreamState {
146 pub fn new() -> Self {
148 Self::default()
149 }
150
151 pub fn record_chunk(&mut self, bytes: u64, time_ms: u64) {
153 self.chunks_received += 1;
154 self.bytes_received += bytes;
155 self.last_chunk_time_ms = time_ms;
156 }
157
158 pub fn start_tool_call(&mut self, id: String) {
160 self.in_tool_call = true;
161 self.current_tool_id = Some(id);
162 self.tool_args_buffer.clear();
163 }
164
165 pub fn append_tool_args(&mut self, args: &str) {
167 self.tool_args_buffer.push_str(args);
168 }
169
170 pub fn end_tool_call(&mut self) {
172 self.in_tool_call = false;
173 self.current_tool_id = None;
174 self.tool_args_buffer.clear();
175 }
176
177 pub fn can_resume(&self) -> bool {
179 self.chunks_received == 0 || !self.in_tool_call
182 }
183}
184
185pub struct ReconnectManager {
187 config: ReconnectConfig,
188 state: ReconnectState,
189 attempt_count: Arc<AtomicU32>,
190 start_time: Option<Instant>,
191 stream_state: StreamState,
192 event_tx: Option<mpsc::Sender<ReconnectEvent>>,
193}
194
195impl ReconnectManager {
196 pub fn new(config: ReconnectConfig) -> Self {
198 Self {
199 config,
200 state: ReconnectState::Idle,
201 attempt_count: Arc::new(AtomicU32::new(0)),
202 start_time: None,
203 stream_state: StreamState::new(),
204 event_tx: None,
205 }
206 }
207
208 pub fn with_events(config: ReconnectConfig, tx: mpsc::Sender<ReconnectEvent>) -> Self {
210 Self {
211 config,
212 state: ReconnectState::Idle,
213 attempt_count: Arc::new(AtomicU32::new(0)),
214 start_time: None,
215 stream_state: StreamState::new(),
216 event_tx: Some(tx),
217 }
218 }
219
220 pub fn state(&self) -> ReconnectState {
222 self.state
223 }
224
225 pub fn attempt_count(&self) -> u32 {
227 self.attempt_count.load(Ordering::SeqCst)
228 }
229
230 pub fn stream_state(&self) -> &StreamState {
232 &self.stream_state
233 }
234
235 pub fn stream_state_mut(&mut self) -> &mut StreamState {
237 &mut self.stream_state
238 }
239
240 pub fn can_retry(&self) -> bool {
242 self.attempt_count() < self.config.max_attempts
243 }
244
245 pub async fn start_attempt(&mut self) -> Option<u64> {
247 if !self.can_retry() {
248 self.state = ReconnectState::Failed;
249 if let Some(ref tx) = self.event_tx {
250 let _ = tx
251 .send(ReconnectEvent::Failed {
252 attempts: self.attempt_count(),
253 error: "Maximum reconnection attempts exceeded".to_string(),
254 })
255 .await;
256 }
257 return None;
258 }
259
260 let attempt = self.attempt_count.fetch_add(1, Ordering::SeqCst) + 1;
261
262 if self.start_time.is_none() {
263 self.start_time = Some(Instant::now());
264 }
265
266 let delay_ms = self.config.backoff_delay_ms(attempt);
268
269 self.state = ReconnectState::Waiting;
270 if let Some(ref tx) = self.event_tx {
271 let _ = tx
272 .send(ReconnectEvent::Waiting {
273 delay_ms,
274 reason: format!("Attempt {} of {}", attempt, self.config.max_attempts),
275 })
276 .await;
277 }
278
279 Some(delay_ms)
280 }
281
282 pub async fn mark_connecting(&mut self) {
284 self.state = ReconnectState::Connecting;
285 if let Some(ref tx) = self.event_tx {
286 let _ = tx
287 .send(ReconnectEvent::AttemptStarted {
288 attempt: self.attempt_count(),
289 max_attempts: self.config.max_attempts,
290 })
291 .await;
292 }
293 }
294
295 pub async fn mark_connected(&mut self) {
297 self.state = ReconnectState::Connected;
298 let total_time_ms = self
299 .start_time
300 .map(|t| t.elapsed().as_millis() as u64)
301 .unwrap_or(0);
302
303 if let Some(ref tx) = self.event_tx {
304 let _ = tx
305 .send(ReconnectEvent::Connected {
306 attempts: self.attempt_count(),
307 total_time_ms,
308 })
309 .await;
310 }
311 }
312
313 pub async fn mark_failed(&mut self, error: &str) {
315 self.state = ReconnectState::Failed;
316 if let Some(ref tx) = self.event_tx {
317 let _ = tx
318 .send(ReconnectEvent::Failed {
319 attempts: self.attempt_count(),
320 error: error.to_string(),
321 })
322 .await;
323 }
324 }
325
326 pub fn reset(&mut self) {
328 self.state = ReconnectState::Idle;
329 self.attempt_count.store(0, Ordering::SeqCst);
330 self.start_time = None;
331 self.stream_state = StreamState::new();
332 }
333
334 pub fn config(&self) -> &ReconnectConfig {
336 &self.config
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reconnect_config_default() {
        let config = ReconnectConfig::default();
        assert_eq!(config.max_attempts, DEFAULT_MAX_RECONNECT_ATTEMPTS);
        assert_eq!(config.initial_backoff_ms, DEFAULT_INITIAL_BACKOFF_MS);
        assert!(config.jitter);
    }

    #[test]
    fn test_backoff_delay_exponential() {
        let config = ReconnectConfig {
            initial_backoff_ms: 1000,
            backoff_multiplier: 2.0,
            max_backoff_ms: 30000,
            jitter: false,
            ..Default::default()
        };

        assert_eq!(config.backoff_delay_ms(1), 1000);
        assert_eq!(config.backoff_delay_ms(2), 2000);
        assert_eq!(config.backoff_delay_ms(3), 4000);
        assert_eq!(config.backoff_delay_ms(4), 8000);
    }

    #[test]
    fn test_backoff_delay_max_cap() {
        let config = ReconnectConfig {
            initial_backoff_ms: 1000,
            backoff_multiplier: 2.0,
            max_backoff_ms: 5000,
            jitter: false,
            ..Default::default()
        };

        assert_eq!(config.backoff_delay_ms(10), 5000);
    }

    #[test]
    fn test_backoff_delay_jitter_bounds() {
        let config = ReconnectConfig {
            initial_backoff_ms: 1000,
            backoff_multiplier: 2.0,
            max_backoff_ms: 5000,
            jitter: true,
            ..Default::default()
        };

        // Jitter adds less than 25% on top of the (capped) exponential delay.
        for _ in 0..100 {
            let uncapped = config.backoff_delay_ms(2);
            assert!((2000..2500).contains(&uncapped), "got {}", uncapped);
            let capped = config.backoff_delay_ms(10);
            assert!((5000..6250).contains(&capped), "got {}", capped);
        }
    }

    #[test]
    fn test_rand_jitter_in_unit_interval() {
        for _ in 0..1000 {
            let j = rand_jitter();
            assert!((0.0..1.0).contains(&j), "got {}", j);
        }
    }

    #[test]
    fn test_stream_state_record_chunk() {
        let mut state = StreamState::new();
        state.record_chunk(100, 1000);
        assert_eq!(state.chunks_received, 1);
        assert_eq!(state.bytes_received, 100);
        assert_eq!(state.last_chunk_time_ms, 1000);
    }

    #[test]
    fn test_stream_state_tool_call() {
        let mut state = StreamState::new();
        state.start_tool_call("tool-1".to_string());
        assert!(state.in_tool_call);
        assert_eq!(state.current_tool_id, Some("tool-1".to_string()));

        state.append_tool_args("{\"arg\":\"value\"}");
        assert_eq!(state.tool_args_buffer, "{\"arg\":\"value\"}");

        state.end_tool_call();
        assert!(!state.in_tool_call);
        assert!(state.current_tool_id.is_none());
    }

    #[test]
    fn test_stream_state_can_resume() {
        let mut state = StreamState::new();
        // Nothing received yet: resumable.
        assert!(state.can_resume());

        // Data received, no tool call open: still resumable.
        state.record_chunk(100, 1000);
        assert!(state.can_resume());

        // Data received and a tool call is open: not resumable.
        state.start_tool_call("tool-1".to_string());
        assert!(!state.can_resume());
    }

    #[test]
    fn test_reconnect_manager_can_retry() {
        let manager = ReconnectManager::new(ReconnectConfig {
            max_attempts: 3,
            ..Default::default()
        });
        assert!(manager.can_retry());
        assert_eq!(manager.attempt_count(), 0);
    }

    #[tokio::test]
    async fn test_reconnect_manager_start_attempt() {
        let mut manager = ReconnectManager::new(ReconnectConfig {
            max_attempts: 3,
            initial_backoff_ms: 1000,
            jitter: false,
            ..Default::default()
        });

        let delay = manager.start_attempt().await;
        assert!(delay.is_some());
        assert_eq!(delay.unwrap(), 1000);
        assert_eq!(manager.attempt_count(), 1);
        assert_eq!(manager.state(), ReconnectState::Waiting);
    }

    #[tokio::test]
    async fn test_reconnect_manager_exhausted() {
        let mut manager = ReconnectManager::new(ReconnectConfig {
            max_attempts: 1,
            ..Default::default()
        });

        let _ = manager.start_attempt().await;
        let delay = manager.start_attempt().await;
        assert!(delay.is_none());
        assert_eq!(manager.state(), ReconnectState::Failed);
    }

    #[tokio::test]
    async fn test_reconnect_manager_reset() {
        let mut manager = ReconnectManager::new(ReconnectConfig::default());
        let _ = manager.start_attempt().await;
        manager.reset();
        assert_eq!(manager.attempt_count(), 0);
        assert_eq!(manager.state(), ReconnectState::Idle);
    }

    #[tokio::test]
    async fn test_reconnect_manager_mark_transitions() {
        let mut manager = ReconnectManager::new(ReconnectConfig::default());
        manager.mark_connecting().await;
        assert_eq!(manager.state(), ReconnectState::Connecting);
        manager.mark_connected().await;
        assert_eq!(manager.state(), ReconnectState::Connected);
        manager.mark_failed("boom").await;
        assert_eq!(manager.state(), ReconnectState::Failed);
    }

    #[tokio::test]
    async fn test_reconnect_manager_emits_events() {
        let (tx, mut rx) = tokio::sync::mpsc::channel(8);
        let mut manager = ReconnectManager::with_events(
            ReconnectConfig {
                max_attempts: 1,
                initial_backoff_ms: 1000,
                jitter: false,
                ..Default::default()
            },
            tx,
        );

        assert_eq!(manager.start_attempt().await, Some(1000));
        manager.mark_connecting().await;
        manager.mark_connected().await;
        // The single allowed attempt is used up, so this one fails.
        assert_eq!(manager.start_attempt().await, None);

        match rx.recv().await {
            Some(ReconnectEvent::Waiting { delay_ms, reason }) => {
                assert_eq!(delay_ms, 1000);
                assert_eq!(reason, "Attempt 1 of 1");
            }
            other => panic!("expected Waiting, got {:?}", other),
        }
        assert!(matches!(
            rx.recv().await,
            Some(ReconnectEvent::AttemptStarted {
                attempt: 1,
                max_attempts: 1
            })
        ));
        assert!(matches!(
            rx.recv().await,
            Some(ReconnectEvent::Connected { attempts: 1, .. })
        ));
        assert!(matches!(
            rx.recv().await,
            Some(ReconnectEvent::Failed { attempts: 1, .. })
        ));
    }
}