1#[allow(unused_imports)]
5use gestura_core_foundation::AppError;
6use std::collections::VecDeque;
7
/// Tuning parameters for [`NoiseCancellationProcessor`].
#[derive(Debug, Clone, PartialEq)]
pub struct NoiseCancellationConfig {
    /// Sample rate of the audio being processed (presumably Hz). Carried for
    /// callers' reference; the processing code in this file does not read it.
    pub sample_rate: u32,
    /// Samples per analysis frame. `process_frame` only accepts frames of
    /// exactly this length and `process_stream` advances `frame_size / 2`
    /// samples per frame.
    pub frame_size: usize,
    /// Number of recent frame spectra kept for adaptive noise estimation;
    /// adaptation starts once at least half this many frames are buffered.
    pub noise_window_size: usize,
    /// Multiplier applied to the noise magnitude before it is subtracted from
    /// each bin (values above `1.0` over-subtract).
    pub subtraction_factor: f32,
    /// Lower bound for the per-bin gain (spectral floor).
    pub min_gain: f32,
    /// Weight of the previous frame's gain when smoothing gains over time;
    /// `0.0` disables smoothing.
    pub smoothing_factor: f32,
    /// When `true`, every processed frame feeds the noise estimate.
    pub adaptive_estimation: bool,
}

