gestura_core_audio/
speech.rs

1//! Speech processing module for Gestura
2//!
3//! This module provides core speech-to-text and LLM processing functionality.
4//! It is designed to be used by both GUI and CLI applications.
5//!
6//! The Tauri-specific event handling (window management, event emission) should
7//! be implemented in the gestura-gui crate.
8
9use crate::audio_capture::{AudioCaptureConfig, record_audio};
10use gestura_core_config::AppConfig;
11use gestura_core_foundation::AppError;
12use serde::{Deserialize, Serialize};
13use std::path::{Path, PathBuf};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16
/// Speech processing configuration
///
/// Providers and credentials are stored as plain strings so the struct
/// serializes cleanly with serde for persistence and IPC.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechConfig {
    /// STT provider: "local-whisper" or "openai-whisper"
    pub stt_provider: String,
    /// LLM provider for processing (e.g. "openai")
    pub llm_provider: String,
    /// OpenAI API key for Whisper API
    pub openai_api_key: String,
    /// Anthropic API key
    pub anthropic_api_key: String,
    /// Google API key
    pub google_api_key: String,
    /// Azure API key
    pub azure_api_key: String,
    /// Local LLM endpoint (e.g., Ollama)
    pub local_llm_endpoint: String,
    /// STT timeout in seconds
    pub stt_timeout: u64,
    /// LLM timeout in seconds
    pub llm_timeout: u64,
    /// Enable fallback to alternative providers
    pub enable_fallback: bool,
    /// Cache LLM responses
    pub cache_responses: bool,
}
43
44impl Default for SpeechConfig {
45    fn default() -> Self {
46        Self {
47            stt_provider: "local-whisper".to_string(),
48            llm_provider: "openai".to_string(),
49            openai_api_key: String::new(),
50            anthropic_api_key: String::new(),
51            google_api_key: String::new(),
52            azure_api_key: String::new(),
53            local_llm_endpoint: "http://localhost:11434".to_string(),
54            stt_timeout: 30,
55            llm_timeout: 60,
56            enable_fallback: true,
57            cache_responses: true,
58        }
59    }
60}
61
62impl SpeechConfig {
63    /// Create SpeechConfig from AppConfig
64    pub fn from_app_config(app_config: &AppConfig) -> Self {
65        Self {
66            stt_provider: match app_config.voice.provider.as_str() {
67                "local" => "local-whisper".to_string(),
68                "openai" => "openai-whisper".to_string(),
69                _ => "local-whisper".to_string(),
70            },
71            llm_provider: app_config.llm.primary.clone(),
72            openai_api_key: app_config.voice.openai_api_key.clone().unwrap_or_default(),
73            anthropic_api_key: app_config
74                .llm
75                .anthropic
76                .as_ref()
77                .map(|a| a.api_key.clone())
78                .unwrap_or_default(),
79            google_api_key: String::new(),
80            azure_api_key: String::new(),
81            local_llm_endpoint: app_config
82                .llm
83                .ollama
84                .as_ref()
85                .map(|o| o.base_url.clone())
86                .unwrap_or_else(|| "http://localhost:11434".to_string()),
87            stt_timeout: 30,
88            llm_timeout: 60,
89            enable_fallback: true,
90            cache_responses: true,
91        }
92    }
93}
94
/// Result of speech transcription
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionResult {
    /// The transcribed text
    pub text: String,
    /// Duration of the audio in seconds
    pub duration_secs: f32,
    /// Path to the temporary audio file (if retained); the caller owns cleanup
    pub audio_path: Option<PathBuf>,
    /// Provider used for transcription (e.g. "local-whisper")
    pub provider: String,
}
107
/// Result of LLM processing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// The AI response text
    pub text: String,
    /// Provider used for processing
    pub provider: String,
    /// Whether this response was served from the response cache
    pub cached: bool,
}
118
/// Core speech processor without Tauri dependencies
///
/// This processor handles:
/// - Audio recording with VAD
/// - Speech-to-text transcription
/// - LLM processing
///
/// Event emission and window management should be handled by the caller.
///
/// Cloning is cheap: clones share the same configuration and recording flag
/// through the inner `Arc`s.
#[derive(Debug, Clone)]
pub struct SpeechProcessor {
    // Current configuration; shared and mutable across clones.
    config: Arc<Mutex<SpeechConfig>>,
    // True while a capture is in progress; used to reject concurrent starts.
    is_recording: Arc<Mutex<bool>>,
}
132
133impl Default for SpeechProcessor {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl SpeechProcessor {
140    /// Create a new speech processor with default configuration
141    pub fn new() -> Self {
142        Self {
143            config: Arc::new(Mutex::new(SpeechConfig::default())),
144            is_recording: Arc::new(Mutex::new(false)),
145        }
146    }
147
148    /// Create a new speech processor with custom configuration
149    pub fn with_config(config: SpeechConfig) -> Self {
150        Self {
151            config: Arc::new(Mutex::new(config)),
152            is_recording: Arc::new(Mutex::new(false)),
153        }
154    }
155
156    /// Update the speech processor configuration
157    pub fn update_config(&self, config: SpeechConfig) {
158        let mut current_config = self.config.lock().unwrap();
159        *current_config = config;
160        tracing::info!("Speech processor configuration updated");
161    }
162
163    /// Get the current configuration
164    pub fn get_config(&self) -> SpeechConfig {
165        self.config.lock().unwrap().clone()
166    }
167
168    /// Check if currently recording
169    pub fn is_recording(&self) -> bool {
170        *self.is_recording.lock().unwrap()
171    }
172
173    /// Set recording state
174    pub fn set_recording(&self, recording: bool) {
175        let mut is_recording = self.is_recording.lock().unwrap();
176        *is_recording = recording;
177    }
178
179    /// Stop the current recording
180    pub fn stop_recording(&self) -> Result<(), AppError> {
181        let mut recording = self.is_recording.lock().unwrap();
182        if !*recording {
183            return Err(AppError::Voice("Not currently recording".to_string()));
184        }
185        *recording = false;
186
187        // Signal the audio capture to stop immediately
188        crate::audio_capture::request_stop_recording();
189
190        tracing::info!("Stopped speech capture and requested audio recording stop");
191        Ok(())
192    }
193
194    /// Record audio from microphone and return the path to the audio file
195    ///
196    /// Returns the duration and path to the recorded audio file.
197    /// The caller is responsible for cleaning up the temp file.
198    pub async fn record_audio_to_file(
199        &self,
200        device_name: Option<String>,
201    ) -> Result<(f32, PathBuf), AppError> {
202        // Check if already recording
203        {
204            let mut recording = self.is_recording.lock().unwrap();
205            if *recording {
206                return Err(AppError::Voice("Already recording".to_string()));
207            }
208            *recording = true;
209        }
210
211        let temp_dir = std::env::temp_dir();
212        let audio_path = temp_dir.join(format!(
213            "gestura_audio_{}.wav",
214            chrono::Utc::now().timestamp()
215        ));
216
217        let config = AudioCaptureConfig {
218            device_name,
219            ..Default::default()
220        };
221
222        let result = record_audio(Duration::from_secs(0), &audio_path, config).await;
223
224        // Reset recording state
225        {
226            let mut recording = self.is_recording.lock().unwrap();
227            *recording = false;
228        }
229
230        match result {
231            Ok(duration) => {
232                tracing::info!("Recorded {:.2}s of audio to {:?}", duration, audio_path);
233                if duration < 0.5 {
234                    let _ = std::fs::remove_file(&audio_path);
235                    return Err(AppError::Voice(
236                        "Recording too short - no audio captured".to_string(),
237                    ));
238                }
239                Ok((duration, audio_path))
240            }
241            Err(e) => {
242                let _ = std::fs::remove_file(&audio_path);
243                Err(e)
244            }
245        }
246    }
247
248    /// Transcribe audio using local Whisper model (whisper-rs)
249    #[cfg(feature = "voice-local")]
250    #[allow(dead_code)]
251    async fn transcribe_with_local_whisper(
252        &self,
253        audio_path: &Path,
254    ) -> Result<TranscriptionResult, AppError> {
255        use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
256
257        // Get model path from config or use default
258        let model_path = self.get_whisper_model_path()?;
259
260        tracing::info!("Loading Whisper model from: {:?}", model_path);
261
262        // Load the Whisper context (model)
263        let ctx = WhisperContext::new_with_params(
264            model_path
265                .to_str()
266                .ok_or_else(|| AppError::Voice("Invalid model path encoding".to_string()))?,
267            WhisperContextParameters::default(),
268        )
269        .map_err(|e| AppError::Voice(format!("Failed to load Whisper model: {}", e)))?;
270
271        // Read and convert audio to f32 samples
272        let samples = self.load_audio_samples(audio_path)?;
273        let duration_secs = samples.len() as f32 / 16000.0; // Whisper expects 16kHz
274
275        // Create transcription parameters
276        let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
277        params.set_language(Some("en"));
278        params.set_print_special(false);
279        params.set_print_progress(false);
280        params.set_print_realtime(false);
281        params.set_print_timestamps(false);
282        params.set_translate(false);
283        params.set_no_context(true);
284        params.set_single_segment(false);
285
286        // Create state and run transcription
287        let mut state = ctx
288            .create_state()
289            .map_err(|e| AppError::Voice(format!("Failed to create Whisper state: {}", e)))?;
290
291        state
292            .full(params, &samples)
293            .map_err(|e| AppError::Voice(format!("Whisper transcription failed: {}", e)))?;
294
295        // Collect all segments
296        let num_segments = state
297            .full_n_segments()
298            .map_err(|e| AppError::Voice(format!("Failed to get segment count: {}", e)))?;
299
300        let mut text = String::new();
301        for i in 0..num_segments {
302            if let Ok(segment_text) = state.full_get_segment_text(i) {
303                text.push_str(&segment_text);
304                text.push(' ');
305            }
306        }
307
308        let text = text.trim().to_string();
309        tracing::info!("Local Whisper transcription complete: {} chars", text.len());
310
311        Ok(TranscriptionResult {
312            text,
313            duration_secs,
314            audio_path: Some(audio_path.to_path_buf()),
315            provider: "local-whisper".to_string(),
316        })
317    }
318
319    /// Fallback when voice-local feature is not enabled
320    #[cfg(not(feature = "voice-local"))]
321    #[allow(dead_code)]
322    async fn transcribe_with_local_whisper(
323        &self,
324        _audio_path: &Path,
325    ) -> Result<TranscriptionResult, AppError> {
326        Err(AppError::Voice(
327            "Local Whisper transcription requires the 'whisper' feature. \
328             Build with `cargo build --features whisper` or use OpenAI Whisper API instead."
329                .to_string(),
330        ))
331    }
332
333    /// Get the path to the Whisper model file
334    #[allow(dead_code)]
335    fn get_whisper_model_path(&self) -> Result<PathBuf, AppError> {
336        // Check environment variable first
337        if let Ok(path) = std::env::var("GESTURA_WHISPER_MODEL") {
338            let path = PathBuf::from(path);
339            if path.exists() {
340                return Ok(path);
341            }
342        }
343
344        // Check standard locations
345        let model_names = ["ggml-base.en.bin", "ggml-small.en.bin", "ggml-tiny.en.bin"];
346        let search_dirs = [
347            // User data directory
348            dirs::data_dir().map(|d| d.join("gestura").join("models")),
349            // Home directory
350            dirs::home_dir().map(|d| d.join(".gestura").join("models")),
351            // Current directory
352            Some(PathBuf::from("models")),
353        ];
354
355        for dir in search_dirs.iter().flatten() {
356            for model_name in &model_names {
357                let model_path = dir.join(model_name);
358                if model_path.exists() {
359                    tracing::info!("Found Whisper model at: {:?}", model_path);
360                    return Ok(model_path);
361                }
362            }
363        }
364
365        Err(AppError::Voice(
366            "Whisper model not found. Please download a model (e.g., ggml-base.en.bin) \
367             and place it in ~/.gestura/models/ or set GESTURA_WHISPER_MODEL environment variable."
368                .to_string(),
369        ))
370    }
371
372    /// Load audio file and convert to f32 samples at 16kHz
373    #[cfg(feature = "voice-local")]
374    #[allow(dead_code)]
375    fn load_audio_samples(&self, audio_path: &Path) -> Result<Vec<f32>, AppError> {
376        use hound::WavReader;
377
378        let mut reader = WavReader::open(audio_path)
379            .map_err(|e| AppError::Voice(format!("Failed to open audio file: {}", e)))?;
380
381        let spec = reader.spec();
382        let sample_rate = spec.sample_rate;
383        let channels = spec.channels as usize;
384
385        // Read samples based on format, propagating decode errors
386        let samples: Vec<f32> = match spec.sample_format {
387            hound::SampleFormat::Int => {
388                let max_val = (1 << (spec.bits_per_sample - 1)) as f32;
389                let raw_samples: Result<Vec<i32>, _> = reader.samples::<i32>().collect();
390                raw_samples
391                    .map_err(|e| AppError::Voice(format!("Failed to decode audio samples: {}", e)))?
392                    .into_iter()
393                    .map(|s| s as f32 / max_val)
394                    .collect()
395            }
396            hound::SampleFormat::Float => {
397                let raw_samples: Result<Vec<f32>, _> = reader.samples::<f32>().collect();
398                raw_samples.map_err(|e| {
399                    AppError::Voice(format!("Failed to decode audio samples: {}", e))
400                })?
401            }
402        };
403
404        // Convert to mono if stereo
405        let mono_samples: Vec<f32> = if channels > 1 {
406            samples
407                .chunks(channels)
408                .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
409                .collect()
410        } else {
411            samples
412        };
413
414        // Resample to 16kHz if needed (simple linear interpolation)
415        let target_rate = 16000;
416        if sample_rate != target_rate {
417            let ratio = sample_rate as f32 / target_rate as f32;
418            let new_len = (mono_samples.len() as f32 / ratio) as usize;
419            let mut resampled = Vec::with_capacity(new_len);
420
421            for i in 0..new_len {
422                let src_idx = i as f32 * ratio;
423                let idx = src_idx as usize;
424                let frac = src_idx - idx as f32;
425
426                let sample = if idx + 1 < mono_samples.len() {
427                    mono_samples[idx] * (1.0 - frac) + mono_samples[idx + 1] * frac
428                } else {
429                    mono_samples[idx.min(mono_samples.len() - 1)]
430                };
431                resampled.push(sample);
432            }
433
434            Ok(resampled)
435        } else {
436            Ok(mono_samples)
437        }
438    }
439
440    /// Determine if text is a conversation or command
441    pub fn is_conversation(&self, text: &str) -> bool {
442        let conversation_keywords = [
443            "help", "what", "how", "can you", "please", "tell me", "explain",
444        ];
445        let command_keywords = [
446            "open", "close", "start", "stop", "launch", "quit", "show", "hide",
447        ];
448
449        let text_lower = text.to_lowercase();
450        let conversation_score = conversation_keywords
451            .iter()
452            .filter(|&keyword| text_lower.contains(keyword))
453            .count();
454        let command_score = command_keywords
455            .iter()
456            .filter(|&keyword| text_lower.contains(keyword))
457            .count();
458
459        conversation_score > command_score
460    }
461}
462
/// Resolve the path to the local Whisper model file.
///
/// Checks (in order):
/// 1. `config.voice.local_model_path` if set
/// 2. `GESTURA_WHISPER_MODEL` environment variable
/// 3. Default search directories (~/.gestura/models/, ./models/)
///
/// Returns an error with an actionable message if no model is found.
#[cfg(feature = "voice-local")]
pub fn resolve_whisper_model_path(config: &AppConfig) -> Result<PathBuf, AppError> {
    // 1. An explicitly configured path wins when it points at a real file.
    if let Some(ref configured) = config.voice.local_model_path {
        let candidate = PathBuf::from(configured);
        if candidate.exists() {
            return Ok(candidate);
        }
        // Config path set but missing — warn and keep searching the defaults.
        tracing::warn!(
            "Configured local_model_path '{}' does not exist; searching defaults",
            configured
        );
    }

    // 2. Environment variable override.
    if let Ok(env_path) = std::env::var("GESTURA_WHISPER_MODEL") {
        let candidate = PathBuf::from(&env_path);
        if candidate.exists() {
            return Ok(candidate);
        }
    }

    // 3. Probe each default directory for the known ggml model filenames.
    let model_names = [
        "ggml-base.en.bin",
        "ggml-base.bin",
        "ggml-small.en.bin",
        "ggml-small.bin",
        "ggml-medium.en.bin",
        "ggml-medium.bin",
        "ggml-large.bin",
    ];

    let search_dirs: Vec<PathBuf> = [
        dirs::home_dir().map(|h| h.join(".gestura").join("models")),
        Some(PathBuf::from("models")),
    ]
    .into_iter()
    .flatten()
    .collect();

    let found = search_dirs
        .iter()
        .flat_map(|dir| model_names.iter().map(move |name| dir.join(name)))
        .find(|candidate| candidate.exists());

    if let Some(model_path) = found {
        tracing::info!("Found Whisper model at: {:?}", model_path);
        return Ok(model_path);
    }

    Err(AppError::Voice(
        "Whisper model not found. Please download a model (e.g., ggml-base.en.bin) \
         and place it in ~/.gestura/models/ or set GESTURA_WHISPER_MODEL environment variable, \
         or configure voice.local_model_path in your settings."
            .to_string(),
    ))
}
527
/// Resolve the path to a local Whisper model file, with an optional session-scoped override.
///
/// This is the core-owned implementation of the GUI's per-session model selection rules.
///
/// When `session_model` is provided (and non-empty after trimming), it is interpreted as:
/// - a full path when it contains a path separator (`/` or `\\`), otherwise
/// - a filename resolved under `AppConfig::whisper_models_dir()`.
///
/// If a session override is provided but the resolved file does not exist, this returns an
/// error instead of silently falling back to global defaults.
///
/// When no session override is provided, this falls back to [`resolve_whisper_model_path`].
#[cfg(feature = "voice-local")]
pub fn resolve_whisper_model_path_with_override(
    config: &AppConfig,
    session_model: Option<&str>,
) -> Result<PathBuf, AppError> {
    let models_dir = AppConfig::whisper_models_dir();
    resolve_whisper_model_path_with_override_in_dir(config, session_model, &models_dir)
}
551
/// Implementation helper for [`resolve_whisper_model_path_with_override`].
///
/// Split out so unit tests can inject a temporary models directory.
#[cfg(feature = "voice-local")]
fn resolve_whisper_model_path_with_override_in_dir(
    config: &AppConfig,
    session_model: Option<&str>,
    whisper_models_dir: &Path,
) -> Result<PathBuf, AppError> {
    // Blank or whitespace-only overrides count as "no override".
    let override_model = match session_model.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => Some(trimmed),
        _ => None,
    };

    let model = match override_model {
        Some(model) => model,
        // No session override: apply the global resolution rules.
        None => return resolve_whisper_model_path(config),
    };

    // A path separator means the caller supplied a full path; a bare name is
    // resolved inside the managed models directory.
    let candidate = if model.contains('/') || model.contains('\\') {
        PathBuf::from(model)
    } else {
        whisper_models_dir.join(model)
    };

    if candidate.exists() {
        Ok(candidate)
    } else {
        Err(AppError::Voice(format!(
            "Local Whisper model file not found at: {}. Please download a whisper.cpp compatible model (.bin file).",
            candidate.display()
        )))
    }
}
582
#[cfg(all(test, feature = "voice-local"))]
mod whisper_model_override_tests {
    use super::*;

