1use serde::{Deserialize, Serialize};
7use std::fmt;
8use thiserror::Error;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12pub enum StreamErrorCategory {
13 Network,
15 Auth,
17 RateLimit,
19 Provider,
21 Format,
23 Resource,
25 Internal,
27 Cancelled,
29}
30
31impl StreamErrorCategory {
32 pub fn is_retryable(&self) -> bool {
34 matches!(
35 self,
36 StreamErrorCategory::Network
37 | StreamErrorCategory::RateLimit
38 | StreamErrorCategory::Provider
39 )
40 }
41
42 pub fn suggested_retry_delay_ms(&self) -> Option<u64> {
44 match self {
45 StreamErrorCategory::Network => Some(1000),
46 StreamErrorCategory::RateLimit => Some(5000),
47 StreamErrorCategory::Provider => Some(2000),
48 _ => None,
49 }
50 }
51}
52
53#[derive(Debug, Clone, Error, Serialize, Deserialize)]
55pub struct StreamError {
56 pub category: StreamErrorCategory,
58 pub code: String,
60 pub message: String,
62 pub provider: Option<String>,
64 pub http_status: Option<u16>,
66 pub retryable: bool,
68 pub retry_after_ms: Option<u64>,
70 pub context: Option<String>,
72}
73
74impl fmt::Display for StreamError {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 write!(f, "[{}] {}", self.code, self.message)
77 }
78}
79
80impl StreamError {
81 pub fn new(
83 category: StreamErrorCategory,
84 code: impl Into<String>,
85 message: impl Into<String>,
86 ) -> Self {
87 let retryable = category.is_retryable();
88 let retry_after_ms = category.suggested_retry_delay_ms();
89 Self {
90 category,
91 code: code.into(),
92 message: message.into(),
93 provider: None,
94 http_status: None,
95 retryable,
96 retry_after_ms,
97 context: None,
98 }
99 }
100
101 pub fn network(message: impl Into<String>) -> Self {
103 Self::new(StreamErrorCategory::Network, "NETWORK_ERROR", message)
104 }
105
106 pub fn timeout(message: impl Into<String>) -> Self {
108 Self::new(StreamErrorCategory::Network, "TIMEOUT", message)
109 }
110
111 pub fn auth(message: impl Into<String>) -> Self {
113 Self::new(StreamErrorCategory::Auth, "AUTH_ERROR", message)
114 }
115
116 pub fn rate_limit(message: impl Into<String>, retry_after_ms: Option<u64>) -> Self {
118 let mut err = Self::new(StreamErrorCategory::RateLimit, "RATE_LIMITED", message);
119 err.retry_after_ms = retry_after_ms;
120 err
121 }
122
123 pub fn provider(provider: impl Into<String>, message: impl Into<String>) -> Self {
125 let mut err = Self::new(StreamErrorCategory::Provider, "PROVIDER_ERROR", message);
126 err.provider = Some(provider.into());
127 err
128 }
129
130 pub fn format(message: impl Into<String>) -> Self {
132 Self::new(StreamErrorCategory::Format, "FORMAT_ERROR", message)
133 }
134
135 pub fn resource(message: impl Into<String>) -> Self {
137 Self::new(StreamErrorCategory::Resource, "RESOURCE_EXHAUSTED", message)
138 }
139
140 pub fn internal(message: impl Into<String>) -> Self {
142 Self::new(StreamErrorCategory::Internal, "INTERNAL_ERROR", message)
143 }
144
145 pub fn cancelled() -> Self {
147 let mut err = Self::new(
148 StreamErrorCategory::Cancelled,
149 "CANCELLED",
150 "Stream was cancelled",
151 );
152 err.retryable = false;
153 err
154 }
155
156 pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
158 self.provider = Some(provider.into());
159 self
160 }
161
162 pub fn with_http_status(mut self, status: u16) -> Self {
164 self.http_status = Some(status);
165 self
166 }
167
168 pub fn with_context(mut self, context: impl Into<String>) -> Self {
170 self.context = Some(context.into());
171 self
172 }
173
174 pub fn with_retry_after(mut self, ms: u64) -> Self {
176 self.retry_after_ms = Some(ms);
177 self.retryable = true;
178 self
179 }
180
181 pub fn non_retryable(mut self) -> Self {
183 self.retryable = false;
184 self.retry_after_ms = None;
185 self
186 }
187
188 pub fn from_http_response(provider: &str, status: u16, body: &str) -> Self {
190 let category = match status {
191 401 | 403 => StreamErrorCategory::Auth,
192 429 => StreamErrorCategory::RateLimit,
193 400 | 422 => StreamErrorCategory::Format,
194 500..=599 => StreamErrorCategory::Provider,
195 _ => StreamErrorCategory::Internal,
196 };
197
198 let code = format!("HTTP_{}", status);
199 let message = if body.is_empty() {
200 format!("HTTP {} error from {}", status, provider)
201 } else {
202 if let Ok(json) = serde_json::from_str::<serde_json::Value>(body) {
204 json.get("error")
205 .and_then(|e| e.get("message").or(Some(e)))
206 .and_then(|m| m.as_str())
207 .map(|s| s.to_string())
208 .unwrap_or_else(|| body.chars().take(200).collect())
209 } else {
210 body.chars().take(200).collect()
211 }
212 };
213
214 Self::new(category, code, message)
215 .with_provider(provider)
216 .with_http_status(status)
217 }
218
219 pub fn log(&self) {
221 match self.category {
222 StreamErrorCategory::Cancelled => {
223 tracing::debug!(
224 category = ?self.category,
225 code = %self.code,
226 "Stream cancelled"
227 );
228 }
229 StreamErrorCategory::RateLimit => {
230 tracing::warn!(
231 category = ?self.category,
232 code = %self.code,
233 provider = ?self.provider,
234 retry_after_ms = ?self.retry_after_ms,
235 "Rate limited: {}", self.message
236 );
237 }
238 StreamErrorCategory::Auth => {
239 tracing::error!(
240 category = ?self.category,
241 code = %self.code,
242 provider = ?self.provider,
243 "Authentication error: {}", self.message
244 );
245 }
246 _ => {
247 tracing::error!(
248 category = ?self.category,
249 code = %self.code,
250 provider = ?self.provider,
251 http_status = ?self.http_status,
252 retryable = self.retryable,
253 "Stream error: {}", self.message
254 );
255 }
256 }
257 }
258}
259
260pub type StreamResult<T> = Result<T, StreamError>;
262
#[cfg(test)]
mod tests {
    //! Unit tests for `StreamErrorCategory` and `StreamError`: retry policy,
    //! constructors, builders, HTTP mapping, display and serde round-trips.