impl Default for NoiseCancellationConfig {
    /// General-purpose defaults: 16000 sample rate, 512-sample frames, a
    /// 20-frame noise window, 2x over-subtraction, a 0.1 gain floor, 0.8 gain
    /// smoothing and adaptive estimation enabled.
    fn default() -> Self {
        Self {
            sample_rate: 16000,
            frame_size: 512,
            noise_window_size: 20,
            subtraction_factor: 2.0,
            min_gain: 0.1,
            smoothing_factor: 0.8,
            adaptive_estimation: true,
        }
    }
}
40
41pub struct NoiseCancellationProcessor {
43 config: NoiseCancellationConfig,
44 noise_spectrum: Vec<f32>,
45 gain_history: VecDeque<Vec<f32>>,
46 frame_buffer: Vec<f32>,
47 window_function: Vec<f32>,
48 noise_frames: VecDeque<Vec<f32>>,
49 is_noise_estimated: bool,
50}
51
52impl NoiseCancellationProcessor {
53 pub fn new(config: NoiseCancellationConfig) -> Self {
55 let window_function = Self::create_hann_window(config.frame_size);
56
57 Self {
58 noise_spectrum: vec![0.0; config.frame_size / 2 + 1],
59 gain_history: VecDeque::with_capacity(10),
60 frame_buffer: Vec::with_capacity(config.frame_size * 2),
61 window_function,
62 noise_frames: VecDeque::with_capacity(config.noise_window_size),
63 is_noise_estimated: false,
64 config,
65 }
66 }
67
68 pub fn process_frame(&mut self, input_frame: &[f32]) -> Result<Vec<f32>, AppError> {
70 if input_frame.len() != self.config.frame_size {
71 return Err(AppError::Internal(format!(
72 "Frame size mismatch: expected {}, got {}",
73 self.config.frame_size,
74 input_frame.len()
75 )));
76 }
77
78 let windowed_frame: Vec<f32> = input_frame
80 .iter()
81 .zip(self.window_function.iter())
82 .map(|(sample, window)| sample * window)
83 .collect();
84
85 let spectrum = self.compute_spectrum(&windowed_frame);
87
88 if !self.is_noise_estimated || self.config.adaptive_estimation {
90 self.update_noise_estimation(&spectrum);
91 }
92
93 let enhanced_spectrum = self.apply_spectral_subtraction(&spectrum);
95
96 let enhanced_frame = self.spectrum_to_time_domain(&enhanced_spectrum);
98
99 Ok(enhanced_frame)
100 }
101
102 pub fn process_stream(&mut self, audio_data: &[f32]) -> Result<Vec<f32>, AppError> {
104 let mut output = Vec::new();
105
106 self.frame_buffer.extend_from_slice(audio_data);
108
109 let hop_size = self.config.frame_size / 2;
111
112 while self.frame_buffer.len() >= self.config.frame_size {
113 let frame: Vec<f32> = self
114 .frame_buffer
115 .iter()
116 .take(self.config.frame_size)
117 .cloned()
118 .collect();
119 let processed_frame = self.process_frame(&frame)?;
120
121 if output.len() < hop_size {
123 output.extend_from_slice(&processed_frame[..hop_size]);
124 } else {
125 let start_idx = output.len() - hop_size;
126 for (i, &sample) in processed_frame.iter().take(hop_size).enumerate() {
127 output[start_idx + i] += sample;
128 }
129 output.extend_from_slice(&processed_frame[hop_size..]);
130 }
131
132 self.frame_buffer.drain(..hop_size);
134 }
135
136 Ok(output)
137 }
138
139 pub fn estimate_noise(&mut self, noise_samples: &[Vec<f32>]) -> Result<(), AppError> {
141 if noise_samples.is_empty() {
142 return Err(AppError::Internal(
143 "Need at least one noise sample".to_string(),
144 ));
145 }
146
147 let mut accumulated_spectrum = vec![0.0; self.config.frame_size / 2 + 1];
148
149 for sample in noise_samples {
150 if sample.len() != self.config.frame_size {
151 continue;
152 }
153
154 let windowed: Vec<f32> = sample
155 .iter()
156 .zip(self.window_function.iter())
157 .map(|(s, w)| s * w)
158 .collect();
159
160 let spectrum = self.compute_spectrum(&windowed);
161
162 for (i, &mag) in spectrum.iter().enumerate() {
163 accumulated_spectrum[i] += mag;
164 }
165 }
166
167 let count = noise_samples.len() as f32;
169 for magnitude in accumulated_spectrum.iter_mut() {
170 *magnitude /= count;
171 }
172
173 self.noise_spectrum = accumulated_spectrum;
174 self.is_noise_estimated = true;
175
176 tracing::info!(
177 "Noise spectrum estimated from {} samples",
178 noise_samples.len()
179 );
180 Ok(())
181 }
182
183 fn update_noise_estimation(&mut self, current_spectrum: &[f32]) {
185 if !self.config.adaptive_estimation {
186 return;
187 }
188
189 self.noise_frames.push_back(current_spectrum.to_vec());
191
192 if self.noise_frames.len() > self.config.noise_window_size {
193 self.noise_frames.pop_front();
194 }
195
196 if self.noise_frames.len() >= self.config.noise_window_size / 2 {
198 let mut updated_noise = vec![0.0; current_spectrum.len()];
199
200 for frame in &self.noise_frames {
201 for (i, &mag) in frame.iter().enumerate() {
202 updated_noise[i] += mag;
203 }
204 }
205
206 let count = self.noise_frames.len() as f32;
207 for magnitude in updated_noise.iter_mut() {
208 *magnitude /= count;
209 }
210
211 if self.is_noise_estimated {
213 let alpha = 0.1; for (i, &new_mag) in updated_noise.iter().enumerate() {
215 self.noise_spectrum[i] =
216 (1.0 - alpha) * self.noise_spectrum[i] + alpha * new_mag;
217 }
218 } else {
219 self.noise_spectrum = updated_noise;
220 self.is_noise_estimated = true;
221 }
222 }
223 }
224
225 fn apply_spectral_subtraction(&mut self, input_spectrum: &[f32]) -> Vec<f32> {
227 let mut enhanced_spectrum = Vec::with_capacity(input_spectrum.len());
228 let mut current_gains = Vec::with_capacity(input_spectrum.len());
229
230 for (i, &input_mag) in input_spectrum.iter().enumerate() {
231 let noise_mag = if i < self.noise_spectrum.len() {
232 self.noise_spectrum[i]
233 } else {
234 0.0
235 };
236
237 let subtracted_mag = input_mag - self.config.subtraction_factor * noise_mag;
239
240 let gain = if input_mag > 0.0 {
242 (subtracted_mag / input_mag).max(self.config.min_gain)
243 } else {
244 self.config.min_gain
245 };
246
247 current_gains.push(gain);
248 enhanced_spectrum.push(input_mag * gain);
249 }
250
251 if let Some(previous_gains) = self.gain_history.back() {
253 for (i, gain) in current_gains.iter_mut().enumerate() {
254 if i < previous_gains.len() {
255 *gain = self.config.smoothing_factor * previous_gains[i]
256 + (1.0 - self.config.smoothing_factor) * *gain;
257 enhanced_spectrum[i] = input_spectrum[i] * *gain;
258 }
259 }
260 }
261
262 self.gain_history.push_back(current_gains);
264 if self.gain_history.len() > 5 {
265 self.gain_history.pop_front();
266 }
267
268 enhanced_spectrum
269 }
270
271 fn compute_spectrum(&self, frame: &[f32]) -> Vec<f32> {
273 let mut spectrum = Vec::with_capacity(self.config.frame_size / 2 + 1);
277
278 for k in 0..=self.config.frame_size / 2 {
279 let mut real = 0.0;
280 let mut imag = 0.0;
281
282 for (n, &sample) in frame.iter().enumerate() {
283 let angle = -2.0 * std::f32::consts::PI * (k as f32) * (n as f32)
284 / (self.config.frame_size as f32);
285 real += sample * angle.cos();
286 imag += sample * angle.sin();
287 }
288
289 let magnitude = (real * real + imag * imag).sqrt();
290 spectrum.push(magnitude);
291 }
292
293 spectrum
294 }
295
296 fn spectrum_to_time_domain(&self, spectrum: &[f32]) -> Vec<f32> {
298 let mut time_domain = vec![0.0; self.config.frame_size];
302
303 for (n, sample) in time_domain.iter_mut().enumerate() {
304 for (k, &magnitude) in spectrum.iter().enumerate() {
305 let angle = 2.0 * std::f32::consts::PI * (k as f32) * (n as f32)
306 / (self.config.frame_size as f32);
307 *sample += magnitude * angle.cos() / (self.config.frame_size as f32);
308 }
309 }
310
311 for (i, sample) in time_domain.iter_mut().enumerate() {
313 *sample *= self.window_function[i];
314 }
315
316 time_domain
317 }
318
319 fn create_hann_window(size: usize) -> Vec<f32> {
321 (0..size)
322 .map(|n| {
323 0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / (size - 1) as f32).cos())
324 })
325 .collect()
326 }
327
328 pub fn get_stats(&self) -> NoiseReductionStats {
330 let noise_level = if !self.noise_spectrum.is_empty() {
331 self.noise_spectrum.iter().sum::<f32>() / self.noise_spectrum.len() as f32
332 } else {
333 0.0
334 };
335
336 NoiseReductionStats {
337 is_noise_estimated: self.is_noise_estimated,
338 noise_level,
339 frames_processed: self.gain_history.len(),
340 subtraction_factor: self.config.subtraction_factor,
341 min_gain: self.config.min_gain,
342 }
343 }
344
345 pub fn reset(&mut self) {
347 self.noise_spectrum.fill(0.0);
348 self.gain_history.clear();
349 self.frame_buffer.clear();
350 self.noise_frames.clear();
351 self.is_noise_estimated = false;
352 }
353
354 pub fn update_config(&mut self, config: NoiseCancellationConfig) {
356 if config.frame_size != self.config.frame_size {
357 self.window_function = Self::create_hann_window(config.frame_size);
359 self.noise_spectrum = vec![0.0; config.frame_size / 2 + 1];
360 self.reset();
361 }
362
363 self.config = config;
364 }
365}
366
/// Snapshot of a `NoiseCancellationProcessor`'s state, as returned by its
/// `get_stats` method. Derives `serde::Serialize` so it can be reported.
#[derive(Debug, Clone, serde::Serialize)]
pub struct NoiseReductionStats {
    /// Whether the processor currently holds a noise estimate.
    pub is_noise_estimated: bool,
    /// Mean of the processor's per-bin noise magnitude estimates.
    pub noise_level: f32,
    /// Frame counter reported by `get_stats`; see that method for exactly
    /// what is counted.
    pub frames_processed: usize,
    /// Copy of the configured `subtraction_factor`.
    pub subtraction_factor: f32,
    /// Copy of the configured `min_gain`.
    pub min_gain: f32,
}
376
377pub fn create_speech_noise_canceller() -> NoiseCancellationProcessor {
379 let config = NoiseCancellationConfig {
380 sample_rate: 16000,
381 frame_size: 512,
382 noise_window_size: 30,
383 subtraction_factor: 1.5,
384 min_gain: 0.15,
385 smoothing_factor: 0.85,
386 adaptive_estimation: true,
387 };
388 NoiseCancellationProcessor::new(config)
389}
390
391pub fn create_music_noise_canceller() -> NoiseCancellationProcessor {
393 let config = NoiseCancellationConfig {
394 sample_rate: 44100,
395 frame_size: 1024,
396 noise_window_size: 50,
397 subtraction_factor: 1.2,
398 min_gain: 0.2,
399 smoothing_factor: 0.9,
400 adaptive_estimation: true,
401 };
402 NoiseCancellationProcessor::new(config)
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Processor with default settings but a tiny frame, to keep tests cheap.
    fn small_processor(frame_size: usize) -> NoiseCancellationProcessor {
        NoiseCancellationProcessor::new(NoiseCancellationConfig {
            frame_size,
            ..NoiseCancellationConfig::default()
        })
    }

    #[test]
    fn test_noise_canceller_creation() {
        let stats = NoiseCancellationProcessor::new(NoiseCancellationConfig::default()).get_stats();
        assert!(!stats.is_noise_estimated);
        assert_eq!(stats.frames_processed, 0);
    }

    #[test]
    fn test_hann_window() {
        let window = NoiseCancellationProcessor::create_hann_window(4);
        assert_eq!(window.len(), 4);
        // Rises away from the left edge and falls towards the right edge.
        assert!(window[0] < window[1]);
        assert!(window[2] > window[3]);
    }

    #[test]
    fn test_noise_estimation() {
        let mut processor = small_processor(4);
        let noise_samples = vec![vec![0.1, 0.1, 0.1, 0.1], vec![0.2, 0.2, 0.2, 0.2]];

        processor.estimate_noise(&noise_samples).unwrap();

        let stats = processor.get_stats();
        assert!(stats.is_noise_estimated);
        assert!(stats.noise_level > 0.0);
    }

    #[test]
    fn test_frame_processing() {
        let mut processor = small_processor(4);
        // Seed a very quiet noise floor before processing.
        processor.estimate_noise(&[vec![0.01, 0.01, 0.01, 0.01]]).unwrap();

        let input_frame = vec![0.5, 0.4, 0.3, 0.2];
        let output_frame = processor.process_frame(&input_frame).unwrap();

        assert_eq!(output_frame.len(), input_frame.len());
    }

    #[test]
    fn test_speech_optimized_settings() {
        let stats = create_speech_noise_canceller().get_stats();

        assert_eq!(stats.subtraction_factor, 1.5);
        assert_eq!(stats.min_gain, 0.15);
    }

    #[test]
    fn test_frame_size_mismatch_is_rejected() {
        let mut processor = small_processor(4);
        assert!(processor.process_frame(&[0.0, 0.0, 0.0]).is_err());
    }

    #[test]
    fn test_estimate_noise_requires_samples() {
        let mut processor = small_processor(4);
        assert!(processor.estimate_noise(&[]).is_err());
        assert!(!processor.get_stats().is_noise_estimated);
    }

    #[test]
    fn test_reset_clears_noise_estimate() {
        let mut processor = small_processor(4);
        processor.estimate_noise(&[vec![0.3; 4]]).unwrap();

        processor.reset();

        let stats = processor.get_stats();
        assert!(!stats.is_noise_estimated);
        assert_eq!(stats.noise_level, 0.0);
        assert_eq!(stats.frames_processed, 0);
    }

    #[test]
    fn test_process_stream_emits_one_hop_per_frame() {
        let mut processor = small_processor(4);

        // Frames start at offsets 0, 2 and 4; each emits 2 samples.
        let first = processor.process_stream(&[0.1; 8]).unwrap();
        assert_eq!(first.len(), 6);

        // Two buffered leftovers plus two new samples form one more frame.
        let second = processor.process_stream(&[0.1; 2]).unwrap();
        assert_eq!(second.len(), 2);
    }
}