Ejemplo n.º 1
0
    def test_gain(self):
        """
        Test gain effect, compare to SoX implementation
        """
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)

        waveform_gain = F.gain(waveform, 3)
        # assertTrue(peak, 1.) would treat 1. as the failure message and only
        # check that the peak is truthy; compare against the bound explicitly.
        peak = waveform_gain.abs().max().item()
        self.assertLessEqual(peak, 1.)

        # Reference output: the same file run through the SoX "gain" effect
        E = torchaudio.sox_effects.SoxEffectsChain()
        E.set_input_file(test_filepath)
        E.append_effect_to_chain("gain", [3])
        sox_gain_waveform = E.sox_build_flow_effects()[0]

        self.assertEqual(waveform_gain, sox_gain_waveform, atol=1e-04, rtol=1e-5)
    def test_equalizer(self):
        """Run F.equalizer_biquad on the whitenoise.wav asset through the consistency helper."""
        noise_path = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(noise_path, normalization=True)

        def func(tensor):
            # Positional arguments after the tensor: sample_rate, center_freq, gain, q
            return F.equalizer_biquad(tensor, 44100, 300., 1., 0.707)

        self._assert_consistency(func, waveform)
Ejemplo n.º 3
0
    def test_batch_Vol(self):
        """Vol applied to a repeated batch must match Vol applied once and then repeated."""
        wav_path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(wav_path)  # (2, 278756), 44100

        # Transform the single waveform, then tile the result three times
        single_out = torchaudio.transforms.Vol(gain=1.1)(waveform)
        expected = single_out.repeat(3, 1, 1)

        # Tile the waveform three times, then transform the batch
        batch_in = waveform.repeat(3, 1, 1)
        computed = torchaudio.transforms.Vol(gain=1.1)(batch_in)

        self.assertEqual(computed, expected)
Ejemplo n.º 4
0
    def test_batch_melspectrogram(self):
        """MelSpectrogram of a repeated batch must match the single result repeated."""
        wav_path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(wav_path)  # (2, 278756), 44100

        # Transform the single waveform, then tile the result three times
        single_out = torchaudio.transforms.MelSpectrogram()(waveform)
        expected = single_out.repeat(3, 1, 1, 1)

        # Tile the waveform three times, then transform the batch
        batch_in = waveform.repeat(3, 1, 1)
        computed = torchaudio.transforms.MelSpectrogram()(batch_in)

        torch.testing.assert_allclose(computed, expected)
Ejemplo n.º 5
0
    def test_highpass(self):
        """Run F.highpass_biquad on the whitenoise.wav asset through the consistency helper."""
        # float64 is flagged as an expected failure before any work is done
        if self.dtype == torch.float64:
            pytest.xfail("This test is known to fail for float64")

        noise_path = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(noise_path, normalization=True)

        def func(tensor):
            # Positional arguments after the tensor: sample_rate, cutoff_freq
            return F.highpass_biquad(tensor, 44100, 2000.)

        self._assert_consistency(func, waveform)
    def test_bandpass_withou_csg(self):
        """Run F.bandpass_biquad with const_skirt_gain disabled through the consistency helper."""
        filepath = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(filepath, normalization=True)

        def func(tensor):
            sample_rate = 44100
            central_freq = 1000.
            q = 0.707
            # The test name says "without csg", so the constant-skirt-gain
            # flag has to be off for this case to be exercised at all.
            const_skirt_gain = False
            return F.bandpass_biquad(tensor, sample_rate, central_freq, q,
                                     const_skirt_gain)

        self._assert_consistency(func, waveform)
Ejemplo n.º 7
0
    def test_perf_biquad_filtering(self):
        """Run F.lfilter with fixed 3-tap coefficients through the consistency helper."""
        noise_path = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(noise_path, normalization=True)

        def func(tensor):
            # Both coefficient vectors follow the input's device and dtype
            a_coeffs = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
            b_coeffs = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        self._assert_consistency(func, waveform)
