1use std::path::Path;
7#[cfg(feature = "voice-local")]
8use std::path::PathBuf;
9
10use crate::speech::TranscriptionResult;
11use gestura_core_config::AppConfig;
12use gestura_core_foundation::AppError;
13use gestura_core_foundation::secrets::{SecretKey, SecretProvider};
14use gestura_core_sessions::agent_sessions::SessionVoiceConfig;
15
/// Trims an optional override value, treating blank/whitespace-only strings
/// as if no override was supplied.
fn normalize_override(value: Option<&str>) -> Option<&str> {
    match value {
        Some(raw) => {
            let trimmed = raw.trim();
            if trimmed.is_empty() { None } else { Some(trimmed) }
        }
        None => None,
    }
}
20
21fn resolve_effective_provider(config: &AppConfig, session: Option<&SessionVoiceConfig>) -> String {
25 normalize_override(session.and_then(|s| s.provider.as_deref()))
26 .map(str::to_string)
27 .unwrap_or_else(|| config.voice.provider.clone())
28}
29
30fn resolve_effective_openai_model(
37 config: &AppConfig,
38 session: Option<&SessionVoiceConfig>,
39) -> String {
40 if let Some(m) = normalize_override(session.and_then(|s| s.model.as_deref())) {
41 return m.to_string();
42 }
43
44 config
45 .voice
46 .openai_model
47 .clone()
48 .unwrap_or_else(|| "gpt-4o-transcribe".to_string())
49}
50
/// Common interface for speech-to-text backends.
///
/// `Send + Sync` is required because selected providers are handed out as
/// `Box<dyn SttProvider>` and may be used from async tasks.
#[async_trait::async_trait]
pub trait SttProvider: Send + Sync {
    /// Short, stable identifier for this backend (the implementations in
    /// this file return "openai", "local-whisper", or "unconfigured").
    fn provider_id(&self) -> &'static str;

    /// Transcribes the audio file at `audio_path`, returning the recognized
    /// text plus metadata, or an `AppError` on failure (implementations in
    /// this file report failures as `AppError::Voice`).
    async fn transcribe_file(&self, audio_path: &Path) -> Result<TranscriptionResult, AppError>;
}
63
/// Placeholder STT provider used when no real backend can be configured.
/// Every transcription attempt fails with the stored explanation, so
/// misconfiguration surfaces at use time instead of panicking at startup.
pub struct UnconfiguredSttProvider {
    // Human-readable reason why STT is unavailable.
    message: String,
}

impl UnconfiguredSttProvider {
    /// Creates a placeholder provider carrying `message` as the failure reason.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}
77
78#[async_trait::async_trait]
79impl SttProvider for UnconfiguredSttProvider {
80 fn provider_id(&self) -> &'static str {
81 "unconfigured"
82 }
83
84 async fn transcribe_file(&self, _audio_path: &Path) -> Result<TranscriptionResult, AppError> {
85 Err(AppError::Voice(self.message.clone()))
86 }
87}
88
/// Speech-to-text provider backed by the OpenAI transcription REST API.
pub struct OpenAiSttProvider {
    /// Bearer token for the OpenAI API.
    pub api_key: String,
    /// API origin, e.g. "https://api.openai.com"; trailing slashes are tolerated.
    pub base_url: String,
    /// Transcription model name sent in the request form.
    pub model: String,
}

