Example #1
def load_params(*paths):
    params = []
    with open(get_asset_path(*paths), 'r') as file:
        for line in file:
            data = json.loads(line)
            for effect in data['effects']:
                for i, arg in enumerate(effect):
                    if arg.startswith("<ASSET_DIR>"):
                        effect[i] = arg.replace("<ASSET_DIR>",
                                                get_asset_path())
            params.append(param(data))
    return params
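
These snippets come from torchaudio's test suite, so names such as json, param, and get_asset_path are imported from elided test modules. The exact implementation of get_asset_path is not shown; a minimal sketch, assuming the assets live in a test/assets directory next to the calling module:

import os

_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')

def get_asset_path(*paths):
    # Build the full path to a test asset, e.g. get_asset_path('io', 'x.wav').
    return os.path.join(_ASSET_DIR, *paths)
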
Example #2
    def __init__(self):
        sound_files = ["sinewave.wav", "steam-train-whistle-daniel_simon.mp3"]
        self.data = [common_utils.get_asset_path(fn) for fn in sound_files]
        self.si, self.ei = torchaudio.info(
            common_utils.get_asset_path("sinewave.wav"))
        self.si.precision = 16
        self.E = torchaudio.sox_effects.SoxEffectsChain()
        self.E.append_effect_to_chain("rate",
                                      [self.si.rate])  # resample to 16000 Hz
        self.E.append_effect_to_chain("channels",
                                      [self.si.channels])  # mono signal
        self.E.append_effect_to_chain(
            "trim", [0, "16000s"])  # first 16000 samples of audio
Example #3
    def test_vad(self):
        common_utils.set_audio_backend('default')
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self.assert_batch_consistencies(F.vad,
                                        waveform,
                                        sample_rate=sample_rate)
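
For reference, torchaudio.functional.vad can also be called directly outside the batch-consistency harness. A minimal sketch reusing the same asset:

import torchaudio
import torchaudio.functional as F

waveform, sample_rate = torchaudio.load(
    common_utils.get_asset_path("vad-go-mono-32000.wav"))
trimmed = F.vad(waveform, sample_rate=sample_rate)  # trims leading silence
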
Example #4
    def test_resample_size(self):
        input_path = common_utils.get_asset_path('sinewave.wav')
        waveform, sample_rate = common_utils.load_wav(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2
        invalid_resampling_method = 'foo'

        with self.assertRaises(ValueError):
            torchaudio.transforms.Resample(
                sample_rate,
                upsample_rate,
                resampling_method=invalid_resampling_method)

        upsample_resample = torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method='sinc_interpolation')
        up_sampled = upsample_resample(waveform)

        # we expect the upsampled signal to have twice as many samples
        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)

        downsample_resample = torchaudio.transforms.Resample(
            sample_rate,
            downsample_rate,
            resampling_method='sinc_interpolation')
        down_sampled = downsample_resample(waveform)

        # we expect the downsampled signal to have half as many samples
        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
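
The Resample transform constructed above is reusable across calls. A minimal standalone sketch, assuming a hypothetical 16 kHz input file named 'input.wav' and the same resampling method as the test:

import torchaudio

resampler = torchaudio.transforms.Resample(
    orig_freq=16000, new_freq=32000, resampling_method='sinc_interpolation')
waveform, sample_rate = torchaudio.load('input.wav')  # hypothetical 16 kHz file
upsampled = resampler(waveform)  # roughly twice as many samples
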
Example #5
    def test_batch_mulaw(self):
        test_filepath = common_utils.get_asset_path(
            'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        waveform_encoded = torchaudio.transforms.MuLawEncoding()(waveform)
        expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = torchaudio.transforms.MuLawEncoding()(waveform_batched)

        # shape = (3, 2, 278756)
        self.assertEqual(computed, expected)

        # Single then transform then batch
        waveform_decoded = torchaudio.transforms.MuLawDecoding()(
            waveform_encoded)
        expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MuLawDecoding()(computed)

        # shape = (3, 2, 278756)
        self.assertEqual(computed, expected)
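
MuLawEncoding quantizes a waveform in [-1, 1] into 256 integer levels by default, and MuLawDecoding approximately inverts it, so a roundtrip only introduces small quantization error. A minimal sketch on synthetic audio:

import torch
import torchaudio

waveform = torch.rand(1, 1000) * 2 - 1  # synthetic audio in [-1, 1]
encoded = torchaudio.transforms.MuLawEncoding()(waveform)  # integers in [0, 255]
decoded = torchaudio.transforms.MuLawDecoding()(encoded)   # approximate inverse
assert torch.allclose(waveform, decoded, atol=0.05)  # quantization error is small
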
Example #6
    def test_vad_from_file(self):
        filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
        waveform, sample_rate = common_utils.load_wav(filepath)
        # Each channel is slightly offset - we can use this to create a batch
        # with different items.
        batch = waveform.view(2, 1, -1)
        self.assert_batch_consistency(F.vad, batch, sample_rate=sample_rate)
Example #7
    def test_batch_TimeStretch(self):
        test_filepath = common_utils.get_asset_path(
            'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        kwargs = {
            'n_fft': 2048,
            'hop_length': 512,
            'win_length': 2048,
            'window': torch.hann_window(2048),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': True,
            'onesided': True,
        }
        rate = 2

        complex_specgrams = torch.stft(waveform, **kwargs)

        # Single then transform then batch
        expected = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
Example #8
class TestDatasets(TorchaudioTestCase):
    backend = 'default'
    path = get_asset_path()

    def test_vctk(self):
        data = VCTK(self.path)
        data[0]
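
Datasets like VCTK are indexable and can be consumed through a DataLoader. A minimal sketch; batch_size=1 avoids collating variable-length clips, and the item layout (waveform, sample rate, then labels) is an assumption based on the usual torchaudio dataset convention:

from torch.utils.data import DataLoader

data = VCTK(get_asset_path())
loader = DataLoader(data, batch_size=1, shuffle=True)
for waveform, sample_rate, *labels in loader:
    pass  # consume one item per iteration here
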
Example #9
    def test_opus(self, bitrate, num_channels, compression_level):
        """`sox_io_backend.info` can check an opus file correctly"""
        path = get_asset_path(
            'io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
        info = sox_io_backend.info(path)
        assert info.sample_rate == 48000
        assert info.num_frames == 32768
        assert info.num_channels == num_channels
Example #10
    def test_opus(self, bitrate, num_channels, compression_level):
        """`sox_io_backend.info` can check an opus file correctly"""
        path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
        info = sox_io_backend.info(path)
        assert info.sample_rate == 48000
        assert info.num_frames == 32768
        assert info.num_channels == num_channels
        assert info.bits_per_sample == 0  # bits_per_sample is irrelevant for compressed formats
        assert info.encoding == "OPUS"
Example #11
def create_temp_assets_dir():
    """
    Creates a temporary directory and moves all files from test/assets there.
    Returns a Tuple[string, TemporaryDirectory] which is the folder path
    and object.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    shutil.copytree(get_asset_path(), os.path.join(tmp_dir.name, "assets"))
    return tmp_dir.name, tmp_dir
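
A sketch of how such a helper might be used, keeping a reference to the returned TemporaryDirectory object so the copy can be cleaned up explicitly:

tmp_path, tmp_dir = create_temp_assets_dir()
try:
    asset = os.path.join(tmp_path, "assets", "sinewave.wav")
    # ... work with the copied assets ...
finally:
    tmp_dir.cleanup()  # remove the temporary copy
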
Example #12
    def test_AmplitudeToDB(self):
        filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
        waveform = common_utils.load_wav(filepath)[0]

        mag_to_db_transform = transforms.AmplitudeToDB('magnitude', 80.)
        power_to_db_transform = transforms.AmplitudeToDB('power', 80.)

        mag_to_db_torch = mag_to_db_transform(torch.abs(waveform))
        power_to_db_torch = power_to_db_transform(torch.pow(waveform, 2))

        self.assertEqual(mag_to_db_torch, power_to_db_torch)
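
The equality holds because AmplitudeToDB applies 20 * log10 to 'magnitude' input and 10 * log10 to 'power' input, and 10 * log10(x**2) == 20 * log10(|x|). A quick numeric check of that identity:

import torch

x = torch.rand(100) + 0.1  # strictly positive magnitudes
mag_db = 20 * torch.log10(x)
power_db = 10 * torch.log10(x ** 2)
assert torch.allclose(mag_db, power_db, atol=1e-5)
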
Example #13
    def test_batch_mfcc(self):
        test_filepath = common_utils.get_asset_path(
            'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)

        # Single then transform then batch
        expected = torchaudio.transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))
        self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
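
For reference, the MFCC transform can be exercised on synthetic audio as well. A minimal sketch with the default 40 coefficients at 16 kHz:

import torch
import torchaudio

waveform = torch.rand(1, 16000) * 2 - 1  # one second of synthetic audio
mfcc = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=40)(waveform)
print(mfcc.shape)  # (1, 40, num_frames)
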
Example #14
    def test_opus(self, bitrate, num_channels, compression_level):
        """`sox_io_backend.load` can load an opus file correctly."""
        ops_path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
        wav_path = self.get_temp_path(f'{bitrate}_{compression_level}_{num_channels}ch.opus.wav')
        sox_utils.convert_audio_file(ops_path, wav_path)

        expected, sample_rate = load_wav(wav_path)
        found, sr = sox_io_backend.load(ops_path)

        assert sample_rate == sr
        self.assertEqual(expected, found)
Example #15
    def test_mel2(self):
        top_db = 80.
        s2db = transforms.AmplitudeToDB('power', top_db)

        waveform = self.waveform.clone()  # (1, 16000)
        waveform_scaled = self.scale(waveform)  # (1, 16000)
        mel_transform = transforms.MelSpectrogram()
        # check defaults
        spectrogram_torch = s2db(
            mel_transform(waveform_scaled))  # (1, 128, 321)
        self.assertTrue(spectrogram_torch.dim() == 3)
        self.assertTrue(
            spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
        # check correctness of filterbank conversion matrix
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
        # check options
        kwargs = {
            'window_fn': torch.hamming_window,
            'pad': 10,
            'win_length': 500,
            'hop_length': 125,
            'n_fft': 800,
            'n_mels': 50
        }
        mel_transform2 = transforms.MelSpectrogram(**kwargs)
        spectrogram2_torch = s2db(
            mel_transform2(waveform_scaled))  # (1, 50, 513)
        self.assertTrue(spectrogram2_torch.dim() == 3)
        self.assertTrue(
            spectrogram2_torch.ge(spectrogram2_torch.max() - top_db).all())
        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
        # check on multi-channel audio
        filepath = common_utils.get_asset_path(
            'steam-train-whistle-daniel_simon.wav')
        x_stereo = common_utils.load_wav(filepath)[0]  # (2, 278756), 44100
        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
        self.assertTrue(spectrogram_stereo.dim() == 3)
        self.assertTrue(spectrogram_stereo.size(0) == 2)
        self.assertTrue(
            spectrogram_stereo.ge(spectrogram_stereo.max() - top_db).all())
        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
        # check filterbank matrix creation
        fb_matrix_transform = transforms.MelScale(n_mels=100,
                                                  sample_rate=16000,
                                                  f_min=0.,
                                                  f_max=None,
                                                  n_stft=400)
        self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
        self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
        self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
Example #16
    def test_mfcc(self, kwargs):
        """mfcc should be numerically compatible with compute-mfcc-feats"""
        wave_file = get_asset_path('kaldi_file.wav')
        waveform = load_wav(wave_file,
                            normalize=False)[0].to(dtype=self.dtype,
                                                   device=self.device)
        result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)
        command = ['compute-mfcc-feats'
                   ] + convert_args(**kwargs) + ['scp:-', 'ark:-']
        kaldi_result = run_kaldi(command, 'scp', wave_file)
        self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
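
Outside the compliance harness, the same Kaldi-compatible features can be computed directly. A minimal sketch, assuming a torchaudio version whose load accepts normalize (the test above likewise loads without normalization):

import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sample_rate = torchaudio.load(
    get_asset_path('kaldi_file.wav'), normalize=False)
feats = kaldi.mfcc(waveform, num_ceps=13)  # (num_frames, num_ceps)
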
Example #17
    def setUp(self):
        super().setUp()

        # 1. test signal for testing resampling
        self.test1_signal_sr = 16000
        self.test1_signal = common_utils.get_whitenoise(
            sample_rate=self.test1_signal_sr, duration=0.5,
        )

        # 2. test audio file corresponding to saved kaldi ark files
        self.test2_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
Example #18
    def test_mp3(self):
        """Providing `format` allows reading mp3 without an extension

        libsox does not check the header for mp3

        https://github.com/pytorch/audio/issues/1040

        The file was generated with the following command
            ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
        """
        path = get_asset_path("mp3_without_ext")
        _, sr = sox_io_backend.load(path, format="mp3")
        assert sr == 16000
Example #19
    def test_batch_Vol(self):
        test_filepath = common_utils.get_asset_path(
            'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        # Single then transform then batch
        expected = torchaudio.transforms.Vol(gain=1.1)(waveform).repeat(
            3, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(
            3, 1, 1))
        self.assertEqual(computed, expected)
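
Vol also accepts gains in other units through gain_type ('amplitude', 'power', or 'db'). A minimal sketch applying an illustrative 3 dB boost to synthetic audio:

import torch
import torchaudio

waveform = torch.rand(1, 1000) * 2 - 1  # synthetic audio in [-1, 1]
louder = torchaudio.transforms.Vol(gain=3.0, gain_type='db')(waveform)
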
Example #20
    def _test_helper(self, file_name, expected_data, fn, expected_dtype):
        """ Takes a file_name for the input data and a function fn that extracts
        the data, then compares the extracted data to expected_data.
        expected_dtype is used to check that the extracted data has the right type.
        """
        test_filepath = common_utils.get_asset_path(file_name)
        expected_output = {
            'key' + str(idx + 1): torch.tensor(val, dtype=expected_dtype)
            for idx, val in enumerate(expected_data)
        }

        for key, vec in fn(test_filepath):
            self.assertTrue(key in expected_output)
            self.assertTrue(isinstance(vec, torch.Tensor))
            self.assertEqual(vec.dtype, expected_dtype)
            self.assertTrue(torch.all(torch.eq(vec, expected_output[key])))
Example #21
class TestIterator(TorchaudioTestCase):
    backend = 'default'
    path = get_asset_path('CommonVoice', 'cv-corpus-4-2019-12-10', 'tt')

    def test_diskcache_iterator(self):
        data = COMMONVOICE(self.path, url="tatar")
        data = dataset_utils.diskcache_iterator(data)
        # Save
        data[0]
        # Load
        data[0]

    def test_bg_iterator(self):
        data = COMMONVOICE(self.path, url="tatar")
        data = dataset_utils.bg_iterator(data, 5)
        for _ in data:
            pass
Example #22
    def test_mp3(self):
        """Providing `format` allows reading mp3 without an extension

        libsox does not check the header for mp3

        https://github.com/pytorch/audio/issues/1040

        The file was generated with the following command
            ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
        """
        path = get_asset_path("mp3_without_ext")
        sinfo = sox_io_backend.info(path, format="mp3")
        assert sinfo.sample_rate == 16000
        assert sinfo.num_frames == 81216
        assert sinfo.num_channels == 1
        assert sinfo.bits_per_sample == 0  # bit_per_sample is irrelevant for compressed formats
        assert sinfo.encoding == "MP3"
Example #23
class TestIterator(TorchaudioTestCase):
    backend = 'default'
    path = get_asset_path()

    def test_diskcache_iterator(self):
        data = COMMONVOICE(self.path, url="tatar")
        data = dataset_utils.diskcache_iterator(data)
        # Save
        data[0]
        # Load
        data[0]

    def test_bg_iterator(self):
        data = COMMONVOICE(self.path, url="tatar")
        data = dataset_utils.bg_iterator(data, 5)
        for _ in data:
            pass
Example #24
    def test_resample(self):
        input_path = common_utils.get_asset_path('sinewave.wav')
        waveform, sample_rate = common_utils.load_wav(input_path)

        upsample_rate = sample_rate * 2
        downsample_rate = sample_rate // 2

        ta_upsampled = F.resample(waveform, sample_rate, upsample_rate)
        lr_upsampled = librosa.resample(
            waveform.squeeze(0).numpy(), sample_rate, upsample_rate)
        lr_upsampled = torch.from_numpy(lr_upsampled).unsqueeze(0)

        self.assertEqual(ta_upsampled, lr_upsampled, atol=1e-2, rtol=1e-5)

        ta_downsampled = F.resample(waveform, sample_rate, downsample_rate)
        lr_downsampled = librosa.resample(
            waveform.squeeze(0).numpy(), sample_rate, downsample_rate)
        lr_downsampled = torch.from_numpy(lr_downsampled).unsqueeze(0)

        self.assertEqual(ta_downsampled, lr_downsampled, atol=1e-2, rtol=1e-5)
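
Note that recent librosa releases (0.10 and later) made the resample rates keyword-only; under those versions the comparison above would pass the rates as, for example:

lr_upsampled = librosa.resample(
    waveform.squeeze(0).numpy(), orig_sr=sample_rate, target_sr=upsample_rate)
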
Example #25
class TestIterator(TorchaudioTestCase):
    backend = 'default'
    path = get_asset_path()

    def test_diskcache_iterator(self):
        data = COMMONVOICE(self.path,
                           version="cv-corpus-4-2019-12-10",
                           language="tatar")
        data = dataset_utils.diskcache_iterator(data)
        # Save
        data[0]
        # Load
        data[0]

    def test_bg_iterator(self):
        data = COMMONVOICE(self.path,
                           version="cv-corpus-4-2019-12-10",
                           language="tatar")
        data = dataset_utils.bg_iterator(data, 5)
        for _ in data:
            pass
Example #26
    def test_batch_TimeStretch(self):
        test_filepath = common_utils.get_asset_path(
            'steam-train-whistle-daniel_simon.wav')
        waveform, _ = torchaudio.load(test_filepath)  # (2, 278756), 44100

        rate = 2

        complex_specgrams = torch.view_as_real(
            torch.stft(
                input=waveform,
                n_fft=2048,
                hop_length=512,
                win_length=2048,
                window=torch.hann_window(2048),
                center=True,
                pad_mode='reflect',
                normalized=True,
                onesided=True,
                return_complex=True,
            ))

        # Single then transform then batch
        expected = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams).repeat(3, 1, 1, 1, 1)

        # Batch then transform
        computed = torchaudio.transforms.TimeStretch(
            fixed_rate=rate,
            n_freq=1025,
            hop_length=512,
        )(complex_specgrams.repeat(3, 1, 1, 1, 1))

        self.assertEqual(computed, expected, atol=1e-5, rtol=1e-5)
Example #27
    def test_dither(self):
        path = get_asset_path('steam-train-whistle-daniel_simon.wav')
        data, _ = load_wav(path)
        result = F.dither(data)
        self.assert_sox_effect(result, path, ['dither'])
Example #28
    def test_dither_noise(self):
        path = get_asset_path('steam-train-whistle-daniel_simon.wav')
        data, _ = load_wav(path)
        result = F.dither(data, noise_shaping=True)
        self.assert_sox_effect(result, path, ['dither', '-s'], atol=1.5e-4)
Example #29
    def test_Vad(self):
        filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
        waveform, sample_rate = common_utils.load_wav(filepath)
        self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)
Example #30
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
    backend = 'sox'

    kaldi_output_dir = common_utils.get_asset_path('kaldi')
    test_filepath = common_utils.get_asset_path('kaldi_file.wav')
    test_filepaths = {prefix: [] for prefix in compliance_utils.TEST_PREFIX}

    def setUp(self):
        super().setUp()

        # 1. test signal for testing resampling
        self.test1_signal_sr = 16000
        self.test1_signal = common_utils.get_whitenoise(
            sample_rate=self.test1_signal_sr,
            duration=0.5,
        )

        # 2. test audio file corresponding to saved kaldi ark files
        self.test2_filepath = common_utils.get_asset_path(
            'kaldi_file_8000.wav')

    # separating test files by their types (e.g. 'spec', 'fbank', etc.)
    for f in os.listdir(kaldi_output_dir):
        dash_idx = f.find('-')
        assert f.endswith('.ark') and dash_idx != -1
        key = f[:dash_idx]
        assert key in test_filepaths
        test_filepaths[key].append(f)

    def _test_get_strided_helper(self, num_samples, window_size, window_shift,
                                 snip_edges):
        waveform = torch.arange(num_samples).float()
        output = kaldi._get_strided(waveform, window_size, window_shift,
                                    snip_edges)

        # from NumFrames in feature-window.cc
        n = window_size
        if snip_edges:
            m = 0 if num_samples < window_size else 1 + (
                num_samples - window_size) // window_shift
        else:
            m = (num_samples + (window_shift // 2)) // window_shift

        self.assertTrue(output.dim() == 2)
        self.assertTrue(output.shape[0] == m and output.shape[1] == n)

        window = torch.empty((m, window_size))

        for r in range(m):
            extract_window(window, waveform, r, window_size, window_shift,
                           snip_edges)
        torch.testing.assert_allclose(window, output)

    def test_get_strided(self):
        # generate any combination where 0 < window_size <= num_samples and
        # 0 < window_shift.
        for num_samples in range(1, 20):
            for window_size in range(1, num_samples + 1):
                for window_shift in range(1, 2 * num_samples + 1):
                    for snip_edges in range(0, 2):
                        self._test_get_strided_helper(num_samples, window_size,
                                                      window_shift, snip_edges)

    def _create_data_set(self):
        # Used to generate the dataset to test on; not run as part of the tests (offline procedure).
        sr = 16000
        x = torch.arange(0, 20).float()
        # between [-6,6]
        y = torch.cos(
            2 * math.pi * x) + 3 * torch.sin(math.pi * x) + 2 * torch.cos(x)
        # between [-2^30, 2^30]
        y = (y / 6 * (1 << 30)).long()
        # clear the last 16 bits because they aren't used anyway
        y = ((y >> 16) << 16).float()
        torchaudio.save(self.test_filepath, y, sr)
        sound, sample_rate = torchaudio.load(self.test_filepath,
                                             normalization=False)
        print(y >> 16)
        self.assertTrue(sample_rate == sr)
        torch.testing.assert_allclose(y, sound)

    def _print_diagnostic(self, output, expect_output):
        # Given an output and an expected output, print the absolute/relative errors (max and mean squared).
        abs_error = output - expect_output
        abs_mse = abs_error.pow(2).sum() / output.numel()
        abs_max_error = torch.max(abs_error.abs())

        relative_error = abs_error / expect_output
        relative_mse = relative_error.pow(2).sum() / output.numel()
        relative_max_error = torch.max(relative_error.abs())

        print('abs_mse:', abs_mse.item(), 'abs_max_error:',
              abs_max_error.item())
        print('relative_mse:', relative_mse.item(), 'relative_max_error:',
              relative_max_error.item())

    def _compliance_test_helper(self,
                                sound_filepath,
                                filepath_key,
                                expected_num_files,
                                expected_num_args,
                                get_output_fn,
                                atol=1e-5,
                                rtol=1e-7):
        """
        Inputs:
            sound_filepath (str): The location of the sound file
            filepath_key (str): A key to `test_filepaths` which matches which files to use
            expected_num_files (int): The expected number of kaldi files to read
            expected_num_args (int): The expected number of arguments used in a kaldi configuration
            get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
                and a configuration and returns an output
            atol (float): absolute tolerance
            rtol (float): relative tolerance
        """
        sound, sr = torchaudio.load_wav(sound_filepath)
        files = self.test_filepaths[filepath_key]

        assert len(files) == expected_num_files, \
            ('number of kaldi {} files changed to {}'.format(
                filepath_key, len(files)))

        for f in files:
            print(f)

            # Read kaldi's output from file
            kaldi_output_path = os.path.join(self.kaldi_output_dir, f)
            kaldi_output_dict = {
                k: v
                for k, v in torchaudio.kaldi_io.read_mat_ark(kaldi_output_path)
            }

            assert len(
                kaldi_output_dict
            ) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
            kaldi_output = kaldi_output_dict['my_id']

            # Construct the same configuration used by kaldi
            args = f.split('-')
            args[-1] = os.path.splitext(args[-1])[0]
            assert len(
                args) == expected_num_args, 'invalid test kaldi file name'
            args = [compliance_utils.parse(arg) for arg in args]

            output = get_output_fn(sound, args)

            self._print_diagnostic(output, kaldi_output)
            torch.testing.assert_allclose(output,
                                          kaldi_output,
                                          atol=atol,
                                          rtol=rtol)

    def test_mfcc_empty(self):
        # Passing in an empty tensor should result in an error
        self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))

    def test_resample_waveform(self):
        def get_output_fn(sound, args):
            output = kaldi.resample_waveform(sound, args[1], args[2])
            return output

        self._compliance_test_helper(self.test2_filepath,
                                     'resample',
                                     32,
                                     3,
                                     get_output_fn,
                                     atol=1e-2,
                                     rtol=1e-5)

    def test_resample_waveform_upsample_size(self):
        upsample_sound = kaldi.resample_waveform(self.test1_signal,
                                                 self.test1_signal_sr,
                                                 self.test1_signal_sr * 2)
        self.assertTrue(
            upsample_sound.size(-1) == self.test1_signal.size(-1) * 2)

    def test_resample_waveform_downsample_size(self):
        downsample_sound = kaldi.resample_waveform(self.test1_signal,
                                                   self.test1_signal_sr,
                                                   self.test1_signal_sr // 2)
        self.assertTrue(
            downsample_sound.size(-1) == self.test1_signal.size(-1) // 2)

    def test_resample_waveform_identity_size(self):
        downsample_sound = kaldi.resample_waveform(self.test1_signal,
                                                   self.test1_signal_sr,
                                                   self.test1_signal_sr)
        self.assertTrue(
            downsample_sound.size(-1) == self.test1_signal.size(-1))

    def _test_resample_waveform_accuracy(self,
                                         up_scale_factor=None,
                                         down_scale_factor=None,
                                         atol=1e-1,
                                         rtol=1e-4):
        # resample the signal and compare it to the ground truth
        n_to_trim = 20
        sample_rate = 1000
        new_sample_rate = sample_rate

        if up_scale_factor is not None:
            new_sample_rate *= up_scale_factor

        if down_scale_factor is not None:
            new_sample_rate //= down_scale_factor

        duration = 5  # seconds
        original_timestamps = torch.arange(0, duration, 1.0 / sample_rate)

        sound = 123 * torch.cos(
            2 * math.pi * 3 * original_timestamps).unsqueeze(0)
        estimate = kaldi.resample_waveform(sound, sample_rate,
                                           new_sample_rate).squeeze()

        new_timestamps = torch.arange(0, duration,
                                      1.0 / new_sample_rate)[:estimate.size(0)]
        ground_truth = 123 * torch.cos(2 * math.pi * 3 * new_timestamps)

        # trim the first/last n samples as these points have boundary effects
        ground_truth = ground_truth[..., n_to_trim:-n_to_trim]
        estimate = estimate[..., n_to_trim:-n_to_trim]

        torch.testing.assert_allclose(estimate,
                                      ground_truth,
                                      atol=atol,
                                      rtol=rtol)

    def test_resample_waveform_downsample_accuracy(self):
        for i in range(1, 20):
            self._test_resample_waveform_accuracy(down_scale_factor=i * 2)

    def test_resample_waveform_upsample_accuracy(self):
        for i in range(1, 20):
            self._test_resample_waveform_accuracy(up_scale_factor=1.0 +
                                                  i / 20.0)

    def test_resample_waveform_multi_channel(self):
        num_channels = 3

        multi_sound = self.test1_signal.repeat(num_channels,
                                               1)  # (num_channels, 8000 smp)

        for i in range(num_channels):
            multi_sound[i, :] *= (i + 1) * 1.5

        multi_sound_sampled = kaldi.resample_waveform(
            multi_sound, self.test1_signal_sr, self.test1_signal_sr // 2)

        # check that sampling is same whether using separately or in a tensor of size (c, n)
        for i in range(num_channels):
            single_channel = self.test1_signal * (i + 1) * 1.5
            single_channel_sampled = kaldi.resample_waveform(
                single_channel, self.test1_signal_sr,
                self.test1_signal_sr // 2)
            torch.testing.assert_allclose(multi_sound_sampled[i, :],
                                          single_channel_sampled[0],
                                          rtol=1e-4,
                                          atol=1e-7)
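
A minimal standalone sketch of the Kaldi-compatible resampler exercised throughout this example, halving the rate of half a second of synthetic 16 kHz audio:

import torch
import torchaudio.compliance.kaldi as kaldi

signal = torch.rand(1, 8000) * 2 - 1  # (num_channels, num_samples) at 16 kHz
resampled = kaldi.resample_waveform(signal, orig_freq=16000, new_freq=8000)
print(resampled.shape)  # roughly half as many samples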