1use std::collections::{HashMap, HashSet};
12use std::sync::RwLock;
13use std::time::Instant;
14
15use lazy_static::lazy_static;
16use tokio::sync::oneshot;
17
/// The choice a user makes when asked whether a tool invocation may run.
///
/// "Once" variants cover a single invocation; "Session" variants are intended
/// to be remembered for the rest of the session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolConfirmationDecision {
    /// Permit this single invocation.
    AllowOnce,
    /// Permit this tool for the remainder of the session.
    AllowSession,
    /// Permit this tool always.
    /// NOTE(review): how "always" is persisted is not visible here — confirm.
    AllowAlways,
    /// Refuse this single invocation.
    DenyOnce,
    /// Refuse this tool for the remainder of the session.
    DenySession,
}

impl ToolConfirmationDecision {
    /// Returns `true` for every allow variant and `false` for every deny variant.
    pub fn is_allowed(self) -> bool {
        match self {
            Self::AllowOnce | Self::AllowSession | Self::AllowAlways => true,
            Self::DenyOnce | Self::DenySession => false,
        }
    }

    /// Canonical snake_case name of the decision; accepted by [`Self::parse`].
    pub fn as_str(self) -> &'static str {
        match self {
            Self::AllowOnce => "allow_once",
            Self::AllowSession => "allow_session",
            Self::AllowAlways => "allow_always",
            Self::DenyOnce => "deny_once",
            Self::DenySession => "deny_session",
        }
    }

    /// Parses a decision from user/API text.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII
    /// case-insensitive. Besides the canonical names, a few aliases are
    /// accepted (`allow`, `deny`, `once_allow`, `always_allow`,
    /// `allow_forever`, `block_session`, ...).
    ///
    /// # Errors
    ///
    /// Returns a message naming the (trimmed, lowercased) input and listing
    /// the canonical names when the text is not recognised.
    pub fn parse(input: &str) -> Result<Self, String> {
        let key = input.trim().to_ascii_lowercase();
        let decision = match key.as_str() {
            "allow" | "allow_once" | "once_allow" => Self::AllowOnce,
            "allow_session" | "session_allow" => Self::AllowSession,
            "allow_always" | "always_allow" | "allow_forever" => Self::AllowAlways,
            "deny" | "deny_once" | "once_deny" => Self::DenyOnce,
            "deny_session" | "session_deny" | "block_session" => Self::DenySession,
            other => {
                return Err(format!(
                    "Unknown tool confirmation decision '{other}'. Expected one of: allow_once, allow_session, allow_always, deny_once, deny_session"
                ));
            }
        };
        Ok(decision)
    }
}

