1use crate::audio_capture::{AudioCaptureConfig, record_audio};
10use gestura_core_config::AppConfig;
11use gestura_core_foundation::AppError;
12use serde::{Deserialize, Serialize};
13use std::path::{Path, PathBuf};
14use std::sync::{Arc, Mutex};
15use std::time::Duration;
16
/// Configuration for speech capture, transcription (STT) and LLM routing.
///
/// Built either from defaults ([`Default`]) or derived from the wider
/// application configuration via [`SpeechConfig::from_app_config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechConfig {
    /// Speech-to-text provider id (e.g. "local-whisper", "openai-whisper").
    pub stt_provider: String,
    /// Primary LLM provider id (copied from `AppConfig::llm.primary`).
    pub llm_provider: String,
    /// OpenAI API key; empty string when unset.
    pub openai_api_key: String,
    /// Anthropic API key; empty string when unset.
    pub anthropic_api_key: String,
    /// Google API key — currently never populated by `from_app_config`.
    pub google_api_key: String,
    /// Azure API key — currently never populated by `from_app_config`.
    pub azure_api_key: String,
    /// Base URL of a local LLM server (defaults to Ollama on localhost).
    pub local_llm_endpoint: String,
    /// STT request timeout, in seconds.
    pub stt_timeout: u64,
    /// LLM request timeout, in seconds.
    pub llm_timeout: u64,
    /// Whether to fall back to another provider on failure.
    /// NOTE(review): the fallback logic itself is not in this file — confirm semantics there.
    pub enable_fallback: bool,
    /// Whether LLM responses may be served from a cache.
    /// NOTE(review): caching logic is not in this file — confirm semantics there.
    pub cache_responses: bool,
}
43
44impl Default for SpeechConfig {
45 fn default() -> Self {
46 Self {
47 stt_provider: "local-whisper".to_string(),
48 llm_provider: "openai".to_string(),
49 openai_api_key: String::new(),
50 anthropic_api_key: String::new(),
51 google_api_key: String::new(),
52 azure_api_key: String::new(),
53 local_llm_endpoint: "http://localhost:11434".to_string(),
54 stt_timeout: 30,
55 llm_timeout: 60,
56 enable_fallback: true,
57 cache_responses: true,
58 }
59 }
60}
61
62impl SpeechConfig {
63 pub fn from_app_config(app_config: &AppConfig) -> Self {
65 Self {
66 stt_provider: match app_config.voice.provider.as_str() {
67 "local" => "local-whisper".to_string(),
68 "openai" => "openai-whisper".to_string(),
69 _ => "local-whisper".to_string(),
70 },
71 llm_provider: app_config.llm.primary.clone(),
72 openai_api_key: app_config.voice.openai_api_key.clone().unwrap_or_default(),
73 anthropic_api_key: app_config
74 .llm
75 .anthropic
76 .as_ref()
77 .map(|a| a.api_key.clone())
78 .unwrap_or_default(),
79 google_api_key: String::new(),
80 azure_api_key: String::new(),
81 local_llm_endpoint: app_config
82 .llm
83 .ollama
84 .as_ref()
85 .map(|o| o.base_url.clone())
86 .unwrap_or_else(|| "http://localhost:11434".to_string()),
87 stt_timeout: 30,
88 llm_timeout: 60,
89 enable_fallback: true,
90 cache_responses: true,
91 }
92 }
93}
94
/// Outcome of a speech-to-text transcription.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionResult {
    /// The transcribed text (whitespace-trimmed by the local Whisper path).
    pub text: String,
    /// Duration of the transcribed audio, in seconds.
    pub duration_secs: f32,
    /// Path to the source audio file, when it was kept on disk.
    pub audio_path: Option<PathBuf>,
    /// Identifier of the STT provider that produced this result
    /// (e.g. "local-whisper").
    pub provider: String,
}
107
/// A response produced by an LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// The generated response text.
    pub text: String,
    /// Identifier of the provider that produced the response.
    pub provider: String,
    /// True when the response was served from a cache rather than a fresh call.
    /// NOTE(review): caching logic is not visible in this file — confirm.
    pub cached: bool,
}
118
/// Coordinates audio capture and transcription.
///
/// Cloning is cheap: both fields are shared behind `Arc<Mutex<...>>`, so
/// every clone observes and mutates the same configuration and recording state.
#[derive(Debug, Clone)]
pub struct SpeechProcessor {
    // Current speech configuration; replaceable at runtime via `update_config`.
    config: Arc<Mutex<SpeechConfig>>,
    // True while an audio recording is in progress.
    is_recording: Arc<Mutex<bool>>,
}
132
133impl Default for SpeechProcessor {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139impl SpeechProcessor {
140 pub fn new() -> Self {
142 Self {
143 config: Arc::new(Mutex::new(SpeechConfig::default())),
144 is_recording: Arc::new(Mutex::new(false)),
145 }
146 }
147
148 pub fn with_config(config: SpeechConfig) -> Self {
150 Self {
151 config: Arc::new(Mutex::new(config)),
152 is_recording: Arc::new(Mutex::new(false)),
153 }
154 }
155
156 pub fn update_config(&self, config: SpeechConfig) {
158 let mut current_config = self.config.lock().unwrap();
159 *current_config = config;
160 tracing::info!("Speech processor configuration updated");
161 }
162
163 pub fn get_config(&self) -> SpeechConfig {
165 self.config.lock().unwrap().clone()
166 }
167
168 pub fn is_recording(&self) -> bool {
170 *self.is_recording.lock().unwrap()
171 }
172
173 pub fn set_recording(&self, recording: bool) {
175 let mut is_recording = self.is_recording.lock().unwrap();
176 *is_recording = recording;
177 }
178
179 pub fn stop_recording(&self) -> Result<(), AppError> {
181 let mut recording = self.is_recording.lock().unwrap();
182 if !*recording {
183 return Err(AppError::Voice("Not currently recording".to_string()));
184 }
185 *recording = false;
186
187 crate::audio_capture::request_stop_recording();
189
190 tracing::info!("Stopped speech capture and requested audio recording stop");
191 Ok(())
192 }
193
    /// Records microphone audio to a temporary WAV file.
    ///
    /// Marks the processor as recording, captures audio via `record_audio`
    /// into `$TMPDIR/gestura_audio_<timestamp>.wav`, then clears the flag.
    /// Returns the captured duration in seconds and the file path, or an
    /// error if already recording, capture fails, or the take is shorter
    /// than 0.5 s (the file is deleted on the failure paths).
    pub async fn record_audio_to_file(
        &self,
        device_name: Option<String>,
    ) -> Result<(f32, PathBuf), AppError> {
        // Claim the recording flag. Scoped so the std::sync mutex guard is
        // dropped before the .await below — it must not be held across an
        // await point.
        {
            let mut recording = self.is_recording.lock().unwrap();
            if *recording {
                return Err(AppError::Voice("Already recording".to_string()));
            }
            *recording = true;
        }

        // Unique output path in the OS temp dir, keyed by the UTC timestamp.
        let temp_dir = std::env::temp_dir();
        let audio_path = temp_dir.join(format!(
            "gestura_audio_{}.wav",
            chrono::Utc::now().timestamp()
        ));

        let config = AudioCaptureConfig {
            device_name,
            ..Default::default()
        };

        // NOTE(review): a zero Duration presumably means "record until
        // request_stop_recording() is called" — confirm in audio_capture.
        let result = record_audio(Duration::from_secs(0), &audio_path, config).await;

        // Always clear the flag, whether capture succeeded or failed.
        {
            let mut recording = self.is_recording.lock().unwrap();
            *recording = false;
        }

        match result {
            Ok(duration) => {
                tracing::info!("Recorded {:.2}s of audio to {:?}", duration, audio_path);
                if duration < 0.5 {
                    // Too short to be useful; drop the file and report.
                    let _ = std::fs::remove_file(&audio_path);
                    return Err(AppError::Voice(
                        "Recording too short - no audio captured".to_string(),
                    ));
                }
                Ok((duration, audio_path))
            }
            Err(e) => {
                // Best-effort cleanup of the partial file.
                let _ = std::fs::remove_file(&audio_path);
                Err(e)
            }
        }
    }
247
248 #[cfg(feature = "voice-local")]
250 #[allow(dead_code)]
251 async fn transcribe_with_local_whisper(
252 &self,
253 audio_path: &Path,
254 ) -> Result<TranscriptionResult, AppError> {
255 use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
256
257 let model_path = self.get_whisper_model_path()?;
259
260 tracing::info!("Loading Whisper model from: {:?}", model_path);
261
262 let ctx = WhisperContext::new_with_params(
264 model_path
265 .to_str()
266 .ok_or_else(|| AppError::Voice("Invalid model path encoding".to_string()))?,
267 WhisperContextParameters::default(),
268 )
269 .map_err(|e| AppError::Voice(format!("Failed to load Whisper model: {}", e)))?;
270
271 let samples = self.load_audio_samples(audio_path)?;
273 let duration_secs = samples.len() as f32 / 16000.0; let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
277 params.set_language(Some("en"));
278 params.set_print_special(false);
279 params.set_print_progress(false);
280 params.set_print_realtime(false);
281 params.set_print_timestamps(false);
282 params.set_translate(false);
283 params.set_no_context(true);
284 params.set_single_segment(false);
285
286 let mut state = ctx
288 .create_state()
289 .map_err(|e| AppError::Voice(format!("Failed to create Whisper state: {}", e)))?;
290
291 state
292 .full(params, &samples)
293 .map_err(|e| AppError::Voice(format!("Whisper transcription failed: {}", e)))?;
294
295 let num_segments = state
297 .full_n_segments()
298 .map_err(|e| AppError::Voice(format!("Failed to get segment count: {}", e)))?;
299
300 let mut text = String::new();
301 for i in 0..num_segments {
302 if let Ok(segment_text) = state.full_get_segment_text(i) {
303 text.push_str(&segment_text);
304 text.push(' ');
305 }
306 }
307
308 let text = text.trim().to_string();
309 tracing::info!("Local Whisper transcription complete: {} chars", text.len());
310
311 Ok(TranscriptionResult {
312 text,
313 duration_secs,
314 audio_path: Some(audio_path.to_path_buf()),
315 provider: "local-whisper".to_string(),
316 })
317 }
318
319 #[cfg(not(feature = "voice-local"))]
321 #[allow(dead_code)]
322 async fn transcribe_with_local_whisper(
323 &self,
324 _audio_path: &Path,
325 ) -> Result<TranscriptionResult, AppError> {
326 Err(AppError::Voice(
327 "Local Whisper transcription requires the 'whisper' feature. \
328 Build with `cargo build --features whisper` or use OpenAI Whisper API instead."
329 .to_string(),
330 ))
331 }
332
333 #[allow(dead_code)]
335 fn get_whisper_model_path(&self) -> Result<PathBuf, AppError> {
336 if let Ok(path) = std::env::var("GESTURA_WHISPER_MODEL") {
338 let path = PathBuf::from(path);
339 if path.exists() {
340 return Ok(path);
341 }
342 }
343
344 let model_names = ["ggml-base.en.bin", "ggml-small.en.bin", "ggml-tiny.en.bin"];
346 let search_dirs = [
347 dirs::data_dir().map(|d| d.join("gestura").join("models")),
349 dirs::home_dir().map(|d| d.join(".gestura").join("models")),
351 Some(PathBuf::from("models")),
353 ];
354
355 for dir in search_dirs.iter().flatten() {
356 for model_name in &model_names {
357 let model_path = dir.join(model_name);
358 if model_path.exists() {
359 tracing::info!("Found Whisper model at: {:?}", model_path);
360 return Ok(model_path);
361 }
362 }
363 }
364
365 Err(AppError::Voice(
366 "Whisper model not found. Please download a model (e.g., ggml-base.en.bin) \
367 and place it in ~/.gestura/models/ or set GESTURA_WHISPER_MODEL environment variable."
368 .to_string(),
369 ))
370 }
371
372 #[cfg(feature = "voice-local")]
374 #[allow(dead_code)]
375 fn load_audio_samples(&self, audio_path: &Path) -> Result<Vec<f32>, AppError> {
376 use hound::WavReader;
377
378 let mut reader = WavReader::open(audio_path)
379 .map_err(|e| AppError::Voice(format!("Failed to open audio file: {}", e)))?;
380
381 let spec = reader.spec();
382 let sample_rate = spec.sample_rate;
383 let channels = spec.channels as usize;
384
385 let samples: Vec<f32> = match spec.sample_format {
387 hound::SampleFormat::Int => {
388 let max_val = (1 << (spec.bits_per_sample - 1)) as f32;
389 let raw_samples: Result<Vec<i32>, _> = reader.samples::<i32>().collect();
390 raw_samples
391 .map_err(|e| AppError::Voice(format!("Failed to decode audio samples: {}", e)))?
392 .into_iter()
393 .map(|s| s as f32 / max_val)
394 .collect()
395 }
396 hound::SampleFormat::Float => {
397 let raw_samples: Result<Vec<f32>, _> = reader.samples::<f32>().collect();
398 raw_samples.map_err(|e| {
399 AppError::Voice(format!("Failed to decode audio samples: {}", e))
400 })?
401 }
402 };
403
404 let mono_samples: Vec<f32> = if channels > 1 {
406 samples
407 .chunks(channels)
408 .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
409 .collect()
410 } else {
411 samples
412 };
413
414 let target_rate = 16000;
416 if sample_rate != target_rate {
417 let ratio = sample_rate as f32 / target_rate as f32;
418 let new_len = (mono_samples.len() as f32 / ratio) as usize;
419 let mut resampled = Vec::with_capacity(new_len);
420
421 for i in 0..new_len {
422 let src_idx = i as f32 * ratio;
423 let idx = src_idx as usize;
424 let frac = src_idx - idx as f32;
425
426 let sample = if idx + 1 < mono_samples.len() {
427 mono_samples[idx] * (1.0 - frac) + mono_samples[idx + 1] * frac
428 } else {
429 mono_samples[idx.min(mono_samples.len() - 1)]
430 };
431 resampled.push(sample);
432 }
433
434 Ok(resampled)
435 } else {
436 Ok(mono_samples)
437 }
438 }
439
440 pub fn is_conversation(&self, text: &str) -> bool {
442 let conversation_keywords = [
443 "help", "what", "how", "can you", "please", "tell me", "explain",
444 ];
445 let command_keywords = [
446 "open", "close", "start", "stop", "launch", "quit", "show", "hide",
447 ];
448
449 let text_lower = text.to_lowercase();
450 let conversation_score = conversation_keywords
451 .iter()
452 .filter(|&keyword| text_lower.contains(keyword))
453 .count();
454 let command_score = command_keywords
455 .iter()
456 .filter(|&keyword| text_lower.contains(keyword))
457 .count();
458
459 conversation_score > command_score
460 }
461}
462
/// Resolves a whisper.cpp model path for the given app config.
///
/// Resolution order: `voice.local_model_path` from the config, then the
/// `GESTURA_WHISPER_MODEL` environment variable, then a scan of default
/// model directories for well-known model filenames.
#[cfg(feature = "voice-local")]
pub fn resolve_whisper_model_path(config: &AppConfig) -> Result<PathBuf, AppError> {
    // 1. Explicitly configured path, if it points at a real file.
    if let Some(ref path_str) = config.voice.local_model_path {
        let configured = PathBuf::from(path_str);
        if configured.exists() {
            return Ok(configured);
        }
        tracing::warn!(
            "Configured local_model_path '{}' does not exist; searching defaults",
            path_str
        );
    }

    // 2. Environment-variable override, if it points at a real file.
    if let Ok(env_path) = std::env::var("GESTURA_WHISPER_MODEL") {
        let candidate = PathBuf::from(&env_path);
        if candidate.exists() {
            return Ok(candidate);
        }
    }

    // 3. Well-known filenames in the default search directories, in order.
    let model_names = [
        "ggml-base.en.bin",
        "ggml-base.bin",
        "ggml-small.en.bin",
        "ggml-small.bin",
        "ggml-medium.en.bin",
        "ggml-medium.bin",
        "ggml-large.bin",
    ];

    let search_dirs: Vec<Option<PathBuf>> = vec![
        dirs::home_dir().map(|h| h.join(".gestura").join("models")),
        Some(PathBuf::from("models")),
    ];

    search_dirs
        .iter()
        .flatten()
        .flat_map(|dir| model_names.iter().map(move |name| dir.join(name)))
        .find(|candidate| candidate.exists())
        .map(|found| {
            tracing::info!("Found Whisper model at: {:?}", found);
            found
        })
        .ok_or_else(|| {
            AppError::Voice(
                "Whisper model not found. Please download a model (e.g., ggml-base.en.bin) \
                 and place it in ~/.gestura/models/ or set GESTURA_WHISPER_MODEL environment variable, \
                 or configure voice.local_model_path in your settings."
                    .to_string(),
            )
        })
}
527
/// Resolves a whisper.cpp model path, preferring a per-session override
/// (a bare filename or a full path) over the config-driven defaults.
#[cfg(feature = "voice-local")]
pub fn resolve_whisper_model_path_with_override(
    config: &AppConfig,
    session_model: Option<&str>,
) -> Result<PathBuf, AppError> {
    // Delegate to the directory-injecting variant with the real models dir.
    let models_dir = AppConfig::whisper_models_dir();
    resolve_whisper_model_path_with_override_in_dir(config, session_model, &models_dir)
}
551
/// Testable core of [`resolve_whisper_model_path_with_override`]: the models
/// directory is injected so tests can point it at a temp dir.
///
/// A non-empty `session_model` containing a path separator is used verbatim;
/// a bare filename is looked up inside `whisper_models_dir`. Without a
/// session override, resolution falls through to
/// [`resolve_whisper_model_path`].
#[cfg(feature = "voice-local")]
fn resolve_whisper_model_path_with_override_in_dir(
    config: &AppConfig,
    session_model: Option<&str>,
    whisper_models_dir: &Path,
) -> Result<PathBuf, AppError> {
    // Treat whitespace-only overrides the same as "no override".
    let override_model = session_model.map(str::trim).filter(|m| !m.is_empty());

    let model = match override_model {
        Some(model) => model,
        None => return resolve_whisper_model_path(config),
    };

    let looks_like_path = model.contains('/') || model.contains('\\');
    let candidate = if looks_like_path {
        // Contains a separator: use the given path as-is.
        PathBuf::from(model)
    } else {
        // Bare filename: resolve it relative to the models directory.
        whisper_models_dir.join(model)
    };

    if candidate.exists() {
        Ok(candidate)
    } else {
        Err(AppError::Voice(format!(
            "Local Whisper model file not found at: {}. Please download a whisper.cpp compatible model (.bin file).",
            candidate.display()
        )))
    }
}
582
// Unit tests for the session-model override resolution. They exercise the
// `_in_dir` variant so the models directory can live in a tempdir.
#[cfg(all(test, feature = "voice-local"))]
mod whisper_model_override_tests {
    use super::*;

    // A bare filename override is looked up inside the injected models dir.
    #[test]
    fn session_model_filename_is_resolved_under_models_dir() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let models_dir = tmp.path().join("models");
        std::fs::create_dir_all(&models_dir).expect("create models dir");

        let model_file = models_dir.join("ggml-tiny.en.bin");
        std::fs::write(&model_file, b"test").expect("write model");

        let cfg = AppConfig::default();
        let resolved = resolve_whisper_model_path_with_override_in_dir(
            &cfg,
            Some("ggml-tiny.en.bin"),
            &models_dir,
        )
        .expect("resolve");

        assert_eq!(resolved, model_file);
    }

    // An override containing a path separator is used verbatim, not joined
    // onto the models directory.
    #[test]
    fn session_model_path_is_used_as_is() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let model_file = tmp.path().join("ggml-small.en.bin");
        std::fs::write(&model_file, b"test").expect("write model");

        let cfg = AppConfig::default();
        let resolved = resolve_whisper_model_path_with_override_in_dir(
            &cfg,
            Some(model_file.to_string_lossy().as_ref()),
            tmp.path(),
        )
        .expect("resolve");

        assert_eq!(resolved, model_file);
    }

    // A missing override file fails with a message telling the user what to
    // download, rather than silently falling back to the defaults.
    #[test]
    fn missing_session_model_returns_actionable_error() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let cfg = AppConfig::default();
        let err = resolve_whisper_model_path_with_override_in_dir(
            &cfg,
            Some("does-not-exist.bin"),
            tmp.path(),
        )
        .expect_err("should error");

        match err {
            AppError::Voice(msg) => assert!(msg.contains("Local Whisper model file not found")),
            other => panic!("unexpected error type: {other:?}"),
        }
    }
}
641
/// Decodes a WAV file into `f32` samples normalized to roughly [-1.0, 1.0],
/// downmixed to mono and linearly resampled to 16 kHz — the input format
/// whisper.cpp expects.
///
/// # Errors
///
/// Returns [`AppError::Voice`] if the file cannot be opened or its samples
/// cannot be decoded.
#[cfg(feature = "voice-local")]
pub fn load_audio_samples_16khz_mono(audio_path: &Path) -> Result<Vec<f32>, AppError> {
    use hound::WavReader;

    let mut reader = WavReader::open(audio_path)
        .map_err(|e| AppError::Voice(format!("Failed to open audio file: {}", e)))?;

    let spec = reader.spec();
    let sample_rate = spec.sample_rate;
    let channels = spec.channels as usize;

    // Normalize integer PCM to floats; float WAVs are taken as-is.
    let samples: Vec<f32> = match spec.sample_format {
        hound::SampleFormat::Int => {
            // Widen to i64 before shifting: for 32-bit samples the old
            // `1i32 << 31` wrapped to i32::MIN, producing a negative scale
            // factor that inverted the whole waveform.
            let max_val = (1i64 << (spec.bits_per_sample - 1)) as f32;
            let raw_samples: Result<Vec<i32>, _> = reader.samples::<i32>().collect();
            raw_samples
                .map_err(|e| AppError::Voice(format!("Failed to decode audio samples: {}", e)))?
                .into_iter()
                .map(|s| s as f32 / max_val)
                .collect()
        }
        hound::SampleFormat::Float => {
            let raw_samples: Result<Vec<f32>, _> = reader.samples::<f32>().collect();
            raw_samples
                .map_err(|e| AppError::Voice(format!("Failed to decode audio samples: {}", e)))?
        }
    };

    // Downmix interleaved channels by averaging each frame.
    let mono_samples: Vec<f32> = if channels > 1 {
        samples
            .chunks(channels)
            .map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
            .collect()
    } else {
        samples
    };

    // Linear-interpolation resample to whisper.cpp's 16 kHz input rate.
    let target_rate = 16000u32;
    if sample_rate != target_rate {
        let ratio = sample_rate as f32 / target_rate as f32;
        let new_len = (mono_samples.len() as f32 / ratio) as usize;
        let mut resampled = Vec::with_capacity(new_len);

        for i in 0..new_len {
            let src_idx = i as f32 * ratio;
            let idx = src_idx as usize;
            let frac = src_idx - idx as f32;

            // Interpolate between neighbors; clamp at the final sample.
            let sample = if idx + 1 < mono_samples.len() {
                mono_samples[idx] * (1.0 - frac) + mono_samples[idx + 1] * frac
            } else {
                mono_samples[idx.min(mono_samples.len() - 1)]
            };
            resampled.push(sample);
        }

        Ok(resampled)
    } else {
        Ok(mono_samples)
    }
}