gestura_core_audio/
audio_capture.rs

1//! Audio capture module for microphone input
2//! Uses cpal for cross-platform audio recording with voice activity detection (VAD)
3
4use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
5use gestura_core_foundation::AppError;
6use std::path::Path;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10
11#[cfg(target_os = "windows")]
12use core::ffi::c_void;
13
/// Silence detection configuration
///
/// These values are the defaults baked into `AudioCaptureConfig::default()`.
const SILENCE_THRESHOLD: f32 = 0.005; // RMS threshold for detecting silence (lowered for sensitivity)
const SILENCE_TIMEOUT_SECS: f32 = 4.0; // Stop recording after 4 seconds of silence
const MAX_RECORDING_SECS: u64 = 120; // Maximum recording duration (2 minutes)
const VAD_WINDOW_MS: u64 = 100; // Window size for VAD analysis
const WAIT_FOR_SPEECH_TIMEOUT_SECS: u64 = 30; // Timeout if no speech detected after 30 seconds
const WHISPER_SAMPLE_RATE: u32 = 16000; // Whisper requires 16kHz audio
21
// Global flag to signal external stop request (e.g., from "Stop Listening" button)
// Written by request_stop_recording()/reset_stop_flag() and polled by the
// blocking recording loop via is_stop_requested().
lazy_static::lazy_static! {
    static ref EXTERNAL_STOP_FLAG: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
}

// Serializes all audio-host access on Windows; taken inside
// with_audio_host_access() around device enumeration and stream setup.
#[cfg(target_os = "windows")]
lazy_static::lazy_static! {
    static ref WINDOWS_AUDIO_HOST_ACCESS: Mutex<()> = Mutex::new(());
}
31
// Minimal hand-rolled COM FFI definitions used by WindowsComGuard below.
#[cfg(target_os = "windows")]
type HResult = i32;

#[cfg(target_os = "windows")]
const COINIT_MULTITHREADED: u32 = 0;

// HRESULT success codes: S_OK = apartment newly initialized for this thread,
// S_FALSE = already initialized (both must be balanced with CoUninitialize).
#[cfg(target_os = "windows")]
const S_OK: HResult = 0;

#[cfg(target_os = "windows")]
const S_FALSE: HResult = 1;