    use super::*;

    #[test]
    fn test_stream_error_category_retryable() {
        assert!(StreamErrorCategory::Network.is_retryable());
        assert!(StreamErrorCategory::RateLimit.is_retryable());
        assert!(StreamErrorCategory::Provider.is_retryable());
        assert!(!StreamErrorCategory::Auth.is_retryable());
        assert!(!StreamErrorCategory::Format.is_retryable());
        assert!(!StreamErrorCategory::Resource.is_retryable());
        assert!(!StreamErrorCategory::Internal.is_retryable());
        assert!(!StreamErrorCategory::Cancelled.is_retryable());
    }

    #[test]
    fn test_stream_error_category_suggested_delays() {
        assert_eq!(StreamErrorCategory::Network.suggested_retry_delay_ms(), Some(1000));
        assert_eq!(StreamErrorCategory::RateLimit.suggested_retry_delay_ms(), Some(5000));
        assert_eq!(StreamErrorCategory::Provider.suggested_retry_delay_ms(), Some(2000));
        assert_eq!(StreamErrorCategory::Auth.suggested_retry_delay_ms(), None);
        assert_eq!(StreamErrorCategory::Format.suggested_retry_delay_ms(), None);
        assert_eq!(StreamErrorCategory::Resource.suggested_retry_delay_ms(), None);
        assert_eq!(StreamErrorCategory::Internal.suggested_retry_delay_ms(), None);
        assert_eq!(StreamErrorCategory::Cancelled.suggested_retry_delay_ms(), None);
    }