Ejemplo n.º 8
0
    def test_bandreject(self):
        """Run F.bandreject_biquad on the whitenoise.wav asset through the consistency helper."""
        # float64 is flagged as an expected failure before any work is done
        if self.dtype == torch.float64:
            pytest.xfail("This test is known to fail for float64")

        noise_path = common_utils.get_asset_path("whitenoise.wav")
        waveform, _ = torchaudio.load(noise_path, normalization=True)

        def func(tensor):
            # Positional arguments after the tensor: sample_rate, central_freq, q
            return F.bandreject_biquad(tensor, 44100, 1000., 0.707)

        self._assert_consistency(func, waveform)
Ejemplo n.º 9
0
    def test_riaa(self):
        """
        Test biquad riaa filter, compare to SoX implementation
        """
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "riaa" effect
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("riaa")
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.riaa_biquad(waveform, sample_rate)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 10
0
    def test_contrast(self):
        """
        Test contrast effect, compare to SoX implementation
        """
        enhancement_amount = 80.
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "contrast" effect
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("contrast", [enhancement_amount])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, _ = torchaudio.load(noise_filepath, normalization=True)
        actual = F.contrast(waveform, enhancement_amount)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 11
0
    def test_dcshift_without_limiter(self):
        """
        Test dcshift effect, compare to SoX implementation
        """
        shift = 0.6
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "dcshift" effect (shift only, no limiter argument)
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("dcshift", [shift])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, _ = torchaudio.load(noise_filepath, normalization=True)
        actual = F.dcshift(waveform, shift)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 12
0
 def _create_data_set(self):
     # Offline procedure that generates the dataset to test on; it is not run as a test.
     sr = 16000
     test_filepath = common_utils.get_asset_path('kaldi_file.wav')

     steps = torch.arange(0, 20).float()
     # between [-6,6]
     signal = (torch.cos(2 * math.pi * steps)
               + 3 * torch.sin(math.pi * steps)
               + 2 * torch.cos(steps))
     # between [-2^30, 2^30]
     scaled = (signal / 6 * (1 << 30)).long()
     # zero out the low 16 bits because they aren't used anyways
     y = ((scaled >> 16) << 16).float()

     # Write the signal out, then read it back without normalization
     torchaudio.save(test_filepath, y, sr)
     sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
     print(y >> 16)

     # The round trip must preserve both the sample rate and the samples
     self.assertTrue(sample_rate == sr)
     torch.testing.assert_allclose(y, sound)
Ejemplo n.º 13
0
    def _test_helper(self, file_name, expected_data, fn, expected_dtype):
        """ Resolve file_name to an asset path, read it with fn and check every
        (key, tensor) pair that fn yields: the key must be one of 'key1',
        'key2', ... (one per entry of expected_data, in order), the value must
        be a torch.Tensor of expected_dtype, and it must equal the matching
        expected_data entry element-wise.
        """
        test_filepath = common_utils.get_asset_path(file_name)

        # Keys are numbered from 1 in the order of expected_data
        expected_output = {}
        for position, val in enumerate(expected_data, 1):
            expected_output['key' + str(position)] = torch.tensor(val, dtype=expected_dtype)

        for key, vec in fn(test_filepath):
            self.assertTrue(key in expected_output)
            self.assertTrue(isinstance(vec, torch.Tensor))
            self.assertEqual(vec.dtype, expected_dtype)
            elementwise_match = torch.eq(vec, expected_output[key])
            self.assertTrue(torch.all(elementwise_match))
Ejemplo n.º 14
0
    def test_overdrive(self):
        """
        Test overdrive effect, compare to SoX implementation
        """
        gain, colour = 30, 40
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "overdrive" effect
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("overdrive", [gain, colour])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, _ = torchaudio.load(noise_filepath, normalization=True)
        actual = F.overdrive(waveform, gain, colour)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 15
