gestura_core_llm/
token_tracker.rs1use crate::TokenUsage;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::VecDeque;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
/// Upper bound on the number of [`UsageRecord`]s kept in a tracker's history;
/// once reached, the oldest record is evicted for each new one.
const MAX_USAGE_HISTORY: usize = 1_000;
15
/// One recorded LLM call: its token usage plus when and in which session it happened.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
    /// UTC time at which the usage was recorded by the tracker.
    pub timestamp: DateTime<Utc>,
    /// Token counts (and optional cost) reported for the call.
    pub usage: TokenUsage,
    /// Session that was active when the call was recorded; `None` if no session was set.
    pub session_id: Option<String>,
}
26
/// Running totals aggregated over a series of calls (a session, a day, or all time).
///
/// `Default` yields all-zero totals, which is the "no calls yet" state.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageStats {
    /// Sum of input (prompt) tokens across all added calls.
    pub total_input_tokens: u64,
    /// Sum of output (completion) tokens across all added calls.
    pub total_output_tokens: u64,
    /// Sum of the per-call `total_tokens` values.
    pub total_tokens: u64,
    /// Sum of per-call estimated costs in USD; calls without a cost contribute 0.
    pub estimated_cost_usd: f64,
    /// Number of calls folded into these totals.
    pub call_count: u64,
    /// `total_tokens / call_count`, recomputed on each add; 0.0 before the first call.
    pub avg_tokens_per_call: f64,
}
43
44impl UsageStats {
45 pub fn add(&mut self, usage: &TokenUsage) {
47 self.total_input_tokens += usage.input_tokens as u64;
48 self.total_output_tokens += usage.output_tokens as u64;
49 self.total_tokens += usage.total_tokens as u64;
50 self.estimated_cost_usd += usage.estimated_cost_usd.unwrap_or(0.0);
51 self.call_count += 1;
52 if self.call_count > 0 {
53 self.avg_tokens_per_call = self.total_tokens as f64 / self.call_count as f64;
54 }
55 }
56
57 pub fn format_summary(&self) -> String {
59 format!(
60 "{}↓ {}↑ | ${:.4} | {} calls",
61 format_token_count(self.total_input_tokens),
62 format_token_count(self.total_output_tokens),
63 self.estimated_cost_usd,
64 self.call_count
65 )
66 }
67
68 pub fn format_compact(&self) -> String {
70 format!(
71 "{}tok ${:.2}",
72 format_token_count(self.total_tokens),
73 self.estimated_cost_usd
74 )
75 }
76}
77
/// Formats a token count for display: plain below 1,000, otherwise with one
/// decimal place and a `K` or `M` suffix (`500`, `1.5K`, `1.5M`).
///
/// The K→M switch happens at 999,950 rather than 1,000,000. Any value from
/// there up would round to `"1000.0K"` at one decimal place, so it is shown
/// as `"1.0M"` instead.
pub fn format_token_count(tokens: u64) -> String {
    // Smallest count whose thousands value rounds to 1000.0 at one decimal.
    const MILLION_THRESHOLD: u64 = 999_950;
    if tokens >= MILLION_THRESHOLD {
        format!("{:.1}M", tokens as f64 / 1_000_000.0)
    } else if tokens >= 1_000 {
        format!("{:.1}K", tokens as f64 / 1_000.0)
    } else {
        tokens.to_string()
    }
}
88
/// Tracks LLM token usage and cost at three scopes (current session, current
/// UTC day, all time) and keeps a bounded history of individual calls.
///
/// Every piece of state sits behind its own `tokio::sync::RwLock`, so all
/// methods take `&self` and the tracker can be shared (e.g. via `Arc`).
#[derive(Debug)]
pub struct TokenTracker {
    /// Recorded calls, oldest at the front; `record_usage` caps it at `MAX_USAGE_HISTORY`.
    history: RwLock<VecDeque<UsageRecord>>,
    /// Totals for the current session; cleared by `set_session` and `reset_session`.
    session_stats: RwLock<UsageStats>,
    /// Totals since this tracker was created; never reset.
    global_stats: RwLock<UsageStats>,
    /// Identifier of the current session, attached to each new `UsageRecord`.
    session_id: RwLock<Option<String>>,
    /// Optional daily spending cap in USD; `None` means no budget is enforced.
    daily_budget_usd: RwLock<Option<f64>>,
    /// Totals for the UTC date stored in `today_date`.
    today_stats: RwLock<UsageStats>,
    /// UTC calendar date that `today_stats` currently covers.
    today_date: RwLock<chrono::NaiveDate>,
}
107
108impl Default for TokenTracker {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl TokenTracker {
115 pub fn new() -> Self {
117 Self {
118 history: RwLock::new(VecDeque::with_capacity(MAX_USAGE_HISTORY)),
119 session_stats: RwLock::new(UsageStats::default()),
120 global_stats: RwLock::new(UsageStats::default()),
121 session_id: RwLock::new(None),
122 daily_budget_usd: RwLock::new(None),
123 today_stats: RwLock::new(UsageStats::default()),
124 today_date: RwLock::new(Utc::now().date_naive()),
125 }
126 }
127
128 pub async fn set_session(&self, session_id: impl Into<String>) {
130 let mut id = self.session_id.write().await;
131 *id = Some(session_id.into());
132 let mut stats = self.session_stats.write().await;
134 *stats = UsageStats::default();
135 }
136
137 pub async fn set_daily_budget(&self, budget_usd: f64) {
139 let mut budget = self.daily_budget_usd.write().await;
140 *budget = Some(budget_usd);
141 }
142
143 pub async fn record_usage(&self, usage: TokenUsage) {
145 let now = Utc::now();
146 let session_id = self.session_id.read().await.clone();
147
148 let record = UsageRecord {
150 timestamp: now,
151 usage: usage.clone(),
152 session_id,
153 };
154
155 let mut history = self.history.write().await;
157 if history.len() >= MAX_USAGE_HISTORY {
158 history.pop_front();
159 }
160 history.push_back(record);
161 drop(history);
162
163 let mut session_stats = self.session_stats.write().await;
165 session_stats.add(&usage);
166 drop(session_stats);
167
168 let mut global_stats = self.global_stats.write().await;
170 global_stats.add(&usage);
171 drop(global_stats);
172
173 let today = now.date_naive();
175 let mut today_date = self.today_date.write().await;
176 if *today_date != today {
177 *today_date = today;
178 let mut today_stats = self.today_stats.write().await;
179 *today_stats = UsageStats::default();
180 }
181 drop(today_date);
182
183 let mut today_stats = self.today_stats.write().await;
185 today_stats.add(&usage);
186 }
187
188 pub async fn get_session_stats(&self) -> UsageStats {
190 self.session_stats.read().await.clone()
191 }
192
193 pub async fn get_global_stats(&self) -> UsageStats {
195 self.global_stats.read().await.clone()
196 }
197
198 pub async fn get_today_stats(&self) -> UsageStats {
200 self.today_stats.read().await.clone()
201 }
202
203 pub async fn check_budget(&self, estimated_cost: f64) -> BudgetStatus {
205 let budget = self.daily_budget_usd.read().await;
206 match *budget {
207 None => BudgetStatus::NoBudgetSet,
208 Some(limit) => {
209 let today = self.today_stats.read().await;
210 let current = today.estimated_cost_usd;
211 let projected = current + estimated_cost;
212 if projected > limit {
213 BudgetStatus::WouldExceed {
214 current,
215 projected,
216 limit,
217 }
218 } else if current > limit * 0.8 {
219 BudgetStatus::NearLimit {
220 current,
221 limit,
222 remaining: limit - current,
223 }
224 } else {
225 BudgetStatus::Ok {
226 current,
227 limit,
228 remaining: limit - current,
229 }
230 }
231 }
232 }
233 }
234
235 pub async fn get_recent_history(&self, count: usize) -> Vec<UsageRecord> {
237 let history = self.history.read().await;
238 history.iter().rev().take(count).cloned().collect()
239 }
240
241 pub async fn reset_session(&self) {
243 let mut stats = self.session_stats.write().await;
244 *stats = UsageStats::default();
245 }
246}
247
/// Result of comparing a prospective call's cost against the daily budget.
/// All amounts are in USD.
#[derive(Debug, Clone)]
pub enum BudgetStatus {
    /// No daily budget has been configured.
    NoBudgetSet,
    /// The call fits and today's spend is at or below 80% of the limit.
    Ok {
        /// Amount spent so far today.
        current: f64,
        /// Configured daily limit.
        limit: f64,
        /// `limit - current`, before the prospective call.
        remaining: f64,
    },
    /// The call fits, but today's spend is already above 80% of the limit.
    NearLimit {
        /// Amount spent so far today.
        current: f64,
        /// Configured daily limit.
        limit: f64,
        /// `limit - current`, before the prospective call.
        remaining: f64,
    },
    /// Today's spend plus the call's estimated cost would be above the limit.
    WouldExceed {
        /// Amount spent so far today.
        current: f64,
        /// `current` plus the call's estimated cost.
        projected: f64,
        /// Configured daily limit.
        limit: f64,
    },
}
272
273static TOKEN_TRACKER: tokio::sync::OnceCell<Arc<TokenTracker>> = tokio::sync::OnceCell::const_new();
275
276pub async fn get_token_tracker() -> &'static Arc<TokenTracker> {
278 TOKEN_TRACKER
279 .get_or_init(|| async { Arc::new(TokenTracker::new()) })
280 .await
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_token_tracker_basic() {
        let tracker = TokenTracker::new();

        let usage = TokenUsage::new(100, 50).with_cost(0.001);
        tracker.record_usage(usage).await;

        let stats = tracker.get_session_stats().await;
        assert_eq!(stats.total_input_tokens, 100);
        assert_eq!(stats.total_output_tokens, 50);
        assert_eq!(stats.total_tokens, 150);
        assert_eq!(stats.call_count, 1);
    }

    #[tokio::test]
    async fn test_budget_tracking() {
        let tracker = TokenTracker::new();
        tracker.set_daily_budget(1.0).await;

        // 0.85 is above 80% of the 1.0 budget, and 0.85 + 0.05 still fits.
        let usage = TokenUsage::new(10000, 5000).with_cost(0.85);
        tracker.record_usage(usage).await;

        let status = tracker.check_budget(0.05).await;
        assert!(matches!(status, BudgetStatus::NearLimit { .. }));
    }

    #[tokio::test]
    async fn test_no_budget_set() {
        let tracker = TokenTracker::new();
        let status = tracker.check_budget(100.0).await;
        assert!(matches!(status, BudgetStatus::NoBudgetSet));
    }

    #[tokio::test]
    async fn test_budget_would_exceed() {
        let tracker = TokenTracker::new();
        tracker.set_daily_budget(1.0).await;
        tracker
            .record_usage(TokenUsage::new(10, 5).with_cost(0.95))
            .await;

        let status = tracker.check_budget(0.10).await;
        assert!(matches!(status, BudgetStatus::WouldExceed { .. }));
    }

    #[tokio::test]
    async fn test_budget_ok_reports_remaining() {
        let tracker = TokenTracker::new();
        tracker.set_daily_budget(1.0).await;
        tracker
            .record_usage(TokenUsage::new(10, 5).with_cost(0.25))
            .await;

        match tracker.check_budget(0.05).await {
            BudgetStatus::Ok {
                current,
                limit,
                remaining,
            } => {
                assert!((current - 0.25).abs() < 1e-12);
                assert!((limit - 1.0).abs() < 1e-12);
                assert!((remaining - 0.75).abs() < 1e-12);
            }
            other => panic!("expected BudgetStatus::Ok, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_set_session_resets_only_session_stats() {
        let tracker = TokenTracker::new();
        tracker
            .record_usage(TokenUsage::new(100, 50).with_cost(0.0))
            .await;

        tracker.set_session("s1").await;
        assert_eq!(tracker.get_session_stats().await.call_count, 0);
        assert_eq!(tracker.get_global_stats().await.call_count, 1);

        tracker
            .record_usage(TokenUsage::new(1, 1).with_cost(0.0))
            .await;
        let recent = tracker.get_recent_history(1).await;
        assert_eq!(recent[0].session_id.as_deref(), Some("s1"));
    }

    #[tokio::test]
    async fn test_recent_history_is_newest_first() {
        let tracker = TokenTracker::new();
        for input in 1..=3 {
            tracker
                .record_usage(TokenUsage::new(input, 0).with_cost(0.0))
                .await;
        }

        let recent = tracker.get_recent_history(2).await;
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].usage.input_tokens, 3);
        assert_eq!(recent[1].usage.input_tokens, 2);
    }

    #[tokio::test]
    async fn test_history_is_capped() {
        let tracker = TokenTracker::new();
        for _ in 0..MAX_USAGE_HISTORY + 5 {
            tracker
                .record_usage(TokenUsage::new(1, 1).with_cost(0.0))
                .await;
        }

        let all = tracker.get_recent_history(usize::MAX).await;
        assert_eq!(all.len(), MAX_USAGE_HISTORY);
        // Global totals still count every call, including evicted ones.
        assert_eq!(
            tracker.get_global_stats().await.call_count,
            (MAX_USAGE_HISTORY + 5) as u64
        );
    }

    #[test]
    fn test_format_token_count() {
        assert_eq!(format_token_count(500), "500");
        assert_eq!(format_token_count(1500), "1.5K");
        assert_eq!(format_token_count(1_500_000), "1.5M");
    }
}