def test_wpe_v8(self):
    """wpe_v8 must agree with the v6 and v7 implementations.

    Both the ``'valid'`` and ``'full'`` statistics modes are checked,
    each against v6 first and then v7.
    """
    for mode in ('valid', 'full'):
        for reference in (wpe.wpe_v6, wpe.wpe_v7):
            expected = reference(
                self.Y, self.K, self.delay, statistics_mode=mode)
            result = wpe.wpe_v8(
                self.Y, self.K, self.delay, statistics_mode=mode)
            tc.assert_allclose(result, expected, atol=1e-6)
def test_wpe_batched_multi_freq(self):
    """Inputs with extra leading (batch, frequency) axes match per-slice results.

    A 3 x 2 grid of scaled copies of ``self.Y`` is processed in one call.
    The result is compared with the single-input result arranged into the
    same grid. ``wpe_v0`` is expected to reject such input.
    """
    def to_batched_multi_freq(x):
        # 3 x 2 grid of copies of ``x`` scaled by 1..6 in row-major order.
        return np.array([
            [x, 2 * x],
            [3 * x, 4 * x],
            [5 * x, 6 * x],
        ])

    batched_input = to_batched_multi_freq(self.Y)
    expected_shape = (3, 2, self.D, self.T)

    # wpe_v0 does not support the extra leading dimensions.
    tc.assert_raises(
        NotImplementedError, wpe.wpe_v0, batched_input,
        self.K, self.delay, statistics_mode='full',
    )

    for implementation in (wpe.wpe_v7, wpe.wpe_v6, wpe.wpe_v8):
        single = implementation(
            self.Y, self.K, self.delay, statistics_mode='full')
        expected = to_batched_multi_freq(single)
        result = implementation(
            batched_input, self.K, self.delay, statistics_mode='full')
        assert expected.shape == expected_shape
        assert result.shape == expected_shape
        tc.assert_allclose(result, expected, atol=1e-6)
def test_wpe_multi_freq(self):
    """Stacking the input twice on a new leading axis stacks the output twice.

    Checked for wpe_v0, v7, v6 and v8 in ``'full'`` statistics mode.
    """
    for implementation in (wpe.wpe_v0, wpe.wpe_v7, wpe.wpe_v6, wpe.wpe_v8):
        single = implementation(
            self.Y, self.K, self.delay, statistics_mode='full')
        # Build the stacked input fresh for every implementation.
        stacked = np.array([self.Y, self.Y])
        result = implementation(
            stacked, self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(result, [single, single], atol=1e-6)
def run(args):
    """Dereverberate every utterance in ``args.wav_scp`` with WPE and write it out.

    Each utterance is read as an STFT and dereverberated by one of two
    implementations. When ``args.nara_wpe`` is set, that is
    ``nara_wpe.wpe.wpe_v8``; otherwise it is the ``wpe`` function in scope
    of this module. The result is inverted back to the time domain channel
    by channel, and the stacked channels are written to ``args.dst_dir`` at
    sample rate ``args.sr``. An utterance for which WPE raises
    ``np.linalg.LinAlgError`` is logged as a warning and skipped.

    Args:
        args: parsed command-line namespace. Fields read: ``wav_scp``,
            ``dst_dir``, ``sr``, ``frame_len``, ``frame_hop``, ``window``,
            ``center``, ``round_power_of_two``, ``num_iters``, ``context``,
            ``taps``, ``delay`` and ``nara_wpe``.
    """
    stft_kwargs = {
        "frame_len": args.frame_len,
        "frame_hop": args.frame_hop,
        "window": args.window,
        "center": args.center,  # False to stay comparable with Kaldi
        "transpose": True  # T x F
    }
    wpe_kwargs = {
        "num_iters": args.num_iters,
        "context": args.context,
        "taps": args.taps,
        "delay": args.delay
    }

    # Choose the WPE backend once, rather than re-importing on every utterance.
    if args.nara_wpe:
        from nara_wpe.wpe import wpe_v8

        def apply_wpe(spectra):
            # spectra: F x N x T
            return wpe_v8(spectra,
                          taps=args.taps,
                          delay=args.delay,
                          iterations=args.num_iters,
                          psd_context=args.context)
    else:
        def apply_wpe(spectra):
            # spectra: F x N x T
            return wpe(spectra, **wpe_kwargs)

    spectrogram_reader = SpectrogramReader(
        args.wav_scp,
        round_power_of_two=args.round_power_of_two,
        **stft_kwargs)

    num_done = 0
    with WaveWriter(args.dst_dir, fs=args.sr) as writer:
        for key, reverbed in spectrogram_reader:
            logger.info(f"Processing utt {key}...")
            # Single-channel input (T x F) gets a leading channel axis.
            if reverbed.ndim == 2:
                reverbed = reverbed[None, ...]
            # N x T x F => F x N x T
            reverbed = np.transpose(reverbed, (2, 0, 1))
            try:
                dereverb = apply_wpe(reverbed)
            except np.linalg.LinAlgError:
                logger.warning(f"{key}: Failed cause LinAlgError in wpe")
                continue
            # F x N x T => N x T x F
            dereverb = np.transpose(dereverb, (1, 2, 0))
            # Invert each channel separately and dump them as multi-channel audio.
            samps = np.stack(
                [inverse_stft(spectra, **stft_kwargs) for spectra in dereverb])
            writer.write(key, samps)
            # Report progress periodically because processing is slow.
            num_done += 1
            if not num_done % 100:
                logger.info(f"Processed {num_done:d} utterances...")
    logger.info(
        f"Processed {num_done:d} utterances over {len(spectrogram_reader):d}")