Exemplo n.º 1
0
class TORCHAUDIODS(Dataset):
    """Dataset over two bundled audio assets, decoded through one sox effects chain.

    The chain is configured from the signal info that ``torchaudio.info``
    reports for ``sinewave.wav`` and is reused for every item.
    """

    # Evaluated once, when the class body runs; instances read it through ``self``.
    test_dirpath, test_dir = create_temp_assets_dir()

    def __init__(self):
        assets = os.path.join(self.test_dirpath, "assets")
        self.asset_dirpath = assets
        self.data = [
            os.path.join(assets, name)
            for name in ("sinewave.wav", "steam-train-whistle-daniel_simon.mp3")
        ]

        # Signal/encoding info of the reference file drives the effect arguments.
        self.si, self.ei = torchaudio.info(os.path.join(assets, "sinewave.wav"))
        self.si.precision = 16

        chain = torchaudio.sox_effects.SoxEffectsChain()
        effects = (
            ("rate", [self.si.rate]),  # resample to the reference file's rate
            ("channels", [self.si.channels]),  # match the reference channel count
            ("trim", [0, "16000s"]),  # first 16000 samples of audio
        )
        for effect_name, effect_args in effects:
            chain.append_effect_to_chain(effect_name, effect_args)
        self.E = chain

    def __getitem__(self, index):
        """Run the effects chain on the file at ``index`` and return only the audio."""
        self.E.set_input_file(self.data[index])
        audio, _sample_rate = self.E.sox_build_flow_effects()
        return audio

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.data)
Exemplo n.º 2
0
class Test_KaldiIO(unittest.TestCase):
    """Checks the ``kio`` ark readers against small fixture files in the assets directory."""

    # Values the fixture files are expected to contain.
    data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
    data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()

    def _test_helper(self, file_name, expected_data, fn, expected_dtype):
        """Read ``file_name`` with ``fn`` and compare each entry to ``expected_data``.

        The i-th element of ``expected_data`` (1-based) is expected under the
        key ``'key<i>'``. Every ``(key, tensor)`` pair yielded by ``fn`` must
        use a known key, be a ``torch.Tensor`` of ``expected_dtype`` and equal
        the expected values element-wise.
        """
        test_filepath = os.path.join(self.test_dirpath, "assets", file_name)

        expected_output = {}
        for position, values in enumerate(expected_data, 1):
            expected_output['key' + str(position)] = torch.tensor(
                values, dtype=expected_dtype)

        for key, vec in fn(test_filepath):
            self.assertTrue(key in expected_output)
            self.assertTrue(isinstance(vec, torch.Tensor))
            self.assertEqual(vec.dtype, expected_dtype)
            elementwise_match = torch.eq(vec, expected_output[key])
            self.assertTrue(torch.all(elementwise_match))

    def test_read_vec_int_ark(self):
        """Integer vectors are read back as int32 tensors."""
        self._test_helper(file_name="vec_int.ark",
                          expected_data=self.data1,
                          fn=kio.read_vec_int_ark,
                          expected_dtype=torch.int32)

    def test_read_vec_flt_ark(self):
        """Float vectors are read back as float32 tensors."""
        self._test_helper(file_name="vec_flt.ark",
                          expected_data=self.data1,
                          fn=kio.read_vec_flt_ark,
                          expected_dtype=torch.float32)

    def test_read_mat_ark(self):
        """Matrices are read back as float32 tensors."""
        matrices = [self.data1, self.data2]
        self._test_helper(file_name="mat.ark",
                          expected_data=matrices,
                          fn=kio.read_mat_ark,
                          expected_dtype=torch.float32)
Exemplo n.º 3
0
class TestDatasets(unittest.TestCase):
    """Smoke tests: build each dataset from the assets directory and read item 0."""

    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    path = os.path.join(test_dirpath, "assets")

    def test_yesno(self):
        dataset = YESNO(self.path, return_dict=True)
        _ = dataset[0]

    def test_vctk(self):
        dataset = VCTK(self.path, return_dict=True)
        _ = dataset[0]

    def test_librispeech(self):
        dataset = LIBRISPEECH(self.path, "dev-clean")
        _ = dataset[0]

    def test_commonvoice(self):
        root = os.path.join(self.path, "commonvoice")
        dataset = COMMONVOICE(root, "train.tsv", "tatar")
        _ = dataset[0]

    def test_commonvoice_diskcache(self):
        root = os.path.join(self.path, "commonvoice")
        cached = DiskCache(COMMONVOICE(root, "train.tsv", "tatar"))
        # Same item twice: first access is the "save" path, second the "load" path.
        _ = cached[0]
        _ = cached[0]
Exemplo n.º 4
0
    def test_pitch(self):
        """Check ``detect_pitch_frequency`` on two pure tones, plain and batched.

        For each reference tone the detected frequencies must all lie within
        ``threshold`` Hz of the nominal value. The waveform is then repeated
        into a batched, two-channel tensor and the result must match the
        equally repeated single-file result.
        """
        # ``test_dir`` is kept bound for the duration of the test.
        test_dirpath, test_dir = common_utils.create_temp_assets_dir()
        test_filepath_100 = os.path.join(test_dirpath, 'assets',
                                         "100Hz_44100Hz_16bit_05sec.wav")
        test_filepath_440 = os.path.join(test_dirpath, 'assets',
                                         "440Hz_44100Hz_16bit_05sec.wav")

        # Files from https://www.mediacollege.com/audio/tone/download/
        tests = [
            (test_filepath_100, 100),
            (test_filepath_440, 440),
        ]

        # Maximum allowed deviation from the reference frequency.
        threshold = 1

        for filename, freq_ref in tests:
            waveform, sample_rate = torchaudio.load(filename)

            freq = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # Count of estimates further than ``threshold`` from the reference.
            s = ((freq - freq_ref).abs() > threshold).sum()
            self.assertFalse(s)

            # Convert to stereo and batch for testing purposes
            freq = freq.repeat(3, 2, 1, 1)
            waveform = waveform.repeat(3, 2, 1, 1)

            freq2 = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # unittest assertion instead of a bare ``assert`` so the check
            # still runs under ``python -O`` and is reported like the others.
            self.assertTrue(torch.allclose(freq, freq2, atol=1e-5))
Exemplo n.º 5
0
class TestDatasets(unittest.TestCase):
    """Smoke tests: build each dataset from the assets directory and read from it."""

    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    path = os.path.join(test_dirpath, "assets")

    def test_yesno(self):
        dataset = YESNO(self.path)
        _ = dataset[0]

    def test_vctk(self):
        dataset = VCTK(self.path)
        _ = dataset[0]

    def test_librispeech(self):
        dataset = LIBRISPEECH(self.path, "dev-clean")
        _ = dataset[0]

    def test_commonvoice(self):
        root = os.path.join(self.path, "commonvoice")
        dataset = COMMONVOICE(root, "train.tsv", "tatar")
        _ = dataset[0]

    def test_commonvoice_diskcache(self):
        root = os.path.join(self.path, "commonvoice")
        cached = diskcache_iterator(COMMONVOICE(root, "train.tsv", "tatar"))
        # Same item twice: first access is the "save" path, second the "load" path.
        _ = cached[0]
        _ = cached[0]

    def test_commonvoice_bg(self):
        root = os.path.join(self.path, "commonvoice")
        stream = bg_iterator(COMMONVOICE(root, "train.tsv", "tatar"), 5)
        # Exhaust the iterator; only completing without error is checked.
        for _ in stream:
            pass

    def test_ljspeech(self):
        dataset = LJSPEECH(self.path)
        _ = dataset[0]

    def test_speechcommands(self):
        dataset = SPEECHCOMMANDS(self.path)
        _ = dataset[0]
Exemplo n.º 6
0
    def test_pitch(self):
        """Detected pitch of two pure-tone files must stay within 1 Hz of the nominal tone."""
        test_dirpath, test_dir = common_utils.create_temp_assets_dir()
        assets = os.path.join(test_dirpath, 'assets')

        # Files from https://www.mediacollege.com/audio/tone/download/
        tones = (
            ("100Hz_44100Hz_16bit_05sec.wav", 100),
            ("440Hz_44100Hz_16bit_05sec.wav", 440),
        )
        tolerance = 1

        for name, freq_ref in tones:
            waveform, sample_rate = torchaudio.load(os.path.join(assets, name))
            freq = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # Any estimate outside the tolerance makes the sum non-zero.
            outside = (freq - freq_ref).abs() > tolerance
            self.assertFalse(outside.sum())