    /// A bare filename override must be looked up inside the models directory.
    #[test]
    fn session_model_filename_is_resolved_under_models_dir() {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let models_dir = tmp_dir.path().join("models");
        std::fs::create_dir_all(&models_dir).expect("create models dir");

        let expected = models_dir.join("ggml-tiny.en.bin");
        std::fs::write(&expected, b"test").expect("write model");

        let config = AppConfig::default();
        let resolved = resolve_whisper_model_path_with_override_in_dir(
            &config,
            Some("ggml-tiny.en.bin"),
            &models_dir,
        )
        .expect("resolve");

        assert_eq!(resolved, expected);
    }

    /// An override containing a path separator is treated as a full path.
    #[test]
    fn session_model_path_is_used_as_is() {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let expected = tmp_dir.path().join("ggml-small.en.bin");
        std::fs::write(&expected, b"test").expect("write model");

        let config = AppConfig::default();
        let resolved = resolve_whisper_model_path_with_override_in_dir(
            &config,
            Some(expected.to_string_lossy().as_ref()),
            tmp_dir.path(),
        )
        .expect("resolve");

        assert_eq!(resolved, expected);
    }

    /// A missing override must fail loudly rather than fall back to defaults.
    #[test]
    fn missing_session_model_returns_actionable_error() {
        let tmp_dir = tempfile::tempdir().expect("tempdir");
        let config = AppConfig::default();
        let err = resolve_whisper_model_path_with_override_in_dir(
            &config,
            Some("does-not-exist.bin"),
            tmp_dir.path(),
        )
        .expect_err("should error");

        match err {
            AppError::Voice(msg) => assert!(msg.contains("Local Whisper model file not found")),
            other => panic!("unexpected error type: {other:?}"),
        }
    }
}
641
/// Load audio file and convert to 16kHz mono f32 samples.
///
/// This is the format required by whisper.cpp / whisper-rs for transcription.
/// Supports WAV files with integer or float samples, any sample rate, mono or stereo.
#[cfg(feature = "voice-local")]
pub fn load_audio_samples_16khz_mono(audio_path: &Path) -> Result<Vec<f32>, AppError> {
    use hound::WavReader;

    let mut wav = WavReader::open(audio_path)
        .map_err(|e| AppError::Voice(format!("Failed to open audio file: {}", e)))?;

    let spec = wav.spec();
    let source_rate = spec.sample_rate;
    let channel_count = spec.channels as usize;

    // Decode to f32, failing fast on the first corrupt sample.
    let decoded: Vec<f32> = match spec.sample_format {
        hound::SampleFormat::Int => {
            // Scale integer samples into [-1.0, 1.0) by the signed max value.
            let scale = (1 << (spec.bits_per_sample - 1)) as f32;
            wav.samples::<i32>()
                .map(|s| {
                    s.map(|v| v as f32 / scale).map_err(|e| {
                        AppError::Voice(format!("Failed to decode audio samples: {}", e))
                    })
                })
                .collect::<Result<_, _>>()?
        }
        hound::SampleFormat::Float => wav
            .samples::<f32>()
            .map(|s| {
                s.map_err(|e| AppError::Voice(format!("Failed to decode audio samples: {}", e)))
            })
            .collect::<Result<_, _>>()?,
    };

    // Downmix interleaved frames to mono by averaging the channels.
    let mono: Vec<f32> = if channel_count > 1 {
        decoded
            .chunks(channel_count)
            .map(|frame| frame.iter().sum::<f32>() / channel_count as f32)
            .collect()
    } else {
        decoded
    };

    // Already at the target rate: nothing more to do.
    const TARGET_RATE: u32 = 16000;
    if source_rate == TARGET_RATE {
        return Ok(mono);
    }

    // Resample to 16kHz with simple linear interpolation between neighbors.
    let ratio = source_rate as f32 / TARGET_RATE as f32;
    let output_len = (mono.len() as f32 / ratio) as usize;
    let resampled: Vec<f32> = (0..output_len)
        .map(|i| {
            let position = i as f32 * ratio;
            let base = position as usize;
            let frac = position - base as f32;

            if base + 1 < mono.len() {
                mono[base] * (1.0 - frac) + mono[base + 1] * frac
            } else {
                // Clamp at the final sample when interpolation would overrun.
                mono[base.min(mono.len() - 1)]
            }
        })
        .collect();

    Ok(resampled)
}