コード例 #1
0
ファイル: test_wpe.py プロジェクト: zuroh/nara_wpe
    def test_wpe_v8(self):
        """Compare wpe_v8 with wpe_v6 and wpe_v7 in both statistics modes.

        For statistics_mode 'valid' and then 'full', the wpe_v8 output must
        match the wpe_v6 output and then the wpe_v7 output to within an
        absolute tolerance of 1e-6.
        """
        for mode in ('valid', 'full'):
            for reference in (wpe.wpe_v6, wpe.wpe_v7):
                expected = reference(self.Y, self.K, self.delay,
                                     statistics_mode=mode)
                # wpe_v8 is evaluated right after each reference so that
                # every comparison is self-contained.
                result = wpe.wpe_v8(self.Y, self.K, self.delay,
                                    statistics_mode=mode)
                tc.assert_allclose(result, expected, atol=1e-6)
コード例 #2
0
ファイル: test_wpe.py プロジェクト: zuroh/nara_wpe
    def test_wpe_batched_multi_freq(self):
        """Check WPE on an input with two extra leading axes of shape (3, 2).

        The batch is built from copies of ``self.Y`` scaled by 1..6.
        wpe_v0 must reject it with NotImplementedError. For wpe_v7, wpe_v6
        and wpe_v8, the batched output must equal the single-input output
        scaled by the same factors, and both arrays must have shape
        (3, 2, self.D, self.T).
        """
        def stack_scaled_copies(x):
            # Arrange x, 2x, ..., 6x into a 3 x 2 grid on new leading axes.
            rows = [(x, x * 2), (x * 3, x * 4), (x * 5, x * 6)]
            return np.array(rows)

        batched_Y = stack_scaled_copies(self.Y)

        tc.assert_raises(NotImplementedError, wpe.wpe_v0, batched_Y,
                         self.K, self.delay, statistics_mode='full')

        for implementation in (wpe.wpe_v7, wpe.wpe_v6, wpe.wpe_v8):
            single = implementation(self.Y, self.K, self.delay,
                                    statistics_mode='full')
            expected = stack_scaled_copies(single)
            result = implementation(batched_Y, self.K, self.delay,
                                    statistics_mode='full')
            assert expected.shape == (3, 2, self.D, self.T)
            assert result.shape == (3, 2, self.D, self.T)
            tc.assert_allclose(result, expected, atol=1e-6)
コード例 #3
0
ファイル: test_wpe.py プロジェクト: zuroh/nara_wpe
    def test_wpe_multi_freq(self):
        """Check that WPE handles ``self.Y`` stacked twice on a new first axis.

        For wpe_v0, wpe_v7, wpe_v6 and wpe_v8 (in that order), both entries
        of the stacked result must match the single-input result to within
        an absolute tolerance of 1e-6.
        """
        implementations = (wpe.wpe_v0, wpe.wpe_v7, wpe.wpe_v6, wpe.wpe_v8)
        for implementation in implementations:
            single = implementation(self.Y, self.K, self.delay,
                                    statistics_mode='full')
            expected = [single] * 2
            # Rebuilt for each implementation so no call can see an input
            # touched by a previous one.
            stacked = np.array([self.Y, self.Y])
            result = implementation(stacked, self.K, self.delay,
                                    statistics_mode='full')
            tc.assert_allclose(result, expected, atol=1e-6)
コード例 #4
0
def run(args):
    """Dereverberate every utterance from ``args.wav_scp`` with WPE.

    Each utterance's spectrogram (from ``SpectrogramReader``) is run through
    WPE, resynthesised channel by channel with ``inverse_stft`` and written
    with ``WaveWriter``. The WPE backend is the ``wpe`` function in scope
    here, or ``nara_wpe.wpe.wpe_v8`` when ``args.nara_wpe`` is set.
    Utterances for which WPE raises ``np.linalg.LinAlgError`` are logged
    and skipped.

    Args:
        args: parsed command-line namespace. Attributes read:
            ``frame_len``, ``frame_hop``, ``window``, ``center`` and
            ``round_power_of_two`` (STFT settings); ``num_iters``,
            ``context``, ``taps`` and ``delay`` (WPE settings);
            ``nara_wpe`` (backend switch); ``wav_scp`` (passed to
            ``SpectrogramReader``); ``dst_dir`` and ``sr`` (passed to
            ``WaveWriter``).
    """
    stft_kwargs = {
        "frame_len": args.frame_len,
        "frame_hop": args.frame_hop,
        "window": args.window,
        "center": args.center,  # set False to be comparable with Kaldi
        "transpose": True  # T x F
    }
    wpe_kwargs = {
        "num_iters": args.num_iters,
        "context": args.context,
        "taps": args.taps,
        "delay": args.delay
    }

    # Choose the backend once, before any I/O, instead of re-importing on
    # every utterance. nara_wpe is imported only when requested, so it stays
    # an optional dependency; a missing install now fails before any output
    # is written.
    if args.nara_wpe:
        from nara_wpe.wpe import wpe_v8

        def dereverberate(spectra):
            # spectra is F x N x T (see the transpose in the loop below).
            return wpe_v8(spectra,
                          taps=args.taps,
                          delay=args.delay,
                          iterations=args.num_iters,
                          psd_context=args.context)
    else:
        def dereverberate(spectra):
            return wpe(spectra, **wpe_kwargs)

    spectrogram_reader = SpectrogramReader(
        args.wav_scp,
        round_power_of_two=args.round_power_of_two,
        **stft_kwargs)

    num_done = 0
    with WaveWriter(args.dst_dir, fs=args.sr) as writer:
        for key, reverbed in spectrogram_reader:
            logger.info(f"Processing utt {key}...")
            if reverbed.ndim == 2:
                # Single-channel input: add a leading channel axis.
                reverbed = reverbed[None, ...]
            # N x T x F => F x N x T
            reverbed = np.transpose(reverbed, (2, 0, 1))
            try:
                dereverb = dereverberate(reverbed)
            except np.linalg.LinAlgError:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning(f"{key}: Failed cause LinAlgError in wpe")
                continue
            # F x N x T => N x T x F
            dereverb = np.transpose(dereverb, (1, 2, 0))
            # dump multi-channel
            samps = np.stack(
                [inverse_stft(spectra, **stft_kwargs) for spectra in dereverb])
            writer.write(key, samps)
            # Progress log, since processing is slow.
            num_done += 1
            if not num_done % 100:
                logger.info(f"Processed {num_done:d} utterances...")
    logger.info(
        f"Processed {num_done:d} utterances over {len(spectrogram_reader):d}")