1use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
5use gestura_core_foundation::AppError;
6use std::path::Path;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10
11#[cfg(target_os = "windows")]
12use core::ffi::c_void;
13
/// RMS level below which a VAD window is treated as silence.
const SILENCE_THRESHOLD: f32 = 0.005;
/// Seconds of continuous silence (after speech was heard) that end a recording.
const SILENCE_TIMEOUT_SECS: f32 = 4.0;
/// Hard cap on total recording length, in seconds.
const MAX_RECORDING_SECS: u64 = 120;
/// Size of each VAD analysis window, in milliseconds.
const VAD_WINDOW_MS: u64 = 100;
/// How long to wait for the first speech before giving up, in seconds.
const WAIT_FOR_SPEECH_TIMEOUT_SECS: u64 = 30;
/// Sample rate expected by Whisper-style speech models.
const WHISPER_SAMPLE_RATE: u32 = 16000;

lazy_static::lazy_static! {
    // Process-wide flag that lets other threads/tasks ask an in-progress
    // recording to stop early (see `request_stop_recording`). Cleared at the
    // start of every new recording via `reset_stop_flag`.
    static ref EXTERNAL_STOP_FLAG: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
}
26
#[cfg(target_os = "windows")]
lazy_static::lazy_static! {
    // Serializes all cpal host access on Windows: `with_audio_host_access`
    // takes this lock around every host operation so COM initialization and
    // the guarded work never interleave between threads.
    static ref WINDOWS_AUDIO_HOST_ACCESS: Mutex<()> = Mutex::new(());
}
31
// Minimal hand-rolled COM interop — just enough of ole32 to initialize and
// tear down COM around audio host access, without a dedicated Windows crate.
#[cfg(target_os = "windows")]
type HResult = i32;

// Join the multithreaded apartment (MTA) when initializing COM.
#[cfg(target_os = "windows")]
const COINIT_MULTITHREADED: u32 = 0;

// CoInitializeEx succeeded and initialized COM for this thread.
#[cfg(target_os = "windows")]
const S_OK: HResult = 0;

// COM was already initialized on this thread; the call still needs a
// matching CoUninitialize (see `WindowsComGuard::initialize_for_audio`).
#[cfg(target_os = "windows")]
const S_FALSE: HResult = 1;

// RPC_E_CHANGED_MODE (0x80010106): the thread was already initialized with a
// different apartment model; no CoUninitialize must be issued in that case.
#[cfg(target_os = "windows")]
const RPC_E_CHANGED_MODE: HResult = -2_147_417_850;