0
    def test_lowpass(self):
        """
        Test biquad lowpass filter, compare to SoX implementation
        """
        cutoff_freq = 3000
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "lowpass" effect
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("lowpass", [cutoff_freq])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.lowpass_biquad(waveform, sample_rate, cutoff_freq)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 16
0
    def test_bandreject(self):
        """
        Test biquad bandreject filter, compare to SoX implementation
        """
        central_freq = 1000
        q = 0.707
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # SoX receives the width as a string made of the q value followed by 'q'
        sox_width = str(q) + 'q'
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("bandreject", [central_freq, sox_width])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.bandreject_biquad(waveform, sample_rate, central_freq, q)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 17
0
    def test_equalizer(self):
        """
        Test biquad peaking equalizer filter, compare to SoX implementation
        """
        center_freq, q, gain = 300, 0.707, 1
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "equalizer" effect.
        # SoX takes (center_freq, q, gain); the functional call below takes gain before q.
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("equalizer", [center_freq, q, gain])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.equalizer_biquad(waveform, sample_rate, center_freq, gain, q)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 18
0
    def test_lfilter(self):
        """Run F.lfilter with a 7-tap IIR coefficient set through the consistency helper."""
        # float64 is flagged as an expected failure before any work is done
        if self.dtype == torch.float64:
            pytest.xfail("This test is known to fail for float64")

        noise_path = common_utils.get_asset_path('whitenoise.wav')
        waveform, _ = torchaudio.load(noise_path, normalization=True)

        def func(tensor):
            # Coefficients of an IIR lowpass filter designed with scipy.signal:
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
            #
            # Example
            #     >>> from scipy.signal import iirdesign
            #     >>> b, a = iirdesign(0.2, 0.3, 1, 60)
            b_coeffs = torch.tensor(
                [0.00299893, -0.0051152, 0.00841964, -0.00747802,
                 0.00841964, -0.0051152, 0.00299893],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            a_coeffs = torch.tensor(
                [1.0, -4.8155751, 10.2217618, -12.14481273,
                 8.49018171, -3.3066882, 0.56088705],
                device=tensor.device,
                dtype=tensor.dtype,
            )
            return F.lfilter(tensor, a_coeffs, b_coeffs)

        self._assert_consistency(func, waveform)
Ejemplo n.º 19
0
    def test_phaser_triangle(self):
        """
        Test phaser effect with triangle modulation, compare to SoX implementation
        """
        gain_in, gain_out = 0.5, 0.8
        delay_ms, decay, speed = 2.0, 0.4, 0.5
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "phaser" effect; "-t" is passed on the
        # SoX side where the functional call below uses sinusoidal=False.
        sox_args = [gain_in, gain_out, delay_ms, decay, speed, "-t"]
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("phaser", sox_args)
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.phaser(
            waveform, sample_rate, gain_in, gain_out, delay_ms, decay, speed,
            sinusoidal=False)

        self.assertEqual(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 20
0
    def test_dither(self):
        """
        Test dither effect with and without noise shaping, compare to SoX implementation
        """
        test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)

        # torchaudio results for both variants are computed up front
        waveform_dithered = F.dither(waveform)
        waveform_dithered_noiseshaped = F.dither(waveform, noise_shaping=True)

        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(test_filepath)

        # Plain dither: SoX effect with no arguments
        sox_chain.append_effect_to_chain("dither", [])
        plain_result = sox_chain.sox_build_flow_effects()
        sox_dither_waveform = plain_result[0]
        self.assertEqual(waveform_dithered, sox_dither_waveform, atol=1e-04, rtol=1e-5)

        # Noise-shaped dither: reset the chain and rerun with "-s";
        # this comparison uses a looser atol (1e-02) than the plain one
        sox_chain.clear_chain()
        sox_chain.append_effect_to_chain("dither", ["-s"])
        shaped_result = sox_chain.sox_build_flow_effects()
        sox_dither_waveform_ns = shaped_result[0]
        self.assertEqual(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02, rtol=1e-5)
Ejemplo n.º 21
0
class TestDatasets(unittest.TestCase):
    """Smoke tests: build each dataset from the asset directory and read from it."""

    # Directory handed to every dataset constructor below
    path = common_utils.get_asset_path()

    def test_yesno(self):
        dataset = YESNO(self.path)
        dataset[0]

    def test_vctk(self):
        dataset = VCTK(self.path)
        dataset[0]

    def test_librispeech(self):
        dataset = LIBRISPEECH(self.path, "dev-clean")
        dataset[0]

    @unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
    def test_commonvoice(self):
        dataset = COMMONVOICE(self.path, url="tatar")
        dataset[0]

    @unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
    def test_commonvoice_diskcache(self):
        cached = diskcache_iterator(COMMONVOICE(self.path, url="tatar"))
        # The first access saves the item
        cached[0]
        # The second access loads it back
        cached[0]

    @unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
    def test_commonvoice_bg(self):
        background = bg_iterator(COMMONVOICE(self.path, url="tatar"), 5)
        # Exhaust the iterator; the items themselves are not inspected
        for _ in background:
            pass

    def test_ljspeech(self):
        dataset = LJSPEECH(self.path)
        dataset[0]

    def test_speechcommands(self):
        dataset = SPEECHCOMMANDS(self.path)
        dataset[0]
Ejemplo n.º 22
0
    def test_vctk_transform_pipeline(self):
        """Compare a resample + noise-shaped dither pipeline against a SoX effects chain."""
        test_filepath_vctk = common_utils.get_asset_path('VCTK-Corpus', 'wav48', 'p224', 'p224_002.wav')
        wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)

        # torchaudio side: resample to 16000, then dither with noise shaping
        resampler = T.Resample(sr_vctk, 16000, resampling_method='sinc_interpolation')
        wf_vctk = F.dither(resampler(wf_vctk), noise_shaping=True)

        # SoX side: effects are appended in the order they are applied
        sox_effects = [
            ("gain", ["-h"]),
            ("channels", [1]),
            ("rate", [16000]),
            ("gain", ["-rh"]),
            ("dither", ["-s"]),
        ]
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(test_filepath_vctk)
        for effect_name, effect_args in sox_effects:
            sox_chain.append_effect_to_chain(effect_name, effect_args)
        wf_vctk_sox = sox_chain.sox_build_flow_effects()[0]

        self.assertEqual(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)