Exemplo n.º 7
0
class Tester(unittest.TestCase):
    """Tests for torchaudio transforms using a synthetic sine wave and bundled audio assets."""

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    # 4 * sample_rate samples of a cosine at ``freq``
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    # scaled by volume * 2**31 and stored as an integer (int64) tensor;
    # ``scale`` below divides by 2**31 again to get floats
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        """Return ``waveform / factor``; non-float input is cast to the default dtype first."""
        if waveform.is_floating_point():
            return waveform / factor
        return waveform.to(torch.get_default_dtype()) / factor

    def test_mu_law_companding(self):
        """Mu-law encode and decode a normalized waveform and check the value ranges."""

        quantization_channels = 256

        waveform = self.waveform.clone()
        # ``self.waveform`` is an integer tensor. Dividing it in place by a
        # scalar either fails (the float result cannot be stored in an integer
        # tensor) or truncates, so convert to floating point before normalizing.
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_mel2(self):
        """Exercise ``MelSpectrogram`` + ``AmplitudeToDB`` on mono and stereo audio.

        Checked for the default transform, for a transform with custom
        options, and for a stereo file: output rank, number of mel bins, and
        that no value lies more than ``top_db`` below that output's maximum.
        The filterbank matrices are checked to have row sums in [0, 1].
        """
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # Check the floor on THIS output (previously re-checked spectrogram_torch).
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # Check the floor on the stereo output (previously re-checked spectrogram_torch).
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        """Check MFCC output shapes, ``melkwargs`` pass-through and the DCT normalization.

        The un-normalized (``norm=None``) result must equal the ``'ortho'``
        result rescaled by ``2 * sqrt(n_mels)`` for coefficient 0 and
        ``2 * sqrt(n_mels / 2)`` for the remaining coefficients.
        """
        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        make_mfcc = torchaudio.transforms.MFCC

        audio_scaled = self.scale(self.waveform.clone())  # (1, 64000)

        # Default mel settings.
        ortho_default = make_mfcc(sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho')
        torch_mfcc = ortho_default(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)

        # A shorter window passed via melkwargs changes the frame count.
        ortho_short_window = make_mfcc(sample_rate=sample_rate,
                                       n_mfcc=n_mfcc,
                                       norm='ortho',
                                       melkwargs={'win_length': 200})
        torch_mfcc2 = ortho_short_window(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # Relation between the two normalization modes.
        unnormalized = make_mfcc(sample_rate=sample_rate, n_mfcc=n_mfcc, norm=None)
        torch_mfcc_norm_none = unnormalized(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        """Compare torchaudio transforms against librosa/scipy on ``sinewave.wav``.

        For three parameter sets the helper compares, within absolute
        tolerances: the power spectrogram, the mel spectrogram, the dB
        conversion of both, and the MFCCs.
        """
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            # NOTE(review): this rebinds ``sample_rate``, so the value passed by the
            # caller is never used -- confirm that is intended.
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            # NOTE(review): ``power`` is hard-coded to 2 here and in the librosa call;
            # the ``power`` argument of this helper is not read.
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=2)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            # cast to librosa's dtype before comparing
            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))

            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
            # this mel-based ``db_librosa`` is also the input of the DCT further down
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)

            self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

    #         librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
    #                                             sr=sample_rate,
    #                                             n_mfcc = n_mfcc,
    #                                             hop_length=hop_length,
    #                                             n_fft=n_fft,
    #                                             htk=True,
    #                                             norm=None,
    #                                             n_mels=n_mels)

            # type-2 orthonormal DCT along axis 0, keeping the first n_mfcc rows
            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        # Three parameter sets; each is unpacked into the helper's keyword arguments.
        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        _test_librosa_consistency_helper(**kwargs3)

    def test_resample_size(self):
        """Resampling to 2x / 0.5x the rate must double / halve the sample count.

        An unknown ``resampling_method`` must raise ``ValueError`` when the
        transform is applied.
        """
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)
        original_length = waveform.size(-1)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2

        # Built outside the context manager: only the call is expected to raise.
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        with self.assertRaises(ValueError):
            invalid_resample(waveform)

        to_double_rate = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = to_double_rate(waveform)
        # twice the rate -> twice as many samples
        self.assertTrue(up_sampled.size(-1) == original_length * 2)

        to_half_rate = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = to_half_rate(waveform)
        # half the rate -> half as many samples
        self.assertTrue(down_sampled.size(-1) == original_length // 2)
Exemplo n.º 8
0
class TestFunctional(unittest.TestCase):
    """Tests for functions in ``torchaudio.functional`` (deltas, istft, filterbanks, pitch)."""

    # shapes of the random signals used by the stft/istft round-trip tests
    data_sizes = [(2, 20), (3, 15), (4, 10)]
    # number of random trials per configuration in the istft tests
    number_of_trials = 100
    # small fixed input for the compute_deltas tests
    specgram = torch.tensor([1., 2., 3., 4.])

    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    # audio file loaded by the batch tests
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def _test_compute_deltas(self,
                             specgram,
                             expected,
                             win_length=3,
                             atol=1e-6,
                             rtol=1e-8):
        """Assert ``F.compute_deltas(specgram)`` matches ``expected`` in shape and value."""
        computed = F.compute_deltas(specgram, win_length=win_length)
        shapes = (computed.shape, expected.shape)
        # The shape pair doubles as the failure message.
        self.assertTrue(shapes[0] == shapes[1], shapes)
        torch.testing.assert_allclose(computed, expected, atol=atol, rtol=rtol)

    def test_compute_deltas_onechannel(self):
        """Deltas of the fixed 4-sample input with two leading singleton dims."""
        single = self.specgram[None, None, :]  # (1, 1, 4)
        self._test_compute_deltas(single, torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]))

    def test_compute_deltas_twochannel(self):
        """Deltas of the fixed input repeated into two identical rows."""
        doubled = self.specgram.repeat(1, 2, 1)  # (1, 2, 4)
        row = [0.5, 1.0, 1.0, 0.5]
        self._test_compute_deltas(doubled, torch.tensor([[row, row]]))

    def test_compute_deltas_randn(self):
        """``compute_deltas`` must preserve the shape of a random 3-D input."""
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        specgram = torch.randn(channel, n_mfcc, time)

        computed = F.compute_deltas(specgram, win_length=2 * 7 + 1)

        shapes = (computed.shape, specgram.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)

    def test_batch_pitch(self):
        """Pitch detection on a batch must equal batching the single-item result."""
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Transform one item, then stack three copies of the result.
        single = F.detect_pitch_frequency(waveform, sample_rate)
        expected = single.unsqueeze(0).repeat(3, 1, 1)

        # Stack three copies of the input, then transform.
        batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = F.detect_pitch_frequency(batched, sample_rate)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
        """Assert ``estimate`` matches ``sound`` after cutting ``sound`` to the estimate's length."""
        # The reconstructed signal may be shorter than the original.
        kept = estimate.size(-1)
        sound = sound[..., :kept]

        shapes = (sound.shape, estimate.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        is_close = torch.allclose(sound, estimate, atol=atol, rtol=rtol)
        self.assertTrue(is_close)

    def _test_istft_is_inverse_of_stft(self, kwargs):
        """Check that ``istft(stft(x))`` reconstructs ``x`` for random signals.

        For every shape in ``self.data_sizes`` and every trial, a random
        signal is produced by ``common_utils.random_float_tensor`` (the trial
        index is its first argument), transformed with ``torch.stft`` and
        inverted with ``torchaudio.functional.istft``. The same check is then
        run on a batch made of three copies of the signal.

        Args:
            kwargs (dict): keyword arguments passed unchanged to both
                ``torch.stft`` and ``torchaudio.functional.istft``.
        """
        for data_size in self.data_sizes:
            for i in range(self.number_of_trials):

                # Non-batch
                sound = common_utils.random_float_tensor(i, data_size)

                stft = torch.stft(sound, **kwargs)
                # The signal length is the size of the last dimension.
                estimate = torchaudio.functional.istft(stft,
                                                       length=sound.size(-1),
                                                       **kwargs)

                self._compare_estimate(sound, estimate)

                # Batch: reuse the stft computed above, repeated 3 times.
                stft = stft.repeat(3, 1, 1, 1, 1)
                sound = sound.repeat(3, 1, 1)

                # ``sound`` is now 3-D, so ``size(1)`` would be the second
                # dimension rather than the signal length; it truncated the
                # comparison to a handful of samples. Use the last dimension.
                estimate = torchaudio.functional.istft(stft,
                                                       length=sound.size(-1),
                                                       **kwargs)
                self._compare_estimate(sound, estimate)

    def test_istft_is_inverse_of_stft1(self):
        """Round trip: hann window, centered, normalized, onesided."""
        options = dict(
            n_fft=12,
            hop_length=4,
            win_length=12,
            window=torch.hann_window(12),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        )
        self._test_istft_is_inverse_of_stft(options)

    def test_istft_is_inverse_of_stft2(self):
        """Round trip: hann window, centered, not normalized, not onesided."""
        options = dict(
            n_fft=12,
            hop_length=2,
            win_length=8,
            window=torch.hann_window(8),
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        )
        self._test_istft_is_inverse_of_stft(options)

    def test_istft_is_inverse_of_stft3(self):
        """Round trip: hamming window, centered, normalized, not onesided."""
        options = dict(
            n_fft=15,
            hop_length=3,
            win_length=11,
            window=torch.hamming_window(11),
            center=True,
            pad_mode='constant',
            normalized=True,
            onesided=False,
        )
        self._test_istft_is_inverse_of_stft(options)

    def test_istft_is_inverse_of_stft4(self):
        """Round trip: hamming window, not centered, not normalized, onesided.

        The window has the same size as ``n_fft``.
        """
        options = dict(
            n_fft=5,
            hop_length=2,
            win_length=5,
            window=torch.hamming_window(5),
            center=False,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        )
        self._test_istft_is_inverse_of_stft(options)

    def test_istft_is_inverse_of_stft5(self):
        """Round trip: hamming window, not centered, not normalized, not onesided.

        The window has the same size as ``n_fft``.
        """
        options = dict(
            n_fft=3,
            hop_length=2,
            win_length=3,
            window=torch.hamming_window(3),
            center=False,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        )
        self._test_istft_is_inverse_of_stft(options)

    def test_istft_of_ones(self):
        """Inverting the hand-written stft of ``torch.ones(4)`` must give ones back."""
        # stft = torch.stft(torch.ones(4), 4)
        # Only the real part of the first frequency row is non-zero (4.0).
        stft = torch.zeros((3, 5, 2))
        stft[0, :, 0] = 4.

        estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
        self._compare_estimate(torch.ones(4), estimate)

    def test_istft_of_zeros(self):
        """Inverting an all-zero stft must give an all-zero signal."""
        # stft = torch.stft(torch.zeros(4), 4)
        all_zero_stft = torch.zeros((3, 5, 2))
        reconstructed = torchaudio.functional.istft(all_zero_stft, n_fft=4, length=4)
        self._compare_estimate(torch.zeros(4), reconstructed)

    def test_istft_requires_overlap_windows(self):
        """A window of size 1 with a hop of 20 leaves gaps, which must raise."""
        stft = torch.zeros((3, 5, 2))
        unit_window = torch.ones(1)
        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft,
                                        n_fft=4,
                                        hop_length=20,
                                        win_length=1,
                                        window=unit_window)

    def test_istft_requires_nola(self):
        """A window of ones satisfies NOLA; a window of zeros does not and must raise."""
        stft = torch.zeros((3, 5, 2))
        ones_window = torch.ones(4)
        zeros_window = torch.zeros(4)

        # Accepted: must not raise.
        torchaudio.functional.istft(stft, n_fft=4, win_length=4, window=ones_window)

        # Rejected: an all-zero window violates NOLA.
        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft, n_fft=4, win_length=4, window=zeros_window)

    def test_istft_requires_non_empty(self):
        """An stft with a zero-sized dimension must be rejected."""
        for empty_shape in ((3, 0, 2), (0, 3, 2)):
            empty_stft = torch.zeros(empty_shape)
            with self.assertRaises(AssertionError):
                torchaudio.functional.istft(empty_stft, 2)

    def _test_istft_of_sine(self, amplitude, L, n):
        """Invert a hand-built stft of ``amplitude * sin(2*pi/L * n * x)`` and compare.

        The stft is written out for hop length and window size both equal to
        ``L`` with a window of ones, not centered and not normalized.
        """
        x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
        sound = amplitude * torch.sin(2 * math.pi / L * x * n)

        # stft = torch.stft(sound, L, hop_length=L, win_length=L,
        #                   window=torch.ones(L), center=False, normalized=False)
        stft = torch.zeros((L // 2 + 1, 2, 2))
        n_rows = stft.size(0)
        peak = (amplitude * L) / 2.0

        if n < n_rows:
            stft[n, :, 1] = -peak

        mirrored = L - n  # symmetric about L // 2
        if 0 <= mirrored < n_rows:
            stft[mirrored, :, 1] = peak

        estimate = torchaudio.functional.istft(stft,
                                               L,
                                               hop_length=L,
                                               win_length=L,
                                               window=torch.ones(L),
                                               center=False,
                                               normalized=False)
        # There is a larger error due to the scaling of amplitude
        self._compare_estimate(sound, estimate, atol=1e-3)

    def test_istft_of_sine(self):
        """Run the sine reconstruction check for several (amplitude, L, n) combinations."""
        cases = (
            (123, 5, 1),
            (150, 5, 2),
            (111, 5, 3),
            (160, 7, 4),
            (145, 8, 5),
            (80, 9, 6),
            (99, 10, 7),
        )
        for amplitude, period, multiple in cases:
            self._test_istft_of_sine(amplitude=amplitude, L=period, n=multiple)

    def _test_linearity_of_istft(self,
                                 data_size,
                                 kwargs,
                                 atol=1e-6,
                                 rtol=1e-8):
        """Check ``istft(a*x + b*y) == a*istft(x) + b*istft(y)`` over repeated trials.

        ``x`` and ``y`` come from ``common_utils.random_float_tensor`` and the
        coefficients ``a`` and ``b`` from ``torch.rand``.
        """
        inverse = torchaudio.functional.istft
        for trial in range(self.number_of_trials):
            first = common_utils.random_float_tensor(trial, data_size)
            second = common_utils.random_float_tensor(trial * 2, data_size)
            a, b = torch.rand(2)

            inverse_first = inverse(first, **kwargs)
            inverse_second = inverse(second, **kwargs)
            combined_after = a * inverse_first + b * inverse_second

            combined_before = inverse(a * first + b * second, **kwargs)
            self._compare_estimate(combined_after, combined_before, atol, rtol)

    def test_linearity_of_istft1(self):
        """Linearity: hann window, centered, normalized, onesided."""
        options = dict(
            n_fft=12,
            window=torch.hann_window(12),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        )
        self._test_linearity_of_istft((2, 7, 7, 2), options)

    def test_linearity_of_istft2(self):
        """Linearity: hann window, centered, not normalized, not onesided."""
        options = dict(
            n_fft=12,
            window=torch.hann_window(12),
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        )
        self._test_linearity_of_istft((2, 12, 7, 2), options)

    def test_linearity_of_istft3(self):
        """Linearity: hamming window, centered, normalized, not onesided."""
        options = dict(
            n_fft=12,
            window=torch.hamming_window(12),
            center=True,
            pad_mode='constant',
            normalized=True,
            onesided=False,
        )
        self._test_linearity_of_istft((2, 12, 7, 2), options)

    def test_linearity_of_istft4(self):
        """Linearity: hamming window, not centered, not normalized, onesided.

        This case uses a looser absolute tolerance (1e-5) than the others.
        """
        options = dict(
            n_fft=12,
            window=torch.hamming_window(12),
            center=False,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        )
        self._test_linearity_of_istft((2, 7, 3, 2), options, atol=1e-5, rtol=1e-8)

    def _test_create_fb(self,
                        n_mels=40,
                        sample_rate=22050,
                        n_fft=2048,
                        fmin=0.0,
                        fmax=8000.0):
        """Compare ``F.create_fb_matrix`` with ``librosa.filters.mel`` filter by filter.

        Args:
            n_mels (int): number of mel filters.
            sample_rate (int): sample rate given to both implementations.
            n_fft (int): FFT size; ``n_fft // 2 + 1`` frequency bins are used.
            fmin (float): lowest filter frequency.
            fmax (float): highest filter frequency.

        Raises:
            unittest.SkipTest: if librosa could not be imported.
        """
        # Using a decorator here causes parametrize to fail on Python 2
        if not IMPORT_LIBROSA:
            raise unittest.SkipTest('Librosa is not available')

        librosa_fb = librosa.filters.mel(sr=sample_rate,
                                         n_fft=n_fft,
                                         n_mels=n_mels,
                                         fmax=fmax,
                                         fmin=fmin,
                                         htk=True,
                                         norm=None)
        fb = F.create_fb_matrix(sample_rate=sample_rate,
                                n_mels=n_mels,
                                f_max=fmax,
                                f_min=fmin,
                                n_freqs=(n_fft // 2 + 1))

        # ``fb`` is indexed by filter on its second axis, ``librosa_fb`` on its first.
        for i_mel_bank in range(n_mels):
            # unittest assertion instead of a bare ``assert`` so the check
            # still runs under ``python -O``.
            self.assertTrue(
                torch.allclose(fb[:, i_mel_bank],
                               torch.tensor(librosa_fb[i_mel_bank]),
                               atol=1e-4))

    def test_create_fb(self):
        """Run the filterbank comparison for the defaults and several overrides."""
        overrides = (
            {},
            {'n_mels': 128, 'sample_rate': 44100},
            {'n_mels': 128, 'fmin': 2000.0, 'fmax': 5000.0},
            {'n_mels': 56, 'fmin': 100.0, 'fmax': 9000.0},
            {'n_mels': 56, 'fmin': 800.0, 'fmax': 900.0},
            {'n_mels': 56, 'fmin': 1900.0, 'fmax': 900.0},
            {'n_mels': 10, 'fmin': 1900.0, 'fmax': 900.0},
        )
        for settings in overrides:
            self._test_create_fb(**settings)

    def test_pitch(self):
        """Check ``detect_pitch_frequency`` on two pure tones, plain and batched.

        For each reference tone the detected frequencies must all lie within
        ``threshold`` Hz of the nominal value. The waveform is then repeated
        into a batched, two-channel tensor and the result must match the
        equally repeated single-file result.
        """
        # ``test_dir`` is kept bound for the duration of the test.
        test_dirpath, test_dir = common_utils.create_temp_assets_dir()
        test_filepath_100 = os.path.join(test_dirpath, 'assets',
                                         "100Hz_44100Hz_16bit_05sec.wav")
        test_filepath_440 = os.path.join(test_dirpath, 'assets',
                                         "440Hz_44100Hz_16bit_05sec.wav")

        # Files from https://www.mediacollege.com/audio/tone/download/
        tests = [
            (test_filepath_100, 100),
            (test_filepath_440, 440),
        ]

        # Maximum allowed deviation from the reference frequency.
        threshold = 1

        for filename, freq_ref in tests:
            waveform, sample_rate = torchaudio.load(filename)

            freq = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # Count of estimates further than ``threshold`` from the reference.
            s = ((freq - freq_ref).abs() > threshold).sum()
            self.assertFalse(s)

            # Convert to stereo and batch for testing purposes
            freq = freq.repeat(3, 2, 1, 1)
            waveform = waveform.repeat(3, 2, 1, 1)

            freq2 = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # unittest assertion instead of a bare ``assert`` so the check
            # still runs under ``python -O`` and is reported like the others.
            self.assertTrue(torch.allclose(freq, freq2, atol=1e-5))

    def _test_batch(self, functional):
        """Apply ``functional`` to one waveform and to a batch of three copies of it.

        NOTE(review): the visible body builds ``expected`` and ``computed`` but
        never compares them; the snippet may be truncated. Confirm whether the
        assertions are missing.
        """
        waveform, sample_rate = torchaudio.load(
            self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)

        # Batch then transform
        waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = functional(waveform)
Exemplo n.º 9
0
 def setUp(self):
     """Create the temporary assets directory and store its path and handle on the instance."""
     assets_location = common_utils.create_temp_assets_dir()
     self.test_dirpath, self.test_dir = assets_location
Exemplo n.º 10
0
class Tester(unittest.TestCase):
    """Tests for the ``transforms`` classes, sharing class-level fixtures.

    The fixtures below are evaluated once, when the class body runs.
    """

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    # Four seconds' worth of samples (4 * sample_rate) of a cosine at ``freq``.
    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    # Scale by ``volume * 2**31`` and store as an integer (long) tensor; tests
    # that need floats convert it themselves (see ``scale``).
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')

    def scale(self, waveform, factor=float(2**31)):
        """Return ``waveform`` divided by ``factor``.

        Integer tensors are first converted to torch's default floating-point
        dtype; floating-point tensors keep their dtype.
        """
        if waveform.is_floating_point():
            as_float = waveform
        else:
            as_float = waveform.to(torch.get_default_dtype())
        return as_float / factor

    def test_scriptmodule_Spectrogram(self):
        """Pass ``transforms.Spectrogram`` and a random (1, 1000) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 1000))
        _test_script_module(transforms.Spectrogram, sample_input)

    def test_scriptmodule_GriffinLim(self):
        """Pass ``transforms.GriffinLim`` and a random (1, 201, 6) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 201, 6))
        _test_script_module(transforms.GriffinLim, sample_input,
                            length=1000, rand_init=False)

    def test_mu_law_companding(self):
        """Check value ranges before, during and after mu-law companding.

        The normalised input must lie in [-1, 1], the encoded signal in
        [0, quantization_channels] and the decoded signal in [-1, 1] again.
        """
        quantization_channels = 256

        waveform = self.waveform.clone()
        # The class fixture is an integer (long) tensor.  Convert it to float
        # before the in-place normalisation so the division below is a true
        # division rather than an integer one.
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_scriptmodule_AmplitudeToDB(self):
        """Pass ``transforms.AmplitudeToDB`` and a random (6, 201) tensor to ``_test_script_module``."""
        sample_input = torch.rand((6, 201))
        _test_script_module(transforms.AmplitudeToDB, sample_input)

    def test_batch_AmplitudeToDB(self):
        """AmplitudeToDB on a tiled batch must equal the tiled single-input result."""
        single = torch.rand((6, 201))
        tiling = (3, 1, 1)

        # Transform one input, then tile the result three times.
        expected = transforms.AmplitudeToDB()(single).repeat(*tiling)

        # Tile the input three times, then transform the whole batch.
        batched = single.repeat(*tiling)
        computed = transforms.AmplitudeToDB()(batched)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_AmplitudeToDB(self):
        """'magnitude' mode on |x| must agree with 'power' mode on x**2."""
        waveform, _ = torchaudio.load(self.test_filepath)
        top_db = 80.

        to_db_from_magnitude = transforms.AmplitudeToDB('magnitude', top_db)
        to_db_from_power = transforms.AmplitudeToDB('power', top_db)

        db_from_magnitude = to_db_from_magnitude(torch.abs(waveform))
        db_from_power = to_db_from_power(torch.pow(waveform, 2))

        self.assertTrue(torch.allclose(db_from_magnitude, db_from_power))

    def test_scriptmodule_MelScale(self):
        """Pass ``transforms.MelScale`` and a random (1, 6, 201) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 6, 201))
        _test_script_module(transforms.MelScale, sample_input)

    def test_melscale_load_save(self):
        """A MelScale state dict must load into a second instance with an equal filterbank."""
        specgram = torch.ones(1, 1000, 100)
        source = transforms.MelScale()
        # Run once before exporting the state.
        # NOTE(review): presumably this call is what sizes ``fb`` for a
        # 1000-bin input when ``n_stft`` is not given - confirm.
        source(specgram)

        restored = transforms.MelScale(n_stft=1000)
        restored.load_state_dict(source.state_dict())

        fb_source = source.fb
        fb_restored = restored.fb

        self.assertEqual(fb_restored.size(), (1000, 128))
        self.assertTrue(torch.allclose(fb_source, fb_restored))

    def test_scriptmodule_MelSpectrogram(self):
        """Pass ``transforms.MelSpectrogram`` and a random (1, 1000) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 1000))
        _test_script_module(transforms.MelSpectrogram, sample_input)

    def test_melspectrogram_load_save(self):
        """A MelSpectrogram state dict must restore both the window and the mel filterbank."""
        source = transforms.MelSpectrogram()
        source(self.waveform.float())

        restored = transforms.MelSpectrogram()
        restored.load_state_dict(source.state_dict())

        window_source = source.spectrogram.window
        window_restored = restored.spectrogram.window
        fb_source = source.mel_scale.fb
        fb_restored = restored.mel_scale.fb

        self.assertTrue(torch.allclose(window_source, window_restored))
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_restored.size(), (201, 128))
        self.assertTrue(torch.allclose(fb_source, fb_restored))

    def test_mel2(self):
        """Exercise MelSpectrogram followed by AmplitudeToDB in three configurations.

        Default options, custom options and stereo input are each checked:
        the dB output must be 3-D, must nowhere fall more than ``top_db``
        below its own maximum, and must have ``n_mels`` rows.  The mel
        filterbank row sums must lie in [0, 1].  Finally a stand-alone
        ``MelScale`` filterbank is checked for row sums and size.
        """
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # The floor check must look at this configuration's own output, not at
        # the default-options spectrogram again.
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # Same here: check the stereo output itself.
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_scriptmodule_MFCC(self):
        """Pass ``transforms.MFCC`` and a random (1, 1000) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 1000))
        _test_script_module(transforms.MFCC, sample_input)

    def test_mfcc(self):
        """Check MFCC output shapes, ``melkwargs`` pass-through and the effect of ``norm``.

        The ``norm=None`` output is compared with the ``norm='ortho'`` output
        rescaled by ``2 * sqrt(n_mels)`` for coefficient 0 and by
        ``2 * sqrt(n_mels / 2)`` for the remaining coefficients.
        """
        audio_scaled = self.scale(self.waveform.clone())  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        shared = dict(sample_rate=sample_rate, n_mfcc=n_mfcc)

        # Default mel options with orthonormal scaling.
        ortho_transform = torchaudio.transforms.MFCC(norm='ortho', **shared)
        torch_mfcc = ortho_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)

        # ``melkwargs`` must reach the mel spectrogram: a shorter window is
        # expected to yield 641 frames instead of 321.
        short_window_transform = torchaudio.transforms.MFCC(
            norm='ortho', melkwargs={'win_length': 200}, **shared)
        torch_mfcc2 = short_window_transform(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # Without normalisation the coefficients differ from the orthonormal
        # ones only by the per-row scale factors applied below.
        unnormalised_transform = torchaudio.transforms.MFCC(norm=None, **shared)
        torch_mfcc_norm_none = unnormalised_transform(audio_scaled)  # (1, 40, 321)

        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available')
    def test_librosa_consistency(self):
        """Compare torchaudio transforms with librosa/scipy on ``sinewave.wav``.

        The inner helper runs the spectrogram, mel spectrogram, dB conversion
        and MFCC comparisons for one parameter set; it is invoked for
        ``kwargs1``, ``kwargs2`` and ``kwargs4`` (``kwargs3`` is disabled).
        """
        def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
            """Run every comparison for one parameter set.

            NOTE: the ``sample_rate`` argument is overwritten just below by the
            rate returned from ``torchaudio.load``, so the passed-in value is
            not used by any comparison.
            """
            input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
            sound, sample_rate = torchaudio.load(input_path)
            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)

            # test core spectrogram
            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=power)
            out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                n_fft=n_fft,
                                                                hop_length=hop_length,
                                                                power=power)

            out_torch = spect_transform(sound).squeeze().cpu()
            self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))

            # test mel spectrogram
            # ``power`` is not forwarded here on either side; both calls use
            # their own defaults for it.
            melspect_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, window_fn=torch.hann_window,
                hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
            librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                         n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                         htk=True, norm=None)
            librosa_mel_tensor = torch.from_numpy(librosa_mel)
            torch_mel = melspect_transform(sound).squeeze().cpu()

            # Cast to librosa's dtype before comparing.
            self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))

            # test s2db
            power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
            power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
            power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
            self.assertTrue(torch.allclose(power_to_db_torch, torch.from_numpy(power_to_db_librosa), atol=5e-3))

            # Magnitude-to-dB is compared on the raw waveform, not on a spectrogram.
            mag_to_db_transform = torchaudio.transforms.AmplitudeToDB('magnitude', 80.)
            mag_to_db_torch = mag_to_db_transform(torch.abs(sound)).squeeze().cpu()
            mag_to_db_librosa = librosa.core.spectrum.amplitude_to_db(sound_librosa)
            self.assertTrue(
                torch.allclose(mag_to_db_torch, torch.from_numpy(mag_to_db_librosa), atol=5e-3)
            )

            # Power-to-dB of the mel spectrogram; ``db_librosa`` is reused for
            # the MFCC reference further down.
            power_to_db_torch = power_to_db_transform(melspect_transform(sound)).squeeze().cpu()
            db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
            db_librosa_tensor = torch.from_numpy(db_librosa)
            self.assertTrue(
                torch.allclose(power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)
            )

            # test MFCC
            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                        n_mfcc=n_mfcc,
                                                        norm='ortho',
                                                        melkwargs=melkwargs)

            # librosa.feature.mfcc doesn't pass kwargs properly since some of the
            # kwargs for melspectrogram and mfcc are the same. We just follow the
            # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
            # to mirror this function call with correct args:

    #         librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
    #                                             sr=sample_rate,
    #                                             n_mfcc = n_mfcc,
    #                                             hop_length=hop_length,
    #                                             n_fft=n_fft,
    #                                             htk=True,
    #                                             norm=None,
    #                                             n_mels=n_mels)

            # Reference MFCC: type-2 orthonormal DCT over the mel axis of the
            # dB mel spectrogram, keeping the first ``n_mfcc`` rows.
            librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
            librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
            torch_mfcc = mfcc_transform(sound).squeeze().cpu()

            self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))

        # Parameter sets; ``sample_rate`` is accepted by the helper but
        # shadowed inside it (see the helper's docstring).
        kwargs1 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        kwargs2 = {
            'n_fft': 600,
            'hop_length': 100,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 20,
            'sample_rate': 16000
        }

        kwargs3 = {
            'n_fft': 200,
            'hop_length': 50,
            'power': 2.0,
            'n_mels': 128,
            'n_mfcc': 50,
            'sample_rate': 24000
        }

        kwargs4 = {
            'n_fft': 400,
            'hop_length': 200,
            'power': 3.0,
            'n_mels': 128,
            'n_mfcc': 40,
            'sample_rate': 16000
        }

        _test_librosa_consistency_helper(**kwargs1)
        _test_librosa_consistency_helper(**kwargs2)
        # NOTE Test passes offline, but fails on CircleCI, see #372.
        # _test_librosa_consistency_helper(**kwargs3)
        _test_librosa_consistency_helper(**kwargs4)

    def test_scriptmodule_Resample(self):
        """Pass ``transforms.Resample`` with rates 100. and 50. and a random (2, 1000) tensor to ``_test_script_module``."""
        sample_input = torch.rand((2, 1000))
        source_rate, target_rate = 100., 50.

        _test_script_module(transforms.Resample, sample_input, source_rate, target_rate)

    def test_batch_Resample(self):
        """Resample on a tiled batch must equal the tiled single-input result."""
        single = torch.randn(2, 2786)
        tiling = (3, 1, 1)

        # Transform one waveform, then tile the result three times.
        expected = transforms.Resample()(single).repeat(*tiling)

        # Tile the waveform three times, then transform the batch.
        batched = single.repeat(*tiling)
        computed = transforms.Resample()(batched)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_ComplexNorm(self):
        """Pass ``transforms.ComplexNorm`` and a random (1, 2, 201, 2) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 2, 201, 2))
        _test_script_module(transforms.ComplexNorm, sample_input)

    def test_resample_size(self):
        """Check Resample output lengths and rejection of an unknown method.

        Doubling the rate must double the number of samples, halving it must
        halve them (integer division), and calling a transform built with
        ``resampling_method='foo'`` must raise ``ValueError``.
        """
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2

        # The transform is built first; the error is expected when it is called.
        invalid_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='foo')
        with self.assertRaises(ValueError):
            invalid_resample(waveform)

        up_sampled = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        down_sampled = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        """ComputeDeltas must return a tensor with the same shape as its input."""
        channel = 13
        input_shape = (channel, channel * 3, 1021)  # (channel, n_mfcc, time)
        specgram = torch.randn(*input_shape)

        delta_transform = transforms.ComputeDeltas(win_length=2 * 7 + 1)
        computed = delta_transform(specgram)

        shapes = (computed.shape, specgram.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)

    def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
        """The ComputeDeltas transform and ``F.compute_deltas`` must agree.

        ``atol`` and ``rtol`` are the tolerances handed to
        ``torch.testing.assert_allclose``.
        """
        channel = 13
        input_shape = (channel, channel * 3, 1021)  # (channel, n_mfcc, time)
        win_length = 2 * 7 + 1
        specgram = torch.randn(*input_shape)

        via_transform = transforms.ComputeDeltas(win_length=win_length)(specgram)
        via_functional = F.compute_deltas(specgram, win_length=win_length)

        torch.testing.assert_allclose(via_functional, via_transform, atol=atol, rtol=rtol)

    def test_compute_deltas_twochannel(self):
        """Check ComputeDeltas values on a two-channel ramp, not only the shape.

        Both channels hold the ramp 1, 2, 3, 4.  ``expected`` equals
        ``(x[t + 1] - x[t - 1]) / 2`` with the edge values repeated, i.e. the
        delta for a 3-frame window, so the transform is built with
        ``win_length=3`` and its output is compared with ``expected``.

        NOTE(review): this relies on ComputeDeltas padding the edges by
        replication by default - confirm against the transform's signature.
        """
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
                                  [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
        # Previously ``expected`` was built but never compared.
        torch.testing.assert_allclose(computed, expected, atol=1e-6, rtol=1e-8)

    def test_batch_MelScale(self):
        """MelScale on a tiled batch must equal the tiled single-input result."""
        single = torch.randn(2, 31, 2786)
        tiling = (3, 1, 1, 1)

        # Transform one spectrogram, then tile the result three times.
        expected = transforms.MelScale()(single).repeat(*tiling)

        # Tile the spectrogram three times, then transform the batch.
        batched = single.repeat(*tiling)
        computed = transforms.MelScale()(batched)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_InverseMelScale(self):
        """InverseMelScale on a tiled batch must roughly equal the tiled single-input result."""
        n_fft = 8
        n_mels = 32
        n_stft = 5
        # Squaring makes every entry of the random mel spectrogram non-negative.
        single = torch.randn(2, n_mels, 32) ** 2
        tiling = (3, 1, 1, 1)

        # Transform one mel spectrogram, then tile the result three times.
        expected = transforms.InverseMelScale(n_stft, n_mels)(single).repeat(*tiling)

        # Tile the mel spectrogram three times, then transform the batch.
        batched = single.repeat(*tiling)
        computed = transforms.InverseMelScale(n_stft, n_mels)(batched)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)

        # Because InverseMelScale runs SGD on randomly initialized values so they do not yield
        # exactly same result. For this reason, tolerance is very relaxed here.
        self.assertTrue(torch.allclose(computed, expected, atol=1.0))

    def test_batch_compute_deltas(self):
        """ComputeDeltas on a tiled batch must equal the tiled single-input result."""
        single = torch.randn(2, 31, 2786)
        tiling = (3, 1, 1, 1)

        # Transform one spectrogram, then tile the result three times.
        expected = transforms.ComputeDeltas()(single).repeat(*tiling)

        # Tile the spectrogram three times, then transform the batch.
        batched = single.repeat(*tiling)
        computed = transforms.ComputeDeltas()(batched)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_MuLawEncoding(self):
        """Pass ``transforms.MuLawEncoding`` and a random (1, 10) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 10))
        _test_script_module(transforms.MuLawEncoding, sample_input)

    def test_scriptmodule_MuLawDecoding(self):
        """Pass ``transforms.MuLawDecoding`` and a random (1, 10) tensor to ``_test_script_module``."""
        sample_input = torch.rand((1, 10))
        _test_script_module(transforms.MuLawDecoding, sample_input)

    def test_batch_mulaw(self):
        """Mu-law encoding and decoding on a batch must equal the batched single-input results."""
        waveform, _ = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        tiling = (3, 1, 1)

        # --- encoding ---
        # Encode one waveform, then stack three copies of the result.
        encoded_single = transforms.MuLawEncoding()(waveform)
        expected = encoded_single.unsqueeze(0).repeat(*tiling)

        # Stack three copies of the waveform, then encode the batch.
        batched_input = waveform.unsqueeze(0).repeat(*tiling)
        computed = transforms.MuLawEncoding()(batched_input)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

        # --- decoding ---
        # Decode the single encoded waveform, then stack three copies.
        decoded_single = transforms.MuLawDecoding()(encoded_single)
        expected = decoded_single.unsqueeze(0).repeat(*tiling)

        # Decode the already batched encoding.
        computed = transforms.MuLawDecoding()(computed)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_spectrogram(self):
        """Spectrogram on a tiled batch must equal the tiled single-input result."""
        waveform, _ = torchaudio.load(self.test_filepath)

        # Transform one waveform, then tile the result three times.
        single_output = transforms.Spectrogram()(waveform)
        expected = single_output.repeat(3, 1, 1, 1)

        # Tile the waveform three times, then transform the batch.
        batched_input = waveform.repeat(3, 1, 1)
        computed = transforms.Spectrogram()(batched_input)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_melspectrogram(self):
        """MelSpectrogram on a tiled batch must equal the tiled single-input result."""
        waveform, _ = torchaudio.load(self.test_filepath)

        # Transform one waveform, then tile the result three times.
        single_output = transforms.MelSpectrogram()(waveform)
        expected = single_output.repeat(3, 1, 1, 1)

        # Tile the waveform three times, then transform the batch.
        batched_input = waveform.repeat(3, 1, 1)
        computed = transforms.MelSpectrogram()(batched_input)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_mfcc(self):
        """MFCC on a tiled batch must equal the tiled single-input result within 1e-5."""
        waveform, _ = torchaudio.load(self.test_filepath)

        # Transform one waveform, then tile the result three times.
        single_output = transforms.MFCC()(waveform)
        expected = single_output.repeat(3, 1, 1, 1)

        # Tile the waveform three times, then transform the batch.
        batched_input = waveform.repeat(3, 1, 1)
        computed = transforms.MFCC()(batched_input)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_scriptmodule_TimeStretch(self):
        """Pass ``transforms.TimeStretch`` and a random (10, 2, 400, 10, 2) tensor to ``_test_script_module``."""
        n_freq = 400
        sample_input = torch.rand((10, 2, n_freq, 10, 2))
        options = dict(n_freq=n_freq, hop_length=512, fixed_rate=1.3)
        _test_script_module(transforms.TimeStretch, sample_input, **options)

    def test_batch_TimeStretch(self):
        """TimeStretch on a tiled batch must equal the tiled single-input result within 1e-5."""
        waveform, _ = torchaudio.load(self.test_filepath)

        n_fft = 2048
        hop_length = 512
        stft_options = dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=n_fft,
            window=torch.hann_window(n_fft),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        )
        rate = 2
        # n_freq = 1025 corresponds to n_fft // 2 + 1 for the STFT above.
        stretch_options = dict(fixed_rate=rate, n_freq=1025, hop_length=hop_length)

        complex_specgrams = torch.stft(waveform, **stft_options)
        tiling = (3, 1, 1, 1, 1)

        # Stretch one spectrogram, then tile the result three times.
        expected = transforms.TimeStretch(**stretch_options)(complex_specgrams).repeat(*tiling)

        # Tile the spectrogram three times, then stretch the batch.
        batched = complex_specgrams.repeat(*tiling)
        computed = transforms.TimeStretch(**stretch_options)(batched)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_Fade(self):
        """Fade on a tiled batch must equal the tiled single-input result."""
        waveform, _ = torchaudio.load(self.test_filepath)
        fade_lengths = (3000, 3000)  # (fade_in_len, fade_out_len)
        tiling = (3, 1, 1)

        # Fade one waveform, then tile the result three times.
        expected = transforms.Fade(*fade_lengths)(waveform).repeat(*tiling)

        # Tile the waveform three times, then fade the batch.
        batched = waveform.repeat(*tiling)
        computed = transforms.Fade(*fade_lengths)(batched)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_scriptmodule_Fade(self):
        """Pass ``transforms.Fade`` with 3000-sample fades and the loaded test file to ``_test_script_module``."""
        waveform, _ = torchaudio.load(self.test_filepath)
        fade_in_len, fade_out_len = 3000, 3000

        _test_script_module(transforms.Fade, waveform, fade_in_len, fade_out_len)

    def test_scriptmodule_FrequencyMasking(self):
        """Pass ``transforms.FrequencyMasking`` and a random (10, 2, 50, 10, 2) tensor to ``_test_script_module``."""
        sample_input = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.FrequencyMasking, sample_input,
                            freq_mask_param=60, iid_masks=False)

    def test_scriptmodule_TimeMasking(self):
        """Pass ``transforms.TimeMasking`` and a random (10, 2, 50, 10, 2) tensor to ``_test_script_module``."""
        sample_input = torch.rand((10, 2, 50, 10, 2))
        _test_script_module(transforms.TimeMasking, sample_input,
                            time_mask_param=30, iid_masks=False)
Exemplo n.º 11
0
class TestFunctionalFiltering(unittest.TestCase):
    test_dirpath, test_dir = create_temp_assets_dir()

    def _test_lfilter_basic(self, dtype, device):
        """
        Filter a seeded random signal with b = [0, 0, 0, 1], a = [1, 0, 0, 0].

        The output from sample 3 onwards must equal the input without its
        last 3 samples, i.e. the input shifted by three samples.
        ``dtype`` and ``device`` are applied to the signal and coefficients.
        """
        torch.random.manual_seed(42)
        tensor_options = dict(dtype=dtype, device=device)
        shift = 3

        waveform = torch.rand(2, 44100 * 1, **tensor_options)
        b_coeffs = torch.tensor([0, 0, 0, 1], **tensor_options)
        a_coeffs = torch.tensor([1, 0, 0, 0], **tensor_options)
        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)

        shifted_output = output_waveform[:, shift:]
        trimmed_input = waveform[:, :-shift]
        torch.testing.assert_allclose(shifted_output, trimmed_input, atol=1e-5, rtol=1e-5)

    def test_lfilter_basic(self):
        """Run ``_test_lfilter_basic`` with float32 on the CPU."""
        cpu = torch.device("cpu")
        self._test_lfilter_basic(torch.float32, cpu)

    def test_lfilter_basic_double(self):
        """Run ``_test_lfilter_basic`` with float64 on the CPU."""
        cpu = torch.device("cpu")
        self._test_lfilter_basic(torch.float64, cpu)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_lfilter_basic_gpu(self):
        """Run ``_test_lfilter_basic`` with float32 on ``cuda:0``; skipped without CUDA."""
        gpu = torch.device("cuda:0")
        self._test_lfilter_basic(torch.float32, gpu)

    def _test_lfilter(self, waveform, device):
        """
        Run ``F.lfilter`` with fixed 7-tap coefficients and check the output shape.

        The coefficients are those of an IIR lowpass filter designed with
        scipy.signal filter design
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign

        Example
          >>> from scipy.signal import iirdesign
          >>> b, a = iirdesign(0.2, 0.3, 1, 60)

        Only the shape is verified: the output must be 2-D with the same
        sizes as ``waveform`` along both dimensions.
        """
        numerator = [
            0.00299893,
            -0.0051152,
            0.00841964,
            -0.00747802,
            0.00841964,
            -0.0051152,
            0.00299893,
        ]
        denominator = [
            1.0,
            -4.8155751,
            10.2217618,
            -12.14481273,
            8.49018171,
            -3.3066882,
            0.56088705,
        ]
        b_coeffs = torch.tensor(numerator, device=device)
        a_coeffs = torch.tensor(denominator, device=device)

        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)

        out_size = output_waveform.size()
        assert len(out_size) == 2
        assert out_size[0] == waveform.size(0)
        assert out_size[1] == waveform.size(1)

    def test_lfilter(self):
        """Run ``_test_lfilter`` on ``whitenoise.wav`` on the CPU."""
        noise_path = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
        noise, _ = torchaudio.load(noise_path, normalization=True)
        cpu = torch.device("cpu")

        self._test_lfilter(noise, cpu)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_lfilter_gpu(self):
        """Run ``_test_lfilter`` on ``whitenoise.wav`` moved to ``cuda:0``; skipped without CUDA."""
        noise_path = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
        noise, _ = torchaudio.load(noise_path, normalization=True)

        gpu = torch.device("cuda:0")
        self._test_lfilter(noise.cuda(device=gpu), gpu)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_gain(self):
        """Compare ``F.gain(waveform, 3)`` with the SoX ``gain`` effect given ``[3]`` on the same file."""
        test_filepath = os.path.join(self.test_dirpath, "assets",
                                     "steam-train-whistle-daniel_simon.wav")
        waveform, _ = torchaudio.load(test_filepath)

        waveform_gain = F.gain(waveform, 3)
        # NOTE(review): ``assertTrue(expr, msg)`` treats the second argument as
        # the failure message, so ``1.`` is not compared with anything; this
        # only checks that the peak magnitude is non-zero.  Presumably a bound
        # such as ``assertLessEqual(peak, 1.)`` was intended - confirm the
        # intent before changing it.
        self.assertTrue(waveform_gain.abs().max().item(), 1.)

        # Reference: the same file run through a SoX chain with one gain effect.
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(test_filepath)
        E.append_effect_to_chain("gain", [3])
        sox_gain_waveform = E.sox_build_flow_effects()[0]

        torch.testing.assert_allclose(waveform_gain,
                                      sox_gain_waveform,
                                      atol=1e-04,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_dither(self):
        """Compare ``F.dither`` with the SoX ``dither`` effect, with and without noise shaping.

        The plain variant is compared at atol=1e-04, the noise-shaped variant
        (SoX option ``-s``) at the looser atol=1e-02.
        """
        test_filepath = os.path.join(self.test_dirpath, "assets",
                                     "steam-train-whistle-daniel_simon.wav")
        waveform, _ = torchaudio.load(test_filepath)

        dithered_plain = F.dither(waveform)
        dithered_shaped = F.dither(waveform, noise_shaping=True)

        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(test_filepath)

        # Plain dither reference.
        chain.append_effect_to_chain("dither", [])
        sox_plain = chain.sox_build_flow_effects()[0]
        torch.testing.assert_allclose(dithered_plain, sox_plain, atol=1e-04, rtol=1e-5)

        # Reset the effects and build the noise-shaped reference.
        chain.clear_chain()
        chain.append_effect_to_chain("dither", ["-s"])
        sox_shaped = chain.sox_build_flow_effects()[0]
        torch.testing.assert_allclose(dithered_shaped, sox_shaped, atol=1e-02, rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_vctk_transform_pipeline(self):
        """Compare a resample + noise-shaped dither pipeline with a SoX effects chain.

        The torchaudio side resamples the VCTK sample to 16000 and dithers it;
        the SoX side applies gain -h, channels 1, rate 16000, gain -rh and
        dither -s to the same file.
        """
        test_filepath_vctk = os.path.join(self.test_dirpath,
                                          "assets/VCTK-Corpus/wav48/p224/",
                                          "p224_002.wav")
        wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)

        # torchaudio pipeline: rate, then dither
        resample = T.Resample(sr_vctk, 16000, resampling_method='sinc_interpolation')
        wf_vctk = F.dither(resample(wf_vctk), noise_shaping=True)

        # SoX reference pipeline, effects applied in the order listed.
        sox_effects = (
            ("gain", ["-h"]),
            ("channels", [1]),
            ("rate", [16000]),
            ("gain", ["-rh"]),
            ("dither", ["-s"]),
        )
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(test_filepath_vctk)
        for effect_name, effect_args in sox_effects:
            chain.append_effect_to_chain(effect_name, effect_args)
        wf_vctk_sox = chain.sox_build_flow_effects()[0]

        torch.testing.assert_allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_lowpass(self):
        """
        Test biquad lowpass filter, compare to SoX implementation

        Both sides filter ``whitenoise.wav`` with a 3000 cutoff; the results
        must agree at atol=1e-4.
        """
        CUTOFF_FREQ = 3000
        noise_filepath = os.path.join(self.test_dirpath, "assets",
                                      "whitenoise.wav")

        # SoX reference output.
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(noise_filepath)
        chain.append_effect_to_chain("lowpass", [CUTOFF_FREQ])
        sox_output_waveform, _ = chain.sox_build_flow_effects()

        # torchaudio output on the same file.
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)

        torch.testing.assert_allclose(output_waveform, sox_output_waveform,
                                      atol=1e-4, rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_highpass(self):
        """
        Highpass biquad: F.highpass_biquad must agree with the SoX "highpass"
        effect on the white noise asset
        """
        cutoff = 2000
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("highpass", [cutoff])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.highpass_biquad(signal, rate, cutoff)

        # TBD - this fails at the 1e-4 level, debug why
        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-3,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_allpass(self):
        """
        Allpass biquad: F.allpass_biquad must agree with the SoX "allpass"
        effect on the white noise asset
        """
        center = 1000
        q_factor = 0.707
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain; the width is passed
        # to SoX as a string with a trailing 'q'
        sox_width = str(q_factor) + 'q'
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("allpass", [center, sox_width])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.allpass_biquad(signal, rate, center, q_factor)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandpass_with_csg(self):
        """
        Bandpass biquad with constant skirt gain enabled: F.bandpass_biquad
        must agree with the SoX "bandpass -c" effect on the white noise asset
        """
        center = 1000
        q_factor = 0.707
        const_skirt_gain = True
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain
        sox_width = str(q_factor) + 'q'
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("bandpass", ["-c", center, sox_width])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.bandpass_biquad(signal, rate, center, q_factor,
                                   const_skirt_gain)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandpass_without_csg(self):
        """
        Bandpass biquad with constant skirt gain disabled: F.bandpass_biquad
        must agree with the plain SoX "bandpass" effect on the white noise
        asset
        """
        center = 1000
        q_factor = 0.707
        const_skirt_gain = False
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain
        sox_width = str(q_factor) + 'q'
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("bandpass", [center, sox_width])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.bandpass_biquad(signal, rate, center, q_factor,
                                   const_skirt_gain)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandreject(self):
        """
        Bandreject biquad: F.bandreject_biquad must agree with the SoX
        "bandreject" effect on the white noise asset
        """
        center = 1000
        q_factor = 0.707
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain
        sox_width = str(q_factor) + 'q'
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("bandreject", [center, sox_width])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.bandreject_biquad(signal, rate, center, q_factor)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_band_with_noise(self):
        """
        Band biquad with noise mode enabled: F.band_biquad must agree with
        the SoX "band -n" effect on the white noise asset
        """
        center = 1000
        q_factor = 0.707
        noise_mode = True
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain
        sox_width = str(q_factor) + 'q'
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("band", ["-n", center, sox_width])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.band_biquad(signal, rate, center, q_factor, noise_mode)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_band_without_noise(self):
        """
        Band biquad with noise mode disabled: F.band_biquad must agree with
        the plain SoX "band" effect on the white noise asset
        """
        center = 1000
        q_factor = 0.707
        noise_mode = False
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain
        sox_width = str(q_factor) + 'q'
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("band", [center, sox_width])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.band_biquad(signal, rate, center, q_factor, noise_mode)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_treble(self):
        """
        Treble biquad: F.treble_biquad must agree with the SoX "treble"
        effect on the white noise asset
        """
        center = 1000
        q_factor = 0.707
        gain = 40
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain (gain comes first)
        sox_width = str(q_factor) + 'q'
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("treble", [gain, center, sox_width])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.treble_biquad(signal, rate, gain, center, q_factor)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_deemph(self):
        """
        Deemph biquad: F.deemph_biquad must agree with the SoX "deemph"
        effect on the white noise asset
        """
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain (effect takes no args)
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("deemph")
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.deemph_biquad(signal, rate)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_riaa(self):
        """
        RIAA biquad: F.riaa_biquad must agree with the SoX "riaa" effect on
        the white noise asset
        """
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain (effect takes no args)
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("riaa")
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.riaa_biquad(signal, rate)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_equalizer(self):
        """
        Peaking equalizer biquad: F.equalizer_biquad must agree with the SoX
        "equalizer" effect on the white noise asset
        """
        center = 300
        q_factor = 0.707
        gain = 1
        wav_path = os.path.join(self.test_dirpath, "assets",
                                "whitenoise.wav")

        # reference output from the SoX effects chain; note the argument
        # order differs from the functional (SoX: freq, Q, gain)
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("equalizer", [center, q_factor, gain])
        expected, _ = chain.sox_build_flow_effects()

        # same filter through the torchaudio functional (freq, gain, Q)
        signal, rate = torchaudio.load(wav_path, normalization=True)
        actual = F.equalizer_biquad(signal, rate, center, gain, q_factor)

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_perf_biquad_filtering(self):
        """
        Generic biquad: F.lfilter with explicit coefficients must agree with
        the SoX "biquad" effect on the white noise asset
        """
        wav_path = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # numerator (b) and denominator (a) coefficients, lowest order first
        b_coeffs = [0.4, 0.2, 0.9]
        a_coeffs = [0.7, 0.2, 0.6]

        # SoX method: "biquad" takes b0 b1 b2 a0 a1 a2
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(wav_path)
        chain.append_effect_to_chain("biquad", b_coeffs + a_coeffs)
        expected, _ = chain.sox_build_flow_effects()

        # torchaudio method: lfilter takes the a coefficients before the b ones
        signal, _ = torchaudio.load(wav_path, normalization=True)
        actual = F.lfilter(signal, torch.tensor(a_coeffs),
                           torch.tensor(b_coeffs))

        torch.testing.assert_allclose(actual,
                                      expected,
                                      atol=1e-4,
                                      rtol=1e-5)
Exemplo n.º 12
0
class Tester(unittest.TestCase):
    """Tests for torchaudio transforms, driven by a synthetic 440 Hz cosine
    and a stereo audio asset."""

    # create a cosine signal for testing (4 * sample_rate samples)
    sample_rate = 16000
    freq = 440
    volume = .3
    waveform = (torch.cos(2 * math.pi *
                          torch.arange(0, 4 * sample_rate).float() * freq /
                          sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    # stored as an integer tensor scaled by 2**31; see `scale` for the inverse
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.wav')

    def scale(self, waveform, factor=2.0**31):
        """Return `waveform` divided by `factor`.

        Non-floating-point tensors are first converted to the default
        floating dtype so the division is not an integer division.
        """
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        return waveform / factor

    def test_mu_law_companding(self):
        """Mu-law encode/decode must keep values inside the expected ranges."""

        quantization_channels = 256

        waveform = self.waveform.clone()
        # self.waveform is an integer tensor: convert before normalising,
        # otherwise the in-place division either truncates every sample to
        # -1/0/1 or is rejected by PyTorch's type promotion rules.
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0.
                        and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(
            waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_AmplitudeToDB(self):
        """'magnitude' mode on |x| must match 'power' mode on x**2."""
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)

        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))

        torch.testing.assert_allclose(mag_to_db_torch, power_to_db_torch)

    def test_melscale_load_save(self):
        """A MelScale state dict must load into a fresh MelScale(n_stft=1000)."""
        specgram = torch.ones(1, 1000, 100)
        melscale_transform = transforms.MelScale()
        # run once so the transform has seen a 1000-bin spectrogram
        melscale_transform(specgram)

        melscale_transform_copy = transforms.MelScale(n_stft=1000)
        melscale_transform_copy.load_state_dict(
            melscale_transform.state_dict())

        fb = melscale_transform.fb
        fb_copy = melscale_transform_copy.fb

        self.assertEqual(fb_copy.size(), (1000, 128))
        torch.testing.assert_allclose(fb, fb_copy)

    def test_melspectrogram_load_save(self):
        """A MelSpectrogram state dict must round-trip window and filterbank."""
        waveform = self.waveform.float()
        mel_spectrogram_transform = transforms.MelSpectrogram()
        mel_spectrogram_transform(waveform)

        mel_spectrogram_transform_copy = transforms.MelSpectrogram()
        mel_spectrogram_transform_copy.load_state_dict(
            mel_spectrogram_transform.state_dict())

        window = mel_spectrogram_transform.spectrogram.window
        window_copy = mel_spectrogram_transform_copy.spectrogram.window

        fb = mel_spectrogram_transform.mel_scale.fb
        fb_copy = mel_spectrogram_transform_copy.mel_scale.fb

        torch.testing.assert_allclose(window, window_copy)
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(fb_copy.size(), (201, 128))
        torch.testing.assert_allclose(fb, fb_copy)

    def test_mel2(self):
        """MelSpectrogram + AmplitudeToDB: output rank, dB floor, mel count
        and filterbank bounds, for defaults, custom options and stereo input."""
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(
            mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(
            spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {
            'window_fn': torch.hamming_window,
            'pad': 10,
            'win_length': 500,
            'hop_length': 125,
            'n_fft': 800,
            'n_mels': 50
        }
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(
            mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # the dB floor must hold for THIS spectrogram, not the default one
        self.assertTrue(
            spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        x_stereo, _ = torchaudio.load(
            self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # again, check the stereo result itself
        self.assertTrue(
            spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100,
                                                  sample_rate=16000,
                                                  f_min=0.,
                                                  f_max=None,
                                                  n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        """MFCC: output shape, melkwargs pass-through and norm scaling."""
        audio_orig = self.waveform.clone()
        audio_scaled = self.scale(audio_orig)  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                    n_mfcc=n_mfcc,
                                                    norm='ortho')
        # check defaults
        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(torch_mfcc.dim() == 3)
        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
        self.assertTrue(torch_mfcc.shape[2] == 321)
        # check melkwargs are passed through
        melkwargs = {'win_length': 200}
        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho',
                                                     melkwargs=melkwargs)
        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
        self.assertTrue(torch_mfcc2.shape[2] == 641)

        # check norms work correctly
        mfcc_transform_norm_none = torchaudio.transforms.MFCC(
            sample_rate=sample_rate, n_mfcc=n_mfcc, norm=None)
        torch_mfcc_norm_none = mfcc_transform_norm_none(
            audio_scaled)  # (1, 40, 321)

        # undo the 'ortho' scaling: the first coefficient and the remaining
        # ones use different factors
        norm_check = torch_mfcc.clone()
        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))

    def test_resample_size(self):
        """Resample: invalid method raises; up/down-sampling by 2 scales length."""
        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='foo')

        # the bad method name is only rejected when the transform is applied
        self.assertRaises(ValueError, invalid_resample, waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate,
            downsample_rate,
            resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)

    def test_compute_deltas(self):
        """ComputeDeltas must preserve the input shape."""
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        transform = transforms.ComputeDeltas(win_length=win_length)
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape,
                        (computed.shape, specgram.shape))

    def test_compute_deltas_transform_same_as_functional(
            self, atol=1e-6, rtol=1e-8):
        """The ComputeDeltas transform must match F.compute_deltas.

        Args:
            atol: absolute tolerance of the comparison.
            rtol: relative tolerance of the comparison.
        """
        channel = 13
        n_mfcc = channel * 3
        time = 1021
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)

        transform = transforms.ComputeDeltas(win_length=win_length)
        computed_transform = transform(specgram)

        computed_functional = F.compute_deltas(specgram, win_length=win_length)
        torch.testing.assert_allclose(computed_functional,
                                      computed_transform,
                                      atol=atol,
                                      rtol=rtol)

    def test_compute_deltas_twochannel(self):
        """ComputeDeltas on a small two-row input must match known values."""
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], [0.5, 1.0, 1.0, 0.5]]])
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        # use a unittest assertion (not a bare assert) like the sibling tests,
        # so the check survives `python -O`
        self.assertTrue(computed.shape == expected.shape,
                        (computed.shape, expected.shape))
        torch.testing.assert_allclose(computed, expected, atol=1e-6, rtol=1e-8)
Exemplo n.º 13
0
class Test_SoxEffectsChain(unittest.TestCase):
    """Tests for torchaudio.sox_effects.SoxEffectsChain on the bundled assets."""

    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, "assets",
                                 "steam-train-whistle-daniel_simon.mp3")

    @classmethod
    def setUpClass(cls):
        """Initialise SoX once for the whole test class."""
        torchaudio.initialize_sox()

    @classmethod
    def tearDownClass(cls):
        """Shut SoX down after the last test of the class."""
        torchaudio.shutdown_sox()

    def test_single_channel(self):
        """Smoke test: an "echos" chain on the sinewave asset runs to completion."""
        fn_sine = os.path.join(self.test_dirpath, "assets", "sinewave.wav")
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(fn_sine)
        E.append_effect_to_chain("echos", [0.8, 0.7, 40, 0.25, 63, 0.3])
        # no assertion on the output: the test only checks that building and
        # running the chain does not raise
        E.sox_build_flow_effects()

    def test_rate_channels(self):
        """"rate" and "channels" must set the output sample rate and channel count."""
        target_rate = 16000
        target_channels = 1
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("rate", [target_rate])
        E.append_effect_to_chain("channels", [target_channels])
        x, sr = E.sox_build_flow_effects()
        # check if effects worked
        self.assertEqual(sr, target_rate)
        self.assertEqual(x.size(0), target_channels)

    def test_lowpass_speed(self):
        """"speed" followed by "rate" must rescale the number of frames."""
        speed = .8
        si, _ = torchaudio.info(self.test_filepath)
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        # effect arguments are passed as bare scalars here rather than lists
        E.append_effect_to_chain("lowpass", 100)
        E.append_effect_to_chain("speed", speed)
        E.append_effect_to_chain("rate", si.rate)
        x, _ = E.sox_build_flow_effects()
        # check if effects worked
        self.assertEqual(x.size(1), int((si.length / si.channels) / speed))

    def test_ulaw_and_siginfo(self):
        """Custom output signal/encoding info must produce 8-bit-like output."""
        si_out = torchaudio.sox_signalinfo_t()
        ei_out = torchaudio.sox_encodinginfo_t()
        si_out.precision = 8
        ei_out.encoding = torchaudio.get_sox_encoding_t(9)
        ei_out.bits_per_sample = 8
        si_in, _ = torchaudio.info(self.test_filepath)
        si_out.rate = 44100
        si_out.channels = 2
        E = torchaudio.sox_effects.SoxEffectsChain(out_siginfo=si_out,
                                                   out_encinfo=ei_out)
        E.set_input_file(self.test_filepath)
        x, _ = E.sox_build_flow_effects()
        # Note: the output was encoded into ulaw because the
        #       number of unique values in the output is less than 256.
        self.assertLess(x.unique().size(0), 2**8 + 1)
        self.assertEqual(x.numel(), si_in.length)

    def test_band_chorus(self):
        """"band" + "chorus" must keep the channel count and not shorten the audio."""
        si_in, ei_in = torchaudio.info(self.test_filepath)
        ei_in.encoding = torchaudio.get_sox_encoding_t(1)
        E = torchaudio.sox_effects.SoxEffectsChain(out_encinfo=ei_in,
                                                   out_siginfo=si_in)
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("band", ["-n", "10k", "3.5k"])
        E.append_effect_to_chain("chorus", [.5, .7, 55, 0.4, .25, 2, '-s'])
        E.append_effect_to_chain("rate", [si_in.rate])
        E.append_effect_to_chain("channels", [si_in.channels])
        x, _ = E.sox_build_flow_effects()
        # The chorus effect will make the output file longer than the input
        self.assertEqual(x.size(0), si_in.channels)
        self.assertGreaterEqual(x.size(1) * x.size(0), si_in.length)

    def test_synth(self):
        """"synth" mixed over the input must preserve channels and total length."""
        si_in, ei_in = torchaudio.info(self.test_filepath)
        len_in_seconds = si_in.length / si_in.channels / si_in.rate
        ei_in.encoding = torchaudio.get_sox_encoding_t(1)
        E = torchaudio.sox_effects.SoxEffectsChain(out_encinfo=ei_in,
                                                   out_siginfo=si_in)
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("synth",
                                 [str(len_in_seconds), "pinknoise", "mix"])
        E.append_effect_to_chain("rate", [44100])
        E.append_effect_to_chain("channels", [2])
        x, _ = E.sox_build_flow_effects()
        self.assertEqual(x.size(0), si_in.channels)
        self.assertEqual(si_in.length, x.size(0) * x.size(1))

    def test_gain(self):
        """"gain" in several modes must keep the peak within the expected bound.

        The amplifying cases previously used ``assertTrue(peak, 1.)``, which
        treats ``1.`` as the failure message and passes for any non-zero
        peak. They now check the bound explicitly.
        NOTE(review): an exact ``peak == 1.`` (clipping) check may have been
        the original intent - confirm against the asset before tightening.
        """
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", ["5"])
        x, _ = E.sox_build_flow_effects()
        E.clear_chain()
        self.assertLessEqual(x.abs().max().item(), 1.)
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", ["-e", "-5"])
        x, _ = E.sox_build_flow_effects()
        E.clear_chain()
        self.assertLess(x.abs().max().item(), 1.)
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", ["-b", "8"])
        x, _ = E.sox_build_flow_effects()
        E.clear_chain()
        self.assertLessEqual(x.abs().max().item(), 1.)
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", ["-n", "-10"])
        x, _ = E.sox_build_flow_effects()
        E.clear_chain()
        self.assertLess(x.abs().max().item(), 1.)

    def test_tempo_or_speed(self):
        """"tempo" and "speed" must rescale the frame count by 1 / factor."""
        # tempo < 1
        tempo = .8
        si, _ = torchaudio.info(self.test_filepath)
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("tempo", ["-s", tempo])
        x, _ = E.sox_build_flow_effects()
        # check if effect worked
        self.assertAlmostEqual(x.size(1),
                               math.ceil((si.length / si.channels) / tempo),
                               delta=1)
        # tempo > 1
        E.clear_chain()
        tempo = 1.2
        E.append_effect_to_chain("tempo", ["-s", tempo])
        x, _ = E.sox_build_flow_effects()
        # check if effect worked
        self.assertAlmostEqual(x.size(1),
                               math.ceil((si.length / si.channels) / tempo),
                               delta=1)
        # speed > 1
        E.clear_chain()
        speed = 1.2
        E.append_effect_to_chain("speed", [speed])
        E.append_effect_to_chain("rate", [si.rate])
        x, _ = E.sox_build_flow_effects()
        # check if effect worked
        self.assertAlmostEqual(x.size(1),
                               math.ceil((si.length / si.channels) / speed),
                               delta=1)
        # speed < 1
        E.clear_chain()
        speed = 0.8
        E.append_effect_to_chain("speed", [speed])
        E.append_effect_to_chain("rate", [si.rate])
        x, _ = E.sox_build_flow_effects()
        # check if effect worked
        self.assertAlmostEqual(x.size(1),
                               math.ceil((si.length / si.channels) / speed),
                               delta=1)

    def test_trim(self):
        """"trim" must return the same frames as slicing the loaded audio."""
        x_orig, _ = torchaudio.load(self.test_filepath)
        # SoX takes sample counts as strings with a trailing "s"
        offset = "10000s"
        offset_int = int(offset[:-1])
        num_frames = "20000s"
        num_frames_int = int(num_frames[:-1])
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("trim", [offset, num_frames])
        x, _ = E.sox_build_flow_effects()
        # check if effect worked
        self.assertTrue(
            x.allclose(x_orig[:, offset_int:(offset_int + num_frames_int)],
                       rtol=1e-4,
                       atol=1e-4))

    def test_silence_contrast(self):
        """"silence" + "contrast" must produce fewer samples than the input."""
        si, _ = torchaudio.info(self.test_filepath)
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("silence", [1, 100, 1])
        E.append_effect_to_chain("contrast", [])
        x, _ = E.sox_build_flow_effects()
        # check if effect worked
        self.assertLess(x.numel(), si.length)

    def test_reverse(self):
        """"reverse" must return the loaded audio with the time axis flipped."""
        x_orig, _ = torchaudio.load(self.test_filepath)
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("reverse", "")
        x_rev, _ = E.sox_build_flow_effects()
        # check if effect worked: re-reversing must recover the original
        rev_idx = torch.LongTensor(range(x_orig.size(1))[::-1])
        self.assertTrue(
            x_orig.allclose(x_rev[:, rev_idx], rtol=1e-5, atol=2e-5))

    def test_compand_fade(self):
        """Smoke test: a "compand" + "fade" chain runs to completion."""
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain(
            "compand", ["0.3,1", "6:-70,-60,-20", "-5", "-90", "0.2"])
        E.append_effect_to_chain("fade", ["q", "0.25", "0", "0.33"])
        # no assertion on the output: the test only checks that building and
        # running the chain does not raise
        E.sox_build_flow_effects()

    def test_biquad_delay(self):
        """"delay" of 15000 samples must lengthen the output by 15000 frames."""
        si, _ = torchaudio.info(self.test_filepath)
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("biquad", [
            "0.25136437", "0.50272873", "0.25136437", "1.0", "-0.17123075",
            "0.17668821"
        ])
        E.append_effect_to_chain("delay", ["15000s"])
        x, _ = E.sox_build_flow_effects()
        # check if effect worked (assertEqual reports both values on failure)
        self.assertEqual(x.size(1), (si.length / si.channels) + 15000)

    def test_invalid_effect_name(self):
        """Appending an unknown effect name must raise LookupError."""
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        # there is no effect named "special"
        with self.assertRaises(LookupError):
            E.append_effect_to_chain("special", [""])

    def test_unimplemented_effect(self):
        """Appending an effect torchaudio does not support must raise."""
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        # the sox spectrogram function is not implemented in torchaudio
        with self.assertRaises(NotImplementedError):
            E.append_effect_to_chain("spectrogram", [""])

    def test_invalid_effect_options(self):
        """Malformed effect options must raise when the chain is run."""
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        # first two options should be combined to "0.3,1"
        E.append_effect_to_chain(
            "compand", ["0.3", "1", "6:-70,-60,-20", "-5", "-90", "0.2"])
        # appending succeeds; the error only surfaces at build time
        with self.assertRaises(RuntimeError):
            E.sox_build_flow_effects()
Exemplo n.º 14
0
class Tester(unittest.TestCase):
    """Tests for ``torchaudio.transforms`` using a synthetic tone and a WAV asset."""

    # create a sinewave signal for testing
    sample_rate = 16000
    freq = 440
    volume = .3
    # cosine at `freq` sampled at `sample_rate` over 4 * sample_rate samples
    waveform = (torch.cos(2 * math.pi *
                          torch.arange(0, 4 * sample_rate).float() * freq /
                          sample_rate))
    waveform.unsqueeze_(0)  # (1, 64000)
    # scaled by volume * 2**31 and stored as an integer (long) tensor
    waveform = (waveform * volume * 2**31).long()
    # file for stereo stft test
    test_dirpath, test_dir = create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.wav')

    def scale(self, waveform, factor=float(2**31)):
        """Return ``waveform`` divided by ``factor``.

        A tensor that is not floating point is first converted to torch's
        default dtype, so the division is carried out in floating point.
        """
        as_float = (waveform if waveform.is_floating_point() else
                    waveform.to(torch.get_default_dtype()))
        return as_float / factor

    def test_mu_law_companding(self):
        """Round-trip a normalized waveform through mu-law encoding/decoding.

        Only value ranges are checked: the normalized input lies in [-1, 1],
        the encoded signal in [0, quantization_channels], and the decoded
        signal in [-1, 1] again.
        """
        quantization_channels = 256

        waveform = self.waveform.clone()
        # self.waveform is an integer tensor. Dividing it in place would
        # either truncate the result to integers or raise, depending on the
        # PyTorch version, so convert to floating point before normalizing.
        if not waveform.is_floating_point():
            waveform = waveform.to(torch.get_default_dtype())
        waveform /= torch.abs(waveform).max()
        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)

        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
        self.assertTrue(waveform_mu.min() >= 0.
                        and waveform_mu.max() <= quantization_channels)

        waveform_exp = transforms.MuLawDecoding(quantization_channels)(
            waveform_mu)
        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)

    def test_batch_AmplitudeToDB(self):
        """AmplitudeToDB gives the same result on a batch as on a single item."""
        batch = 3
        spec = torch.rand((6, 201))

        # transform the single item, then tile the result into a batch
        single_db = transforms.AmplitudeToDB()(spec)
        expected = single_db.repeat(batch, 1, 1)

        # tile the input into a batch, then transform
        batched_spec = spec.repeat(batch, 1, 1)
        computed = transforms.AmplitudeToDB()(batched_spec)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_AmplitudeToDB(self):
        """'magnitude' mode on |x| must agree with 'power' mode on x**2."""
        top_db = 80.
        waveform, _sample_rate = torchaudio.load(self.test_filepath)

        from_magnitude = transforms.AmplitudeToDB('magnitude', top_db)
        from_power = transforms.AmplitudeToDB('power', top_db)

        db_of_magnitude = from_magnitude(torch.abs(waveform))
        db_of_power = from_power(torch.pow(waveform, 2))

        self.assertTrue(torch.allclose(db_of_magnitude, db_of_power))

    def test_melscale_load_save(self):
        """A MelScale filterbank survives a ``state_dict`` round trip.

        The source transform is built with defaults and applied once to a
        (1, 1000, 100) input before its state is copied into a second
        transform built with ``n_stft=1000``.
        """
        n_stft = 1000

        source = transforms.MelScale()
        # presumably the call is what materializes source.fb - confirm
        source(torch.ones(1, n_stft, 100))

        restored = transforms.MelScale(n_stft=n_stft)
        restored.load_state_dict(source.state_dict())

        source_fb, restored_fb = source.fb, restored.fb
        self.assertEqual(restored_fb.size(), (1000, 128))
        self.assertTrue(torch.allclose(source_fb, restored_fb))

    def test_melspectrogram_load_save(self):
        """MelSpectrogram window and filterbank survive a ``state_dict`` copy."""
        source = transforms.MelSpectrogram()
        source(self.waveform.float())

        restored = transforms.MelSpectrogram()
        restored.load_state_dict(source.state_dict())

        source_window = source.spectrogram.window
        restored_window = restored.spectrogram.window
        source_fb = source.mel_scale.fb
        restored_fb = restored.mel_scale.fb

        self.assertTrue(torch.allclose(source_window, restored_window))
        # the default for n_fft = 400 and n_mels = 128
        self.assertEqual(restored_fb.size(), (201, 128))
        self.assertTrue(torch.allclose(source_fb, restored_fb))

    def test_mel2(self):
        """Check MelSpectrogram (followed by AmplitudeToDB) and the mel filterbank.

        Covers the default transform, a transform built from custom keyword
        arguments, a two-channel file, and direct construction of a MelScale
        filterbank. Each dB spectrogram is checked against its own
        ``max - top_db`` floor.
        """
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        def assert_floored(spec_db):
            # no value may lie more than top_db below this spectrogram's max
            self.assertTrue(spec_db.ge(spec_db.max() - top_db).all())

        def assert_valid_fb(fb):
            # summing the filterbank over dim 1 must give values in [0, 1]
            sums = fb.sum(1)
            self.assertTrue(sums.le(1.).all())
            self.assertTrue(sums.ge(0.).all())

        waveform = self.waveform.clone()  # (1, 64000)
        waveform_scaled = self.scale(waveform)  # (1, 64000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(
            mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        assert_floored(spectrogram_torch)
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        assert_valid_fb(mel_transform.mel_scale.fb)
        # check options
        kwargs = {
            'window_fn': torch.hamming_window,
            'pad': 10,
            'win_length': 500,
            'hop_length': 125,
            'n_fft': 800,
            'n_mels': 50
        }
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(
            mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # check the spectrogram produced by mel_transform2, not the first one
        assert_floored(spectrogram2_torch)
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        assert_valid_fb(mel_transform2.mel_scale.fb)
        # check on multi-channel audio
        x_stereo, sr_stereo = torchaudio.load(
            self.test_filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # check the stereo spectrogram, not the first one
        assert_floored(spectrogram_stereo)
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100,
                                                  sample_rate=16000,
                                                  f_min=0.,
                                                  f_max=None,
                                                  n_stft=400)
        assert_valid_fb(fb_matrix_transform.fb)
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))

    def test_mfcc(self):
        """Check MFCC output shapes, ``melkwargs`` pass-through and the DCT norm.

        The un-normalized (``norm=None``) result must equal the 'ortho' result
        after undoing the orthonormal scaling of the coefficients.
        """
        audio_scaled = self.scale(self.waveform.clone())  # (1, 64000)

        sample_rate = 16000
        n_mfcc = 40
        n_mels = 128
        MFCC = torchaudio.transforms.MFCC

        # check defaults
        ortho_transform = MFCC(sample_rate=sample_rate,
                               n_mfcc=n_mfcc,
                               norm='ortho')
        mfcc_ortho = ortho_transform(audio_scaled)  # (1, 40, 321)
        self.assertTrue(mfcc_ortho.dim() == 3)
        self.assertTrue(mfcc_ortho.shape[1] == n_mfcc)
        self.assertTrue(mfcc_ortho.shape[2] == 321)

        # check melkwargs are passed through
        short_window_transform = MFCC(sample_rate=sample_rate,
                                      n_mfcc=n_mfcc,
                                      norm='ortho',
                                      melkwargs={'win_length': 200})
        mfcc_short_window = short_window_transform(audio_scaled)  # (1, 40, 641)
        self.assertTrue(mfcc_short_window.shape[2] == 641)

        # check norms work correctly
        unnormed_transform = MFCC(sample_rate=sample_rate,
                                  n_mfcc=n_mfcc,
                                  norm=None)
        mfcc_unnormed = unnormed_transform(audio_scaled)  # (1, 40, 321)

        # rescale the 'ortho' output: coefficient 0 and the remaining
        # coefficients use different factors
        rescaled = mfcc_ortho.clone()
        rescaled[:, 0, :] *= math.sqrt(n_mels) * 2
        rescaled[:, 1:, :] *= math.sqrt(n_mels / 2) * 2

        self.assertTrue(mfcc_unnormed.allclose(rescaled))

    def test_batch_Resample(self):
        """Resample gives the same result on a batch as on a single item."""
        batch = 3
        waveform = torch.randn(2, 2786)

        # transform the single item, then tile the result into a batch
        single_out = transforms.Resample()(waveform)
        expected = single_out.repeat(batch, 1, 1)

        # tile the input into a batch, then transform
        batched_waveform = waveform.repeat(batch, 1, 1)
        computed = transforms.Resample()(batched_waveform)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_resample_size(self):
        """Resample doubles/halves the sample count and rejects unknown methods."""
        Resample = torchaudio.transforms.Resample
        wav_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        waveform, sample_rate = torchaudio.load(wav_path)
        n_samples = waveform.size(-1)

        doubled_rate = sample_rate * 2
        halved_rate = sample_rate // 2

        # an unknown resampling method is rejected when the transform is applied
        bad_resample = Resample(sample_rate,
                                doubled_rate,
                                resampling_method='foo')
        with self.assertRaises(ValueError):
            bad_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        upsampler = Resample(sample_rate,
                             doubled_rate,
                             resampling_method='sinc_interpolation')
        up_sampled = upsampler(waveform)
        self.assertTrue(up_sampled.size(-1) == n_samples * 2)

        # we expect the downsampled signal to have half as many samples
        downsampler = Resample(sample_rate,
                               halved_rate,
                               resampling_method='sinc_interpolation')
        down_sampled = downsampler(waveform)
        self.assertTrue(down_sampled.size(-1) == n_samples // 2)

    def test_compute_deltas(self):
        """ComputeDeltas must return a tensor of the same shape as its input."""
        n_channels = 13
        n_coeffs = n_channels * 3
        n_frames = 1021
        window = 2 * 7 + 1

        specgram = torch.randn(n_channels, n_coeffs, n_frames)
        deltas = transforms.ComputeDeltas(win_length=window)(specgram)

        shapes = (deltas.shape, specgram.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)

    def test_compute_deltas_transform_same_as_functional(
            self, atol=1e-6, rtol=1e-8):
        """The ComputeDeltas transform must match ``F.compute_deltas``.

        Args:
            atol: absolute tolerance passed to ``assert_allclose``.
            rtol: relative tolerance passed to ``assert_allclose``.
        """
        n_channels = 13
        n_coeffs = n_channels * 3
        n_frames = 1021
        window = 2 * 7 + 1
        specgram = torch.randn(n_channels, n_coeffs, n_frames)

        via_transform = transforms.ComputeDeltas(win_length=window)(specgram)
        via_functional = F.compute_deltas(specgram, win_length=window)

        torch.testing.assert_allclose(via_functional,
                                      via_transform,
                                      atol=atol,
                                      rtol=rtol)

    def test_compute_deltas_twochannel(self):
        """ComputeDeltas on a two-row ramp gives the known delta values.

        Both the shape and the values of the result are checked against
        ``expected``.
        """
        specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
        expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], [0.5, 1.0, 1.0, 0.5]]])
        # The expected values correspond to a 3-frame window; the transform's
        # default window would produce different numbers.
        transform = transforms.ComputeDeltas(win_length=3)
        computed = transform(specgram)
        self.assertTrue(computed.shape == expected.shape,
                        (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_MelScale(self):
        """MelScale gives the same result on a batch as on a single item."""
        batch = 3
        specgram = torch.randn(2, 31, 2786)

        # transform the single item, then tile the result into a batch
        single_out = transforms.MelScale()(specgram)
        expected = single_out.repeat(batch, 1, 1, 1)

        # tile the input into a batch, then transform
        batched_specgram = specgram.repeat(batch, 1, 1, 1)
        computed = transforms.MelScale()(batched_specgram)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_InverseMelScale(self):
        """InverseMelScale gives roughly the same result on a batch as on one item.

        Only the shape is compared exactly; values are compared with a very
        loose absolute tolerance.
        """
        batch = 3
        n_mels = 32
        n_stft = 5
        mel_spec = torch.randn(2, n_mels, 32)**2

        # transform the single item, then tile the result into a batch
        single_out = transforms.InverseMelScale(n_stft, n_mels)(mel_spec)
        expected = single_out.repeat(batch, 1, 1, 1)

        # tile the input into a batch, then transform
        batched_mel_spec = mel_spec.repeat(batch, 1, 1, 1)
        computed = transforms.InverseMelScale(n_stft, n_mels)(batched_mel_spec)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)

        # InverseMelScale runs SGD on randomly initialized values, so two runs
        # do not yield exactly the same result; hence the relaxed tolerance.
        self.assertTrue(torch.allclose(computed, expected, atol=1.0))

    def test_batch_compute_deltas(self):
        """ComputeDeltas gives the same result on a batch as on a single item."""
        batch = 3
        specgram = torch.randn(2, 31, 2786)

        # transform the single item, then tile the result into a batch
        single_out = transforms.ComputeDeltas()(specgram)
        expected = single_out.repeat(batch, 1, 1, 1)

        # tile the input into a batch, then transform
        batched_specgram = specgram.repeat(batch, 1, 1, 1)
        computed = transforms.ComputeDeltas()(batched_specgram)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_mulaw(self):
        """Mu-law encoding and decoding agree between single and batched input.

        The batched encoder output is fed to the batched decoder, and the
        single-item encoder output to the single-item decoder.
        """
        batch = 3

        def check_same(computed, expected):
            shapes = (computed.shape, expected.shape)
            self.assertTrue(shapes[0] == shapes[1], shapes)
            self.assertTrue(torch.allclose(computed, expected))

        waveform, _sample_rate = torchaudio.load(
            self.test_filepath)  # (2, 278756), 44100

        # encoding: single item tiled afterwards vs. tiled input
        encoded_single = transforms.MuLawEncoding()(waveform)
        expected_encoded = encoded_single.unsqueeze(0).repeat(batch, 1, 1)
        batched_waveform = waveform.unsqueeze(0).repeat(batch, 1, 1)
        encoded_batch = transforms.MuLawEncoding()(batched_waveform)
        check_same(encoded_batch, expected_encoded)

        # decoding: single item tiled afterwards vs. the batched encoding
        decoded_single = transforms.MuLawDecoding()(encoded_single)
        expected_decoded = decoded_single.unsqueeze(0).repeat(batch, 1, 1)
        decoded_batch = transforms.MuLawDecoding()(encoded_batch)
        check_same(decoded_batch, expected_decoded)

    def test_batch_spectrogram(self):
        """Spectrogram gives the same result on a batch as on a single item."""
        batch = 3
        waveform, _sample_rate = torchaudio.load(self.test_filepath)

        # transform the single item, then tile the result into a batch
        single_out = transforms.Spectrogram()(waveform)
        expected = single_out.repeat(batch, 1, 1, 1)

        # tile the input into a batch, then transform
        batched_waveform = waveform.repeat(batch, 1, 1)
        computed = transforms.Spectrogram()(batched_waveform)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_melspectrogram(self):
        """MelSpectrogram gives the same result on a batch as on a single item."""
        batch = 3
        waveform, _sample_rate = torchaudio.load(self.test_filepath)

        # transform the single item, then tile the result into a batch
        single_out = transforms.MelSpectrogram()(waveform)
        expected = single_out.repeat(batch, 1, 1, 1)

        # tile the input into a batch, then transform
        batched_waveform = waveform.repeat(batch, 1, 1)
        computed = transforms.MelSpectrogram()(batched_waveform)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_batch_mfcc(self):
        """MFCC gives the same result on a batch as on a single item (mp3 input)."""
        batch = 3
        mp3_path = os.path.join(self.test_dirpath, 'assets',
                                'steam-train-whistle-daniel_simon.mp3')
        waveform, _sample_rate = torchaudio.load(mp3_path)

        # transform the single item, then tile the result into a batch
        single_out = transforms.MFCC()(waveform)
        expected = single_out.repeat(batch, 1, 1, 1)

        # tile the input into a batch, then transform
        batched_waveform = waveform.repeat(batch, 1, 1)
        computed = transforms.MFCC()(batched_waveform)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_TimeStretch(self):
        """TimeStretch gives the same result on a batch as on a single item."""
        batch = 3
        rate = 2
        waveform, _sample_rate = torchaudio.load(self.test_filepath)

        complex_specgrams = torch.stft(waveform,
                                       n_fft=2048,
                                       hop_length=512,
                                       win_length=2048,
                                       window=torch.hann_window(2048),
                                       center=True,
                                       pad_mode='reflect',
                                       normalized=True,
                                       onesided=True)

        def stretch(specgrams):
            # a fresh transform per call, configured to match the STFT above
            transform = transforms.TimeStretch(fixed_rate=rate,
                                               n_freq=1025,
                                               hop_length=512)
            return transform(specgrams)

        # transform the single item, then tile the result into a batch
        expected = stretch(complex_specgrams).repeat(batch, 1, 1, 1, 1)

        # tile the input into a batch, then transform
        computed = stretch(complex_specgrams.repeat(batch, 1, 1, 1, 1))

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_batch_Fade(self):
        """Fade gives the same result on a batch as on a single item."""
        batch = 3
        fade_in_len = fade_out_len = 3000
        waveform, _sample_rate = torchaudio.load(self.test_filepath)

        # transform the single item, then tile the result into a batch
        single_out = transforms.Fade(fade_in_len, fade_out_len)(waveform)
        expected = single_out.repeat(batch, 1, 1)

        # tile the input into a batch, then transform
        batched_waveform = waveform.repeat(batch, 1, 1)
        computed = transforms.Fade(fade_in_len, fade_out_len)(batched_waveform)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_Vol(self):
        """Vol gives the same result on a batch as on a single item."""
        batch = 3
        gain = 1.1
        waveform, _sample_rate = torchaudio.load(self.test_filepath)

        # transform the single item, then tile the result into a batch
        single_out = transforms.Vol(gain=gain)(waveform)
        expected = single_out.repeat(batch, 1, 1)

        # tile the input into a batch, then transform
        batched_waveform = waveform.repeat(batch, 1, 1)
        computed = transforms.Vol(gain=gain)(batched_waveform)

        shapes = (computed.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)
        self.assertTrue(torch.allclose(computed, expected))
Exemplo n.º 15
0
class Test_LoadSave(unittest.TestCase):
    """Tests for ``torchaudio.load``, ``torchaudio.save`` and ``torchaudio.info``."""

    # directory returned by common_utils.create_temp_assets_dir(); the tests
    # read from and write to paths underneath it
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    # mp3 asset used as the default input by the tests of this class
    test_filepath = os.path.join(test_dirpath, "assets",
                                 "steam-train-whistle-daniel_simon.mp3")

    def test_1_save(self):
        """Exercise ``torchaudio.save`` on valid and invalid inputs.

        Saves the mp3 asset as wav in several forms, checks that badly shaped
        tensors and missing directories are rejected, writes
        ``assets/sinewave.wav``, and checks the precision reported by
        ``torchaudio.info`` for a default save and a save with precision 32.
        """
        # load signal
        x, sr = torchaudio.load(self.test_filepath, normalization=False)

        # check save
        new_filepath = os.path.join(self.test_dirpath, "test.wav")
        torchaudio.save(new_filepath, x, sr)
        self.assertTrue(os.path.isfile(new_filepath))
        os.unlink(new_filepath)

        # check automatic normalization
        # (x is divided in place by 2**31 and then saved again)
        x /= 1 << 31
        torchaudio.save(new_filepath, x, sr)
        self.assertTrue(os.path.isfile(new_filepath))
        os.unlink(new_filepath)

        # test save 1d tensor
        x = x[0, :]  # get mono signal
        x.squeeze_()  # remove channel dim
        torchaudio.save(new_filepath, x, sr)
        self.assertTrue(os.path.isfile(new_filepath))
        os.unlink(new_filepath)

        # don't allow invalid sizes as inputs
        # NOTE: x is reshaped in place inside the `with` blocks, so each block
        # starts from the shape the previous one left behind
        with self.assertRaises(ValueError):
            x.unsqueeze_(1)  # L x C not C x L
            torchaudio.save(new_filepath, x, sr)

        with self.assertRaises(ValueError):
            x.squeeze_()
            x.unsqueeze_(1)
            x.unsqueeze_(0)  # 1 x L x 1
            torchaudio.save(new_filepath, x, sr)

        # don't save to folders that don't exist
        with self.assertRaises(OSError):
            new_filepath = os.path.join(self.test_dirpath, "no-path",
                                        "test.wav")
            torchaudio.save(new_filepath, x, sr)

        # save created file
        # NOTE(review): sinewave.wav is not removed afterwards; other tests in
        # this class load the same path and presumably rely on this test
        # having run first - confirm
        sinewave_filepath = os.path.join(self.test_dirpath, "assets",
                                         "sinewave.wav")
        sr = 16000
        freq = 440
        volume = 0.3

        y = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq /
                       sr))
        y.unsqueeze_(0)
        # y is between -1 and 1, so must scale
        y = (y * volume * (2**31)).long()
        torchaudio.save(sinewave_filepath, y, sr)
        self.assertTrue(os.path.isfile(sinewave_filepath))

        # test precision
        # the file saved above (no precision argument) must report 16, the
        # one saved with new_precision must report 32
        new_precision = 32
        new_filepath = os.path.join(self.test_dirpath, "test.wav")
        si, ei = torchaudio.info(sinewave_filepath)
        torchaudio.save(new_filepath, y, sr, new_precision)
        si32, ei32 = torchaudio.info(new_filepath)
        self.assertEqual(si.precision, 16)
        self.assertEqual(si32.precision, new_precision)
        os.unlink(new_filepath)

    def test_2_load(self):
        """Exercise the options of ``torchaudio.load`` on the stereo mp3 asset.

        Covers default loading, ``normalization=False``, ``offset``,
        ``num_frames``, ``channels_first=False``, loading into a supplied
        ``LongTensor``, and the ``OSError`` raised for unreadable paths.
        """
        # check normal loading
        x, sr = torchaudio.load(self.test_filepath)
        self.assertEqual(sr, 44100)
        self.assertEqual(x.size(), (2, 278756))

        # check no normalizing
        x, _ = torchaudio.load(self.test_filepath, normalization=False)
        self.assertTrue(x.min() <= -1.0)
        self.assertTrue(x.max() >= 1.0)

        # check offset
        offset = 15
        x, _ = torchaudio.load(self.test_filepath)
        x_offset, _ = torchaudio.load(self.test_filepath, offset=offset)
        self.assertTrue(x[:, offset:].allclose(x_offset))

        # check number of frames
        n = 201
        x, _ = torchaudio.load(self.test_filepath, num_frames=n)
        # assertEqual is required here: assertTrue(x.size(), (2, n)) treats
        # the tuple as a failure message and passes for any non-empty size
        self.assertEqual(x.size(), (2, n))

        # check channels first
        x, _ = torchaudio.load(self.test_filepath, channels_first=False)
        self.assertEqual(x.size(), (278756, 2))

        # check different input tensor type
        x, _ = torchaudio.load(self.test_filepath,
                               torch.LongTensor(),
                               normalization=False)
        self.assertIsInstance(x, torch.LongTensor)

        # check raising errors
        with self.assertRaises(OSError):
            torchaudio.load("file-does-not-exist.mp3")

        # a path built next to the test directory (not an audio file path)
        # must be rejected as well
        tdir = os.path.join(os.path.dirname(self.test_dirpath), "torchaudio")
        with self.assertRaises(OSError):
            torchaudio.load(tdir)

    def test_3_load_and_save_is_identity(self):
        """Saving a loaded wav and loading it again returns the same data."""
        src_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        dst_path = os.path.join(self.test_dirpath, 'test.wav')

        original, original_rate = torchaudio.load(src_path)
        torchaudio.save(dst_path, original, original_rate)
        reloaded, reloaded_rate = torchaudio.load(dst_path)

        self.assertTrue(original.allclose(reloaded))
        self.assertEqual(original_rate, reloaded_rate)
        os.unlink(dst_path)

    def test_4_load_partial(self):
        """Check ``num_frames``/``offset`` loading on mono and two-channel files.

        Partial loads are compared sample-for-sample against the matching
        slice of a full load for a mono wav, a two-channel wav built from it,
        and the two-channel mp3 asset. Also checks reading past the end of
        the file and an offset beyond the end of the file.
        """
        num_frames = 101
        offset = 201
        # load entire mono sinewave wav file, load a partial copy and then compare
        input_sine_path = os.path.join(self.test_dirpath, 'assets',
                                       'sinewave.wav')
        x_sine_full, sr_sine = torchaudio.load(input_sine_path)
        x_sine_part, _ = torchaudio.load(input_sine_path,
                                         num_frames=num_frames,
                                         offset=offset)
        l1_error = x_sine_full[:, offset:(
            num_frames + offset)].sub(x_sine_part).abs().sum().item()
        # test for the correct number of samples and that the correct portion was loaded
        self.assertEqual(x_sine_part.size(1), num_frames)
        self.assertEqual(l1_error, 0.)
        # create a two channel version of this wavefile
        # repeat along dim 0 (channels); repeating along dim 1 would only
        # produce a mono signal of twice the length
        x_2ch_sine = x_sine_full.repeat(2, 1)
        out_2ch_sine_path = os.path.join(self.test_dirpath, 'assets',
                                         '2ch_sinewave.wav')
        torchaudio.save(out_2ch_sine_path, x_2ch_sine, sr_sine)
        x_2ch_sine_load, _ = torchaudio.load(out_2ch_sine_path,
                                             num_frames=num_frames,
                                             offset=offset)
        os.unlink(out_2ch_sine_path)
        l1_error = x_2ch_sine_load.sub(
            x_2ch_sine[:, offset:(offset + num_frames)]).abs().sum().item()
        self.assertEqual(x_2ch_sine_load.size(), (2, num_frames))
        self.assertEqual(l1_error, 0.)

        # test with two channel mp3
        x_2ch_full, sr_2ch = torchaudio.load(self.test_filepath,
                                             normalization=True)
        x_2ch_part, _ = torchaudio.load(self.test_filepath,
                                        normalization=True,
                                        num_frames=num_frames,
                                        offset=offset)
        l1_error = x_2ch_full[:, offset:(
            offset + num_frames)].sub(x_2ch_part).abs().sum().item()
        self.assertEqual(x_2ch_part.size(1), num_frames)
        self.assertEqual(l1_error, 0.)

        # check behavior if number of samples would exceed file length
        offset_ns = 300
        x_ns, _ = torchaudio.load(input_sine_path,
                                  num_frames=100000,
                                  offset=offset_ns)
        self.assertEqual(x_ns.size(1), x_sine_full.size(1) - offset_ns)

        # check when offset is beyond the end of the file
        with self.assertRaises(RuntimeError):
            torchaudio.load(input_sine_path, offset=100000)

    def test_5_get_info(self):
        """``torchaudio.info`` reports the expected metadata for sinewave.wav."""
        wav_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
        signal_info, encoding_info = torchaudio.info(wav_path)

        self.assertEqual(signal_info.channels, 1)
        self.assertEqual(signal_info.length, 64000)
        self.assertEqual(signal_info.rate, 16000)
        self.assertEqual(encoding_info.bits_per_sample, 16)
Exemplo n.º 16
0
class TestFunctional(unittest.TestCase):
    data_sizes = [(2, 20), (3, 15), (4, 10)]
    number_of_trials = 100
    specgram = torch.tensor([1., 2., 3., 4.])

    test_dirpath, test_dir = common_utils.create_temp_assets_dir()

    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.mp3')
    waveform_train, sr_train = torchaudio.load(test_filepath)

    def test_torchscript_spectrogram(self):
        """Pass ``F.spectrogram`` to the ``_test_torchscript_functional`` helper."""
        win_length = 400
        waveform = torch.rand((1, 1000))
        hann = torch.hann_window(win_length)

        # Positional order after the waveform: pad, window, n_fft, hop,
        # window size, power, normalize.
        _test_torchscript_functional(F.spectrogram, waveform, 0, hann, 400,
                                     200, win_length, 2, False)

    def test_torchscript_griffinlim(self):
        """Pass ``F.griffinlim`` to the ``_test_torchscript_functional`` helper."""
        tensor = torch.rand((1, 201, 6))
        n_fft = 400
        ws = 400
        hop = 200
        window = torch.hann_window(ws)
        power = 2
        normalize = False
        momentum = 0.99
        n_iter = 32
        length = 1000
        # Last positional argument of F.griffinlim (test_griffinlim passes its
        # ``rand_init`` flag in this slot). The value stays the integer 0 that
        # the call has always used; it used to be written as a bare literal
        # next to an unused ``init`` local.
        rand_init = 0

        _test_torchscript_functional(F.griffinlim, tensor, window, n_fft, hop,
                                     ws, power, normalize, n_iter, momentum,
                                     length, rand_init)

    @unittest.skipIf(not IMPORT_LIBROSA, 'Librosa not available')
    def test_griffinlim(self):
        """Compare ``F.griffinlim`` against ``librosa.griffinlim``."""
        # NOTE: This test is flaky without a fixed random seed
        # See https://github.com/pytorch/audio/issues/382
        torch.random.manual_seed(42)
        waveform = torch.rand((1, 1000))

        n_fft, win_length, hop_length = 400, 400, 100
        hann = torch.hann_window(win_length)
        normalized = False
        momentum, n_iter, length = 0.99, 8, 1000

        # Both implementations are given the same initialisation mode.
        rand_init = False
        if rand_init:
            librosa_init = 'random'
        else:
            librosa_init = None

        # Square root of the power-2 spectrogram.
        power_spec = F.spectrogram(waveform, 0, hann, n_fft, hop_length,
                                   win_length, 2, normalized)
        magnitude = power_spec.sqrt()

        ta_out = F.griffinlim(magnitude, hann, n_fft, hop_length, win_length,
                              1, normalized, n_iter, momentum, length,
                              rand_init)
        lr_array = librosa.griffinlim(magnitude.squeeze(0).numpy(),
                                      n_iter=n_iter,
                                      hop_length=hop_length,
                                      momentum=momentum,
                                      init=librosa_init,
                                      length=length)
        lr_out = torch.from_numpy(lr_array).unsqueeze(0)

        self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))

    def _test_compute_deltas(self,
                             specgram,
                             expected,
                             win_length=3,
                             atol=1e-6,
                             rtol=1e-8):
        """Assert that ``F.compute_deltas(specgram)`` matches ``expected``.

        Args:
            specgram: Input passed to ``F.compute_deltas``.
            expected: Tensor the result must match in shape and value.
            win_length: Forwarded to ``F.compute_deltas``.
            atol: Absolute tolerance for the value comparison.
            rtol: Relative tolerance for the value comparison.
        """
        deltas = F.compute_deltas(specgram, win_length=win_length)

        # Report both shapes in the failure message.
        shape_pair = (deltas.shape, expected.shape)
        self.assertTrue(shape_pair[0] == shape_pair[1], shape_pair)

        torch.testing.assert_allclose(deltas, expected, atol=atol, rtol=rtol)

    def test_compute_deltas_onechannel(self):
        """Deltas of the class-level ``specgram`` with two leading unit axes."""
        # Indexing with None prepends two axes of size one.
        single_channel = self.specgram[None, None, :]
        self._test_compute_deltas(single_channel,
                                  torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]))

    def test_compute_deltas_twochannel(self):
        """Deltas of the class-level ``specgram`` duplicated along axis 1."""
        expected_row = [0.5, 1.0, 1.0, 0.5]
        # Two identical rows behind a leading axis of size one.
        two_channel = torch.stack([self.specgram, self.specgram]).unsqueeze(0)
        self._test_compute_deltas(two_channel,
                                  torch.tensor([[expected_row, expected_row]]))

    def test_compute_deltas_randn(self):
        """Check that ``F.compute_deltas`` keeps the shape of a random input.

        The same call is also handed to ``_test_torchscript_functional``.
        """
        n_channels = 13
        win_length = 2 * 7 + 1
        specgram = torch.randn(n_channels, n_channels * 3, 1021)

        deltas = F.compute_deltas(specgram, win_length=win_length)
        shape_pair = (deltas.shape, specgram.shape)
        self.assertTrue(shape_pair[0] == shape_pair[1], shape_pair)

        _test_torchscript_functional(F.compute_deltas,
                                     specgram,
                                     win_length=win_length)

    def test_batch_pitch(self):
        """``F.detect_pitch_frequency`` must give the same result when batched."""
        single, sample_rate = torchaudio.load(self.test_filepath)
        batch = single.unsqueeze(0).repeat(3, 1, 1)

        # Transform one item, then replicate the result three times.
        pitch_single = F.detect_pitch_frequency(single, sample_rate)
        expected = pitch_single.unsqueeze(0).repeat(3, 1, 1)

        # Replicate first, then transform the whole batch at once.
        computed = F.detect_pitch_frequency(batch, sample_rate)

        shape_pair = (computed.shape, expected.shape)
        self.assertTrue(shape_pair[0] == shape_pair[1], shape_pair)
        self.assertTrue(torch.allclose(computed, expected))

        _test_torchscript_functional(F.detect_pitch_frequency, batch,
                                     sample_rate)

    def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
        """Assert that ``estimate`` matches the leading part of ``sound``.

        ``sound`` is cut along its last axis to the length of ``estimate``
        before comparing, which covers a reconstructed signal that is shorter
        than the original.
        """
        n_samples = estimate.size(-1)
        reference = sound[..., :n_samples]

        shape_pair = (reference.shape, estimate.shape)
        self.assertTrue(shape_pair[0] == shape_pair[1], shape_pair)

        values_match = torch.allclose(reference,
                                      estimate,
                                      atol=atol,
                                      rtol=rtol)
        self.assertTrue(values_match)

    def _test_istft_is_inverse_of_stft(self, kwargs):
        """Check that ``istft(stft(x))`` reconstructs ``x`` for random signals.

        For every shape in ``self.data_sizes`` and every trial index, a random
        signal is generated, transformed with ``torch.stft`` and inverted with
        ``torchaudio.functional.istft``. This is done once as generated and
        once with an extra leading batch axis of size 3.

        Args:
            kwargs (dict): Keyword arguments forwarded unchanged to both
                ``torch.stft`` and ``torchaudio.functional.istft``.
        """
        for data_size in self.data_sizes:
            for i in range(self.number_of_trials):

                # Non-batch
                sound = common_utils.random_float_tensor(i, data_size)

                stft = torch.stft(sound, **kwargs)
                estimate = torchaudio.functional.istft(stft,
                                                       length=sound.size(-1),
                                                       **kwargs)

                self._compare_estimate(sound, estimate)

                # Batch: replicate the transform and the signal three times
                # along a new leading axis.
                stft = stft.repeat(3, 1, 1, 1, 1)
                sound = sound.repeat(3, 1, 1)

                # The sample axis is the last one. Once the batch axis has
                # been added, ``size(1)`` is no longer the signal length, so
                # index from the end to request the full-length reconstruction.
                estimate = torchaudio.functional.istft(stft,
                                                       length=sound.size(-1),
                                                       **kwargs)
                self._compare_estimate(sound, estimate)

    def test_istft_is_inverse_of_stft1(self):
        """Round trip with a hann window: centered, normalized, onesided."""
        window = torch.hann_window(12)
        self._test_istft_is_inverse_of_stft(
            dict(n_fft=12,
                 hop_length=4,
                 win_length=12,
                 window=window,
                 center=True,
                 pad_mode='reflect',
                 normalized=True,
                 onesided=True))

    def test_istft_is_inverse_of_stft2(self):
        """Round trip with a hann window: centered, not normalized, two-sided."""
        window = torch.hann_window(8)
        self._test_istft_is_inverse_of_stft(
            dict(n_fft=12,
                 hop_length=2,
                 win_length=8,
                 window=window,
                 center=True,
                 pad_mode='reflect',
                 normalized=False,
                 onesided=False))

    def test_istft_is_inverse_of_stft3(self):
        """Round trip with a hamming window: centered, normalized, two-sided."""
        window = torch.hamming_window(11)
        self._test_istft_is_inverse_of_stft(
            dict(n_fft=15,
                 hop_length=3,
                 win_length=11,
                 window=window,
                 center=True,
                 pad_mode='constant',
                 normalized=True,
                 onesided=False))

    def test_istft_is_inverse_of_stft4(self):
        """Round trip with a hamming window: not centered, not normalized, onesided.

        The window has the same size as ``n_fft``.
        """
        size = 5
        self._test_istft_is_inverse_of_stft(
            dict(n_fft=size,
                 hop_length=2,
                 win_length=size,
                 window=torch.hamming_window(size),
                 center=False,
                 pad_mode='constant',
                 normalized=False,
                 onesided=True))

    def test_istft_is_inverse_of_stft5(self):
        """Round trip with a hamming window: not centered, not normalized, two-sided.

        The window has the same size as ``n_fft``.
        """
        size = 3
        self._test_istft_is_inverse_of_stft(
            dict(n_fft=size,
                 hop_length=2,
                 win_length=size,
                 window=torch.hamming_window(size),
                 center=False,
                 pad_mode='reflect',
                 normalized=False,
                 onesided=False))

    def test_istft_of_ones(self):
        """``istft`` of a hand-written transform must reconstruct four ones."""
        # stft = torch.stft(torch.ones(4), 4)
        # Only the real part of the first row is non-zero (4.0 in each of the
        # five columns); every other entry is zero.
        spectrum = torch.zeros((3, 5, 2))
        spectrum[0, :, 0] = 4.

        reconstructed = torchaudio.functional.istft(spectrum,
                                                    n_fft=4,
                                                    length=4)
        self._compare_estimate(torch.ones(4), reconstructed)

    def test_istft_of_zeros(self):
        """``istft`` of an all-zero transform must reconstruct four zeros."""
        # stft = torch.stft(torch.zeros(4), 4)
        spectrum = torch.zeros((3, 5, 2))
        reconstructed = torchaudio.functional.istft(spectrum,
                                                    n_fft=4,
                                                    length=4)

        self._compare_estimate(torch.zeros(4), reconstructed)

    def test_istft_requires_overlap_windows(self):
        """``istft`` must reject a hop that leaves gaps between windows."""
        spectrum = torch.zeros((3, 5, 2))

        # The window has size 1 but the hop is 20, so consecutive windows
        # leave a gap; this is expected to raise an error.
        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(spectrum,
                                        n_fft=4,
                                        hop_length=20,
                                        win_length=1,
                                        window=torch.ones(1))

    def test_istft_requires_nola(self):
        """``istft`` must accept a window of ones and reject a window of zeros."""
        spectrum = torch.zeros((3, 5, 2))
        common = dict(n_fft=4, win_length=4)

        # A window of ones meets NOLA, so this call must simply succeed.
        torchaudio.functional.istft(spectrum, window=torch.ones(4), **common)

        # A window of zeros does not meet NOLA. This should throw an error.
        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(spectrum,
                                        window=torch.zeros(4),
                                        **common)

    def test_istft_requires_non_empty(self):
        """``istft`` must reject input with a zero-sized axis."""
        for empty_shape in ((3, 0, 2), (0, 3, 2)):
            with self.assertRaises(AssertionError):
                torchaudio.functional.istft(torch.zeros(empty_shape), 2)

    def _test_istft_of_sine(self, amplitude, L, n):
        """Invert a hand-built transform of ``amplitude * sin(2*pi/L*n*x)``.

        The hop length and the window size both equal ``L``.

        Args:
            amplitude: Scale of the sine wave.
            L: Window size and hop length; the signal has ``2 * L + 1`` samples.
            n: Frequency multiplier of the sine wave.
        """
        samples = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
        sound = amplitude * torch.sin(2 * math.pi / L * samples * n)

        # Hand-written equivalent of
        # torch.stft(sound, L, hop_length=L, win_length=L,
        #            window=torch.ones(L), center=False, normalized=False)
        n_bins = L // 2 + 1
        spectrum = torch.zeros((n_bins, 2, 2))
        peak = (amplitude * L) / 2.0
        if n < n_bins:
            spectrum[n, :, 1] = -peak

        mirrored_bin = L - n
        if 0 <= mirrored_bin < n_bins:
            # symmetric about L // 2
            spectrum[mirrored_bin, :, 1] = peak

        reconstructed = torchaudio.functional.istft(spectrum,
                                                    L,
                                                    hop_length=L,
                                                    win_length=L,
                                                    window=torch.ones(L),
                                                    center=False,
                                                    normalized=False)
        # There is a larger error due to the scaling of amplitude
        self._compare_estimate(sound, reconstructed, atol=1e-3)

    def test_istft_of_sine(self):
        """Run ``_test_istft_of_sine`` over a fixed list of configurations."""
        # (amplitude, L, n)
        cases = (
            (123, 5, 1),
            (150, 5, 2),
            (111, 5, 3),
            (160, 7, 4),
            (145, 8, 5),
            (80, 9, 6),
            (99, 10, 7),
        )
        for amplitude, window_size, multiplier in cases:
            self._test_istft_of_sine(amplitude=amplitude,
                                     L=window_size,
                                     n=multiplier)

    def _test_linearity_of_istft(self,
                                 data_size,
                                 kwargs,
                                 atol=1e-6,
                                 rtol=1e-8):
        """Check ``istft(a*x + b*y) == a*istft(x) + b*istft(y)`` on random input.

        Args:
            data_size: Shape passed to ``common_utils.random_float_tensor``.
            kwargs (dict): Keyword arguments forwarded to every ``istft`` call.
            atol: Absolute tolerance for the comparison.
            rtol: Relative tolerance for the comparison.
        """
        istft = torchaudio.functional.istft
        for trial in range(self.number_of_trials):
            first = common_utils.random_float_tensor(trial, data_size)
            second = common_utils.random_float_tensor(trial * 2, data_size)
            a, b = torch.rand(2)

            # Invert each input separately, then combine the results.
            inverted_first = istft(first, **kwargs)
            inverted_second = istft(second, **kwargs)
            combined_after = a * inverted_first + b * inverted_second

            # Combine the inputs first, then invert once.
            combined_before = istft(a * first + b * second, **kwargs)

            self._compare_estimate(combined_after, combined_before, atol,
                                   rtol)

    def test_linearity_of_istft1(self):
        """Linearity with a hann window: centered, normalized, onesided."""
        settings = dict(n_fft=12,
                        window=torch.hann_window(12),
                        center=True,
                        pad_mode='reflect',
                        normalized=True,
                        onesided=True)
        self._test_linearity_of_istft((2, 7, 7, 2), settings)

    def test_linearity_of_istft2(self):
        """Linearity with a hann window: centered, not normalized, two-sided."""
        settings = dict(n_fft=12,
                        window=torch.hann_window(12),
                        center=True,
                        pad_mode='reflect',
                        normalized=False,
                        onesided=False)
        self._test_linearity_of_istft((2, 12, 7, 2), settings)

    def test_linearity_of_istft3(self):
        """Linearity with a hamming window: centered, normalized, two-sided."""
        settings = dict(n_fft=12,
                        window=torch.hamming_window(12),
                        center=True,
                        pad_mode='constant',
                        normalized=True,
                        onesided=False)
        self._test_linearity_of_istft((2, 12, 7, 2), settings)

    def test_linearity_of_istft4(self):
        """Linearity with a hamming window: not centered, not normalized, onesided."""
        settings = dict(n_fft=12,
                        window=torch.hamming_window(12),
                        center=False,
                        pad_mode='constant',
                        normalized=False,
                        onesided=True)
        self._test_linearity_of_istft((2, 7, 3, 2),
                                      settings,
                                      atol=1e-5,
                                      rtol=1e-8)

    def _test_create_fb(self,
                        n_mels=40,
                        sample_rate=22050,
                        n_fft=2048,
                        fmin=0.0,
                        fmax=8000.0):
        """Compare ``F.create_fb_matrix`` with ``librosa.filters.mel``.

        Args:
            n_mels: Number of mel filter banks.
            sample_rate: Sample rate passed to both implementations.
            n_fft: FFT size; ``n_fft // 2 + 1`` is passed as ``n_freqs``.
            fmin: Lower frequency bound.
            fmax: Upper frequency bound.

        Raises:
            unittest.SkipTest: If librosa could not be imported.
        """
        # Using a decorator here causes parametrize to fail on Python 2
        if not IMPORT_LIBROSA:
            raise unittest.SkipTest('Librosa is not available')

        librosa_fb = librosa.filters.mel(sr=sample_rate,
                                         n_fft=n_fft,
                                         n_mels=n_mels,
                                         fmax=fmax,
                                         fmin=fmin,
                                         htk=True,
                                         norm=None)
        fb = F.create_fb_matrix(sample_rate=sample_rate,
                                n_mels=n_mels,
                                f_max=fmax,
                                f_min=fmin,
                                n_freqs=(n_fft // 2 + 1))

        # Each column of ``fb`` is compared with the matching row of
        # ``librosa_fb``. self.assertTrue is used instead of a bare ``assert``
        # so the check is not stripped when Python runs with -O.
        for i_mel_bank in range(n_mels):
            self.assertTrue(
                torch.allclose(fb[:, i_mel_bank],
                               torch.tensor(librosa_fb[i_mel_bank]),
                               atol=1e-4))

    def test_create_fb(self):
        """Run ``_test_create_fb`` with the defaults and several overrides."""
        overrides_per_run = (
            {},
            {'n_mels': 128, 'sample_rate': 44100},
            {'n_mels': 128, 'fmin': 2000.0, 'fmax': 5000.0},
            {'n_mels': 56, 'fmin': 100.0, 'fmax': 9000.0},
            {'n_mels': 56, 'fmin': 800.0, 'fmax': 900.0},
            {'n_mels': 56, 'fmin': 1900.0, 'fmax': 900.0},
            {'n_mels': 10, 'fmin': 1900.0, 'fmax': 900.0},
        )
        for overrides in overrides_per_run:
            self._test_create_fb(**overrides)

    def test_gain(self):
        """Compare ``F.gain`` (3 dB) with the sox ``gain`` effect."""
        waveform_gain = F.gain(self.waveform_train, 3)
        # The original ``assertTrue(peak, 1.)`` used ``1.`` as the failure
        # message and therefore only checked that the peak was non-zero.
        # NOTE(review): the intended check is presumably that the amplified
        # signal stays within [-1, 1] -- confirm.
        peak = waveform_gain.abs().max().item()
        self.assertLessEqual(peak, 1.)

        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", [3])
        sox_gain_waveform = E.sox_build_flow_effects()[0]

        self.assertTrue(
            torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04))

    def test_dither(self):
        """Compare ``F.dither`` with the sox ``dither`` effect.

        Both the plain variant and the noise-shaped variant (sox ``-s``) are
        checked.
        """
        plain = F.dither(self.waveform_train)
        noise_shaped = F.dither(self.waveform_train, noise_shaping=True)

        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(self.test_filepath)

        # Plain dither.
        chain.append_effect_to_chain("dither", [])
        sox_plain = chain.sox_build_flow_effects()[0]
        self.assertTrue(torch.allclose(plain, sox_plain, atol=1e-04))

        # Noise-shaped dither on a cleared chain, with a looser tolerance.
        chain.clear_chain()
        chain.append_effect_to_chain("dither", ["-s"])
        sox_noise_shaped = chain.sox_build_flow_effects()[0]
        self.assertTrue(
            torch.allclose(noise_shaped, sox_noise_shaped, atol=1e-02))

    def test_vctk_transform_pipeline(self):
        """Compare a resample + dither pipeline with an equivalent sox chain."""
        vctk_path = os.path.join(self.test_dirpath,
                                 "assets/VCTK-Corpus/wav48/p224/",
                                 "p224_002.wav")
        waveform, source_rate = torchaudio.load(vctk_path)

        # rate
        resampler = T.Resample(source_rate,
                               16000,
                               resampling_method='sinc_interpolation')
        waveform = resampler(waveform)
        # dither
        waveform = F.dither(waveform, noise_shaping=True)

        # Same processing expressed as sox effects, applied in this order.
        sox_effects = (
            ("gain", ["-h"]),
            ("channels", [1]),
            ("rate", [16000]),
            ("gain", ["-rh"]),
            ("dither", ["-s"]),
        )
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(vctk_path)
        for effect_name, effect_args in sox_effects:
            chain.append_effect_to_chain(effect_name, effect_args)
        sox_waveform = chain.sox_build_flow_effects()[0]

        self.assertTrue(
            torch.allclose(waveform, sox_waveform, rtol=1e-03, atol=1e-03))

    def test_pitch(self):
        """Check ``detect_pitch_frequency`` on two pure-tone recordings.

        Every detected frequency must lie within 1 of the reference tone, and
        the result must not change when the input is replicated into a
        stereo batch.
        """
        # Reuse the class-level asset directory instead of creating another
        # temporary copy of the assets for this one test.
        asset_dir = os.path.join(self.test_dirpath, 'assets')

        # Files from https://www.mediacollege.com/audio/tone/download/
        tests = [
            (os.path.join(asset_dir, "100Hz_44100Hz_16bit_05sec.wav"), 100),
            (os.path.join(asset_dir, "440Hz_44100Hz_16bit_05sec.wav"), 440),
        ]

        for filename, freq_ref in tests:
            waveform, sample_rate = torchaudio.load(filename)

            freq = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # Count the estimates further than ``threshold`` from the
            # reference; none are allowed.
            threshold = 1
            n_outliers = ((freq - freq_ref).abs() > threshold).sum().item()
            self.assertEqual(n_outliers, 0)

            # Convert to stereo and batch for testing purposes
            freq = freq.repeat(3, 2, 1, 1)
            waveform = waveform.repeat(3, 2, 1, 1)

            freq2 = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # self.assertTrue instead of a bare ``assert`` so the check is
            # not stripped when Python runs with -O.
            self.assertTrue(torch.allclose(freq, freq2, atol=1e-5))

    def _test_batch(self, functional):
        """Apply ``functional`` to the test file unbatched and as a batch of 3.

        Args:
            functional: Callable applied to the loaded waveform tensor.

        NOTE(review): as written, ``expected`` and ``computed`` are built but
        never compared; the method appears to be cut short -- confirm against
        the upstream source.
        """
        waveform, sample_rate = torchaudio.load(
            self.test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)

        # Batch then transform
        waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = functional(waveform)
class Test_Kaldi(unittest.TestCase):
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
    test_8000_filepath = os.path.join(test_dirpath, 'assets',
                                      'kaldi_file_8000.wav')
    kaldi_output_dir = os.path.join(test_dirpath, 'assets', 'kaldi')
    test_filepaths = {prefix: [] for prefix in compliance.utils.TEST_PREFIX}

    # Separate the test files by type (e.g. 'spec', 'fbank', etc.)
    for f in os.listdir(kaldi_output_dir):
        dash_idx = f.find('-')
        assert f.endswith('.ark') and dash_idx != -1
        key = f[:dash_idx]
        assert key in test_filepaths
        test_filepaths[key].append(f)

    def _test_get_strided_helper(self, num_samples, window_size, window_shift,
                                 snip_edges):
        """Check ``kaldi._get_strided`` against ``extract_window``.

        Args:
            num_samples: Length of the ``arange`` waveform that is framed.
            window_size: Number of samples per frame.
            window_shift: Step between consecutive frames.
            snip_edges: Selects which frame-count formula applies; also
                forwarded to both implementations.
        """
        waveform = torch.arange(num_samples).float()
        strided = kaldi._get_strided(waveform, window_size, window_shift,
                                     snip_edges)

        # from NumFrames in feature-window.cc
        if not snip_edges:
            num_frames = (num_samples + (window_shift // 2)) // window_shift
        elif num_samples < window_size:
            num_frames = 0
        else:
            num_frames = 1 + (num_samples - window_size) // window_shift

        self.assertTrue(strided.dim() == 2)
        self.assertTrue(strided.shape[0] == num_frames
                        and strided.shape[1] == window_size)

        # Fill the reference one frame at a time.
        reference = torch.empty((num_frames, window_size))
        for frame_idx in range(num_frames):
            extract_window(reference, waveform, frame_idx, window_size,
                           window_shift, snip_edges)
        self.assertTrue(torch.allclose(reference, strided))

    def test_get_strided(self):
        """Run the strided-framing check over a grid of small configurations."""
        # generate any combination where 0 < window_size <= num_samples and
        # 0 < window_shift.
        for num_samples in range(1, 20):
            max_shift = 2 * num_samples
            for window_size in range(1, num_samples + 1):
                for window_shift in range(1, max_shift + 1):
                    # snip_edges is passed as the integers 0 and 1.
                    for snip_edges in (0, 1):
                        self._test_get_strided_helper(num_samples, window_size,
                                                      window_shift, snip_edges)

    def _create_data_set(self):
        """Write ``assets/kaldi_file.wav`` and check that it loads back unchanged.

        Used to generate the dataset to test on; this is not used in testing
        (offline procedure).
        """
        parent_dir = os.path.dirname(
            os.path.dirname(os.path.realpath(__file__)))
        wav_path = os.path.join(parent_dir, 'assets', 'kaldi_file.wav')
        sr = 16000

        x = torch.arange(0, 20).float()
        # between [-6,6]
        y = torch.cos(
            2 * math.pi * x) + 3 * torch.sin(math.pi * x) + 2 * torch.cos(x)
        # between [-2^30, 2^30]
        scaled = (y / 6 * (1 << 30)).long()
        # clear the last 16 bits because they aren't used anyways
        y = ((scaled >> 16) << 16).float()

        torchaudio.save(wav_path, y, sr)
        loaded, loaded_rate = torchaudio.load(wav_path, normalization=False)
        print(y >> 16)
        self.assertTrue(loaded_rate == sr)
        self.assertTrue(torch.allclose(y, loaded))

    def _print_diagnostic(self, output, expect_output):
        """Print absolute and relative error statistics for ``output``.

        For each kind of error, the mean squared value and the maximum
        magnitude with respect to ``expect_output`` are printed.
        """
        n_elements = output.numel()
        difference = output - expect_output
        relative = difference / expect_output

        abs_mse = difference.pow(2).sum() / n_elements
        abs_max_error = difference.abs().max()
        relative_mse = relative.pow(2).sum() / n_elements
        relative_max_error = relative.abs().max()

        print('abs_mse:', abs_mse.item(), 'abs_max_error:',
              abs_max_error.item())
        print('relative_mse:', relative_mse.item(), 'relative_max_error:',
              relative_max_error.item())

    def _compliance_test_helper(self,
                                sound_filepath,
                                filepath_key,
                                expected_num_files,
                                expected_num_args,
                                get_output_fn,
                                atol=1e-5,
                                rtol=1e-8):
        """
        Compare the output of ``get_output_fn`` with stored kaldi results.

        Each file name registered under ``filepath_key`` encodes the kaldi
        configuration as '-'-separated fields. The fields are parsed and
        passed to ``get_output_fn``, and the result is compared with the
        matrix stored in the file under the key 'my_id'.

        Inputs:
            sound_filepath (str): The location of the sound file
            filepath_key (str): A key to `test_filepaths` which matches which files to use
            expected_num_files (int): The expected number of kaldi files to read
            expected_num_args (int): The expected number of arguments used in a kaldi configuration
            get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
                and a configuration and returns an output
            atol (float): absolute tolerance
            rtol (float): relative tolerance
        """
        sound, _sample_rate = torchaudio.load_wav(sound_filepath)
        files = self.test_filepaths[filepath_key]

        assert len(files) == expected_num_files, (
            'number of kaldi %s file changed to %d' %
            (filepath_key, len(files)))

        for f in files:
            print(f)

            # Read kaldi's output from file
            kaldi_output_path = os.path.join(self.kaldi_output_dir, f)
            kaldi_output_dict = dict(
                torchaudio.kaldi_io.read_mat_ark(kaldi_output_path))

            assert len(
                kaldi_output_dict
            ) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
            kaldi_output = kaldi_output_dict['my_id']

            # Construct the same configuration used by kaldi
            args = f.split('-')
            args[-1] = os.path.splitext(args[-1])[0]
            assert len(
                args) == expected_num_args, 'invalid test kaldi file name'
            args = [compliance.utils.parse(arg) for arg in args]

            output = get_output_fn(sound, args)

            self._print_diagnostic(output, kaldi_output)
            # assertEqual, not assertTrue: with assertTrue the second shape
            # was only used as the failure message, so a mismatch went
            # unnoticed (torch.allclose broadcasts its arguments).
            self.assertEqual(output.shape, kaldi_output.shape)
            self.assertTrue(
                torch.allclose(output, kaldi_output, atol=atol, rtol=rtol))

    def test_spectrogram(self):
        """Compliance test of ``kaldi.spectrogram`` against the 'spec' files."""
        # Keyword for each configuration field, starting at args[1].
        keywords = (
            'blackman_coeff',
            'dither',
            'energy_floor',
            'frame_length',
            'frame_shift',
            'preemphasis_coefficient',
            'raw_energy',
            'remove_dc_offset',
            'round_to_power_of_two',
            'snip_edges',
            'subtract_mean',
            'window_type',
        )

        def run_spectrogram(sound, args):
            options = {
                keyword: args[position]
                for position, keyword in enumerate(keywords, start=1)
            }
            return kaldi.spectrogram(sound, **options)

        self._compliance_test_helper(self.test_filepath,
                                     'spec',
                                     131,
                                     13,
                                     run_spectrogram,
                                     atol=1e-3,
                                     rtol=0)

    def test_fbank(self):
        """Compliance test of ``kaldi.fbank`` against the 'fbank' files."""
        # Keyword for each configuration field, starting at args[1].
        # ``dither`` is not read from the file name; it is fixed to 0.0.
        keywords = (
            'blackman_coeff',
            'energy_floor',
            'frame_length',
            'frame_shift',
            'high_freq',
            'htk_compat',
            'low_freq',
            'num_mel_bins',
            'preemphasis_coefficient',
            'raw_energy',
            'remove_dc_offset',
            'round_to_power_of_two',
            'snip_edges',
            'subtract_mean',
            'use_energy',
            'use_log_fbank',
            'use_power',
            'vtln_high',
            'vtln_low',
            'vtln_warp',
            'window_type',
        )

        def run_fbank(sound, args):
            options = {
                keyword: args[position]
                for position, keyword in enumerate(keywords, start=1)
            }
            return kaldi.fbank(sound, dither=0.0, **options)

        self._compliance_test_helper(self.test_filepath,
                                     'fbank',
                                     97,
                                     22,
                                     run_fbank,
                                     atol=1e-3,
                                     rtol=1e-1)

    def test_mfcc(self):
        """Compliance test of ``kaldi.mfcc`` against the 'mfcc' files."""
        # Keyword for each configuration field, starting at args[1].
        # ``dither`` is not read from the file name; it is fixed to 0.0.
        keywords = (
            'blackman_coeff',
            'energy_floor',
            'frame_length',
            'frame_shift',
            'high_freq',
            'htk_compat',
            'low_freq',
            'num_mel_bins',
            'preemphasis_coefficient',
            'raw_energy',
            'remove_dc_offset',
            'round_to_power_of_two',
            'snip_edges',
            'subtract_mean',
            'use_energy',
            'num_ceps',
            'cepstral_lifter',
            'vtln_high',
            'vtln_low',
            'vtln_warp',
            'window_type',
        )

        def run_mfcc(sound, args):
            options = {
                keyword: args[position]
                for position, keyword in enumerate(keywords, start=1)
            }
            return kaldi.mfcc(sound, dither=0.0, **options)

        self._compliance_test_helper(self.test_filepath,
                                     'mfcc',
                                     145,
                                     22,
                                     run_mfcc,
                                     atol=1e-3)

    def test_mfcc_empty(self):
        """``kaldi.mfcc`` must reject an empty tensor."""
        # Passing in an empty tensor should result in an error
        with self.assertRaises(AssertionError):
            kaldi.mfcc(torch.empty(0))

    def test_resample_waveform(self):
        """Compliance test of ``kaldi.resample_waveform`` against the 'resample' files."""

        def run_resample(sound, args):
            # args[1] and args[2] are passed positionally after the signal.
            return kaldi.resample_waveform(sound, args[1], args[2])

        self._compliance_test_helper(self.test_8000_filepath,
                                     'resample',
                                     32,
                                     3,
                                     run_resample,
                                     atol=1e-2,
                                     rtol=1e-5)

    def test_resample_waveform_upsample_size(self):
        """Resampling to twice the rate must double the length of the last axis."""
        sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
        target_rate = sample_rate * 2
        resampled = kaldi.resample_waveform(sound, sample_rate, target_rate)
        expected_length = sound.size(-1) * 2
        self.assertTrue(resampled.size(-1) == expected_length)

    def test_resample_waveform_downsample_size(self):
        """Resampling to half the rate must halve the length of the last axis."""
        sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
        target_rate = sample_rate // 2
        resampled = kaldi.resample_waveform(sound, sample_rate, target_rate)
        expected_length = sound.size(-1) // 2
        self.assertTrue(resampled.size(-1) == expected_length)

    def test_resample_waveform_identity_size(self):
        """Resampling to the same rate must leave the length of the last axis unchanged."""
        sound, sample_rate = torchaudio.load_wav(self.test_8000_filepath)
        resampled = kaldi.resample_waveform(sound, sample_rate, sample_rate)
        expected_length = sound.size(-1)
        self.assertTrue(resampled.size(-1) == expected_length)

    def _test_resample_waveform_accuracy(self,
                                         up_scale_factor=None,
                                         down_scale_factor=None,
                                         atol=1e-1,
                                         rtol=1e-4):
        """Resample a synthetic cosine and compare it with the analytic signal.

        A 3 Hz cosine of amplitude 123, sampled at 1000 Hz for 5 seconds, is
        resampled with ``kaldi.resample_waveform``. The result is compared with
        the same cosine evaluated directly on the new sampling grid.

        Args:
            up_scale_factor: if given, the target rate is multiplied by it.
            down_scale_factor: if given, the target rate is floor-divided by it
                (applied after ``up_scale_factor``).
            atol: absolute tolerance passed to ``torch.allclose``.
            rtol: relative tolerance passed to ``torch.allclose``.
        """
        amplitude = 123
        frequency = 3
        duration = 5  # seconds
        n_to_trim = 20
        sample_rate = 1000

        new_sample_rate = sample_rate
        if up_scale_factor is not None:
            new_sample_rate *= up_scale_factor
        if down_scale_factor is not None:
            new_sample_rate //= down_scale_factor

        def cosine(timestamps):
            return amplitude * torch.cos(2 * math.pi * frequency * timestamps)

        # source signal, shaped (1, num_samples) for the resampler
        source_times = torch.arange(0, duration, 1.0 / sample_rate)
        sound = cosine(source_times).unsqueeze(0)
        estimate = kaldi.resample_waveform(sound, sample_rate,
                                           new_sample_rate).squeeze()

        # analytic signal on the new grid, cut to the number of samples produced
        target_times = torch.arange(0, duration, 1.0 / new_sample_rate)
        ground_truth = cosine(target_times[:estimate.size(0)])

        # the first/last n samples suffer from boundary effects, so drop them
        interior = slice(n_to_trim, -n_to_trim)
        ground_truth = ground_truth[..., interior]
        estimate = estimate[..., interior]

        matches = torch.allclose(ground_truth, estimate, atol=atol, rtol=rtol)
        self.assertTrue(matches)

    def test_resample_waveform_downsample_accuracy(self):
        """Check resampling accuracy for every even down-scale factor from 2 to 38."""
        for factor in range(2, 40, 2):
            self._test_resample_waveform_accuracy(down_scale_factor=factor)

    def test_resample_waveform_upsample_accuracy(self):
        """Check resampling accuracy for up-scale factors 1.05, 1.10, ..., 1.95."""
        for step in range(1, 20):
            scale = 1.0 + step / 20.0
            self._test_resample_waveform_accuracy(up_scale_factor=scale)

    def test_resample_waveform_multi_channel(self):
        """Resampling a (c, n) tensor must match resampling each channel alone.

        Three channels are built from the same recording, scaled by 1.5, 3.0
        and 4.5, and down-sampled to half the rate both jointly and one by one.
        """
        num_channels = 3

        sound, sample_rate = torchaudio.load_wav(
            self.test_8000_filepath)  # (1, 8000)
        target_rate = sample_rate // 2

        # stack scaled copies of the recording: (num_channels, 8000)
        multi_sound = sound.repeat(num_channels, 1)
        for channel in range(num_channels):
            multi_sound[channel, :] *= (channel + 1) * 1.5

        jointly_resampled = kaldi.resample_waveform(multi_sound, sample_rate,
                                                    target_rate)

        # each row of the joint result must agree with the stand-alone result
        for channel in range(num_channels):
            single_channel = sound * (channel + 1) * 1.5
            separately_resampled = kaldi.resample_waveform(
                single_channel, sample_rate, target_rate)
            same = torch.allclose(jointly_resampled[channel, :],
                                  separately_resampled,
                                  rtol=1e-4)
            self.assertTrue(same)
Exemplo n.º 18
0
class TestFunctionalFiltering(unittest.TestCase):
    # Evaluated once, when the class body runs. The tests below look up their
    # audio files under os.path.join(test_dirpath, "assets").
    # NOTE(review): presumably a temporary copy of the asset directory, with
    # test_dir kept alive to own it - confirm in create_temp_assets_dir.
    test_dirpath, test_dir = create_temp_assets_dir()

    def _test_lfilter_basic(self, dtype, device):
        """
        Check ``F.lfilter`` against a pure delay.

        A seeded random two-channel signal is filtered with numerator
        coefficients ``[0, 0, 0, 1]`` and denominator coefficients
        ``[1, 0, 0, 0]``. The output should be the same as the input but
        shifted by three samples.

        Args:
            dtype: dtype used for the signal and both coefficient tensors.
            device: device on which all tensors are created.
        """

        torch.random.manual_seed(42)
        waveform = torch.rand(2, 44100 * 1, dtype=dtype, device=device)
        b_coeffs = torch.tensor([0, 0, 0, 1], dtype=dtype, device=device)
        a_coeffs = torch.tensor([1, 0, 0, 0], dtype=dtype, device=device)
        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)

        # unittest assertion rather than a bare assert, so the check still
        # runs when Python is started with -O
        self.assertTrue(
            torch.allclose(waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5))

    def test_lfilter_basic(self):
        """Delay-filter check with float32 tensors on the CPU."""
        cpu = torch.device("cpu")
        self._test_lfilter_basic(torch.float32, cpu)

    def test_lfilter_basic_double(self):
        """Delay-filter check with float64 tensors on the CPU."""
        cpu = torch.device("cpu")
        self._test_lfilter_basic(torch.float64, cpu)

    def test_lfilter_basic_gpu(self):
        """Delay-filter check with float32 tensors on the first CUDA device.

        When CUDA is unavailable the test is reported as skipped instead of
        printing a message and counting as a pass.
        """
        if not torch.cuda.is_available():
            self.skipTest("skipping GPU test for lfilter_basic because device not available")

        self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))

    def _test_lfilter(self, waveform, device):
        """
        Filter ``waveform`` with a fixed IIR lowpass and check the output shape.

        The coefficients come from an IIR lowpass filter designed with
        scipy.signal filter design
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign

        Example
          >>> from scipy.signal import iirdesign
          >>> b, a = iirdesign(0.2, 0.3, 1, 60)

        The output must be two-dimensional with the same first two sizes as
        the input; ``F.lfilter`` is then handed to the TorchScript helper.

        Args:
            waveform: input signal passed to ``F.lfilter``.
            device: device on which the coefficient tensors are created.
        """

        b_coeffs = torch.tensor(
            [
                0.00299893,
                -0.0051152,
                0.00841964,
                -0.00747802,
                0.00841964,
                -0.0051152,
                0.00299893,
            ],
            device=device,
        )
        a_coeffs = torch.tensor(
            [
                1.0,
                -4.8155751,
                10.2217618,
                -12.14481273,
                8.49018171,
                -3.3066882,
                0.56088705,
            ],
            device=device,
        )

        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)

        # unittest assertions (not bare asserts): they survive python -O and
        # report both compared values on failure
        self.assertEqual(len(output_waveform.size()), 2)
        self.assertEqual(output_waveform.size(0), waveform.size(0))
        self.assertEqual(output_waveform.size(1), waveform.size(1))
        _test_torchscript_functional(F.lfilter, waveform, a_coeffs, b_coeffs)

    def test_lfilter(self):
        """Run the IIR-filter shape check on the white-noise asset, on the CPU."""
        noise_path = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
        noise, _ = torchaudio.load(noise_path, normalization=True)
        cpu = torch.device("cpu")

        self._test_lfilter(noise, cpu)

    def test_lfilter_gpu(self):
        """Run the IIR-filter shape check on the white-noise asset, on ``cuda:0``.

        When CUDA is unavailable the test is reported as skipped instead of
        printing a message and counting as a pass.
        """
        if not torch.cuda.is_available():
            self.skipTest("skipping GPU test for lfilter because device not available")

        filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)
        cuda0 = torch.device("cuda:0")
        cuda_waveform = waveform.cuda(device=cuda0)
        self._test_lfilter(cuda_waveform, cuda0)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_lowpass(self):
        """
        Test biquad lowpass filter, compare to SoX implementation
        """

        CUTOFF_FREQ = 3000

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: the same filter applied through a SoX effects chain
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("lowpass", [CUTOFF_FREQ])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.lowpass_biquad, waveform, sample_rate, CUTOFF_FREQ)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_highpass(self):
        """
        Test biquad highpass filter, compare to SoX implementation
        """

        CUTOFF_FREQ = 2000

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: the same filter applied through a SoX effects chain
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("highpass", [CUTOFF_FREQ])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ)

        # TBD - this fails at the 1e-4 level, debug why
        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-3))
        _test_torchscript_functional(F.highpass_biquad, waveform, sample_rate, CUTOFF_FREQ)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_allpass(self):
        """
        Test biquad allpass filter, compare to SoX implementation
        """

        CENTRAL_FREQ = 1000
        Q = 0.707

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: the same filter applied through a SoX effects chain
        # (the width is handed to SoX as a string with a 'q' suffix)
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("allpass", [CENTRAL_FREQ, str(Q) + 'q'])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.allpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.allpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandpass_with_csg(self):
        """
        Test biquad bandpass filter with constant skirt gain, compare to SoX
        implementation
        """

        CENTRAL_FREQ = 1000
        Q = 0.707
        CONST_SKIRT_GAIN = True

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: SoX "bandpass" with the "-c" flag, the counterpart
        # of CONST_SKIRT_GAIN=True on the torchaudio side
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("bandpass", ["-c", CENTRAL_FREQ, str(Q) + 'q'])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandpass_without_csg(self):
        """
        Test biquad bandpass filter without constant skirt gain, compare to
        SoX implementation
        """

        CENTRAL_FREQ = 1000
        Q = 0.707
        CONST_SKIRT_GAIN = False

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: SoX "bandpass" without the "-c" flag
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("bandpass", [CENTRAL_FREQ, str(Q) + 'q'])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.bandpass_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.bandpass_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, CONST_SKIRT_GAIN)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_bandreject(self):
        """
        Test biquad bandreject filter, compare to SoX implementation
        """

        CENTRAL_FREQ = 1000
        Q = 0.707

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: the same filter applied through a SoX effects chain
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("bandreject", [CENTRAL_FREQ, str(Q) + 'q'])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.bandreject_biquad(waveform, sample_rate, CENTRAL_FREQ, Q)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.bandreject_biquad, waveform, sample_rate, CENTRAL_FREQ, Q)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_band_with_noise(self):
        """
        Test biquad band filter with noise mode, compare to SoX implementation
        """

        CENTRAL_FREQ = 1000
        Q = 0.707
        NOISE = True

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: SoX "band" with the "-n" flag, the counterpart of
        # NOISE=True on the torchaudio side
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("band", ["-n", CENTRAL_FREQ, str(Q) + 'q'])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_band_without_noise(self):
        """
        Test biquad band filter without noise mode, compare to SoX implementation
        """

        CENTRAL_FREQ = 1000
        Q = 0.707
        NOISE = False

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: SoX "band" without the "-n" flag
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("band", [CENTRAL_FREQ, str(Q) + 'q'])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.band_biquad(waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.band_biquad, waveform, sample_rate, CENTRAL_FREQ, Q, NOISE)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_treble(self):
        """
        Test biquad treble filter, compare to SoX implementation
        """

        CENTRAL_FREQ = 1000
        Q = 0.707
        GAIN = 40

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: the same filter applied through a SoX effects chain
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("treble", [GAIN, CENTRAL_FREQ, str(Q) + 'q'])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.treble_biquad(waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.treble_biquad, waveform, sample_rate, GAIN, CENTRAL_FREQ, Q)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_deemph(self):
        """
        Test biquad deemph filter, compare to SoX implementation
        """

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: the same filter applied through a SoX effects chain
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("deemph")
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.deemph_biquad(waveform, sample_rate)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.deemph_biquad, waveform, sample_rate)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_riaa(self):
        """
        Test biquad riaa filter, compare to SoX implementation
        """

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output: the same filter applied through a SoX effects chain
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("riaa")
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.riaa_biquad(waveform, sample_rate)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.riaa_biquad, waveform, sample_rate)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_equalizer(self):
        """
        Test biquad peaking equalizer filter, compare to SoX implementation
        """

        CENTER_FREQ = 300
        Q = 0.707
        GAIN = 1

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        # reference output via SoX; note the argument order differs between
        # the two APIs: SoX takes (freq, Q, gain), torchaudio (freq, gain, Q)
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("equalizer", [CENTER_FREQ, Q, GAIN])
        sox_output_waveform, _ = E.sox_build_flow_effects()

        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        output_waveform = F.equalizer_biquad(waveform, sample_rate, CENTER_FREQ, GAIN, Q)

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(sox_output_waveform, output_waveform, atol=1e-4))
        _test_torchscript_functional(F.equalizer_biquad, waveform, sample_rate, CENTER_FREQ, GAIN, Q)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_perf_biquad_filtering(self):
        """
        Compare ``F.lfilter`` with the SoX "biquad" effect for arbitrary
        second-order coefficients.

        Only numerical agreement is asserted. No timing is measured, because
        wall-clock values were never reported or checked by this test.
        """

        noise_filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.wav")

        b0 = 0.4
        b1 = 0.2
        b2 = 0.9
        a0 = 0.7
        a1 = 0.2
        a2 = 0.6

        # SoX method
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(noise_filepath)
        E.append_effect_to_chain("biquad", [b0, b1, b2, a0, a1, a2])
        waveform_sox_out, _ = E.sox_build_flow_effects()

        # torchaudio method
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        waveform_lfilter_out = F.lfilter(
            waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
        )

        # unittest assertion rather than a bare assert, so it survives python -O
        self.assertTrue(torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4))
        _test_torchscript_functional(
            F.lfilter, waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
        )
Exemplo n.º 19
0
class TestFunctional(unittest.TestCase):
    # Shapes of the random signals generated by the STFT/ISTFT round-trip test.
    data_sizes = [(2, 20), (3, 15), (4, 10)]
    # Number of random trials per configuration in the round-trip and
    # linearity tests.
    number_of_trials = 100
    # Small fixed input shared by the compute_deltas tests.
    specgram = torch.tensor([1., 2., 3., 4.])

    # Evaluated once, when the class body runs.
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()

    test_filepath = os.path.join(test_dirpath, 'assets',
                                 'steam-train-whistle-daniel_simon.wav')
    # Waveform and sample rate of the file above, loaded at class-definition
    # time and used by the gain and dither tests.
    waveform_train, sr_train = torchaudio.load(test_filepath)

    def _test_compute_deltas(self,
                             specgram,
                             expected,
                             win_length=3,
                             atol=1e-6,
                             rtol=1e-8):
        """Check ``F.compute_deltas(specgram)`` against ``expected``.

        The shapes must be identical and the values must agree within the
        given absolute/relative tolerances.
        """
        deltas = F.compute_deltas(specgram, win_length=win_length)

        shapes = (deltas.shape, expected.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)

        torch.testing.assert_allclose(deltas, expected, atol=atol, rtol=rtol)

    def test_compute_deltas_onechannel(self):
        """Deltas of the shared fixture presented as a (1, 1, 4) tensor."""
        single_row = self.specgram[None, None, :]
        deltas = [0.5, 1.0, 1.0, 0.5]
        self._test_compute_deltas(single_row, torch.tensor([[deltas]]))

    def test_compute_deltas_twochannel(self):
        """Deltas of the shared fixture repeated into a (1, 2, 4) tensor."""
        two_rows = self.specgram.repeat(1, 2, 1)
        deltas = [0.5, 1.0, 1.0, 0.5]
        self._test_compute_deltas(two_rows, torch.tensor([[deltas, deltas]]))

    def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
        """Assert that ``estimate`` reproduces ``sound``.

        ``sound`` is first cut along its last axis to the length of
        ``estimate``, which covers the case of a reconstruction that is
        shorter than the original. Shapes must then match exactly and values
        must agree within ``atol``/``rtol``.
        """
        reconstructed_length = estimate.size(-1)
        sound = sound[..., :reconstructed_length]

        shapes = (sound.shape, estimate.shape)
        self.assertTrue(shapes[0] == shapes[1], shapes)

        values_agree = torch.allclose(sound, estimate, atol=atol, rtol=rtol)
        self.assertTrue(values_agree)

    def _test_istft_is_inverse_of_stft(self, kwargs):
        """Round-trip random signals through ``torch.stft`` and ``istft``.

        For every shape in ``data_sizes`` and every trial index a random sound
        is generated, transformed with ``torch.stft(**kwargs)`` and inverted
        with ``torchaudio.functional.istft`` using the same ``kwargs``. The
        reconstruction must match the original signal.
        """
        for shape in self.data_sizes:
            for trial in range(self.number_of_trials):
                # the trial index is what random_float_tensor receives first
                sound = common_utils.random_float_tensor(trial, shape)

                spectrum = torch.stft(sound, **kwargs)
                reconstruction = torchaudio.functional.istft(
                    spectrum, length=sound.size(1), **kwargs)

                self._compare_estimate(sound, reconstruction)

    def test_istft_is_inverse_of_stft1(self):
        """Round trip with a Hann window: centered, normalized, onesided."""
        window_size = 12
        stft_options = dict(
            n_fft=12,
            hop_length=4,
            win_length=window_size,
            window=torch.hann_window(window_size),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        )

        self._test_istft_is_inverse_of_stft(stft_options)

    def test_istft_is_inverse_of_stft2(self):
        """Round trip with a Hann window: centered, not normalized, not onesided."""
        window_size = 8
        stft_options = dict(
            n_fft=12,
            hop_length=2,
            win_length=window_size,
            window=torch.hann_window(window_size),
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        )

        self._test_istft_is_inverse_of_stft(stft_options)

    def test_istft_is_inverse_of_stft3(self):
        """Round trip with a Hamming window: centered, normalized, not onesided."""
        window_size = 11
        stft_options = dict(
            n_fft=15,
            hop_length=3,
            win_length=window_size,
            window=torch.hamming_window(window_size),
            center=True,
            pad_mode='constant',
            normalized=True,
            onesided=False,
        )

        self._test_istft_is_inverse_of_stft(stft_options)

    def test_istft_is_inverse_of_stft4(self):
        """Round trip with a Hamming window: not centered, not normalized, onesided.

        The window has the same size as ``n_fft``.
        """
        window_size = 5
        stft_options = dict(
            n_fft=5,
            hop_length=2,
            win_length=window_size,
            window=torch.hamming_window(window_size),
            center=False,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        )

        self._test_istft_is_inverse_of_stft(stft_options)

    def test_istft_is_inverse_of_stft5(self):
        """Round trip with a Hamming window: not centered, not normalized, not onesided.

        The window has the same size as ``n_fft``.
        """
        window_size = 3
        stft_options = dict(
            n_fft=3,
            hop_length=2,
            win_length=window_size,
            window=torch.hamming_window(window_size),
            center=False,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        )

        self._test_istft_is_inverse_of_stft(stft_options)

    def test_istft_of_ones(self):
        """Inverting the hand-written STFT of four ones must give four ones."""
        # stft = torch.stft(torch.ones(4), 4)
        # -> shape (3, 5, 2): row 0 holds the pair [4., 0.] in every column,
        #    all other entries are zero
        stft = torch.zeros((3, 5, 2))
        stft[0, :, 0] = 4.

        estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
        self._compare_estimate(torch.ones(4), estimate)

    def test_istft_of_zeros(self):
        """Inverting an all-zero STFT must give an all-zero signal."""
        # stft = torch.stft(torch.zeros(4), 4)
        zero_spectrum = torch.zeros(3, 5, 2)
        expected = torch.zeros(4)

        estimate = torchaudio.functional.istft(zero_spectrum, n_fft=4, length=4)
        self._compare_estimate(expected, estimate)

    def test_istft_requires_overlap_windows(self):
        """``istft`` must raise when consecutive windows leave a gap.

        The window has size 1 while the hop length is 20, so the frames do not
        overlap and an ``AssertionError`` is expected.
        """
        stft = torch.zeros((3, 5, 2))
        tiny_window = torch.ones(1)

        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft,
                                        n_fft=4,
                                        hop_length=20,
                                        win_length=1,
                                        window=tiny_window)

    def test_istft_requires_nola(self):
        """``istft`` must accept a NOLA-compliant window and reject one that is not.

        A window of ones meets NOLA; a window of zeros does not and should
        raise ``AssertionError``.
        """
        stft = torch.zeros((3, 5, 2))
        shared = {'n_fft': 4, 'win_length': 4}
        kwargs_ok = dict(shared, window=torch.ones(4))
        kwargs_not_ok = dict(shared, window=torch.zeros(4))

        # must run without raising
        torchaudio.functional.istft(stft, **kwargs_ok)

        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft, **kwargs_not_ok)

    def test_istft_requires_non_empty(self):
        """``istft`` must raise when either of the first two dimensions is empty."""
        for empty_shape in ((3, 0, 2), (0, 3, 2)):
            empty_spectrum = torch.zeros(empty_shape)
            with self.assertRaises(AssertionError):
                torchaudio.functional.istft(empty_spectrum, 2)

    def _test_istft_of_sine(self, amplitude, L, n):
        """Invert a hand-built STFT of ``amplitude * sin(2*pi/L * n * x)``.

        The spectrum is the one obtained with hop length and window size both
        equal to ``L`` and a rectangular window, i.e.

            torch.stft(sound, L, hop_length=L, win_length=L,
                       window=torch.ones(L), center=False, normalized=False)

        Args:
            amplitude: peak value of the sine.
            L: period length in samples; also ``n_fft``, hop and window size.
            n: number of cycles per ``L`` samples.
        """
        samples = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
        sound = amplitude * torch.sin(2 * math.pi / L * samples * n)

        num_bins = L // 2 + 1
        stft = torch.zeros((num_bins, 2, 2))
        peak = (amplitude * L) / 2.0

        # imaginary part of bin n carries the negative peak ...
        if n < stft.size(0):
            stft[n, :, 1] = -peak

        # ... and the mirrored bin L - n (symmetric about L // 2) the positive one
        mirrored = L - n
        if 0 <= mirrored < stft.size(0):
            stft[mirrored, :, 1] = peak

        estimate = torchaudio.functional.istft(stft,
                                               L,
                                               hop_length=L,
                                               win_length=L,
                                               window=torch.ones(L),
                                               center=False,
                                               normalized=False)
        # There is a larger error due to the scaling of amplitude
        self._compare_estimate(sound, estimate, atol=1e-3)

    def test_istft_of_sine(self):
        """Run the sine inversion check for several (amplitude, L, n) triples."""
        cases = (
            (123, 5, 1),
            (150, 5, 2),
            (111, 5, 3),
            (160, 7, 4),
            (145, 8, 5),
            (80, 9, 6),
            (99, 10, 7),
        )
        for amplitude, period, cycles in cases:
            self._test_istft_of_sine(amplitude=amplitude, L=period, n=cycles)

    def _test_linearity_of_istft(self,
                                 data_size,
                                 kwargs,
                                 atol=1e-6,
                                 rtol=1e-8):
        """Check that ``istft(a*x + b*y) == a*istft(x) + b*istft(y)``.

        For each trial two random tensors of shape ``data_size`` and two random
        scalars are drawn; the inverse of the linear combination is compared
        with the linear combination of the inverses.
        """

        def inverse(spectrum):
            return torchaudio.functional.istft(spectrum, **kwargs)

        for trial in range(self.number_of_trials):
            first = common_utils.random_float_tensor(trial, data_size)
            second = common_utils.random_float_tensor(trial * 2, data_size)
            a, b = torch.rand(2)

            inverse_first = inverse(first)
            inverse_second = inverse(second)
            combined_inverses = a * inverse_first + b * inverse_second

            inverse_of_combination = inverse(a * first + b * second)
            self._compare_estimate(combined_inverses, inverse_of_combination,
                                   atol, rtol)

    def test_linearity_of_istft1(self):
        """Linearity with a Hann window: centered, normalized, onesided."""
        istft_options = dict(
            n_fft=12,
            window=torch.hann_window(12),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        )
        self._test_linearity_of_istft((2, 7, 7, 2), istft_options)

    def test_linearity_of_istft2(self):
        """Linearity with a Hann window: centered, not normalized, not onesided."""
        istft_options = dict(
            n_fft=12,
            window=torch.hann_window(12),
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        )
        self._test_linearity_of_istft((2, 12, 7, 2), istft_options)

    def test_linearity_of_istft3(self):
        """Linearity with a Hamming window: centered, normalized, not onesided."""
        istft_options = dict(
            n_fft=12,
            window=torch.hamming_window(12),
            center=True,
            pad_mode='constant',
            normalized=True,
            onesided=False,
        )
        self._test_linearity_of_istft((2, 12, 7, 2), istft_options)

    def test_linearity_of_istft4(self):
        """Linearity with a Hamming window: not centered, not normalized, onesided.

        Uses a looser absolute tolerance (1e-5) than the other linearity tests.
        """
        istft_options = dict(
            n_fft=12,
            window=torch.hamming_window(12),
            center=False,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        )
        self._test_linearity_of_istft((2, 7, 3, 2), istft_options,
                                      atol=1e-5, rtol=1e-8)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_gain(self):
        """Compare ``F.gain`` (3 dB) with the SoX "gain" effect on the train whistle."""
        waveform_gain = F.gain(self.waveform_train, 3)

        # The former `assertTrue(peak, 1.)` passed 1. as the failure *message*,
        # so it only checked that the peak was non-zero. Compare the peak with
        # the bound instead.
        # NOTE(review): assumes the intent was "amplified signal stays within
        # [-1, 1]"; the slack is looser than the atol of the SoX comparison
        # below so this sanity check is never the stricter of the two - confirm.
        peak = waveform_gain.abs().max().item()
        self.assertLessEqual(peak, 1. + 1e-3)

        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", [3])
        sox_gain_waveform = E.sox_build_flow_effects()[0]

        self.assertTrue(
            torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04))

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_dither(self):
        """Compare ``F.dither`` with the SoX "dither" effect, with and without noise shaping."""
        plain = F.dither(self.waveform_train)
        noise_shaped = F.dither(self.waveform_train, noise_shaping=True)

        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(self.test_filepath)

        # plain dither
        chain.append_effect_to_chain("dither", [])
        sox_plain = chain.sox_build_flow_effects()[0]
        plain_matches = torch.allclose(plain, sox_plain, atol=1e-04)
        self.assertTrue(plain_matches)

        # reuse the chain for the noise-shaped variant ("-s")
        chain.clear_chain()
        chain.append_effect_to_chain("dither", ["-s"])
        sox_noise_shaped = chain.sox_build_flow_effects()[0]
        shaped_matches = torch.allclose(noise_shaped,
                                        sox_noise_shaped,
                                        atol=1e-02)
        self.assertTrue(shaped_matches)

    @unittest.skipIf("sox" not in BACKENDS, "sox not available")
    @AudioBackendScope("sox")
    def test_vctk_transform_pipeline(self):
        """Compare a torch resample + dither pipeline with a SoX chain.

        The VCTK sample ``p224_002.wav`` is resampled to 16000 and dithered
        with noise shaping on the torch side. The SoX side applies
        ``gain -h``, ``channels 1``, ``rate 16000``, ``gain -rh`` and
        ``dither -s``. Both outputs must agree with rtol and atol of 1e-3.
        """
        target_rate = 16000
        vctk_path = os.path.join(self.test_dirpath,
                                 "assets/VCTK-Corpus/wav48/p224/",
                                 "p224_002.wav")
        loaded, source_rate = torchaudio.load(vctk_path)

        # Torch pipeline: resample first, then noise-shaped dither.
        resampler = T.Resample(source_rate,
                               target_rate,
                               resampling_method='sinc_interpolation')
        torch_result = F.dither(resampler(loaded), noise_shaping=True)

        # SoX pipeline: effects are appended in the order they are applied.
        sox_effects = (
            ("gain", ["-h"]),
            ("channels", [1]),
            ("rate", [target_rate]),
            ("gain", ["-rh"]),
            ("dither", ["-s"]),
        )
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(vctk_path)
        for effect_name, effect_args in sox_effects:
            chain.append_effect_to_chain(effect_name, effect_args)
        sox_result, _ = chain.sox_build_flow_effects()

        self.assertTrue(
            torch.allclose(torch_result, sox_result, rtol=1e-03, atol=1e-03))

    def test_pitch(self):
        """Check ``detect_pitch_frequency`` on two pure-tone recordings.

        For the 100 Hz and 440 Hz tone files, no detected frequency value may
        differ from the reference by more than 1.
        """
        # test_dir is not used directly.
        # NOTE(review): presumably it keeps the temporary assets directory
        # alive while the test runs -- confirm in common_utils.
        test_dirpath, test_dir = common_utils.create_temp_assets_dir()

        # Files from https://www.mediacollege.com/audio/tone/download/
        tone_files = [
            ("100Hz_44100Hz_16bit_05sec.wav", 100),
            ("440Hz_44100Hz_16bit_05sec.wav", 440),
        ]
        max_deviation = 1

        for basename, freq_ref in tone_files:
            tone_path = os.path.join(test_dirpath, 'assets', basename)
            waveform, sample_rate = torchaudio.load(tone_path)

            detected = torchaudio.functional.detect_pitch_frequency(
                waveform, sample_rate)

            # Count the values that are further than max_deviation away from
            # the reference; the count must be zero.
            deviation = (detected - freq_ref).abs()
            outlier_count = (deviation > max_deviation).sum()
            self.assertFalse(outlier_count)

    def test_DB_to_amplitude(self):
        """Check that ``F.DB_to_amplitude`` inverts ``F.amplitude_to_DB``.

        Random noise and its spectrogram are converted to DB and back, once
        with multiplier 20 / power 0.5 and once with multiplier 10 / power 1.
        Each round trip must reproduce the reference within atol=5e-5.
        """
        # Make some noise
        noise = torch.rand(1000)
        to_spectrogram = torchaudio.transforms.Spectrogram()
        spec = to_spectrogram(noise)

        amin = 1e-10
        ref = 1.0
        db_multiplier = math.log10(max(amin, ref))

        def round_trip(values, multiplier, power):
            # values -> DB -> values, with no top_db clamping.
            decibels = F.amplitude_to_DB(values,
                                         multiplier,
                                         amin,
                                         db_multiplier,
                                         top_db=None)
            return F.DB_to_amplitude(decibels, ref, power)

        magnitude = torch.abs(noise)
        # (input, expected output, multiplier, power)
        cases = [
            # Waveform amplitude -> DB -> amplitude
            (magnitude, magnitude, 20., 0.5),
            # Spectrogram amplitude -> DB -> amplitude
            (spec, spec, 20., 0.5),
            # Waveform power -> DB -> power
            (noise, magnitude, 10., 1.),
            # Spectrogram power -> DB -> power
            (spec, spec, 10., 1.),
        ]

        for values, expected, multiplier, power in cases:
            restored = round_trip(values, multiplier, power)
            self.assertTrue(torch.allclose(expected, restored, atol=5e-5))
Exemplo n.º 20
0
 def setUpClass(cls):
     """Create the temporary assets directory and store it on the class."""
     assets = common_utils.create_temp_assets_dir()
     cls.test_dirpath, cls.test_dir = assets
Exemplo n.º 21
0
class TestFunctionalFiltering(unittest.TestCase):
    test_dirpath, test_dir = common_utils.create_temp_assets_dir()

    def _test_lfilter_basic(self, dtype, device):
        """
        Filter random noise with a pure delay and verify the shift.

        The numerator ``[0, 0, 0, 1]`` with denominator ``[1, 0, 0, 0]`` is
        expected to delay the signal by three samples, so the output from
        index 3 onwards must match the input without its last three samples.

        Args:
            dtype: dtype used for the signal and both coefficient tensors.
            device: device on which all tensors are created.
        """
        torch.random.manual_seed(42)
        tensor_opts = {"dtype": dtype, "device": device}

        signal = torch.rand(2, 44100 * 1, **tensor_opts)
        numerator = torch.tensor([0, 0, 0, 1], **tensor_opts)
        denominator = torch.tensor([1, 0, 0, 0], **tensor_opts)
        delayed = F.lfilter(signal, denominator, numerator)

        delay = 3
        expected = signal[:, :-delay]
        actual = delayed[:, delay:]
        assert torch.allclose(expected, actual, atol=1e-5)

    def test_lfilter_basic(self):
        """Run the basic delay-filter check with float32 tensors on the CPU."""
        cpu = torch.device("cpu")
        self._test_lfilter_basic(torch.float32, cpu)

    def test_lfilter_basic_double(self):
        """Run the basic delay-filter check with float64 tensors on the CPU."""
        cpu = torch.device("cpu")
        self._test_lfilter_basic(torch.float64, cpu)

    def test_lfilter_basic_gpu(self):
        """Run the basic delay-filter check with float32 tensors on ``cuda:0``.

        When CUDA is unavailable the test is reported as skipped. It used to
        print a message and pass, which hid the fact that nothing was checked.
        """
        if not torch.cuda.is_available():
            self.skipTest(
                "skipping GPU test for lfilter_basic because device not available"
            )
        self._test_lfilter_basic(torch.float32, torch.device("cuda:0"))

    def _test_lfilter(self, waveform, device):
        """
        Apply a fixed 7-coefficient IIR filter and check the output shape.

        The coefficients come from an IIR lowpass design made with
        scipy.signal filter design:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign

        Example
          >>> from scipy.signal import iirdesign
          >>> b, a = iirdesign(0.2, 0.3, 1, 60)

        Args:
            waveform: tensor handed to ``F.lfilter``.
            device: device on which the coefficient tensors are created.
        """
        numerator = [
            0.00299893,
            -0.0051152,
            0.00841964,
            -0.00747802,
            0.00841964,
            -0.0051152,
            0.00299893,
        ]
        denominator = [
            1.0,
            -4.8155751,
            10.2217618,
            -12.14481273,
            8.49018171,
            -3.3066882,
            0.56088705,
        ]
        b_coeffs = torch.tensor(numerator, device=device)
        a_coeffs = torch.tensor(denominator, device=device)

        filtered = F.lfilter(waveform, a_coeffs, b_coeffs)

        # Only the shape is verified here: a 2-D result whose first two
        # dimensions match the input.
        assert filtered.dim() == 2
        assert filtered.size(0) == waveform.size(0)
        assert filtered.size(1) == waveform.size(1)

    def test_lfilter(self):
        """Run the IIR filter shape check on the white-noise asset (CPU)."""
        cpu = torch.device("cpu")
        noise_path = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
        noise, _ = torchaudio.load(noise_path, normalization=True)
        self._test_lfilter(noise, cpu)

    def test_lfilter_gpu(self):
        """Run the IIR filter shape check on the white-noise asset on ``cuda:0``.

        When CUDA is unavailable the test is reported as skipped. It used to
        print a message and pass, which hid the fact that nothing was checked.
        """
        if not torch.cuda.is_available():
            self.skipTest(
                "skipping GPU test for lfilter because device not available")

        filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
        waveform, _ = torchaudio.load(filepath, normalization=True)
        cuda0 = torch.device("cuda:0")
        # Move the loaded waveform to the same device as the coefficients.
        cuda_waveform = waveform.cuda(device=cuda0)
        self._test_lfilter(cuda_waveform, cuda0)

    def test_lowpass(self):
        """
        Compare ``F.lowpass_biquad`` with the SoX ``lowpass`` effect.

        Both are applied to the white-noise asset with a cutoff of 3000 and
        must agree within an absolute tolerance of 1e-4.
        """
        cutoff = 3000
        noise_path = os.path.join(self.test_dirpath, "assets",
                                  "whitenoise.mp3")

        # Reference output from SoX.
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(noise_path)
        chain.append_effect_to_chain("lowpass", [cutoff])
        expected, _ = chain.sox_build_flow_effects()

        # Output of the torch implementation on the same file.
        noise, rate = torchaudio.load(noise_path, normalization=True)
        actual = F.lowpass_biquad(noise, rate, cutoff)

        assert torch.allclose(expected, actual, atol=1e-4)

    def test_highpass(self):
        """
        Compare ``F.highpass_biquad`` with the SoX ``highpass`` effect.

        Both are applied to the white-noise asset with a cutoff of 2000 and
        must agree within an absolute tolerance of 1e-3.
        """
        cutoff = 2000
        noise_path = os.path.join(self.test_dirpath, "assets",
                                  "whitenoise.mp3")

        # Reference output from SoX.
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(noise_path)
        chain.append_effect_to_chain("highpass", [cutoff])
        expected, _ = chain.sox_build_flow_effects()

        # Output of the torch implementation on the same file.
        noise, rate = torchaudio.load(noise_path, normalization=True)
        actual = F.highpass_biquad(noise, rate, cutoff)

        # TBD - the comparison fails at the 1e-4 level; the cause still needs
        # to be debugged.
        assert torch.allclose(expected, actual, atol=1e-3)

    def test_perf_biquad_filtering(self):
        """Compare ``F.lfilter`` with the SoX ``biquad`` effect.

        The same six biquad coefficients are applied to the white-noise asset
        through SoX and through ``F.lfilter``; the outputs must agree within
        an absolute tolerance of 1e-4. Wall-clock durations of both paths are
        measured, but nothing is asserted on them.
        """
        noise_path = os.path.join(self.test_dirpath, "assets",
                                  "whitenoise.mp3")

        b_taps = [0.4, 0.2, 0.9]  # b0, b1, b2
        a_taps = [0.7, 0.2, 0.6]  # a0, a1, a2

        # SoX method. The timer starts after the input file has been set.
        chain = torchaudio.sox_effects.SoxEffectsChain()
        chain.set_input_file(noise_path)
        _sox_started = time.time()
        # SoX receives the coefficients in the order b0 b1 b2 a0 a1 a2.
        chain.append_effect_to_chain("biquad", b_taps + a_taps)
        sox_filtered, _ = chain.sox_build_flow_effects()
        _sox_elapsed = time.time() - _sox_started

        # lfilter method. This timer also covers loading the file.
        _lfilter_started = time.time()
        noise, _ = torchaudio.load(noise_path, normalization=True)
        lfilter_filtered = F.lfilter(noise, torch.tensor(a_taps),
                                     torch.tensor(b_taps))
        _lfilter_elapsed = time.time() - _lfilter_started

        assert torch.allclose(sox_filtered, lfilter_filtered, atol=1e-4)