#[cfg(target_os = "windows")]
#[link(name = "ole32")]
unsafe extern "system" {
    fn CoInitializeEx(pv_reserved: *mut c_void, coinit: u32) -> HResult;
    fn CoUninitialize();
}
53
/// RAII guard pairing a successful `CoInitializeEx` with `CoUninitialize`
/// on drop.
#[cfg(target_os = "windows")]
struct WindowsComGuard {
    // True only when this guard's CoInitializeEx call must be balanced by
    // CoUninitialize; false when an existing COM state is merely reused.
    should_uninitialize: bool,
}
58
#[cfg(target_os = "windows")]
impl WindowsComGuard {
    /// Ensures COM is usable on the current thread before audio host calls.
    ///
    /// Never fails hard: on an unexpected HRESULT it logs a warning and
    /// returns a guard that leaves the thread's COM state untouched.
    fn initialize_for_audio() -> Self {
        // SAFETY: CoInitializeEx is called with a null reserved pointer and a
        // valid COINIT flag, per the ole32 contract.
        let hr = unsafe { CoInitializeEx(std::ptr::null_mut(), COINIT_MULTITHREADED) };
        match hr {
            // S_OK: freshly initialized; S_FALSE: already initialized — both
            // require a balancing CoUninitialize from this guard on drop.
            S_OK | S_FALSE => Self {
                should_uninitialize: true,
            },
            // A different apartment model is already active on this thread;
            // reuse it and do NOT uninitialize on drop.
            RPC_E_CHANGED_MODE => {
                tracing::debug!(
                    "Windows COM apartment already initialized with a different threading model; reusing current apartment for audio host access"
                );
                Self {
                    should_uninitialize: false,
                }
            }
            // Any other failure: log and proceed best-effort with whatever
            // COM state the thread already has.
            other => {
                tracing::warn!(
                    hresult = format_args!("{other:#010x}"),
                    "Failed to initialize COM before Windows audio host access; continuing with existing thread state"
                );
                Self {
                    should_uninitialize: false,
                }
            }
        }
    }
}
87
#[cfg(target_os = "windows")]
impl Drop for WindowsComGuard {
    fn drop(&mut self) {
        // Balance the successful CoInitializeEx from `initialize_for_audio`;
        // skipped when the thread's existing COM state was merely reused.
        if self.should_uninitialize {
            unsafe { CoUninitialize() };
        }
    }
}
96
/// Runs `operation` with safe access to the platform audio host.
///
/// On Windows this serializes host access behind a global mutex and keeps
/// COM initialized for the calling thread for the duration of the closure;
/// on every other platform it simply invokes the closure directly.
fn with_audio_host_access<T>(operation: impl FnOnce() -> T) -> T {
    #[cfg(target_os = "windows")]
    {
        // Recover the lock even if a previous holder panicked — the guarded
        // section holds no invariants that poisoning would protect.
        let _host_guard = match WINDOWS_AUDIO_HOST_ACCESS.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        let _com = WindowsComGuard::initialize_for_audio();
        operation()
    }

    #[cfg(not(target_os = "windows"))]
    {
        operation()
    }
}
112
/// Signals any in-progress recording to stop at its next poll (the capture
/// loop in `record_audio_with_vad` checks this roughly every 100 ms).
pub fn request_stop_recording() {
    tracing::info!("External stop requested for audio recording");
    EXTERNAL_STOP_FLAG.store(true, Ordering::SeqCst);
}
118
/// Clears the external stop flag; invoked at the start of each new recording.
pub fn reset_stop_flag() {
    EXTERNAL_STOP_FLAG.store(false, Ordering::SeqCst);
}
123
/// Returns true if an external stop has been requested since the last reset.
pub fn is_stop_requested() -> bool {
    EXTERNAL_STOP_FLAG.load(Ordering::SeqCst)
}
128
/// Tunable parameters for a single VAD-driven capture session.
#[derive(Debug, Clone)]
pub struct AudioCaptureConfig {
    /// Specific input device to use; `None` selects the system default.
    pub device_name: Option<String>,
    /// RMS level below which a window counts as silence.
    pub silence_threshold: f32,
    /// Seconds of post-speech silence that end the recording.
    pub silence_timeout_secs: f32,
    /// Hard cap on total recording length, in seconds.
    pub max_recording_secs: u64,
    /// Seconds to wait for initial speech before aborting.
    pub wait_for_speech_timeout_secs: u64,
}
143
144impl Default for AudioCaptureConfig {
145 fn default() -> Self {
146 Self {
147 device_name: None,
148 silence_threshold: SILENCE_THRESHOLD,
149 silence_timeout_secs: SILENCE_TIMEOUT_SECS,
150 max_recording_secs: MAX_RECORDING_SECS,
151 wait_for_speech_timeout_secs: WAIT_FOR_SPEECH_TIMEOUT_SECS,
152 }
153 }
154}
155
156pub async fn record_audio(
162 _duration: Duration,
163 output_path: &Path,
164 config: AudioCaptureConfig,
165) -> Result<f32, AppError> {
166 let output_path = output_path.to_path_buf();
167
168 reset_stop_flag();
170
171 tokio::task::spawn_blocking(move || record_audio_with_vad(&output_path, &config))
174 .await
175 .map_err(|e| AppError::Voice(format!("Recording task failed: {}", e)))?
176}
177
/// Root-mean-square amplitude of `samples`; 0.0 for an empty slice.
fn calculate_rms(samples: &[f32]) -> f32 {
    match samples.len() {
        0 => 0.0,
        count => {
            let energy: f32 = samples.iter().map(|&s| s * s).sum();
            (energy / count as f32).sqrt()
        }
    }
}
186
/// Mutable voice-activity-detection bookkeeping shared with the audio callback.
struct VadState {
    // Instant of the most recent window whose RMS exceeded the threshold.
    last_speech_time: Instant,
    // True once any window has crossed the speech threshold.
    has_detected_speech: bool,
    // When capture started; drives max-duration and wait-for-speech timeouts.
    recording_start: Instant,
    // Ensures the "max duration reached" message is logged only once.
    has_logged_max_duration: bool,
    // Throttles the periodic RMS status log (emitted every ~2 s).
    last_rms_log_time: Instant,
    // Loudest RMS observed so far, reported in diagnostics.
    peak_rms: f32,
}
196
197impl VadState {
198 fn new() -> Self {
199 let now = Instant::now();
200 Self {
201 last_speech_time: now,
202 has_detected_speech: false,
203 recording_start: now,
204 has_logged_max_duration: false,
205 last_rms_log_time: now,
206 peak_rms: 0.0,
207 }
208 }
209}
210
/// Description of one audio input device, serializable for external callers.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AudioDeviceInfo {
    /// Human-readable device name as reported by the audio backend.
    pub name: String,
    /// True if this is the system's default input device.
    pub is_default: bool,
}
217
218pub fn is_microphone_available() -> bool {
220 with_audio_host_access(|| {
221 let host = cpal::default_host();
222 host.default_input_device().is_some()
223 })
224}
225
226pub fn list_audio_input_devices() -> Vec<AudioDeviceInfo> {
228 with_audio_host_access(|| {
229 let host = cpal::default_host();
230 let default_device_name = host.default_input_device().and_then(|d| d.name().ok());
231
232 let mut devices = Vec::new();
233
234 if let Ok(input_devices) = host.input_devices() {
235 for device in input_devices {
236 if let Ok(name) = device.name() {
237 let is_default = default_device_name
238 .as_ref()
239 .map(|d| d == &name)
240 .unwrap_or(false);
241 devices.push(AudioDeviceInfo { name, is_default });
242 }
243 }
244 }
245
246 devices
247 })
248}
249
/// Blocking capture loop: opens the configured (or default) input device,
/// streams samples through `process_audio_data` for voice-activity
/// detection, then resamples to 16 kHz mono and writes the WAV to
/// `output_path`. Returns the recorded duration in seconds, measured at the
/// device's native rate.
fn record_audio_with_vad(output_path: &Path, config: &AudioCaptureConfig) -> Result<f32, AppError> {
    with_audio_host_access(|| {
        let host = cpal::default_host();

        // Resolve the capture device: prefer the configured name, fall back
        // to the system default if it is missing or not found.
        let device = if let Some(ref name) = config.device_name {
            let found = host
                .input_devices()
                .ok()
                .and_then(|mut devices| devices.find(|d| d.name().ok().as_deref() == Some(name)));

            if let Some(dev) = found {
                tracing::info!("Using configured audio input device: {}", name);
                dev
            } else {
                tracing::warn!("Configured device '{}' not found, using default", name);
                host.default_input_device()
                    .ok_or_else(|| AppError::Voice("No input device available".into()))?
            }
        } else {
            host.default_input_device()
                .ok_or_else(|| AppError::Voice("No input device available".into()))?
        };

        tracing::info!("Using audio input device: {:?}", device.name());

        let device_config = device
            .default_input_config()
            .map_err(|e| AppError::Voice(format!("Failed to get config: {}", e)))?;

        // Native capture parameters; conversion to 16 kHz mono happens after
        // the recording finishes.
        let sample_rate = device_config.sample_rate().0;
        let channels = device_config.channels();

        tracing::info!("Audio config: {}Hz, {} channels", sample_rate, channels);

        // State shared with the audio callback: accumulated samples, VAD
        // scratch buffer, VAD bookkeeping, and the stop flag.
        let samples: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
        let samples_clone = Arc::clone(&samples);

        let vad_state = Arc::new(Mutex::new(VadState::new()));
        let vad_state_clone = Arc::clone(&vad_state);

        let should_stop = Arc::new(AtomicBool::new(false));
        let should_stop_clone = Arc::clone(&should_stop);

        // Interleaved sample count covering VAD_WINDOW_MS of audio.
        let samples_per_window =
            (sample_rate as u64 * channels as u64 * VAD_WINDOW_MS / 1000) as usize;

        let vad_buffer: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
        let vad_buffer_clone = Arc::clone(&vad_buffer);

        // Copy the config scalars so the 'static callback does not borrow
        // `config`.
        let silence_threshold = config.silence_threshold;
        let silence_timeout = config.silence_timeout_secs;
        let max_recording = config.max_recording_secs;
        let wait_for_speech = config.wait_for_speech_timeout_secs;

        // Build the input stream; i16 devices are normalized to f32 in the
        // callback so the rest of the pipeline deals only with floats.
        let stream = match device_config.sample_format() {
            cpal::SampleFormat::F32 => device
                .build_input_stream(
                    &device_config.clone().into(),
                    move |data: &[f32], _: &cpal::InputCallbackInfo| {
                        process_audio_data(
                            data,
                            &samples_clone,
                            &vad_buffer_clone,
                            &vad_state_clone,
                            &should_stop_clone,
                            samples_per_window,
                            silence_threshold,
                            silence_timeout,
                            max_recording,
                            wait_for_speech,
                        );
                    },
                    |err| {
                        tracing::error!("Audio stream error: {}", err);
                    },
                    None,
                )
                .map_err(|e| AppError::Voice(format!("Failed to build stream: {}", e)))?,
            cpal::SampleFormat::I16 => {
                let samples_clone_i16 = Arc::clone(&samples);
                let vad_buffer_i16 = Arc::clone(&vad_buffer);
                let vad_state_i16 = Arc::clone(&vad_state);
                let should_stop_i16 = Arc::clone(&should_stop);

                device
                    .build_input_stream(
                        &device_config.clone().into(),
                        move |data: &[i16], _: &cpal::InputCallbackInfo| {
                            // Scale i16 samples into approximately [-1.0, 1.0].
                            let f32_data: Vec<f32> =
                                data.iter().map(|&s| s as f32 / i16::MAX as f32).collect();
                            process_audio_data(
                                &f32_data,
                                &samples_clone_i16,
                                &vad_buffer_i16,
                                &vad_state_i16,
                                &should_stop_i16,
                                samples_per_window,
                                silence_threshold,
                                silence_timeout,
                                max_recording,
                                wait_for_speech,
                            );
                        },
                        |err| {
                            tracing::error!("Audio stream error: {}", err);
                        },
                        None,
                    )
                    .map_err(|e| AppError::Voice(format!("Failed to build stream: {}", e)))?
            }
            _ => return Err(AppError::Voice("Unsupported sample format".into())),
        };

        stream
            .play()
            .map_err(|e| AppError::Voice(format!("Failed to start stream: {}", e)))?;

        tracing::info!(
            "Recording with VAD - will stop after {}s of silence...",
            config.silence_timeout_secs
        );

        // Poll every 100 ms for any of the three stop conditions: VAD stop,
        // external stop request, or the max-duration safety net.
        loop {
            std::thread::sleep(Duration::from_millis(100));

            if should_stop.load(Ordering::SeqCst) {
                tracing::info!("Recording stopped by VAD (silence/max duration)");
                break;
            }

            if is_stop_requested() {
                tracing::info!("Recording stopped by external request");
                should_stop.store(true, Ordering::SeqCst);
                break;
            }

            let state = vad_state.lock().unwrap();
            if state.recording_start.elapsed().as_secs() >= config.max_recording_secs {
                break;
            }
        }

        // Pause and drop the stream so the OS releases the microphone promptly.
        let _ = stream.pause();
        drop(stream);
        tracing::info!("Audio stream stopped and microphone released");

        let recorded_samples = samples.lock().unwrap();
        let sample_count = recorded_samples.len();
        // Duration at the native rate: interleaved samples / (rate * channels).
        let duration_secs = sample_count as f32 / (sample_rate as f32 * channels as f32);

        tracing::info!("Recorded {} samples ({:.2}s)", sample_count, duration_secs);

        if sample_count == 0 {
            // Distinguish a user cancel from a genuine capture failure.
            if is_stop_requested() {
                return Err(AppError::Voice("Recording cancelled by user".into()));
            }
            return Err(AppError::Voice("No audio captured".into()));
        }

        // Convert to the 16 kHz mono format expected downstream.
        let resampled = resample_to_16khz(&recorded_samples, sample_rate, channels);

        save_samples_to_wav(&resampled, WHISPER_SAMPLE_RATE, 1, output_path)?;

        Ok(duration_secs)
    })
}
437
438#[allow(clippy::too_many_arguments)]
440fn process_audio_data(
441 data: &[f32],
442 samples: &Arc<Mutex<Vec<f32>>>,
443 vad_buffer: &Arc<Mutex<Vec<f32>>>,
444 vad_state: &Arc<Mutex<VadState>>,
445 should_stop: &Arc<AtomicBool>,
446 samples_per_window: usize,
447 silence_threshold: f32,
448 silence_timeout: f32,
449 max_recording: u64,
450 wait_for_speech: u64,
451) {
452 if should_stop.load(Ordering::SeqCst) {
454 return;
455 }
456
457 {
459 let mut buffer = samples.lock().unwrap();
460 buffer.extend_from_slice(data);
461 }
462
463 {
465 let mut vad_buf = vad_buffer.lock().unwrap();
466 vad_buf.extend_from_slice(data);
467
468 while vad_buf.len() >= samples_per_window {
470 let window: Vec<f32> = vad_buf.drain(..samples_per_window).collect();
471 let rms = calculate_rms(&window);
472
473 let mut state = vad_state.lock().unwrap();
474 let now = Instant::now();
475
476 if rms > state.peak_rms {
478 state.peak_rms = rms;
479 }
480
481 if now.duration_since(state.last_rms_log_time).as_secs() >= 2 {
483 tracing::info!(
484 "VAD status: current_rms={:.4}, peak_rms={:.4}, threshold={:.4}, speech_detected={}",
485 rms,
486 state.peak_rms,
487 silence_threshold,
488 state.has_detected_speech
489 );
490 state.last_rms_log_time = now;
491 }
492
493 if rms > silence_threshold {
494 state.last_speech_time = now;
496 if !state.has_detected_speech {
497 state.has_detected_speech = true;
498 tracing::info!(
499 "🎤 Speech detected! (RMS: {:.4} > threshold: {:.4})",
500 rms,
501 silence_threshold
502 );
503 }
504 } else if state.has_detected_speech {
505 let silence_duration = now.duration_since(state.last_speech_time);
507 if silence_duration.as_secs_f32() >= silence_timeout
508 && !should_stop.load(Ordering::SeqCst)
509 {
510 tracing::info!(
511 "🔇 Silence timeout reached ({:.1}s) - stopping recording",
512 silence_duration.as_secs_f32()
513 );
514 should_stop.store(true, Ordering::SeqCst);
515 }
516 } else {
517 let waiting_duration = now.duration_since(state.recording_start).as_secs();
519 if waiting_duration >= wait_for_speech {
520 tracing::warn!(
521 "⏱️ No speech detected after {}s (peak_rms={:.4}, threshold={:.4}) - stopping",
522 waiting_duration,
523 state.peak_rms,
524 silence_threshold
525 );
526 should_stop.store(true, Ordering::SeqCst);
527 }
528 }
529
530 if now.duration_since(state.recording_start).as_secs() >= max_recording {
532 if !state.has_logged_max_duration {
533 tracing::info!("Max recording duration reached");
534 state.has_logged_max_duration = true;
535 }
536 should_stop.store(true, Ordering::SeqCst);
537 }
538 }
539 }
540}
541
542fn resample_to_16khz(samples: &[f32], source_rate: u32, channels: u16) -> Vec<f32> {
545 if samples.is_empty() {
546 return Vec::new();
547 }
548
549 let mono_samples: Vec<f32> = if channels > 1 {
551 samples
552 .chunks(channels as usize)
553 .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
554 .collect()
555 } else {
556 samples.to_vec()
557 };
558
559 if source_rate == WHISPER_SAMPLE_RATE {
561 tracing::info!("Audio already at 16kHz, no resampling needed");
562 return mono_samples;
563 }
564
565 let ratio = source_rate as f64 / WHISPER_SAMPLE_RATE as f64;
567 let output_len = (mono_samples.len() as f64 / ratio).ceil() as usize;
568 let mut resampled = Vec::with_capacity(output_len);
569
570 tracing::info!(
571 "Resampling audio from {}Hz to {}Hz ({} -> {} samples)",
572 source_rate,
573 WHISPER_SAMPLE_RATE,
574 mono_samples.len(),
575 output_len
576 );
577
578 for i in 0..output_len {
580 let src_pos = i as f64 * ratio;
581 let src_idx = src_pos.floor() as usize;
582 let frac = (src_pos - src_idx as f64) as f32;
583
584 if src_idx + 1 < mono_samples.len() {
585 let sample = mono_samples[src_idx] * (1.0 - frac) + mono_samples[src_idx + 1] * frac;
587 resampled.push(sample);
588 } else if src_idx < mono_samples.len() {
589 resampled.push(mono_samples[src_idx]);
590 }
591 }
592
593 resampled
594}
595
596fn save_samples_to_wav(
598 samples: &[f32],
599 sample_rate: u32,
600 channels: u16,
601 path: &Path,
602) -> Result<(), AppError> {
603 let spec = hound::WavSpec {
604 channels,
605 sample_rate,
606 bits_per_sample: 16,
607 sample_format: hound::SampleFormat::Int,
608 };
609
610 let mut writer = hound::WavWriter::create(path, spec)
611 .map_err(|e| AppError::Voice(format!("Failed to create WAV: {}", e)))?;
612
613 for sample in samples {
615 let sample_i16 = (*sample * i16::MAX as f32) as i16;
616 writer
617 .write_sample(sample_i16)
618 .map_err(|e| AppError::Voice(format!("Failed to write sample: {}", e)))?;
619 }
620
621 writer
622 .finalize()
623 .map_err(|e| AppError::Voice(format!("Failed to finalize WAV: {}", e)))?;
624
625 tracing::info!(
626 "Saved audio to {:?} ({}Hz, {} channels)",
627 path,
628 sample_rate,
629 channels
630 );
631 Ok(())
632}