impl OpenAiSttProvider {
    /// Builds the full transcription endpoint URL from `base_url`,
    /// stripping any trailing slashes before appending the fixed path.
    pub fn transcription_url(&self) -> String {
        let mut url = self.base_url.trim_end_matches('/').to_string();
        url.push_str("/v1/audio/transcriptions");
        url
    }
}
103
104#[async_trait::async_trait]
105impl SttProvider for OpenAiSttProvider {
106 fn provider_id(&self) -> &'static str {
107 "openai"
108 }
109
110 async fn transcribe_file(&self, audio_path: &Path) -> Result<TranscriptionResult, AppError> {
111 let client = reqwest::Client::new();
112
113 let bytes = std::fs::read(audio_path)
114 .map_err(|e| AppError::Voice(format!("Failed to read audio file: {e}")))?;
115 let file_name = audio_path
116 .file_name()
117 .and_then(|s| s.to_str())
118 .unwrap_or("audio.wav")
119 .to_string();
120
121 let part = reqwest::multipart::Part::bytes(bytes)
122 .file_name(file_name)
123 .mime_str("audio/wav")
124 .map_err(|e| AppError::Voice(format!("Invalid multipart audio part: {e}")))?;
125
126 let form = reqwest::multipart::Form::new()
127 .text("model", self.model.clone())
128 .part("file", part);
129
130 let resp = client
131 .post(self.transcription_url())
132 .bearer_auth(&self.api_key)
133 .multipart(form)
134 .send()
135 .await
136 .map_err(|e| AppError::Voice(format!("OpenAI STT request failed: {e}")))?;
137
138 if !resp.status().is_success() {
139 let status = resp.status();
140 let body = resp.text().await.unwrap_or_default();
141 return Err(AppError::Voice(format!(
142 "OpenAI STT API error {status}: {body}"
143 )));
144 }
145
146 #[derive(serde::Deserialize)]
147 struct WhisperResponse {
148 text: String,
149 }
150
151 let result: WhisperResponse = resp
152 .json()
153 .await
154 .map_err(|e| AppError::Voice(format!("Failed to parse OpenAI STT response: {e}")))?;
155
156 Ok(TranscriptionResult {
157 text: result.text,
158 duration_secs: 0.0,
159 audio_path: Some(audio_path.to_path_buf()),
160 provider: "openai-whisper".to_string(),
161 })
162 }
163}
164
/// Offline speech-to-text provider that runs a Whisper model locally via the
/// `whisper_rs` bindings. Compiled only when the `voice-local` feature is
/// enabled.
#[cfg(feature = "voice-local")]
pub struct LocalWhisperProvider {
    /// Filesystem path to the Whisper model file to load.
    pub model_path: PathBuf,
}
170
#[cfg(feature = "voice-local")]
#[async_trait::async_trait]
impl SttProvider for LocalWhisperProvider {
    fn provider_id(&self) -> &'static str {
        "local-whisper"
    }

    /// Transcribes `audio_path` with a locally loaded Whisper model.
    ///
    /// The CPU-heavy inference runs under `spawn_blocking` so the async
    /// runtime is not stalled. The model is reloaded from disk on every call —
    /// NOTE(review): consider caching the `WhisperContext` if this is hot.
    async fn transcribe_file(&self, audio_path: &Path) -> Result<TranscriptionResult, AppError> {
        // Owned copies so the 'static blocking closure can take ownership.
        let model_path = self.model_path.clone();
        let audio_path = audio_path.to_path_buf();

        tokio::task::spawn_blocking(move || {
            use whisper_rs::{
                FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters,
            };

            let ctx = WhisperContext::new_with_params(
                model_path
                    .to_str()
                    .ok_or_else(|| AppError::Voice("Invalid model path encoding".to_string()))?,
                WhisperContextParameters::default(),
            )
            .map_err(|e| AppError::Voice(format!("Failed to load Whisper model: {e}")))?;

            // Decode to the 16 kHz mono sample format the loader provides;
            // duration is derived from the sample count.
            let samples = crate::speech::load_audio_samples_16khz_mono(&audio_path)?;
            let duration_secs = samples.len() as f32 / 16000.0;

            let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
            // Transcription is pinned to English here.
            params.set_language(Some("en"));
            // Suppress whisper.cpp's console output and translation mode.
            params.set_print_special(false);
            params.set_print_progress(false);
            params.set_print_realtime(false);
            params.set_print_timestamps(false);
            params.set_translate(false);
            params.set_no_context(true);
            params.set_single_segment(false);

            let mut state = ctx
                .create_state()
                .map_err(|e| AppError::Voice(format!("Failed to create Whisper state: {e}")))?;
            state
                .full(params, &samples)
                .map_err(|e| AppError::Voice(format!("Whisper transcription failed: {e}")))?;

            let num_segments = state
                .full_n_segments()
                .map_err(|e| AppError::Voice(format!("Failed to get segment count: {e}")))?;
            // Join segment texts with single spaces; segments whose text
            // cannot be read are silently skipped.
            let mut text = String::new();
            for i in 0..num_segments {
                if let Ok(seg) = state.full_get_segment_text(i) {
                    text.push_str(seg.trim());
                    text.push(' ');
                }
            }
            let text = text.trim().to_string();

            Ok(TranscriptionResult {
                text,
                duration_secs,
                audio_path: Some(audio_path),
                provider: "local-whisper".to_string(),
            })
        })
        .await
        // Outer error: the blocking task itself panicked or was cancelled.
        .map_err(|e| AppError::Voice(format!("Local Whisper transcription task failed: {e}")))?
    }
}
240
241pub async fn select_provider(
250 config: &AppConfig,
251 secrets: Option<&dyn SecretProvider>,
252) -> Box<dyn SttProvider> {
253 select_provider_with_session_voice_config(config, None, secrets).await
254}
255
256pub async fn select_provider_with_session_voice_config(
272 config: &AppConfig,
273 session_voice_config: Option<&SessionVoiceConfig>,
274 secrets: Option<&dyn SecretProvider>,
275) -> Box<dyn SttProvider> {
276 let effective_provider = resolve_effective_provider(config, session_voice_config);
277
278 match effective_provider.as_str() {
279 "openai" => {
280 let api_key = resolve_openai_stt_api_key(config, secrets).await;
281
282 if api_key.is_empty() {
283 return Box::new(UnconfiguredSttProvider::new(
284 "OpenAI STT selected but no API key configured. Set voice.openai_api_key, or store a key in secure storage under 'voice_openai' (preferred) or 'openai'.",
285 ));
286 }
287
288 let base_url = config
289 .voice
290 .openai_base_url
291 .clone()
292 .unwrap_or_else(|| "https://api.openai.com".to_string());
293 let model = resolve_effective_openai_model(config, session_voice_config);
294
295 Box::new(OpenAiSttProvider {
296 api_key,
297 base_url,
298 model,
299 })
300 }
301 "local" => {
302 #[cfg(feature = "voice-local")]
303 {
304 let session_model =
305 normalize_override(session_voice_config.and_then(|s| s.model.as_deref()));
306
307 match crate::speech::resolve_whisper_model_path_with_override(config, session_model)
308 {
309 Ok(model_path) => Box::new(LocalWhisperProvider { model_path }),
310 Err(e) => Box::new(UnconfiguredSttProvider::new(e.to_string())),
311 }
312 }
313 #[cfg(not(feature = "voice-local"))]
314 {
315 Box::new(UnconfiguredSttProvider::new(
316 "Local Whisper selected but the 'voice-local' feature is disabled.",
317 ))
318 }
319 }
320 "none" => Box::new(UnconfiguredSttProvider::new(
321 "STT provider is disabled (voice.provider=none).",
322 )),
323 other => Box::new(UnconfiguredSttProvider::new(format!(
324 "Unknown STT provider '{other}'. Supported: openai | local | none"
325 ))),
326 }
327}
328
329async fn resolve_openai_stt_api_key(
337 config: &AppConfig,
338 secrets: Option<&dyn SecretProvider>,
339) -> String {
340 let config_key = config.voice.openai_api_key.clone().unwrap_or_default();
341 if !config_key.is_empty() {
342 return config_key;
343 }
344
345 if let Some(secrets) = secrets {
346 if let Some(k) = secrets.get_secret(SecretKey::VoiceOpenAi).await
347 && !k.is_empty()
348 {
349 return k;
350 }
351 if let Some(k) = secrets.get_secret(SecretKey::OpenAi).await
352 && !k.is_empty()
353 {
354 return k;
355 }
356 }
357
358 config
360 .llm
361 .openai
362 .as_ref()
363 .map(|c| c.api_key.clone())
364 .unwrap_or_default()
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    /// In-memory `SecretProvider` backed by a `HashMap`; empty stored values
    /// are reported as absent.
    #[derive(Debug, Default)]
    struct TestSecrets(std::collections::HashMap<SecretKey, String>);

    #[async_trait::async_trait]
    impl SecretProvider for TestSecrets {
        async fn get_secret(&self, key: SecretKey) -> Option<String> {
            self.0.get(&key).cloned().filter(|s| !s.is_empty())
        }
    }

    /// Thread-shared buffer the mock HTTP server writes the raw request
    /// bytes into, so tests can assert on headers and body afterwards.
    #[derive(Clone, Default)]
    struct CapturedRequest(Arc<Mutex<Vec<u8>>>);

    impl CapturedRequest {
        /// Moves the captured bytes out, leaving an empty buffer behind.
        fn take(&self) -> Vec<u8> {
            std::mem::take(&mut *self.0.lock().expect("capture lock"))
        }
    }

    /// Byte-wise substring search; returns the start index of `needle`.
    fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
        haystack.windows(needle.len()).position(|w| w == needle)
    }

    /// Reads one full HTTP/1.1 request from `stream`.
    ///
    /// Handles both `Content-Length` and chunked body framing, and answers
    /// `Expect: 100-continue` so a client waiting for the interim response
    /// actually sends the body.
    fn capture_http_request(stream: &mut TcpStream) -> Vec<u8> {
        let mut buf = Vec::<u8>::new();
        let mut tmp = [0u8; 8 * 1024];

        let mut header_end: Option<usize> = None;
        let mut content_length: Option<usize> = None;
        let mut chunked = false;
        let mut sent_continue = false;

        loop {
            match stream.read(&mut tmp) {
                Ok(0) => break,
                Ok(n) => {
                    buf.extend_from_slice(&tmp[..n]);
                }
                Err(_) => break,
            }

            // Parse headers once the terminating blank line has arrived.
            if header_end.is_none()
                && let Some(pos) = find_subslice(&buf, b"\r\n\r\n")
            {
                let end = pos + 4;
                header_end = Some(end);

                let header_text = String::from_utf8_lossy(&buf[..end]);
                for line in header_text.split("\r\n") {
                    let lower = line.to_ascii_lowercase();
                    if let Some(v) = lower.strip_prefix("content-length:")
                        && let Ok(n) = v.trim().parse::<usize>()
                    {
                        content_length = Some(n);
                    }
                    if lower.starts_with("transfer-encoding:") && lower.contains("chunked") {
                        chunked = true;
                    }
                    if lower == "expect: 100-continue" {
                        if !sent_continue {
                            let _ = stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n");
                            let _ = stream.flush();
                            sent_continue = true;
                        }
                    }
                }
            }

            // Stop once the whole body (under either framing scheme) is buffered.
            if let Some(h_end) = header_end {
                if let Some(len) = content_length {
                    if buf.len() >= h_end + len {
                        break;
                    }
                } else if chunked {
                    // The final zero-length chunk terminates a chunked body.
                    if find_subslice(&buf[h_end..], b"\r\n0\r\n\r\n").is_some() {
                        break;
                    }
                }
            }
        }

        buf
    }

    /// Spawns a one-shot HTTP server on an ephemeral local port that records
    /// the request and replies with the given status/content-type/body.
    /// Returns the base URL, the capture handle, and the server thread handle.
    /// (The reason phrase is always "OK" regardless of status — harmless for
    /// these tests, which only assert on the numeric status.)
    fn spawn_mock_http_server(
        status: u16,
        content_type: &'static str,
        body: &'static str,
    ) -> (String, CapturedRequest, thread::JoinHandle<()>) {
        let listener = TcpListener::bind(("127.0.0.1", 0)).expect("bind tcp listener");
        let addr = listener.local_addr().expect("local addr");

        let captured = CapturedRequest::default();
        let captured_for_thread = captured.clone();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            // Guard against a hung client stalling the test forever.
            let _ = stream.set_read_timeout(Some(Duration::from_secs(2)));

            let req = capture_http_request(&mut stream);
            *captured_for_thread.0.lock().expect("capture lock") = req;

            let body_bytes = body.as_bytes();
            let resp = format!(
                "HTTP/1.1 {status} OK\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
                body_bytes.len()
            );
            stream
                .write_all(resp.as_bytes())
                .and_then(|_| stream.write_all(body_bytes))
                .and_then(|_| stream.flush())
                .expect("write response");
        });

        (format!("http://{addr}"), captured, handle)
    }

    #[test]
    fn openai_transcription_url_uses_base_url() {
        let p = OpenAiSttProvider {
            api_key: "x".into(),
            base_url: "https://example.com/".into(),
            model: "whisper-1".into(),
        };
        assert_eq!(
            p.transcription_url(),
            "https://example.com/v1/audio/transcriptions"
        );
    }

    // Model resolution order: session override > config value > built-in default.
    #[test]
    fn resolve_openai_model_prefers_session_then_config_then_default() {
        let mut cfg = AppConfig::default();
        cfg.voice.openai_model = Some("cfg-model".into());

        let session = SessionVoiceConfig {
            provider: None,
            model: Some("session-model".into()),
        };
        assert_eq!(
            resolve_effective_openai_model(&cfg, Some(&session)),
            "session-model"
        );

        // A blank session override is treated as absent.
        let session_blank = SessionVoiceConfig {
            provider: None,
            model: Some(" ".into()),
        };
        assert_eq!(
            resolve_effective_openai_model(&cfg, Some(&session_blank)),
            "cfg-model"
        );

        let mut cfg2 = AppConfig::default();
        cfg2.voice.openai_model = None;
        assert_eq!(
            resolve_effective_openai_model(&cfg2, None),
            "gpt-4o-transcribe"
        );
    }

    #[tokio::test]
    async fn session_provider_override_wins_over_config() {
        let mut cfg = AppConfig::default();
        cfg.voice.provider = "openai".into();
        cfg.voice.openai_api_key = Some("cfg_voice".into());

        let session = SessionVoiceConfig {
            provider: Some("none".into()),
            model: None,
        };

        let p = select_provider_with_session_voice_config(&cfg, Some(&session), None).await;
        assert_eq!(p.provider_id(), "unconfigured");
    }

    #[tokio::test]
    async fn session_provider_override_is_trimmed() {
        let mut cfg = AppConfig::default();
        cfg.voice.provider = "none".into();
        cfg.voice.openai_api_key = Some("cfg_voice".into());

        let session = SessionVoiceConfig {
            provider: Some(" openai ".into()),
            model: None,
        };

        let p = select_provider_with_session_voice_config(&cfg, Some(&session), None).await;
        assert_eq!(p.provider_id(), "openai");
    }

    #[tokio::test]
    async fn blank_session_provider_override_uses_config_provider() {
        let mut cfg = AppConfig::default();
        cfg.voice.provider = "none".into();

        let session = SessionVoiceConfig {
            provider: Some(" ".into()),
            model: None,
        };

        let p = select_provider_with_session_voice_config(&cfg, Some(&session), None).await;
        assert_eq!(p.provider_id(), "unconfigured");
    }

    #[tokio::test]
    async fn unknown_provider_yields_unconfigured_provider() {
        let mut cfg = AppConfig::default();
        cfg.voice.provider = "wat".into();

        let p = select_provider(&cfg, None).await;
        assert_eq!(p.provider_id(), "unconfigured");
    }

    #[cfg(feature = "voice-local")]
    #[tokio::test]
    async fn session_local_model_path_override_selects_local_provider() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let model_file = tmp.path().join("ggml-tiny.en.bin");
        std::fs::write(&model_file, b"test").expect("write model");

        let mut cfg = AppConfig::default();
        cfg.voice.provider = "openai".into();

        let session = SessionVoiceConfig {
            provider: Some("local".into()),
            model: Some(model_file.to_string_lossy().to_string()),
        };

        let p = select_provider_with_session_voice_config(&cfg, Some(&session), None).await;
        assert_eq!(p.provider_id(), "local-whisper");
    }

    #[tokio::test]
    async fn resolve_openai_key_prefers_voice_config_over_secrets() {
        let mut cfg = AppConfig::default();
        cfg.voice.provider = "openai".into();
        cfg.voice.openai_api_key = Some("cfg_voice".into());

        let mut secrets = TestSecrets::default();
        secrets
            .0
            .insert(SecretKey::VoiceOpenAi, "secret_voice".into());

        let p = select_provider(&cfg, Some(&secrets)).await;
        assert_eq!(p.provider_id(), "openai");
    }

    // Exercises the key fallback chain: voice secret, then general secret,
    // then the LLM config key.
    #[tokio::test]
    async fn resolve_openai_key_uses_voice_secret_then_general_secret_then_llm_fallback() {
        let mut cfg = AppConfig::default();
        cfg.voice.provider = "openai".into();
        cfg.voice.openai_api_key = None;
        cfg.llm.openai = Some(gestura_core_config::OpenAiConfig {
            api_key: "cfg_llm".into(),
            model: "gpt-4o-mini".into(),
            base_url: None,
        });

        let mut s1 = TestSecrets::default();
        s1.0.insert(SecretKey::VoiceOpenAi, "secret_voice".into());
        let p1 = select_provider(&cfg, Some(&s1)).await;
        assert_eq!(p1.provider_id(), "openai");

        let mut s2 = TestSecrets::default();
        s2.0.insert(SecretKey::OpenAi, "secret_general".into());
        let p2 = select_provider(&cfg, Some(&s2)).await;
        assert_eq!(p2.provider_id(), "openai");

        let p3 = select_provider(&cfg, None).await;
        assert_eq!(p3.provider_id(), "openai");
    }

    #[tokio::test]
    async fn openai_selected_without_any_key_is_unconfigured() {
        let mut cfg = AppConfig::default();
        cfg.voice.provider = "openai".into();
        cfg.voice.openai_api_key = None;
        cfg.llm.openai = None;

        let secrets = TestSecrets::default();
        let p = select_provider(&cfg, Some(&secrets)).await;
        assert_eq!(p.provider_id(), "unconfigured");
    }

    // End-to-end against the mock server: verifies auth header, multipart
    // form fields, and response parsing.
    #[tokio::test]
    async fn openai_stt_request_includes_bearer_auth_and_model_field() {
        let (base_url, captured, server) =
            spawn_mock_http_server(200, "application/json", r#"{"text":"hello"}"#);

        let tmp = tempfile::tempdir().expect("tempdir");
        let audio_path = tmp.path().join("audio.wav");
        std::fs::write(&audio_path, b"RIFF....WAVEfmt ").expect("write audio");

        let p = OpenAiSttProvider {
            api_key: "TEST_KEY".into(),
            base_url,
            model: "gpt-4o-transcribe".into(),
        };

        let result = tokio::time::timeout(Duration::from_secs(5), p.transcribe_file(&audio_path))
            .await
            .expect("transcribe timeout")
            .expect("transcribe ok");
        assert_eq!(result.text, "hello");

        server.join().expect("server join");
        let req = String::from_utf8_lossy(&captured.take()).to_ascii_lowercase();

        assert!(req.contains("post /v1/audio/transcriptions"));
        assert!(req.contains("authorization: bearer test_key"));
        assert!(req.contains("content-type: multipart/form-data"));
        assert!(req.contains("name=\"model\""));
        assert!(req.contains("gpt-4o-transcribe"));
        assert!(req.contains("name=\"file\""));
    }

    #[tokio::test]
    async fn openai_stt_non_success_status_maps_to_voice_error_with_body() {
        let (base_url, _captured, server) = spawn_mock_http_server(401, "text/plain", "nope");

        let tmp = tempfile::tempdir().expect("tempdir");
        let audio_path = tmp.path().join("audio.wav");
        std::fs::write(&audio_path, b"x").expect("write audio");

        let p = OpenAiSttProvider {
            api_key: "TEST_KEY".into(),
            base_url,
            model: "gpt-4o-transcribe".into(),
        };

        let err = tokio::time::timeout(Duration::from_secs(5), p.transcribe_file(&audio_path))
            .await
            .expect("transcribe timeout")
            .expect_err("expected error");

        server.join().expect("server join");

        // The error message should carry both the status code and the body.
        match err {
            AppError::Voice(msg) => {
                assert!(msg.contains("401"), "msg={msg}");
                assert!(msg.contains("nope"), "msg={msg}");
            }
            other => panic!("expected AppError::Voice, got {other:?}"),
        }
    }
}