// 0x80010106 as i32: the thread was already initialized with a different
// apartment/threading model.
#[cfg(target_os = "windows")]
const RPC_E_CHANGED_MODE: HResult = -2_147_417_850;
46
#[cfg(target_os = "windows")]
#[link(name = "ole32")]
unsafe extern "system" {
    // Raw COM initialization entry points from ole32.dll; only called through
    // WindowsComGuard in this module.
    fn CoInitializeEx(pv_reserved: *mut c_void, coinit: u32) -> HResult;
    fn CoUninitialize();
}
53
/// Guard that pairs a successful `CoInitializeEx` call with a balancing
/// `CoUninitialize` in `Drop`.
#[cfg(target_os = "windows")]
struct WindowsComGuard {
    // True only when our CoInitializeEx call returned S_OK/S_FALSE and we
    // therefore owe the matching CoUninitialize.
    should_uninitialize: bool,
}
58
#[cfg(target_os = "windows")]
impl WindowsComGuard {
    /// Ensure a COM apartment is available on the current thread before
    /// touching the Windows audio host.
    ///
    /// Never fails: on RPC_E_CHANGED_MODE (thread already initialized with a
    /// different threading model) or any other error HRESULT, the guard is
    /// returned with `should_uninitialize = false` so the pre-existing COM
    /// state is reused and left untouched on drop.
    fn initialize_for_audio() -> Self {
        let hr = unsafe { CoInitializeEx(std::ptr::null_mut(), COINIT_MULTITHREADED) };
        match hr {
            // S_OK/S_FALSE both mean we must balance with CoUninitialize.
            S_OK | S_FALSE => Self {
                should_uninitialize: true,
            },
            RPC_E_CHANGED_MODE => {
                tracing::debug!(
                    "Windows COM apartment already initialized with a different threading model; reusing current apartment for audio host access"
                );
                Self {
                    should_uninitialize: false,
                }
            }
            other => {
                tracing::warn!(
                    hresult = format_args!("{other:#010x}"),
                    "Failed to initialize COM before Windows audio host access; continuing with existing thread state"
                );
                Self {
                    should_uninitialize: false,
                }
            }
        }
    }
}
87
#[cfg(target_os = "windows")]
impl Drop for WindowsComGuard {
    fn drop(&mut self) {
        // Balance the successful CoInitializeEx; skipped when COM was left in
        // a pre-existing or incompatible state (see initialize_for_audio).
        if self.should_uninitialize {
            unsafe { CoUninitialize() };
        }
    }
}
96
/// Run `operation` with platform-appropriate audio host setup.
///
/// On Windows: serializes host access behind a process-wide mutex and makes
/// sure a COM apartment is initialized for the current thread before the
/// closure runs. A poisoned mutex is recovered (`into_inner`) rather than
/// propagated, since the guarded section protects no shared data invariants.
/// On all other platforms the closure is invoked directly.
fn with_audio_host_access<T>(operation: impl FnOnce() -> T) -> T {
    #[cfg(target_os = "windows")]
    {
        let _audio_lock = WINDOWS_AUDIO_HOST_ACCESS
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let _com_guard = WindowsComGuard::initialize_for_audio();
        operation()
    }

    #[cfg(not(target_os = "windows"))]
    {
        operation()
    }
}
112
/// Request the audio recording to stop from external code
/// (e.g. a "Stop Listening" button). Safe to call from any thread; the
/// recording loop polls this flag roughly every 100ms.
pub fn request_stop_recording() {
    tracing::info!("External stop requested for audio recording");
    EXTERNAL_STOP_FLAG.store(true, Ordering::SeqCst);
}
118
/// Reset the external stop flag (call before starting a new recording).
/// `record_audio` does this automatically.
pub fn reset_stop_flag() {
    EXTERNAL_STOP_FLAG.store(false, Ordering::SeqCst);
}
123
/// Check if external stop was requested via `request_stop_recording`.
pub fn is_stop_requested() -> bool {
    EXTERNAL_STOP_FLAG.load(Ordering::SeqCst)
}
128
/// Audio capture configuration
///
/// Defaults (see the `Default` impl) come from the module-level constants.
#[derive(Debug, Clone)]
pub struct AudioCaptureConfig {
    /// Optional device name to use (None = default device)
    pub device_name: Option<String>,
    /// Silence threshold for VAD (RMS value)
    pub silence_threshold: f32,
    /// Seconds of silence before stopping
    pub silence_timeout_secs: f32,
    /// Maximum recording duration in seconds
    pub max_recording_secs: u64,
    /// Timeout for waiting for speech to start
    pub wait_for_speech_timeout_secs: u64,
}
143
impl Default for AudioCaptureConfig {
    /// Default input device with the module-level VAD thresholds/timeouts.
    fn default() -> Self {
        Self {
            device_name: None,
            silence_threshold: SILENCE_THRESHOLD,
            silence_timeout_secs: SILENCE_TIMEOUT_SECS,
            max_recording_secs: MAX_RECORDING_SECS,
            wait_for_speech_timeout_secs: WAIT_FOR_SPEECH_TIMEOUT_SECS,
        }
    }
}
155
/// Record audio from the microphone until user stops speaking (4 seconds of silence)
/// Returns the duration of recorded audio in seconds
///
/// This function runs the entire recording process in a blocking task
/// to avoid Send/Sync issues with cpal::Stream
///
/// * `_duration` - unused; kept for backward compatibility with callers.
///   Maximum length is governed by `config.max_recording_secs` instead.
/// * `output_path` - destination WAV file (written as 16kHz mono for Whisper)
/// * `config` - VAD thresholds and timeouts
///
/// # Errors
/// Returns `AppError::Voice` if the blocking task fails to join or the
/// underlying recording fails.
pub async fn record_audio(
    _duration: Duration,
    output_path: &Path,
    config: AudioCaptureConfig,
) -> Result<f32, AppError> {
    let output_path = output_path.to_path_buf();

    // Reset the external stop flag before starting a new recording
    reset_stop_flag();

    // Run the entire recording process in a blocking task
    // because cpal::Stream is not Send
    tokio::task::spawn_blocking(move || record_audio_with_vad(&output_path, &config))
        .await
        .map_err(|e| AppError::Voice(format!("Recording task failed: {}", e)))?
}
177
/// Calculate RMS (Root Mean Square) of audio samples - a measure of audio energy
///
/// An empty slice yields 0.0 rather than NaN.
fn calculate_rms(samples: &[f32]) -> f32 {
    let n = samples.len();
    if n == 0 {
        return 0.0;
    }
    let energy = samples.iter().fold(0.0_f32, |acc, &s| acc + s * s);
    (energy / n as f32).sqrt()
}
186
/// Voice Activity Detection state
///
/// Shared behind a `Mutex` between the audio callback (`process_audio_data`)
/// and the control loop in `record_audio_with_vad`.
struct VadState {
    // Instant of the most recent window whose RMS exceeded the threshold.
    last_speech_time: Instant,
    // Becomes true once any speech has been heard; gates the silence timeout.
    has_detected_speech: bool,
    // When recording began; used for max-duration and wait-for-speech checks.
    recording_start: Instant,
    // Ensures "max duration reached" is only logged once.
    has_logged_max_duration: bool,
    // Throttles the periodic RMS debug log (every 2 seconds).
    last_rms_log_time: Instant,
    // Loudest window RMS seen so far; reported in diagnostics.
    peak_rms: f32,
}
196
impl VadState {
    /// Fresh state with all timestamps anchored to "now" and no speech heard.
    fn new() -> Self {
        let now = Instant::now();
        Self {
            last_speech_time: now,
            has_detected_speech: false,
            recording_start: now,
            has_logged_max_duration: false,
            last_rms_log_time: now,
            peak_rms: 0.0,
        }
    }
}
210
/// Audio device information
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AudioDeviceInfo {
    /// Device name as reported by the audio host.
    pub name: String,
    /// True when this is the host's current default input device.
    pub is_default: bool,
}
217
218/// Check if microphone is available
219pub fn is_microphone_available() -> bool {
220    with_audio_host_access(|| {
221        let host = cpal::default_host();
222        host.default_input_device().is_some()
223    })
224}
225
226/// List all available audio input devices
227pub fn list_audio_input_devices() -> Vec<AudioDeviceInfo> {
228    with_audio_host_access(|| {
229        let host = cpal::default_host();
230        let default_device_name = host.default_input_device().and_then(|d| d.name().ok());
231
232        let mut devices = Vec::new();
233
234        if let Ok(input_devices) = host.input_devices() {
235            for device in input_devices {
236                if let Ok(name) = device.name() {
237                    let is_default = default_device_name
238                        .as_ref()
239                        .map(|d| d == &name)
240                        .unwrap_or(false);
241                    devices.push(AudioDeviceInfo { name, is_default });
242                }
243            }
244        }
245
246        devices
247    })
248}
249
/// Record audio with Voice Activity Detection - stops after silence timeout
///
/// Blocking call (run via `spawn_blocking` from `record_audio`). Opens the
/// configured (or default) input device, streams samples through
/// `process_audio_data` until a stop condition fires (silence after speech,
/// no speech at all, max duration, or external stop), then resamples to
/// 16kHz mono and writes a WAV file to `output_path`.
///
/// Returns the captured duration in seconds, measured at the device's native
/// sample rate before resampling.
///
/// # Errors
/// `AppError::Voice` when no device/config is available, the stream cannot be
/// built or started, the sample format is unsupported, nothing was captured,
/// or the WAV write fails.
fn record_audio_with_vad(output_path: &Path, config: &AudioCaptureConfig) -> Result<f32, AppError> {
    with_audio_host_access(|| {
        // Get default audio host
        let host = cpal::default_host();

        // Try to find the specified device, or fall back to default
        let device = if let Some(ref name) = config.device_name {
            let found = host
                .input_devices()
                .ok()
                .and_then(|mut devices| devices.find(|d| d.name().ok().as_deref() == Some(name)));

            if let Some(dev) = found {
                tracing::info!("Using configured audio input device: {}", name);
                dev
            } else {
                tracing::warn!("Configured device '{}' not found, using default", name);
                host.default_input_device()
                    .ok_or_else(|| AppError::Voice("No input device available".into()))?
            }
        } else {
            host.default_input_device()
                .ok_or_else(|| AppError::Voice("No input device available".into()))?
        };

        tracing::info!("Using audio input device: {:?}", device.name());

        // Get supported config
        let device_config = device
            .default_input_config()
            .map_err(|e| AppError::Voice(format!("Failed to get config: {}", e)))?;

        let sample_rate = device_config.sample_rate().0;
        let channels = device_config.channels();

        tracing::info!("Audio config: {}Hz, {} channels", sample_rate, channels);

        // Create shared buffer for samples (filled by the audio callback)
        let samples: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
        let samples_clone = Arc::clone(&samples);

        // Shared VAD state
        let vad_state = Arc::new(Mutex::new(VadState::new()));
        let vad_state_clone = Arc::clone(&vad_state);

        // Flag to signal when to stop recording
        let should_stop = Arc::new(AtomicBool::new(false));
        let should_stop_clone = Arc::clone(&should_stop);

        // Samples per VAD window (interleaved, so rate * channels)
        let samples_per_window =
            (sample_rate as u64 * channels as u64 * VAD_WINDOW_MS / 1000) as usize;

        // Buffer for VAD analysis
        let vad_buffer: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
        let vad_buffer_clone = Arc::clone(&vad_buffer);

        // Capture config values for closure
        let silence_threshold = config.silence_threshold;
        let silence_timeout = config.silence_timeout_secs;
        let max_recording = config.max_recording_secs;
        let wait_for_speech = config.wait_for_speech_timeout_secs;

        // Build input stream based on sample format
        let stream = match device_config.sample_format() {
            cpal::SampleFormat::F32 => device
                .build_input_stream(
                    &device_config.clone().into(),
                    move |data: &[f32], _: &cpal::InputCallbackInfo| {
                        process_audio_data(
                            data,
                            &samples_clone,
                            &vad_buffer_clone,
                            &vad_state_clone,
                            &should_stop_clone,
                            samples_per_window,
                            silence_threshold,
                            silence_timeout,
                            max_recording,
                            wait_for_speech,
                        );
                    },
                    |err| {
                        tracing::error!("Audio stream error: {}", err);
                    },
                    None,
                )
                .map_err(|e| AppError::Voice(format!("Failed to build stream: {}", e)))?,
            cpal::SampleFormat::I16 => {
                // Separate Arc clones because the F32 closures above consumed
                // the first set (only one match arm ever runs).
                let samples_clone_i16 = Arc::clone(&samples);
                let vad_buffer_i16 = Arc::clone(&vad_buffer);
                let vad_state_i16 = Arc::clone(&vad_state);
                let should_stop_i16 = Arc::clone(&should_stop);

                device
                    .build_input_stream(
                        &device_config.clone().into(),
                        move |data: &[i16], _: &cpal::InputCallbackInfo| {
                            // Normalize i16 to f32 in roughly [-1.0, 1.0]
                            let f32_data: Vec<f32> =
                                data.iter().map(|&s| s as f32 / i16::MAX as f32).collect();
                            process_audio_data(
                                &f32_data,
                                &samples_clone_i16,
                                &vad_buffer_i16,
                                &vad_state_i16,
                                &should_stop_i16,
                                samples_per_window,
                                silence_threshold,
                                silence_timeout,
                                max_recording,
                                wait_for_speech,
                            );
                        },
                        |err| {
                            tracing::error!("Audio stream error: {}", err);
                        },
                        None,
                    )
                    .map_err(|e| AppError::Voice(format!("Failed to build stream: {}", e)))?
            }
            _ => return Err(AppError::Voice("Unsupported sample format".into())),
        };

        // Start recording
        stream
            .play()
            .map_err(|e| AppError::Voice(format!("Failed to start stream: {}", e)))?;

        tracing::info!(
            "Recording with VAD - will stop after {}s of silence...",
            config.silence_timeout_secs
        );

        // Wait for speech to end (with silence timeout), max duration, or external stop request
        loop {
            std::thread::sleep(Duration::from_millis(100));

            // Check internal VAD stop flag
            if should_stop.load(Ordering::SeqCst) {
                tracing::info!("Recording stopped by VAD (silence/max duration)");
                break;
            }

            // Check external stop request (e.g., "Stop Listening" button)
            if is_stop_requested() {
                tracing::info!("Recording stopped by external request");
                should_stop.store(true, Ordering::SeqCst);
                break;
            }

            // Also check for max duration from main thread
            // (belt-and-braces: the callback performs the same check)
            let state = vad_state.lock().unwrap();
            if state.recording_start.elapsed().as_secs() >= config.max_recording_secs {
                break;
            }
        }

        // Stop recording - explicitly pause and drop the stream to release the microphone
        let _ = stream.pause();
        drop(stream);
        tracing::info!("Audio stream stopped and microphone released");

        // Get recorded samples
        let recorded_samples = samples.lock().unwrap();
        let sample_count = recorded_samples.len();
        let duration_secs = sample_count as f32 / (sample_rate as f32 * channels as f32);

        tracing::info!("Recorded {} samples ({:.2}s)", sample_count, duration_secs);

        // Nothing captured: distinguish a user cancellation from a genuine
        // capture failure (both are reported as errors to the caller)
        if sample_count == 0 {
            if is_stop_requested() {
                return Err(AppError::Voice("Recording cancelled by user".into()));
            }
            return Err(AppError::Voice("No audio captured".into()));
        }

        // Resample audio to 16kHz mono for Whisper compatibility
        let resampled = resample_to_16khz(&recorded_samples, sample_rate, channels);

        // Save to WAV file at 16kHz mono (what Whisper expects)
        save_samples_to_wav(&resampled, WHISPER_SAMPLE_RATE, 1, output_path)?;

        Ok(duration_secs)
    })
}
437
438/// Process audio data for VAD analysis
439#[allow(clippy::too_many_arguments)]
440fn process_audio_data(
441    data: &[f32],
442    samples: &Arc<Mutex<Vec<f32>>>,
443    vad_buffer: &Arc<Mutex<Vec<f32>>>,
444    vad_state: &Arc<Mutex<VadState>>,
445    should_stop: &Arc<AtomicBool>,
446    samples_per_window: usize,
447    silence_threshold: f32,
448    silence_timeout: f32,
449    max_recording: u64,
450    wait_for_speech: u64,
451) {
452    // Early exit if we should stop - don't process any more audio
453    if should_stop.load(Ordering::SeqCst) {
454        return;
455    }
456
457    // Store all samples
458    {
459        let mut buffer = samples.lock().unwrap();
460        buffer.extend_from_slice(data);
461    }
462
463    // Add to VAD buffer for analysis
464    {
465        let mut vad_buf = vad_buffer.lock().unwrap();
466        vad_buf.extend_from_slice(data);
467
468        // Analyze when we have enough samples
469        while vad_buf.len() >= samples_per_window {
470            let window: Vec<f32> = vad_buf.drain(..samples_per_window).collect();
471            let rms = calculate_rms(&window);
472
473            let mut state = vad_state.lock().unwrap();
474            let now = Instant::now();
475
476            // Track peak RMS for debugging
477            if rms > state.peak_rms {
478                state.peak_rms = rms;
479            }
480
481            // Log RMS periodically (every 2 seconds) for debugging
482            if now.duration_since(state.last_rms_log_time).as_secs() >= 2 {
483                tracing::info!(
484                    "VAD status: current_rms={:.4}, peak_rms={:.4}, threshold={:.4}, speech_detected={}",
485                    rms,
486                    state.peak_rms,
487                    silence_threshold,
488                    state.has_detected_speech
489                );
490                state.last_rms_log_time = now;
491            }
492
493            if rms > silence_threshold {
494                // Speech detected
495                state.last_speech_time = now;
496                if !state.has_detected_speech {
497                    state.has_detected_speech = true;
498                    tracing::info!(
499                        "🎤 Speech detected! (RMS: {:.4} > threshold: {:.4})",
500                        rms,
501                        silence_threshold
502                    );
503                }
504            } else if state.has_detected_speech {
505                // Check if silence timeout reached (only stop once)
506                let silence_duration = now.duration_since(state.last_speech_time);
507                if silence_duration.as_secs_f32() >= silence_timeout
508                    && !should_stop.load(Ordering::SeqCst)
509                {
510                    tracing::info!(
511                        "🔇 Silence timeout reached ({:.1}s) - stopping recording",
512                        silence_duration.as_secs_f32()
513                    );
514                    should_stop.store(true, Ordering::SeqCst);
515                }
516            } else {
517                // No speech detected yet - check for "waiting for speech" timeout
518                let waiting_duration = now.duration_since(state.recording_start).as_secs();
519                if waiting_duration >= wait_for_speech {
520                    tracing::warn!(
521                        "⏱️ No speech detected after {}s (peak_rms={:.4}, threshold={:.4}) - stopping",
522                        waiting_duration,
523                        state.peak_rms,
524                        silence_threshold
525                    );
526                    should_stop.store(true, Ordering::SeqCst);
527                }
528            }
529
530            // Check max recording duration (only log once)
531            if now.duration_since(state.recording_start).as_secs() >= max_recording {
532                if !state.has_logged_max_duration {
533                    tracing::info!("Max recording duration reached");
534                    state.has_logged_max_duration = true;
535                }
536                should_stop.store(true, Ordering::SeqCst);
537            }
538        }
539    }
540}
541
542/// Resample audio from source sample rate to 16kHz mono for Whisper
543/// Uses simple linear interpolation for resampling
544fn resample_to_16khz(samples: &[f32], source_rate: u32, channels: u16) -> Vec<f32> {
545    if samples.is_empty() {
546        return Vec::new();
547    }
548
549    // First, convert to mono if stereo
550    let mono_samples: Vec<f32> = if channels > 1 {
551        samples
552            .chunks(channels as usize)
553            .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
554            .collect()
555    } else {
556        samples.to_vec()
557    };
558
559    // If already at 16kHz, return mono samples
560    if source_rate == WHISPER_SAMPLE_RATE {
561        tracing::info!("Audio already at 16kHz, no resampling needed");
562        return mono_samples;
563    }
564
565    // Calculate resampling ratio
566    let ratio = source_rate as f64 / WHISPER_SAMPLE_RATE as f64;
567    let output_len = (mono_samples.len() as f64 / ratio).ceil() as usize;
568    let mut resampled = Vec::with_capacity(output_len);
569
570    tracing::info!(
571        "Resampling audio from {}Hz to {}Hz ({} -> {} samples)",
572        source_rate,
573        WHISPER_SAMPLE_RATE,
574        mono_samples.len(),
575        output_len
576    );
577
578    // Linear interpolation resampling
579    for i in 0..output_len {
580        let src_pos = i as f64 * ratio;
581        let src_idx = src_pos.floor() as usize;
582        let frac = (src_pos - src_idx as f64) as f32;
583
584        if src_idx + 1 < mono_samples.len() {
585            // Interpolate between two samples
586            let sample = mono_samples[src_idx] * (1.0 - frac) + mono_samples[src_idx + 1] * frac;
587            resampled.push(sample);
588        } else if src_idx < mono_samples.len() {
589            resampled.push(mono_samples[src_idx]);
590        }
591    }
592
593    resampled
594}
595
596/// Save audio samples to a WAV file
597fn save_samples_to_wav(
598    samples: &[f32],
599    sample_rate: u32,
600    channels: u16,
601    path: &Path,
602) -> Result<(), AppError> {
603    let spec = hound::WavSpec {
604        channels,
605        sample_rate,
606        bits_per_sample: 16,
607        sample_format: hound::SampleFormat::Int,
608    };
609
610    let mut writer = hound::WavWriter::create(path, spec)
611        .map_err(|e| AppError::Voice(format!("Failed to create WAV: {}", e)))?;
612
613    // Convert f32 samples to i16
614    for sample in samples {
615        let sample_i16 = (*sample * i16::MAX as f32) as i16;
616        writer
617            .write_sample(sample_i16)
618            .map_err(|e| AppError::Voice(format!("Failed to write sample: {}", e)))?;
619    }
620
621    writer
622        .finalize()
623        .map_err(|e| AppError::Voice(format!("Failed to finalize WAV: {}", e)))?;
624
625    tracing::info!(
626        "Saved audio to {:?} ({}Hz, {} channels)",
627        path,
628        sample_rate,
629        channels
630    );
631    Ok(())
632}