    #[test]
    fn test_stream_error_new() {
        let err = StreamError::new(StreamErrorCategory::Network, "TEST", "Test error");
        assert_eq!(err.category, StreamErrorCategory::Network);
        assert_eq!(err.code, "TEST");
        assert_eq!(err.message, "Test error");
        assert!(err.retryable);
        assert_eq!(err.retry_after_ms, Some(1000));
        assert_eq!(err.provider, None);
        assert_eq!(err.http_status, None);
        assert_eq!(err.context, None);
    }

    #[test]
    fn test_stream_error_network() {
        let err = StreamError::network("Connection failed");
        assert_eq!(err.category, StreamErrorCategory::Network);
        assert_eq!(err.code, "NETWORK_ERROR");
        assert!(err.retryable);
    }

    #[test]
    fn test_stream_error_timeout() {
        let err = StreamError::timeout("Request timed out");
        assert_eq!(err.category, StreamErrorCategory::Network);
        assert_eq!(err.code, "TIMEOUT");
        assert!(err.retryable);
    }

    #[test]
    fn test_stream_error_auth() {
        let err = StreamError::auth("Invalid API key");
        assert_eq!(err.category, StreamErrorCategory::Auth);
        assert_eq!(err.code, "AUTH_ERROR");
        assert!(!err.retryable);
        assert_eq!(err.retry_after_ms, None);
    }

    #[test]
    fn test_stream_error_rate_limit() {
        let err = StreamError::rate_limit("Too many requests", Some(5000));
        assert_eq!(err.category, StreamErrorCategory::RateLimit);
        assert_eq!(err.code, "RATE_LIMITED");
        assert_eq!(err.retry_after_ms, Some(5000));
        assert!(err.retryable);
    }

    #[test]
    fn test_stream_error_rate_limit_server_delay_wins() {
        let err = StreamError::rate_limit("Too many requests", Some(1234));
        assert_eq!(err.retry_after_ms, Some(1234));
    }

    #[test]
    fn test_stream_error_provider() {
        let err = StreamError::provider("openai", "Model not found");
        assert_eq!(err.category, StreamErrorCategory::Provider);
        assert_eq!(err.code, "PROVIDER_ERROR");
        assert_eq!(err.provider, Some("openai".to_string()));
        assert!(err.retryable);
    }

    #[test]
    fn test_stream_error_simple_constructors() {
        let format = StreamError::format("bad json");
        assert_eq!(format.category, StreamErrorCategory::Format);
        assert_eq!(format.code, "FORMAT_ERROR");
        assert!(!format.retryable);

        let resource = StreamError::resource("out of memory");
        assert_eq!(resource.category, StreamErrorCategory::Resource);
        assert_eq!(resource.code, "RESOURCE_EXHAUSTED");
        assert!(!resource.retryable);

        let internal = StreamError::internal("oops");
        assert_eq!(internal.category, StreamErrorCategory::Internal);
        assert_eq!(internal.code, "INTERNAL_ERROR");
        assert!(!internal.retryable);
    }

    #[test]
    fn test_stream_error_cancelled() {
        let err = StreamError::cancelled();
        assert_eq!(err.category, StreamErrorCategory::Cancelled);
        assert_eq!(err.code, "CANCELLED");
        assert_eq!(err.message, "Stream was cancelled");
        assert!(!err.retryable);
        assert_eq!(err.retry_after_ms, None);
    }