/// Maps a plain yes/no answer onto the single-invocation decisions.
impl From<bool> for ToolConfirmationDecision {
    fn from(value: bool) -> Self {
        match value {
            true => Self::AllowOnce,
            false => Self::DenyOnce,
        }
    }
}
85
86#[derive(Debug)]
88pub struct PendingToolConfirmation {
89 pub session_id: Option<String>,
91 pub tool_name: String,
93 pub tool_args: String,
95 pub created_at: Instant,
97 sender: oneshot::Sender<ToolConfirmationDecision>,
99}
100
101#[derive(Debug, Default)]
106pub struct ToolConfirmationManager {
107 pending: RwLock<HashMap<String, PendingToolConfirmation>>,
108 session_confirmed: RwLock<HashMap<String, HashSet<String>>>,
109 session_blocked: RwLock<HashMap<String, HashSet<String>>>,
110}
111
112impl ToolConfirmationManager {
113 pub fn new() -> Self {
115 Self {
116 pending: RwLock::new(HashMap::new()),
117 session_confirmed: RwLock::new(HashMap::new()),
118 session_blocked: RwLock::new(HashMap::new()),
119 }
120 }
121
122 pub fn register(
126 &self,
127 confirmation_id: String,
128 session_id: Option<String>,
129 tool_name: String,
130 tool_args: String,
131 ) -> oneshot::Receiver<ToolConfirmationDecision> {
132 let (tx, rx) = oneshot::channel();
133 let pending = PendingToolConfirmation {
134 session_id,
135 tool_name,
136 tool_args,
137 created_at: Instant::now(),
138 sender: tx,
139 };
140
141 if let Ok(mut map) = self.pending.write() {
142 map.insert(confirmation_id, pending);
143 }
144 rx
145 }
146
147 pub fn resolve(
152 &self,
153 confirmation_id: &str,
154 session_id: Option<&str>,
155 approved: bool,
156 ) -> Result<(), String> {
157 self.resolve_decision(
158 confirmation_id,
159 session_id,
160 ToolConfirmationDecision::from(approved),
161 )
162 }
163
164 pub fn resolve_decision(
169 &self,
170 confirmation_id: &str,
171 session_id: Option<&str>,
172 decision: ToolConfirmationDecision,
173 ) -> Result<(), String> {
174 let pending = {
175 let mut map = self
176 .pending
177 .write()
178 .map_err(|_| "tool confirmation manager poisoned".to_string())?;
179 map.remove(confirmation_id)
180 }
181 .ok_or_else(|| format!("Unknown confirmation id: {confirmation_id}"))?;
182
183 if let Some(expected) = pending.session_id.as_deref()
184 && let Some(got) = session_id
185 && expected != got
186 {
187 return Err(format!(
188 "Session mismatch for confirmation {confirmation_id}: expected {expected}, got {got}"
189 ));
190 }
191
192 let _ = pending.sender.send(decision);
194 Ok(())
195 }
196
197 pub fn apply_session_policy_decision(
202 &self,
203 session_id: &str,
204 tool_name: &str,
205 decision: ToolConfirmationDecision,
206 ) {
207 match decision {
208 ToolConfirmationDecision::AllowSession | ToolConfirmationDecision::AllowAlways => {
209 if let Ok(mut map) = self.session_confirmed.write() {
210 map.entry(session_id.to_string())
211 .or_default()
212 .insert(tool_name.to_string());
213 }
214 if let Ok(mut map) = self.session_blocked.write()
216 && let Some(set) = map.get_mut(session_id)
217 {
218 set.remove(tool_name);
219 }
220 }
221 ToolConfirmationDecision::DenySession => {
222 if let Ok(mut map) = self.session_blocked.write() {
223 map.entry(session_id.to_string())
224 .or_default()
225 .insert(tool_name.to_string());
226 }
227 if let Ok(mut map) = self.session_confirmed.write()
229 && let Some(set) = map.get_mut(session_id)
230 {
231 set.remove(tool_name);
232 }
233 }
234 ToolConfirmationDecision::AllowOnce | ToolConfirmationDecision::DenyOnce => {}
235 }
236 }
237
238 pub fn is_tool_allowed_for_session(&self, session_id: &str, tool_name: &str) -> bool {
240 self.session_confirmed
241 .read()
242 .ok()
243 .and_then(|m| m.get(session_id).cloned())
244 .is_some_and(|set| set.contains(tool_name))
245 }
246
247 pub fn is_tool_blocked_for_session(&self, session_id: &str, tool_name: &str) -> bool {
249 self.session_blocked
250 .read()
251 .ok()
252 .and_then(|m| m.get(session_id).cloned())
253 .is_some_and(|set| set.contains(tool_name))
254 }
255
256 pub fn abandon(&self, confirmation_id: &str) {
260 if let Ok(mut map) = self.pending.write() {
261 map.remove(confirmation_id);
262 }
263 }
264
265 pub fn pending_count(&self) -> usize {
267 self.pending.read().map(|m| m.len()).unwrap_or_default()
268 }
269}
270
271lazy_static! {
272 pub static ref TOOL_CONFIRMATIONS: ToolConfirmationManager = ToolConfirmationManager::new();
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers a confirmation with `{}` arguments and returns its receiver.
    fn register_tool(
        manager: &ToolConfirmationManager,
        id: &str,
        session: Option<&str>,
        tool: &str,
    ) -> oneshot::Receiver<ToolConfirmationDecision> {
        manager.register(
            String::from(id),
            session.map(String::from),
            String::from(tool),
            String::from("{}"),
        )
    }

    #[tokio::test]
    async fn register_and_resolve_allows() {
        let manager = ToolConfirmationManager::new();
        let receiver = register_tool(&manager, "c1", Some("s1"), "shell");

        manager.resolve("c1", Some("s1"), true).unwrap();

        let decision = receiver.await.unwrap();
        assert!(decision.is_allowed());
        assert_eq!(manager.pending_count(), 0);
    }

    #[tokio::test]
    async fn resolve_rejects_session_mismatch() {
        let manager = ToolConfirmationManager::new();
        // Keep the receiver alive for the duration of the test.
        let _receiver = register_tool(&manager, "c2", Some("s1"), "file");

        let message = manager.resolve("c2", Some("s2"), true).unwrap_err();
        assert!(message.contains("Session mismatch"));
    }

    #[tokio::test]
    async fn resolve_allows_missing_session_id_when_confirmation_id_matches() {
        let manager = ToolConfirmationManager::new();
        let receiver = register_tool(&manager, "c2b", Some("s1"), "shell");

        manager.resolve("c2b", None, true).unwrap();

        let decision = receiver.await.unwrap();
        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn abandon_removes_pending() {
        let manager = ToolConfirmationManager::new();
        let _receiver = register_tool(&manager, "c3", None, "file");
        assert_eq!(manager.pending_count(), 1);

        manager.abandon("c3");
        assert_eq!(manager.pending_count(), 0);
    }

    #[test]
    fn session_policy_is_recorded() {
        let manager = ToolConfirmationManager::new();

        manager.apply_session_policy_decision(
            "s1",
            "file",
            ToolConfirmationDecision::AllowSession,
        );
        assert!(manager.is_tool_allowed_for_session("s1", "file"));
        assert!(!manager.is_tool_blocked_for_session("s1", "file"));

        manager.apply_session_policy_decision(
            "s1",
            "file",
            ToolConfirmationDecision::DenySession,
        );
        assert!(!manager.is_tool_allowed_for_session("s1", "file"));
        assert!(manager.is_tool_blocked_for_session("s1", "file"));
    }
}