gestura_core_mcp/
notifications.rs

1//! MCP Notifications - Progress, logging, and cancellation
2//! Provides notification handling for long-running operations.
3
4use super::types::{
5    CancelledNotification, LogLevel, LoggingMessage, ProgressNotification, ProgressToken,
6};
7use std::collections::HashMap;
8use std::sync::RwLock;
9use tokio::sync::broadcast;
10
11/// Notification sender for MCP notifications
12pub type NotificationSender = broadcast::Sender<McpNotification>;
13/// Notification receiver for MCP notifications
14pub type NotificationReceiver = broadcast::Receiver<McpNotification>;
15
16/// MCP notification types
17#[derive(Debug, Clone, serde::Serialize)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum McpNotification {
20    /// Progress update
21    Progress(ProgressNotification),
22    /// Log message
23    Log(LoggingMessage),
24    /// Request cancelled
25    Cancelled(CancelledNotification),
26    /// Tools list changed
27    ToolsListChanged,
28    /// Resources list changed
29    ResourcesListChanged,
30    /// Prompts list changed
31    PromptsListChanged,
32}
33
34/// Progress tracker for long-running operations
35#[derive(Debug)]
36pub struct ProgressTracker {
37    active_operations: RwLock<HashMap<String, OperationProgress>>,
38    sender: NotificationSender,
39}
40
41/// Progress state for an operation
42#[derive(Debug, Clone)]
43pub struct OperationProgress {
44    pub token: ProgressToken,
45    pub current: f64,
46    pub total: Option<f64>,
47    pub message: Option<String>,
48    pub cancelled: bool,
49}
50
51impl ProgressTracker {
52    /// Create a new progress tracker
53    pub fn new(sender: NotificationSender) -> Self {
54        Self {
55            active_operations: RwLock::new(HashMap::new()),
56            sender,
57        }
58    }
59
60    /// Start tracking a new operation
61    pub fn start_operation(&self, token: impl Into<ProgressToken>, total: Option<f64>) -> String {
62        let token = token.into();
63        let id = match &token {
64            ProgressToken::String(s) => s.clone(),
65            ProgressToken::Integer(i) => i.to_string(),
66        };
67
68        let progress = OperationProgress {
69            token: token.clone(),
70            current: 0.0,
71            total,
72            message: None,
73            cancelled: false,
74        };
75
76        if let Ok(mut ops) = self.active_operations.write() {
77            ops.insert(id.clone(), progress);
78        }
79
80        id
81    }
82
83    /// Update progress for an operation
84    pub fn update_progress(&self, id: &str, current: f64, message: Option<String>) {
85        let notification = {
86            let mut ops = match self.active_operations.write() {
87                Ok(ops) => ops,
88                Err(_) => return,
89            };
90
91            if let Some(op) = ops.get_mut(id) {
92                op.current = current;
93                op.message = message.clone();
94
95                Some(ProgressNotification {
96                    progress_token: op.token.clone(),
97                    progress: current,
98                    total: op.total,
99                    message,
100                })
101            } else {
102                None
103            }
104        };
105
106        if let Some(notif) = notification {
107            let _ = self.sender.send(McpNotification::Progress(notif));
108        }
109    }
110
111    /// Complete an operation
112    pub fn complete_operation(&self, id: &str) {
113        if let Ok(mut ops) = self.active_operations.write()
114            && let Some(op) = ops.remove(id)
115        {
116            let _ = self
117                .sender
118                .send(McpNotification::Progress(ProgressNotification {
119                    progress_token: op.token,
120                    progress: op.total.unwrap_or(100.0),
121                    total: op.total,
122                    message: Some("Complete".to_string()),
123                }));
124        }
125    }
126
127    /// Cancel an operation
128    pub fn cancel_operation(&self, id: &str, reason: Option<String>) -> bool {
129        if let Ok(mut ops) = self.active_operations.write()
130            && let Some(op) = ops.get_mut(id)
131        {
132            op.cancelled = true;
133            let _ = self
134                .sender
135                .send(McpNotification::Cancelled(CancelledNotification {
136                    request_id: serde_json::json!(id),
137                    reason,
138                }));
139            return true;
140        }
141
142        false
143    }
144
145    /// Check if an operation is cancelled
146    pub fn is_cancelled(&self, id: &str) -> bool {
147        if let Ok(ops) = self.active_operations.read()
148            && let Some(op) = ops.get(id)
149        {
150            return op.cancelled;
151        }
152
153        false
154    }
155}
156
157/// Logger for MCP logging notifications
158#[derive(Debug, Clone)]
159pub struct McpLogger {
160    sender: NotificationSender,
161    logger_name: Option<String>,
162}
163
164impl McpLogger {
165    /// Create a new MCP logger
166    pub fn new(sender: NotificationSender, logger_name: Option<String>) -> Self {
167        Self {
168            sender,
169            logger_name,
170        }
171    }
172
173    /// Log a message at the specified level
174    pub fn log(&self, level: LogLevel, data: impl Into<serde_json::Value>) {
175        let _ = self.sender.send(McpNotification::Log(LoggingMessage {
176            level,
177            logger: self.logger_name.clone(),
178            data: data.into(),
179        }));
180    }
181
182    /// Log a debug message
183    pub fn debug(&self, message: impl Into<String>) {
184        self.log(LogLevel::Debug, serde_json::json!(message.into()));
185    }
186
187    /// Log an info message
188    pub fn info(&self, message: impl Into<String>) {
189        self.log(LogLevel::Info, serde_json::json!(message.into()));
190    }
191
192    /// Log a warning message
193    pub fn warning(&self, message: impl Into<String>) {
194        self.log(LogLevel::Warning, serde_json::json!(message.into()));
195    }
196
197    /// Log an error message
198    pub fn error(&self, message: impl Into<String>) {
199        self.log(LogLevel::Error, serde_json::json!(message.into()));
200    }
201}
202
203/// Create a notification channel
204pub fn create_notification_channel() -> (NotificationSender, NotificationReceiver) {
205    broadcast::channel(100)
206}