    #[test]
    fn test_stream_error_with_context() {
        let err = StreamError::network("Failed")
            .with_provider("anthropic")
            .with_http_status(500)
            .with_context("During streaming response");

        assert_eq!(err.provider, Some("anthropic".to_string()));
        assert_eq!(err.http_status, Some(500));
        assert_eq!(err.context, Some("During streaming response".to_string()));
    }

    #[test]
    fn test_stream_error_with_retry_after_forces_retryable() {
        let err = StreamError::auth("expired").with_retry_after(250);
        assert!(err.retryable);
        assert_eq!(err.retry_after_ms, Some(250));
    }

    #[test]
    fn test_stream_error_non_retryable_clears_delay() {
        let err = StreamError::network("boom").non_retryable();
        assert!(!err.retryable);
        assert_eq!(err.retry_after_ms, None);
    }

    #[test]
    fn test_stream_error_from_http_response() {
        let err = StreamError::from_http_response(
            "openai",
            401,
            r#"{"error":{"message":"Invalid API key"}}"#,
        );
        assert_eq!(err.category, StreamErrorCategory::Auth);
        assert_eq!(err.http_status, Some(401));
        assert_eq!(err.code, "HTTP_401");
        assert_eq!(err.provider.as_deref(), Some("openai"));
        assert!(err.message.contains("Invalid API key"));
    }

    #[test]
    fn test_stream_error_from_http_response_rate_limit() {
        let err = StreamError::from_http_response("anthropic", 429, "Rate limit exceeded");
        assert_eq!(err.category, StreamErrorCategory::RateLimit);
        assert_eq!(err.message, "Rate limit exceeded");
        assert!(err.retryable);
    }

    #[test]
    fn test_stream_error_from_http_response_empty_body() {
        let err = StreamError::from_http_response("openai", 503, "");
        assert_eq!(err.category, StreamErrorCategory::Provider);
        assert_eq!(err.code, "HTTP_503");
        assert_eq!(err.message, "HTTP 503 error from openai");
        assert!(err.retryable);
    }

    #[test]
    fn test_stream_error_from_http_response_string_error_field() {
        let err = StreamError::from_http_response("openai", 400, r#"{"error":"bad request"}"#);
        assert_eq!(err.category, StreamErrorCategory::Format);
        assert_eq!(err.message, "bad request");
    }

    #[test]
    fn test_stream_error_from_http_response_error_object_without_message() {
        let body = r#"{"error":{"type":"overloaded"}}"#;
        let err = StreamError::from_http_response("anthropic", 529, body);
        assert_eq!(err.category, StreamErrorCategory::Provider);
        assert_eq!(err.message, body);
    }

    #[test]
    fn test_stream_error_from_http_response_truncates_raw_body_by_chars() {
        let body = "é".repeat(300);
        let err = StreamError::from_http_response("openai", 500, &body);
        assert_eq!(err.message.chars().count(), 200);
        assert!(err.message.chars().all(|c| c == 'é'));
    }

    #[test]
    fn test_stream_error_display() {
        let err = StreamError::network("Connection reset");
        let display = format!("{}", err);
        assert!(display.contains("NETWORK_ERROR"));
        assert!(display.contains("Connection reset"));
        assert_eq!(display, "[NETWORK_ERROR] Connection reset");
    }

    #[test]
    fn test_stream_error_serde_round_trip() {
        let err = StreamError::rate_limit("slow down", Some(1500))
            .with_provider("anthropic")
            .with_http_status(429)
            .with_context("ctx");
        let json = serde_json::to_string(&err).expect("StreamError serializes");
        let back: StreamError = serde_json::from_str(&json).expect("StreamError deserializes");

        assert_eq!(back.category, err.category);
        assert_eq!(back.code, err.code);
        assert_eq!(back.message, err.message);
        assert_eq!(back.provider, err.provider);
        assert_eq!(back.http_status, err.http_status);
        assert_eq!(back.retryable, err.retryable);
        assert_eq!(back.retry_after_ms, err.retry_after_ms);
        assert_eq!(back.context, err.context);
    }
}