Пример #1
0
    def test_AudioToMelSpectrogramPreprocessor1(self):
        """Check that the stft_conv=False and stft_conv=True paths agree.

        Both preprocessors share identical settings chosen to test the pure
        STFT implementation as much as possible (dither=0, mag_power=1.0,
        normalize=False, preemph=0.0, log=False, pad_to=0); only ``stft_conv``
        differs. On random inputs the returned lengths must match exactly and
        the features must agree within a mean absolute difference of 1e-3 and
        a max absolute difference of 1e-2.
        """
        common_kwargs = dict(
            dither=0,
            mag_power=1.0,
            normalize=False,
            preemph=0.0,
            log=False,
            pad_to=0,
        )
        instance1 = modules.AudioToMelSpectrogramPreprocessor(stft_conv=False, **common_kwargs)
        instance2 = modules.AudioToMelSpectrogramPreprocessor(stft_conv=True, **common_kwargs)

        # Ensure that the two implementations behave similarly.
        for _ in range(10):
            input_signal = torch.randn(size=(4, 512))
            length = torch.randint(low=161, high=500, size=[4])
            res1, length1 = instance1(input_signal=input_signal, length=length)
            res2, length2 = instance2(input_signal=input_signal, length=length)
            # zip() stops at the shorter input, so check the counts first;
            # otherwise a missing entry would go unnoticed.
            assert len(length1) == len(length2)
            for len1, len2 in zip(length1, length2):
                assert len1 == len2
            assert res1.shape == res2.shape
            abs_diff = torch.abs(res1 - res2)
            assert torch.mean(abs_diff) <= 1e-3
            assert torch.max(abs_diff) <= 1e-2
Пример #2
0
    def test_AudioToMelSpectrogramPreprocessor_batch(self):
        """Check that batched preprocessing matches per-sample preprocessing.

        A single preprocessor (normalize="per_feature", dither=0, pad_to=0) is
        run once on a batch of 4 signals and once on each signal individually
        (batch size 1). The concatenated per-sample results must match the
        batched results: identical shapes, identical lengths, and features
        within a mean and max absolute difference of 1e-3.
        """
        instance1 = modules.AudioToMelSpectrogramPreprocessor(
            normalize="per_feature", dither=0, pad_to=0)

        # Ensure that batch size 1 and batch size 4 give the same results.
        for _ in range(10):
            input_signal = torch.randn(size=(4, 512))
            length = torch.randint(low=161, high=500, size=[4])

            with torch.no_grad():
                # batch size 1: process every signal on its own
                res_instance, length_instance = [], []
                for i in range(input_signal.size(0)):
                    res_ins, length_ins = instance1(
                        input_signal=input_signal[i:i + 1],
                        length=length[i:i + 1])
                    res_instance.append(res_ins)
                    length_instance.append(length_ins)

                res_instance = torch.cat(res_instance, 0)
                length_instance = torch.cat(length_instance, 0)

                # batch size 4
                res_batch, length_batch = instance1(input_signal=input_signal,
                                                    length=length)

            assert res_instance.shape == res_batch.shape
            assert length_instance.shape == length_batch.shape
            # Matching shapes alone would not catch wrong per-sample lengths.
            assert torch.equal(length_instance, length_batch)
            abs_diff = torch.abs(res_instance - res_batch)
            assert torch.mean(abs_diff) <= 1e-3
            assert torch.max(abs_diff) <= 1e-3
Пример #3
0
    def test_SpectrogramAugmentationr_numba_kernel(self, caplog):
        """Check SpectrogramAugmentation with the Numba SpecAugment kernel enabled.

        The test is skipped through ``numba_utils`` when the installed numba
        version is unsupported. Otherwise it:

        - builds the augmentation with ``use_numba_spec_augment=True``;
        - runs it on the first output of AudioToMelSpectrogramPreprocessor;
        - checks the output shape is unchanged;
        - checks the DEBUG log reports that the Numba kernel is available.

        The logger's propagate flag and verbosity are changed for the duration
        of the test and restored in ``finally``. A failing assertion therefore
        cannot leak DEBUG logging into later tests.
        """
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)

        original_propagate = logging._logger.propagate
        original_verbosity = logging.get_verbosity()
        # Propagate records so pytest's caplog can capture the debug message.
        logging._logger.propagate = True
        logging.set_verbosity(logging.DEBUG)
        caplog.set_level(logging.DEBUG)

        try:
            # Make sure constructor works
            instance1 = modules.SpectrogramAugmentation(
                freq_masks=10, time_masks=3, rect_masks=3, use_numba_spec_augment=True
            )
            assert isinstance(instance1, modules.SpectrogramAugmentation)

            # Make sure forward doesn't throw with expected input
            instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0)
            input_signal = torch.randn(size=(8, 512))
            length = torch.randint(low=161, high=500, size=[8])
            res0 = instance0(input_signal=input_signal, length=length)
            res = instance1(input_spec=res0[0], length=length)

            assert res.shape == res0[0].shape

            # Check that the numba kernel debug message indicates it is available for use.
            assert """Numba SpecAugment kernel is available""" in caplog.text
        finally:
            logging._logger.propagate = original_propagate
            logging.set_verbosity(original_verbosity)