Ejemplo n.º 23
0
    def test_phaser(self):
        """Run F.phaser (sinusoidal modulation) through the consistency helper."""
        noise_path = common_utils.get_asset_path("whitenoise.wav")
        # The file's own sample rate is not used; func passes a fixed 44100
        waveform, _ = torchaudio.load(noise_path, normalization=True)

        def func(tensor):
            # Positional arguments after the tensor:
            # sample_rate, gain_in, gain_out, delay_ms, decay, speed
            return F.phaser(tensor, 44100, 0.5, 0.8, 2.0, 0.4, 0.5, sinusoidal=True)

        self._assert_consistency(func, waveform)
Ejemplo n.º 24
0
    def test_highpass(self):
        """
        Test biquad highpass filter, compare to SoX implementation
        """
        cutoff_freq = 2000
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # Reference result from the SoX "highpass" effect
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("highpass", [cutoff_freq])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.highpass_biquad(waveform, sample_rate, cutoff_freq)

        # TBD - this fails at the 1e-4 level, debug why
        torch.testing.assert_allclose(actual, expected, atol=1e-3, rtol=1e-5)
Ejemplo n.º 25
0
    def test_mel2(self):
        """Check MelSpectrogram followed by AmplitudeToDB.

        Covers the default settings, a custom set of options and multi-channel
        input; for each, the result's rank, its n_mels dimension and the top_db
        floor are verified. Also checks that filterbank row sums lie in [0, 1].
        """
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 16000)
        waveform_scaled = self.scale(waveform)  # (1, 16000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
                  'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        # the top_db floor must hold for the spectrogram built with the custom options
        self.assertTrue(spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        x_stereo, _ = torchaudio.load(filepath)  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        # the top_db floor must hold for the stereo spectrogram as well
        self.assertTrue(spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(
            n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
Ejemplo n.º 26
0
    def test_resample_size(self):
        """Check Resample output lengths and that an unknown resampling method raises ValueError."""
        input_path = common_utils.get_asset_path('sinewave.wav')
        waveform, sample_rate = torchaudio.load(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2

        # The transform is built outside the context manager so that only
        # the call itself is expected to raise
        invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
        with self.assertRaises(ValueError):
            invalid_resample(waveform)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # doubling the rate should double the number of samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate, downsample_rate, resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # halving the rate should halve the number of samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
Ejemplo n.º 27
0
    def test_perf_biquad_filtering(self):
        """
        Test lfilter with biquad coefficients, compare to SoX implementation
        """
        fn_sine = common_utils.get_asset_path('whitenoise.wav')

        b0, b1, b2 = 0.4, 0.2, 0.9
        a0, a1, a2 = 0.7, 0.2, 0.6

        # SoX method: the "biquad" effect takes the b coefficients first, then the a coefficients
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(fn_sine)
        sox_chain.append_effect_to_chain("biquad", [b0, b1, b2, a0, a1, a2])
        waveform_sox_out, _ = sox_chain.sox_build_flow_effects()

        # lfilter method: the a coefficients are passed before the b coefficients
        waveform, _ = torchaudio.load(fn_sine, normalization=True)
        a_coeffs = torch.tensor([a0, a1, a2])
        b_coeffs = torch.tensor([b0, b1, b2])
        waveform_lfilter_out = F.lfilter(waveform, a_coeffs, b_coeffs)

        self.assertEqual(waveform_lfilter_out, waveform_sox_out, atol=1e-4, rtol=1e-5)
Ejemplo n.º 28
0
    def test_treble(self):
        """
        Test biquad treble filter, compare to SoX implementation
        """
        central_freq, q, gain = 1000, 0.707, 40
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # SoX receives the width as a string made of the q value followed by 'q'
        sox_width = str(q) + 'q'
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("treble", [gain, central_freq, sox_width])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.treble_biquad(waveform, sample_rate, gain, central_freq, q)

        torch.testing.assert_allclose(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 29
0
    def test_bandpass_without_csg(self):
        """
        Test biquad bandpass filter, compare to SoX implementation
        """
        central_freq, q = 1000, 0.707
        const_skirt_gain = False
        noise_filepath = common_utils.get_asset_path('whitenoise.wav')

        # SoX receives the width as a string made of the q value followed by 'q'
        sox_width = str(q) + 'q'
        sox_chain = torchaudio.sox_effects.SoxEffectsChain()
        sox_chain.set_input_file(noise_filepath)
        sox_chain.append_effect_to_chain("bandpass", [central_freq, sox_width])
        expected, _ = sox_chain.sox_build_flow_effects()

        # Result from the torchaudio functional implementation
        waveform, sample_rate = torchaudio.load(noise_filepath, normalization=True)
        actual = F.bandpass_biquad(waveform, sample_rate, central_freq, q, const_skirt_gain)

        torch.testing.assert_allclose(actual, expected, atol=1e-4, rtol=1e-5)
Ejemplo n.º 30
0
    def test_batch_mulaw(self):
        """MuLawEncoding/MuLawDecoding on a batch must match the single result, batched."""
        wav_path = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(wav_path)  # (2, 278756), 44100

        encoder = torchaudio.transforms.MuLawEncoding
        decoder = torchaudio.transforms.MuLawDecoding

        # --- Encoding ---
        # Encode the single waveform, then add a batch axis and tile it three times
        waveform_encoded = encoder()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Tile the waveform into a batch of three, then encode
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = encoder()(waveform_batched)

        # presumably both are (3, 2, 278756), i.e. the batched waveform shape - TODO confirm
        torch.testing.assert_allclose(computed, expected)

        # --- Decoding ---
        # Decode the single encoded waveform, then add a batch axis and tile it
        waveform_decoded = decoder()(waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Decode the batch that was encoded above
        computed = decoder()(computed)

        torch.testing.assert_allclose(computed, expected)