gestura_core_audio/
noise_cancellation.rs

//! Noise cancellation and audio enhancement for Gestura.app.
//! Reduces background noise and enhances speech quality.
3
4#[allow(unused_imports)]
5use gestura_core_foundation::AppError;
6use std::collections::VecDeque;
7
/// Tunable parameters for the noise cancellation processor.
///
/// Two configurations compare equal when every field is equal, which lets
/// callers detect whether a configuration actually changed.
#[derive(Debug, Clone, PartialEq)]
pub struct NoiseCancellationConfig {
    /// Sample rate of the audio in Hz.
    pub sample_rate: u32,
    /// Number of samples per processed frame; frames of any other length
    /// are rejected by `process_frame`.
    pub frame_size: usize,
    /// Maximum number of recent magnitude spectra kept for adaptive noise
    /// estimation.
    pub noise_window_size: usize,
    /// Multiplier applied to the estimated noise magnitude before it is
    /// subtracted from the input magnitude.
    pub subtraction_factor: f32,
    /// Lower bound applied to the per-bin gain to prevent over-subtraction.
    pub min_gain: f32,
    /// Weight given to the previous frame's gain when smoothing gain updates
    /// (the current frame's gain receives `1.0 - smoothing_factor`).
    pub smoothing_factor: f32,
    /// When `true`, the noise estimate keeps being updated from processed frames.
    pub adaptive_estimation: bool,
}
26
27impl Default for NoiseCancellationConfig {
28    fn default() -> Self {
29        Self {
30            sample_rate: 16000,
31            frame_size: 512,
32            noise_window_size: 20,
33            subtraction_factor: 2.0,
34            min_gain: 0.1,
35            smoothing_factor: 0.8,
36            adaptive_estimation: true,
37        }
38    }
39}
40
41/// Noise cancellation processor
42pub struct NoiseCancellationProcessor {
43    config: NoiseCancellationConfig,
44    noise_spectrum: Vec<f32>,
45    gain_history: VecDeque<Vec<f32>>,
46    frame_buffer: Vec<f32>,
47    window_function: Vec<f32>,
48    noise_frames: VecDeque<Vec<f32>>,
49    is_noise_estimated: bool,
50}
51
52impl NoiseCancellationProcessor {
53    /// Create a new noise cancellation processor
54    pub fn new(config: NoiseCancellationConfig) -> Self {
55        let window_function = Self::create_hann_window(config.frame_size);
56
57        Self {
58            noise_spectrum: vec![0.0; config.frame_size / 2 + 1],
59            gain_history: VecDeque::with_capacity(10),
60            frame_buffer: Vec::with_capacity(config.frame_size * 2),
61            window_function,
62            noise_frames: VecDeque::with_capacity(config.noise_window_size),
63            is_noise_estimated: false,
64            config,
65        }
66    }
67
68    /// Process audio frame with noise cancellation
69    pub fn process_frame(&mut self, input_frame: &[f32]) -> Result<Vec<f32>, AppError> {
70        if input_frame.len() != self.config.frame_size {
71            return Err(AppError::Internal(format!(
72                "Frame size mismatch: expected {}, got {}",
73                self.config.frame_size,
74                input_frame.len()
75            )));
76        }
77
78        // Apply window function
79        let windowed_frame: Vec<f32> = input_frame
80            .iter()
81            .zip(self.window_function.iter())
82            .map(|(sample, window)| sample * window)
83            .collect();
84
85        // Compute FFT (simplified - in real implementation use proper FFT library)
86        let spectrum = self.compute_spectrum(&windowed_frame);
87
88        // Update noise estimation if needed
89        if !self.is_noise_estimated || self.config.adaptive_estimation {
90            self.update_noise_estimation(&spectrum);
91        }
92
93        // Apply spectral subtraction
94        let enhanced_spectrum = self.apply_spectral_subtraction(&spectrum);
95
96        // Convert back to time domain (simplified IFFT)
97        let enhanced_frame = self.spectrum_to_time_domain(&enhanced_spectrum);
98
99        Ok(enhanced_frame)
100    }
101
102    /// Process streaming audio with overlap-add
103    pub fn process_stream(&mut self, audio_data: &[f32]) -> Result<Vec<f32>, AppError> {
104        let mut output = Vec::new();
105
106        // Add new data to buffer
107        self.frame_buffer.extend_from_slice(audio_data);
108
109        // Process complete frames with 50% overlap
110        let hop_size = self.config.frame_size / 2;
111
112        while self.frame_buffer.len() >= self.config.frame_size {
113            let frame: Vec<f32> = self
114                .frame_buffer
115                .iter()
116                .take(self.config.frame_size)
117                .cloned()
118                .collect();
119            let processed_frame = self.process_frame(&frame)?;
120
121            // Overlap-add
122            if output.len() < hop_size {
123                output.extend_from_slice(&processed_frame[..hop_size]);
124            } else {
125                let start_idx = output.len() - hop_size;
126                for (i, &sample) in processed_frame.iter().take(hop_size).enumerate() {
127                    output[start_idx + i] += sample;
128                }
129                output.extend_from_slice(&processed_frame[hop_size..]);
130            }
131
132            // Remove processed samples
133            self.frame_buffer.drain(..hop_size);
134        }
135
136        Ok(output)
137    }
138
139    /// Estimate noise spectrum from initial frames
140    pub fn estimate_noise(&mut self, noise_samples: &[Vec<f32>]) -> Result<(), AppError> {
141        if noise_samples.is_empty() {
142            return Err(AppError::Internal(
143                "Need at least one noise sample".to_string(),
144            ));
145        }
146
147        let mut accumulated_spectrum = vec![0.0; self.config.frame_size / 2 + 1];
148
149        for sample in noise_samples {
150            if sample.len() != self.config.frame_size {
151                continue;
152            }
153
154            let windowed: Vec<f32> = sample
155                .iter()
156                .zip(self.window_function.iter())
157                .map(|(s, w)| s * w)
158                .collect();
159
160            let spectrum = self.compute_spectrum(&windowed);
161
162            for (i, &mag) in spectrum.iter().enumerate() {
163                accumulated_spectrum[i] += mag;
164            }
165        }
166
167        // Average the accumulated spectrum
168        let count = noise_samples.len() as f32;
169        for magnitude in accumulated_spectrum.iter_mut() {
170            *magnitude /= count;
171        }
172
173        self.noise_spectrum = accumulated_spectrum;
174        self.is_noise_estimated = true;
175
176        tracing::info!(
177            "Noise spectrum estimated from {} samples",
178            noise_samples.len()
179        );
180        Ok(())
181    }
182
183    /// Update noise estimation adaptively
184    fn update_noise_estimation(&mut self, current_spectrum: &[f32]) {
185        if !self.config.adaptive_estimation {
186            return;
187        }
188
189        // Store current frame for noise estimation
190        self.noise_frames.push_back(current_spectrum.to_vec());
191
192        if self.noise_frames.len() > self.config.noise_window_size {
193            self.noise_frames.pop_front();
194        }
195
196        // Update noise spectrum (simple moving average)
197        if self.noise_frames.len() >= self.config.noise_window_size / 2 {
198            let mut updated_noise = vec![0.0; current_spectrum.len()];
199
200            for frame in &self.noise_frames {
201                for (i, &mag) in frame.iter().enumerate() {
202                    updated_noise[i] += mag;
203                }
204            }
205
206            let count = self.noise_frames.len() as f32;
207            for magnitude in updated_noise.iter_mut() {
208                *magnitude /= count;
209            }
210
211            // Smooth update with existing noise estimate
212            if self.is_noise_estimated {
213                let alpha = 0.1; // Update rate
214                for (i, &new_mag) in updated_noise.iter().enumerate() {
215                    self.noise_spectrum[i] =
216                        (1.0 - alpha) * self.noise_spectrum[i] + alpha * new_mag;
217                }
218            } else {
219                self.noise_spectrum = updated_noise;
220                self.is_noise_estimated = true;
221            }
222        }
223    }
224
225    /// Apply spectral subtraction for noise reduction
226    fn apply_spectral_subtraction(&mut self, input_spectrum: &[f32]) -> Vec<f32> {
227        let mut enhanced_spectrum = Vec::with_capacity(input_spectrum.len());
228        let mut current_gains = Vec::with_capacity(input_spectrum.len());
229
230        for (i, &input_mag) in input_spectrum.iter().enumerate() {
231            let noise_mag = if i < self.noise_spectrum.len() {
232                self.noise_spectrum[i]
233            } else {
234                0.0
235            };
236
237            // Spectral subtraction
238            let subtracted_mag = input_mag - self.config.subtraction_factor * noise_mag;
239
240            // Calculate gain
241            let gain = if input_mag > 0.0 {
242                (subtracted_mag / input_mag).max(self.config.min_gain)
243            } else {
244                self.config.min_gain
245            };
246
247            current_gains.push(gain);
248            enhanced_spectrum.push(input_mag * gain);
249        }
250
251        // Apply gain smoothing
252        if let Some(previous_gains) = self.gain_history.back() {
253            for (i, gain) in current_gains.iter_mut().enumerate() {
254                if i < previous_gains.len() {
255                    *gain = self.config.smoothing_factor * previous_gains[i]
256                        + (1.0 - self.config.smoothing_factor) * *gain;
257                    enhanced_spectrum[i] = input_spectrum[i] * *gain;
258                }
259            }
260        }
261
262        // Store gains for next frame
263        self.gain_history.push_back(current_gains);
264        if self.gain_history.len() > 5 {
265            self.gain_history.pop_front();
266        }
267
268        enhanced_spectrum
269    }
270
271    /// Compute magnitude spectrum (simplified)
272    fn compute_spectrum(&self, frame: &[f32]) -> Vec<f32> {
273        // This is a simplified spectrum computation
274        // In a real implementation, use proper FFT library like rustfft
275
276        let mut spectrum = Vec::with_capacity(self.config.frame_size / 2 + 1);
277
278        for k in 0..=self.config.frame_size / 2 {
279            let mut real = 0.0;
280            let mut imag = 0.0;
281
282            for (n, &sample) in frame.iter().enumerate() {
283                let angle = -2.0 * std::f32::consts::PI * (k as f32) * (n as f32)
284                    / (self.config.frame_size as f32);
285                real += sample * angle.cos();
286                imag += sample * angle.sin();
287            }
288
289            let magnitude = (real * real + imag * imag).sqrt();
290            spectrum.push(magnitude);
291        }
292
293        spectrum
294    }
295
296    /// Convert spectrum back to time domain (simplified IFFT)
297    fn spectrum_to_time_domain(&self, spectrum: &[f32]) -> Vec<f32> {
298        // This is a simplified inverse transform
299        // In a real implementation, use proper IFFT
300
301        let mut time_domain = vec![0.0; self.config.frame_size];
302
303        for (n, sample) in time_domain.iter_mut().enumerate() {
304            for (k, &magnitude) in spectrum.iter().enumerate() {
305                let angle = 2.0 * std::f32::consts::PI * (k as f32) * (n as f32)
306                    / (self.config.frame_size as f32);
307                *sample += magnitude * angle.cos() / (self.config.frame_size as f32);
308            }
309        }
310
311        // Apply window function again
312        for (i, sample) in time_domain.iter_mut().enumerate() {
313            *sample *= self.window_function[i];
314        }
315
316        time_domain
317    }
318
319    /// Create Hann window function
320    fn create_hann_window(size: usize) -> Vec<f32> {
321        (0..size)
322            .map(|n| {
323                0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / (size - 1) as f32).cos())
324            })
325            .collect()
326    }
327
328    /// Get noise reduction statistics
329    pub fn get_stats(&self) -> NoiseReductionStats {
330        let noise_level = if !self.noise_spectrum.is_empty() {
331            self.noise_spectrum.iter().sum::<f32>() / self.noise_spectrum.len() as f32
332        } else {
333            0.0
334        };
335
336        NoiseReductionStats {
337            is_noise_estimated: self.is_noise_estimated,
338            noise_level,
339            frames_processed: self.gain_history.len(),
340            subtraction_factor: self.config.subtraction_factor,
341            min_gain: self.config.min_gain,
342        }
343    }
344
345    /// Reset processor state
346    pub fn reset(&mut self) {
347        self.noise_spectrum.fill(0.0);
348        self.gain_history.clear();
349        self.frame_buffer.clear();
350        self.noise_frames.clear();
351        self.is_noise_estimated = false;
352    }
353
354    /// Update configuration
355    pub fn update_config(&mut self, config: NoiseCancellationConfig) {
356        if config.frame_size != self.config.frame_size {
357            // Frame size changed, need to recreate window and reset state
358            self.window_function = Self::create_hann_window(config.frame_size);
359            self.noise_spectrum = vec![0.0; config.frame_size / 2 + 1];
360            self.reset();
361        }
362
363        self.config = config;
364    }
365}
366
367/// Noise reduction statistics
368#[derive(Debug, Clone, serde::Serialize)]
369pub struct NoiseReductionStats {
370    pub is_noise_estimated: bool,
371    pub noise_level: f32,
372    pub frames_processed: usize,
373    pub subtraction_factor: f32,
374    pub min_gain: f32,
375}
376
377/// Create a noise cancellation processor with speech-optimized settings
378pub fn create_speech_noise_canceller() -> NoiseCancellationProcessor {
379    let config = NoiseCancellationConfig {
380        sample_rate: 16000,
381        frame_size: 512,
382        noise_window_size: 30,
383        subtraction_factor: 1.5,
384        min_gain: 0.15,
385        smoothing_factor: 0.85,
386        adaptive_estimation: true,
387    };
388    NoiseCancellationProcessor::new(config)
389}
390
391/// Create a noise cancellation processor with music-optimized settings
392pub fn create_music_noise_canceller() -> NoiseCancellationProcessor {
393    let config = NoiseCancellationConfig {
394        sample_rate: 44100,
395        frame_size: 1024,
396        noise_window_size: 50,
397        subtraction_factor: 1.2,
398        min_gain: 0.2,
399        smoothing_factor: 0.9,
400        adaptive_estimation: true,
401    };
402    NoiseCancellationProcessor::new(config)
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a processor with a tiny frame size so tests stay fast.
    fn small_processor() -> NoiseCancellationProcessor {
        let config = NoiseCancellationConfig {
            frame_size: 4,
            ..NoiseCancellationConfig::default()
        };
        NoiseCancellationProcessor::new(config)
    }

    #[test]
    fn test_noise_canceller_creation() {
        let config = NoiseCancellationConfig::default();
        let processor = NoiseCancellationProcessor::new(config);

        let stats = processor.get_stats();
        assert!(!stats.is_noise_estimated);
        assert_eq!(stats.frames_processed, 0);
    }

    #[test]
    fn test_hann_window() {
        let window = NoiseCancellationProcessor::create_hann_window(4);
        assert_eq!(window.len(), 4);
        assert!(window[0] < window[1]); // Should increase from edges
        assert!(window[2] > window[3]); // Should decrease to edges
    }

    #[test]
    fn test_noise_estimation() {
        let mut processor = small_processor();

        let noise_samples = vec![vec![0.1, 0.1, 0.1, 0.1], vec![0.2, 0.2, 0.2, 0.2]];

        processor.estimate_noise(&noise_samples).unwrap();

        let stats = processor.get_stats();
        assert!(stats.is_noise_estimated);
        assert!(stats.noise_level > 0.0);
    }

    #[test]
    fn test_noise_estimation_rejects_empty_input() {
        let mut processor = small_processor();

        assert!(processor.estimate_noise(&[]).is_err());
        assert!(!processor.get_stats().is_noise_estimated);
    }

    #[test]
    fn test_frame_processing() {
        let mut processor = small_processor();

        // Estimate noise first
        let noise_samples = vec![vec![0.01, 0.01, 0.01, 0.01]];
        processor.estimate_noise(&noise_samples).unwrap();

        // Process a frame
        let input_frame = vec![0.5, 0.4, 0.3, 0.2];
        let output_frame = processor.process_frame(&input_frame).unwrap();

        assert_eq!(output_frame.len(), input_frame.len());
        assert!(output_frame.iter().all(|sample| sample.is_finite()));
    }

    #[test]
    fn test_frame_size_mismatch_is_rejected() {
        let mut processor = small_processor();

        // One sample short of the configured frame size.
        assert!(processor.process_frame(&[0.0_f32; 3]).is_err());
    }

    #[test]
    fn test_reset_clears_noise_estimate() {
        let mut processor = small_processor();

        let noise_samples = vec![vec![0.1, 0.1, 0.1, 0.1]];
        processor.estimate_noise(&noise_samples).unwrap();
        assert!(processor.get_stats().is_noise_estimated);

        processor.reset();

        let stats = processor.get_stats();
        assert!(!stats.is_noise_estimated);
        assert_eq!(stats.noise_level, 0.0);
        assert_eq!(stats.frames_processed, 0);
    }

    #[test]
    fn test_speech_optimized_settings() {
        let processor = create_speech_noise_canceller();
        let stats = processor.get_stats();

        assert_eq!(stats.subtraction_factor, 1.5);
        assert_eq!(stats.min_gain, 0.15);
    }

    #[test]
    fn test_music_optimized_settings() {
        let processor = create_music_noise_canceller();
        let stats = processor.get_stats();

        assert_eq!(stats.subtraction_factor, 1.2);
        assert_eq!(stats.min_gain, 0.2);
    }
}