Пример #4
0
    def test_AudioToMelSpectrogramPreprocessor2(self):
        """Check stft_conv=False vs stft_conv=True with ASR-style defaults.

        Only ``dither=0`` is overridden, so the preprocessor's remaining
        defaults (the STFT pipeline as used in ASR models) stay in effect.
        Lengths must match exactly. Features must agree within a mean absolute
        difference of 3e-3 and a max absolute difference of 3. The max bound is
        loose, presumably because the default post-processing amplifies small
        STFT differences (TODO confirm).
        """
        instance1 = modules.AudioToMelSpectrogramPreprocessor(dither=0, stft_conv=False)
        instance2 = modules.AudioToMelSpectrogramPreprocessor(dither=0, stft_conv=True)

        # Ensure that the two implementations behave similarly.
        for _ in range(5):
            input_signal = torch.randn(size=(4, 512))
            length = torch.randint(low=161, high=500, size=[4])
            res1, length1 = instance1(input_signal=input_signal, length=length)
            res2, length2 = instance2(input_signal=input_signal, length=length)
            # zip() stops at the shorter input, so check the counts first.
            assert len(length1) == len(length2)
            for len1, len2 in zip(length1, length2):
                assert len1 == len2
            assert res1.shape == res2.shape
            abs_diff = torch.abs(res1 - res2)
            assert torch.mean(abs_diff) <= 3e-3
            assert torch.max(abs_diff) <= 3
Пример #5
0
    def test_MaskedPatchAugmentation(self):
        """Smoke-test MaskedPatchAugmentation construction and forward pass.

        The augmentation is built with patch_size=16, mask_patches=0.5,
        freq_masks=2 and freq_width=10. It is then applied to the first output
        of AudioToMelSpectrogramPreprocessor, and its output must keep the
        input's shape.
        """
        # Make sure constructor works
        instance1 = modules.MaskedPatchAugmentation(patch_size=16, mask_patches=0.5, freq_masks=2, freq_width=10)
        assert isinstance(instance1, modules.MaskedPatchAugmentation)

        # Make sure forward doesn't throw with expected input
        instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0)
        input_signal = torch.randn(size=(4, 512))
        length = torch.randint(low=161, high=500, size=[4])
        spec = instance0(input_signal=input_signal, length=length)[0]
        res = instance1(input_spec=spec, length=length)

        assert res.shape == spec.shape
Пример #6
0
    def test_SpectrogramAugmentationr(self):
        """Smoke-test SpectrogramAugmentation construction and forward pass.

        An augmentation with 10 frequency masks, 3 time masks and 3 rectangular
        masks is applied to the first output of AudioToMelSpectrogramPreprocessor.
        The augmented spectrogram must have the same shape as its input.
        """
        batch_size = 4
        augmenter = modules.SpectrogramAugmentation(freq_masks=10, time_masks=3, rect_masks=3)
        assert isinstance(augmenter, modules.SpectrogramAugmentation)

        preprocessor = modules.AudioToMelSpectrogramPreprocessor(dither=0)
        signal = torch.randn(size=(batch_size, 512))
        signal_len = torch.randint(low=161, high=500, size=[batch_size])
        spec = preprocessor(input_signal=signal, length=signal_len)[0]

        augmented = augmenter(input_spec=spec)
        assert augmented.shape == spec.shape
Пример #7
0
    def test_CropOrPadSpectrogramAugmentation(self):
        """Check that CropOrPadSpectrogramAugmentation fixes the last axis length.

        A spectrogram from AudioToMelSpectrogramPreprocessor is passed through
        the augmentation with audio_length=128. The output must keep the batch
        and feature dimensions, have ``audio_length`` entries along the last
        axis, and report ``audio_length`` for every item's new length.
        """
        audio_length = 128
        batch_size = 4

        # Make sure constructor works
        instance1 = modules.CropOrPadSpectrogramAugmentation(audio_length=audio_length)
        assert isinstance(instance1, modules.CropOrPadSpectrogramAugmentation)

        # Make sure forward doesn't throw with expected input
        instance0 = modules.AudioToMelSpectrogramPreprocessor(dither=0)
        input_signal = torch.randn(size=(batch_size, 512))
        length = torch.randint(low=161, high=500, size=[batch_size])
        spec = instance0(input_signal=input_signal, length=length)[0]
        res, new_length = instance1(input_signal=spec, length=length)

        # Take the feature count from the preprocessor output instead of
        # hard-coding its default, so only the cropped/padded axis is pinned.
        assert res.shape == torch.Size([batch_size, spec.shape[1], audio_length])
        # Compare against audio_length rather than a duplicated literal.
        assert len(new_length) == batch_size
        assert bool(torch.all(new_length == audio_length))