gestura_core_mcp/
notifications.rs1use super::types::{
5 CancelledNotification, LogLevel, LoggingMessage, ProgressNotification, ProgressToken,
6};
7use std::collections::HashMap;
8use std::sync::RwLock;
9use tokio::sync::broadcast;
10
/// Sending half of the broadcast channel that carries [`McpNotification`] values.
pub type NotificationSender = broadcast::Sender<McpNotification>;
/// Receiving half of the broadcast channel that carries [`McpNotification`] values.
pub type NotificationReceiver = broadcast::Receiver<McpNotification>;
15
/// A notification published on the broadcast channel.
///
/// Serialized by serde as an internally tagged value: the output carries a
/// `type` field whose value is the snake_case variant name (for example
/// `progress` or `tools_list_changed`).
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpNotification {
    /// Progress report for a tracked operation.
    Progress(ProgressNotification),
    /// Log record emitted through [`McpLogger`].
    Log(LoggingMessage),
    /// Notice that an operation was cancelled.
    Cancelled(CancelledNotification),
    /// Payload-free marker; by its name, signals that the tool list changed.
    ToolsListChanged,
    /// Payload-free marker; by its name, signals that the resource list changed.
    ResourcesListChanged,
    /// Payload-free marker; by its name, signals that the prompt list changed.
    PromptsListChanged,
}
33
/// Keeps the state of in-flight operations and publishes their progress and
/// cancellation notifications.
#[derive(Debug)]
pub struct ProgressTracker {
    /// Active operations keyed by the string id returned from `start_operation`.
    active_operations: RwLock<HashMap<String, OperationProgress>>,
    /// Channel on which progress and cancellation notifications are sent.
    sender: NotificationSender,
}
40
/// Snapshot of a single tracked operation.
#[derive(Debug, Clone)]
pub struct OperationProgress {
    /// Token the operation was started with; echoed in progress notifications.
    pub token: ProgressToken,
    /// Most recently recorded progress value (starts at `0.0`).
    pub current: f64,
    /// Expected final progress value, when known.
    pub total: Option<f64>,
    /// Message attached to the most recent progress update, if any.
    pub message: Option<String>,
    /// Set once the operation has been cancelled.
    pub cancelled: bool,
}
50
51impl ProgressTracker {
52 pub fn new(sender: NotificationSender) -> Self {
54 Self {
55 active_operations: RwLock::new(HashMap::new()),
56 sender,
57 }
58 }
59
60 pub fn start_operation(&self, token: impl Into<ProgressToken>, total: Option<f64>) -> String {
62 let token = token.into();
63 let id = match &token {
64 ProgressToken::String(s) => s.clone(),
65 ProgressToken::Integer(i) => i.to_string(),
66 };
67
68 let progress = OperationProgress {
69 token: token.clone(),
70 current: 0.0,
71 total,
72 message: None,
73 cancelled: false,
74 };
75
76 if let Ok(mut ops) = self.active_operations.write() {
77 ops.insert(id.clone(), progress);
78 }
79
80 id
81 }
82
83 pub fn update_progress(&self, id: &str, current: f64, message: Option<String>) {
85 let notification = {
86 let mut ops = match self.active_operations.write() {
87 Ok(ops) => ops,
88 Err(_) => return,
89 };
90
91 if let Some(op) = ops.get_mut(id) {
92 op.current = current;
93 op.message = message.clone();
94
95 Some(ProgressNotification {
96 progress_token: op.token.clone(),
97 progress: current,
98 total: op.total,
99 message,
100 })
101 } else {
102 None
103 }
104 };
105
106 if let Some(notif) = notification {
107 let _ = self.sender.send(McpNotification::Progress(notif));
108 }
109 }
110
111 pub fn complete_operation(&self, id: &str) {
113 if let Ok(mut ops) = self.active_operations.write()
114 && let Some(op) = ops.remove(id)
115 {
116 let _ = self
117 .sender
118 .send(McpNotification::Progress(ProgressNotification {
119 progress_token: op.token,
120 progress: op.total.unwrap_or(100.0),
121 total: op.total,
122 message: Some("Complete".to_string()),
123 }));
124 }
125 }
126
127 pub fn cancel_operation(&self, id: &str, reason: Option<String>) -> bool {
129 if let Ok(mut ops) = self.active_operations.write()
130 && let Some(op) = ops.get_mut(id)
131 {
132 op.cancelled = true;
133 let _ = self
134 .sender
135 .send(McpNotification::Cancelled(CancelledNotification {
136 request_id: serde_json::json!(id),
137 reason,
138 }));
139 return true;
140 }
141
142 false
143 }
144
145 pub fn is_cancelled(&self, id: &str) -> bool {
147 if let Ok(ops) = self.active_operations.read()
148 && let Some(op) = ops.get(id)
149 {
150 return op.cancelled;
151 }
152
153 false
154 }
155}
156
/// Publishes log records as [`McpNotification::Log`] notifications.
#[derive(Debug, Clone)]
pub struct McpLogger {
    /// Channel on which log notifications are sent.
    sender: NotificationSender,
    /// Optional logger name copied into every emitted record.
    logger_name: Option<String>,
}
163
164impl McpLogger {
165 pub fn new(sender: NotificationSender, logger_name: Option<String>) -> Self {
167 Self {
168 sender,
169 logger_name,
170 }
171 }
172
173 pub fn log(&self, level: LogLevel, data: impl Into<serde_json::Value>) {
175 let _ = self.sender.send(McpNotification::Log(LoggingMessage {
176 level,
177 logger: self.logger_name.clone(),
178 data: data.into(),
179 }));
180 }
181
182 pub fn debug(&self, message: impl Into<String>) {
184 self.log(LogLevel::Debug, serde_json::json!(message.into()));
185 }
186
187 pub fn info(&self, message: impl Into<String>) {
189 self.log(LogLevel::Info, serde_json::json!(message.into()));
190 }
191
192 pub fn warning(&self, message: impl Into<String>) {
194 self.log(LogLevel::Warning, serde_json::json!(message.into()));
195 }
196
197 pub fn error(&self, message: impl Into<String>) {
199 self.log(LogLevel::Error, serde_json::json!(message.into()));
200 }
201}
202
203pub fn create_notification_channel() -> (NotificationSender, NotificationReceiver) {
205 broadcast::channel(100)
206}