def assert_sphere(self, sample_rate, num_channels, duration):
    """Check that `sox_io_backend.save` writes sph files equivalent to SoX's.

    Follows the mp3 comparison strategy. A reference wav is encoded to sph
    twice, once by torchaudio and once by the SoX CLI. Each sph file is
    decoded back to wav with SoX, and the two decoded signals must match.
    """
    reference_wav = self.get_temp_path('1.reference.wav')
    ta_sph = self.get_temp_path('2.1.torchaudio.sph')
    ta_wav = self.get_temp_path('2.2.torchaudio.wav')
    sox_sph = self.get_temp_path('3.1.sox.sph')
    sox_wav = self.get_temp_path('3.2.sox.wav')

    # Step 1: write a normalized float32 reference wav.
    frame_count = duration * sample_rate
    reference = get_wav_data('float32', num_channels, normalize=True, num_frames=frame_count)
    save_wav(reference_wav, reference, sample_rate)

    # Step 2: torchaudio encodes to sph, then SoX decodes back to wav.
    sox_io_backend.save(ta_sph, load_wav(reference_wav)[0], sample_rate)
    # Decode at 32 bits: the sph output is said to be 24-bit, which the
    # scipy-based wav loader cannot read.
    sox_utils.convert_audio_file(ta_sph, ta_wav, bit_depth=32)
    found = load_wav(ta_wav)[0]

    # Step 3: SoX performs both the encode and the decode.
    sox_utils.convert_audio_file(reference_wav, sox_sph)
    sox_utils.convert_audio_file(sox_sph, sox_wav, bit_depth=32)
    expected = load_wav(sox_wav)[0]

    self.assertEqual(found, expected)
def assert_amb(self, dtype, sample_rate, num_channels, duration):
    """Check that `sox_io_backend.save` writes amb files equivalent to SoX's.

    Follows the mp3 comparison strategy. The same un-normalized source wav is
    encoded to amb by torchaudio and by SoX. Both are decoded to wav with SoX,
    and the two results must match.
    """
    reference_wav = self.get_temp_path('1.reference.wav')
    ta_amb = self.get_temp_path('2.1.torchaudio.amb')
    ta_wav = self.get_temp_path('2.2.torchaudio.wav')
    sox_amb = self.get_temp_path('3.1.sox.amb')
    sox_wav = self.get_temp_path('3.2.sox.wav')

    # Step 1: reference wav in the requested dtype (not normalized).
    frame_count = duration * sample_rate
    reference = get_wav_data(dtype, num_channels, normalize=False, num_frames=frame_count)
    save_wav(reference_wav, reference, sample_rate)

    # Step 2: torchaudio encode (from the raw, un-normalized samples), SoX decode.
    raw_samples = load_wav(reference_wav, normalize=False)[0]
    sox_io_backend.save(ta_amb, raw_samples, sample_rate)
    sox_utils.convert_audio_file(ta_amb, ta_wav)
    found = load_wav(ta_wav)[0]

    # Step 3: SoX encode and SoX decode.
    sox_utils.convert_audio_file(reference_wav, sox_amb)
    sox_utils.convert_audio_file(sox_amb, sox_wav)
    expected = load_wav(sox_wav)[0]

    self.assertEqual(found, expected)
def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
    """Check that `sox_io_backend.save` writes mp3 files equivalent to SoX's.

    mp3 encoding introduces delay and boundary effects, so the encoded files
    are not compared directly. Each mp3 is converted back to wav, and the
    decoded signals are compared instead:

                     1. Generate original wav file with SciPy
                              |
                              v
              -------------- wav ----------------
             |                                   |
             | 2.1. load with scipy              | 3.1. Convert to mp3 with Sox
             |      then save with torchaudio    |
             v                                   v
            mp3                                 mp3
             |                                   |
             | 2.2. Convert to wav with Sox      | 3.2. Convert to wav with Sox
             v                                   v
            wav                                 wav
             |                                   |
             | 2.3. load with scipy              | 3.3. load with scipy
             v                                   v
           tensor -------> compare <--------- tensor
    """
    reference_wav = self.get_temp_path('1.reference.wav')
    ta_mp3 = self.get_temp_path('2.1.torchaudio.mp3')
    ta_wav = self.get_temp_path('2.2.torchaudio.wav')
    sox_mp3 = self.get_temp_path('3.1.sox.mp3')
    sox_wav = self.get_temp_path('3.2.sox.wav')

    # 1. Normalized float32 reference wav.
    frame_count = duration * sample_rate
    reference = get_wav_data('float32', num_channels, normalize=True, num_frames=frame_count)
    save_wav(reference_wav, reference, sample_rate)

    # 2.x torchaudio encodes at `bit_rate`; SoX decodes.
    sox_io_backend.save(ta_mp3, load_wav(reference_wav)[0], sample_rate, compression=bit_rate)
    sox_utils.convert_audio_file(ta_mp3, ta_wav)
    found = load_wav(ta_wav)[0]

    # 3.x SoX both encodes (same bit rate) and decodes.
    sox_utils.convert_audio_file(reference_wav, sox_mp3, compression=bit_rate)
    sox_utils.convert_audio_file(sox_mp3, sox_wav)
    expected = load_wav(sox_wav)[0]

    self.assertEqual(found, expected)
def test_save_flac(self, sample_rate, num_channels, compression_level):
    """A TorchScript-compiled save writes the same flac as the eager function.

    Both flac files are converted to 32-bit wav and reloaded. Each must keep
    the sample rate and reproduce the original float32 samples.
    """
    scripted_save = torch_script(py_save_func)
    expected = get_wav_data('float32', num_channels)

    name_tail = f'{sample_rate}_{num_channels}_{compression_level}.flac'
    py_path = self.get_temp_path(f'test_save_py_{name_tail}')
    ts_path = self.get_temp_path(f'test_save_ts_{name_tail}')

    # Positional tail shared by both calls: (tensor, sr, True, level, None, None).
    shared_args = (expected, sample_rate, True, compression_level, None, None)
    py_save_func(py_path, *shared_args)
    scripted_save(ts_path, *shared_args)

    # Convert to 32-bit wav: the flac output is said to be 24-bit, which the
    # scipy-based loader cannot handle.
    py_wav = f'{py_path}.wav'
    ts_wav = f'{ts_path}.wav'
    for flac_file, wav_file in ((py_path, py_wav), (ts_path, ts_wav)):
        sox_utils.convert_audio_file(flac_file, wav_file, bit_depth=32)

    py_data, py_sr = load_wav(py_wav, normalize=True)
    ts_data, ts_sr = load_wav(ts_wav, normalize=True)

    self.assertEqual(sample_rate, py_sr)
    self.assertEqual(sample_rate, ts_sr)
    self.assertEqual(expected, py_data)
    self.assertEqual(expected, ts_data)
def _assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
    """Check that `sox_io_backend.save` writes vorbis files close to SoX's output.

    Follows the mp3 comparison strategy: encode with torchaudio and with SoX,
    decode both with SoX, then compare. A small fraction of samples may
    exceed the absolute tolerance (see below).
    """
    reference_wav = self.get_temp_path('1.reference.wav')
    ta_vorbis = self.get_temp_path('2.1.torchaudio.vorbis')
    ta_wav = self.get_temp_path('2.2.torchaudio.wav')
    sox_vorbis = self.get_temp_path('3.1.sox.vorbis')
    sox_wav = self.get_temp_path('3.2.sox.wav')

    # 1. int16 reference wav (not normalized).
    frame_count = duration * sample_rate
    reference = get_wav_data('int16', num_channels, normalize=False, num_frames=frame_count)
    save_wav(reference_wav, reference, sample_rate)

    # 2. torchaudio encode at `quality_level`, SoX decode.
    sox_io_backend.save(
        ta_vorbis, load_wav(reference_wav)[0], sample_rate,
        compression=quality_level, dtype=None)
    sox_utils.convert_audio_file(ta_vorbis, ta_wav)
    found = load_wav(ta_wav)[0]

    # 3. SoX encode and decode.
    sox_utils.convert_audio_file(reference_wav, sox_vorbis, compression=quality_level)
    sox_utils.convert_audio_file(sox_vorbis, sox_wav)
    expected = load_wav(sox_wav)[0]

    # SoX's vorbis encoding has a random boundary effect that makes a few
    # samples deviate more than the rest. Therefore up to `max_failure_allowed`
    # of the samples may fall outside `atol`. Callers should pass a reasonably
    # long duration so that this ratio is meaningful.
    atol = 1.0e-4
    max_failure_allowed = 0.01
    out_of_tolerance = (found - expected).abs() > atol
    failure_ratio = out_of_tolerance.sum().item() / found.numel()
    if failure_ratio > max_failure_allowed:
        # Too many outliers: re-run a strict comparison purely so that it
        # fails with a detailed error message.
        self.assertEqual(found, expected, atol=atol, rtol=1.3e-6)
def assert_wav(
    self,
    dtype,
    sample_rate,
    num_channels,
    normalize,
    channels_first=True,
    duration=1,
):
    """Check that `soundfile_backend.load` reads wav files the same way scipy does.

    A wav file is written with the requested dtype/layout. It is then loaded
    with both the scipy-based reference loader and the soundfile backend, and
    the results must agree.
    """
    wav_file = self.get_temp_path("reference.wav")
    source = get_wav_data(
        dtype,
        num_channels,
        normalize=normalize,
        num_frames=duration * sample_rate,
        channels_first=channels_first,
    )
    save_wav(wav_file, source, sample_rate, channels_first=channels_first)

    layout = {"normalize": normalize, "channels_first": channels_first}
    expected = load_wav(wav_file, **layout)[0]
    loaded, sr = soundfile_backend.load(wav_file, **layout)

    assert sr == sample_rate
    self.assertEqual(loaded, expected)
def test_vad_from_file(self):
    """`F.vad` gives the same result for a batch as for its individual items."""
    asset = common_utils.get_asset_path("vad-go-stereo-44100.wav")
    waveform, sample_rate = common_utils.load_wav(asset)
    # The two channels are slightly offset from each other. Reshaping them
    # into a batch of two mono signals therefore yields distinct items.
    self.assert_batch_consistency(
        F.vad, waveform.view(2, 1, -1), sample_rate=sample_rate)
def assert_amr_nb(self, duration):
    """Check that `sox_io_backend.load` decodes amr-nb the same way SoX does.

    A mono 8 kHz amr-nb file is generated with SoX and converted to a
    reference wav with SoX. It is also loaded directly by torchaudio, and the
    two signals must match within tolerance.
    """
    sample_rate, num_channels = 8000, 1
    amr_file = self.get_temp_path('1.original.amr-nb')
    reference_wav = self.get_temp_path('2.reference.wav')

    sox_utils.gen_audio_file(
        amr_file, sample_rate, num_channels, bit_depth=32, duration=duration)
    sox_utils.convert_audio_file(amr_file, reference_wav)

    decoded, sr = sox_io_backend.load(amr_file)
    reference = load_wav(reference_wav)[0]

    assert sr == sample_rate
    self.assertEqual(decoded, reference, atol=4e-05, rtol=1.3e-06)
def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
    """Check that `sox_io_backend.load` decodes vorbis the same way SoX does.

    SoX generates a 16-bit vorbis file at `quality_level` and converts it to
    a reference wav. torchaudio's direct decode of the vorbis file must match
    that reference within tolerance.
    """
    vorbis_file = self.get_temp_path('1.original.vorbis')
    reference_wav = self.get_temp_path('2.reference.wav')

    sox_utils.gen_audio_file(
        vorbis_file, sample_rate, num_channels,
        compression=quality_level, bit_depth=16, duration=duration)
    sox_utils.convert_audio_file(vorbis_file, reference_wav)

    decoded, sr = sox_io_backend.load(vorbis_file)
    reference = load_wav(reference_wav)[0]

    assert sr == sample_rate
    self.assertEqual(decoded, reference, atol=4e-05, rtol=1.3e-06)
def test_tarfile(self, ext, compression):
    """Applying effects to compressed audio via file-like file works"""
    sample_rate = 16000
    channels_first = True
    effect_chain = [['band', '300', '10']]
    # Only mp3 needs an explicit format hint when reading from a file object.
    format_ = ext if ext == 'mp3' else None

    member_name = f'input.{ext}'
    source_path = self.get_temp_path(member_name)
    reference_path = self.get_temp_path('reference.wav')
    # NOTE(review): TarFile(..., 'w') writes an *uncompressed* tar despite the
    # .tar.gz name; only the audio itself is compressed here.
    archive_path = self.get_temp_path('archive.tar.gz')

    sox_utils.gen_audio_file(
        source_path, sample_rate, num_channels=2, compression=compression)
    sox_utils.run_sox_effect(
        source_path, reference_path, effect_chain, output_bitdepth=32)
    expected, expected_sr = load_wav(reference_path)

    with tarfile.TarFile(archive_path, 'w') as writer:
        writer.add(source_path, arcname=member_name)
    with tarfile.TarFile(archive_path, 'r') as reader:
        member_stream = reader.extractfile(member_name)
        found, sr = sox_effects.apply_effects_file(
            member_stream, effect_chain,
            channels_first=channels_first, format=format_)

    # Keep the result on disk to help debugging failures.
    save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
    assert sr == expected_sr
    self.assertEqual(found, expected)
def test_apply_effects(self, args):
    """`apply_effects_tensor` should return identical data as sox command"""
    effect_chain = args['effects']
    channel_count = args.get("num_channels", 2)
    in_rate = args.get("input_sample_rate", 8000)
    out_rate = args.get("output_sample_rate")

    source_path = self.get_temp_path('input.wav')
    reference_path = self.get_temp_path('reference.wav')

    # An 800 Hz float32 sinusoid, saved to disk so the sox CLI can process it.
    signal = get_sinusoid(
        frequency=800, sample_rate=in_rate, n_channels=channel_count, dtype='float32')
    save_wav(source_path, signal, in_rate)
    sox_utils.run_sox_effect(
        source_path, reference_path, effect_chain, output_sample_rate=out_rate)
    expected, expected_sr = load_wav(reference_path)

    found, sr = sox_effects.apply_effects_tensor(signal, in_rate, effect_chain)

    assert sr == expected_sr
    self.assertEqual(expected, found)
def test_resample_size(self):
    """`Resample` validates its method and scales the sample count by the rate ratio.

    Checks three things:
    - An unknown `resampling_method` raises ValueError.
    - Upsampling by 2x doubles the number of samples.
    - Downsampling by 2x halves it (integer division).
    """
    input_path = common_utils.get_asset_path('sinewave.wav')
    waveform, sample_rate = common_utils.load_wav(input_path)

    upsample_rate = sample_rate * 2
    downsample_rate = sample_rate // 2
    invalid_resampling_method = 'foo'

    with self.assertRaises(ValueError):
        torchaudio.transforms.Resample(
            sample_rate, upsample_rate, resampling_method=invalid_resampling_method)

    upsample_resample = torchaudio.transforms.Resample(
        sample_rate, upsample_rate, resampling_method='sinc_interpolation')
    up_sampled = upsample_resample(waveform)
    # Upsampled signal should have twice as many samples. assertEqual (rather
    # than assertTrue(a == b)) reports both sizes when it fails.
    self.assertEqual(up_sampled.size(-1), waveform.size(-1) * 2)

    downsample_resample = torchaudio.transforms.Resample(
        sample_rate, downsample_rate, resampling_method='sinc_interpolation')
    down_sampled = downsample_resample(waveform)
    # Downsampled signal should have half as many samples.
    self.assertEqual(down_sampled.size(-1), waveform.size(-1) // 2)
def test_requests(self, ext, compression):
    """Effects can be applied to audio streamed over HTTP (`requests` raw stream)."""
    sample_rate = 16000
    channels_first = True
    effect_chain = [['band', '300', '10']]
    # Only mp3 needs an explicit format hint when reading from a stream.
    format_ = ext if ext == 'mp3' else None

    served_name = f'input.{ext}'
    source_path = self.get_temp_path(served_name)
    reference_path = self.get_temp_path('reference.wav')

    sox_utils.gen_audio_file(
        source_path, sample_rate, num_channels=2, compression=compression)
    sox_utils.run_sox_effect(
        source_path, reference_path, effect_chain, output_bitdepth=32)
    expected, expected_sr = load_wav(reference_path)

    with requests.get(self.get_url(served_name), stream=True) as response:
        found, sr = sox_effects.apply_effects_file(
            response.raw, effect_chain,
            channels_first=channels_first, format=format_)

    # Keep the result on disk to help debugging failures.
    save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
    assert sr == expected_sr
    self.assertEqual(found, expected)
def test_apply_effects_path(self):
    """`apply_effects_file` should return identical data as sox command when file path is given as a Path Object"""
    channels_first = True
    effect_chain = [["hilbert"]]
    in_rate = out_rate = 8000

    source_path = self.get_temp_path('input.wav')
    reference_path = self.get_temp_path('reference.wav')

    # Stereo int32 input, run through the sox CLI to get the reference.
    samples = get_wav_data('int32', 2, channels_first=channels_first)
    save_wav(source_path, samples, in_rate, channels_first=channels_first)
    sox_utils.run_sox_effect(
        source_path, reference_path, effect_chain, output_sample_rate=out_rate)
    expected, expected_sr = load_wav(reference_path)

    # Same effects via torchaudio, addressing the input through pathlib.Path.
    found, sr = sox_effects.apply_effects_file(
        Path(source_path), effect_chain, normalize=False, channels_first=channels_first)

    assert sr == expected_sr
    self.assertEqual(found, expected)
def test_bytesio(self, ext, compression):
    """Applying effects via BytesIO object works"""
    sample_rate = 16000
    channels_first = True
    effect_chain = [['band', '300', '10']]
    # Only mp3 needs an explicit format hint when reading from a buffer.
    format_ = ext if ext == 'mp3' else None

    source_path = self.get_temp_path(f'input.{ext}')
    reference_path = self.get_temp_path('reference.wav')

    sox_utils.gen_audio_file(
        source_path, sample_rate, num_channels=2, compression=compression)
    sox_utils.run_sox_effect(
        source_path, reference_path, effect_chain, output_bitdepth=32)
    expected, expected_sr = load_wav(reference_path)

    # Load the whole encoded file into memory and hand torchaudio the buffer.
    with open(source_path, 'rb') as handle:
        buffer = io.BytesIO(handle.read())
    found, sr = sox_effects.apply_effects_file(
        buffer, effect_chain, channels_first=channels_first, format=format_)

    # Keep the result on disk to help debugging failures.
    save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
    assert sr == expected_sr
    self.assertEqual(found, expected)
def test_apply_effects_str(self, args):
    """`apply_effects_file` should return identical data as sox command"""
    channels_first = True
    effect_chain = args['effects']
    channel_count = args.get("num_channels", 2)
    in_rate = args.get("input_sample_rate", 8000)
    out_rate = args.get("output_sample_rate")

    source_path = self.get_temp_path('input.wav')
    reference_path = self.get_temp_path('reference.wav')

    # int32 input, processed by the sox CLI to produce the reference.
    samples = get_wav_data('int32', channel_count, channels_first=channels_first)
    save_wav(source_path, samples, in_rate, channels_first=channels_first)
    sox_utils.run_sox_effect(
        source_path, reference_path, effect_chain, output_sample_rate=out_rate)
    expected, expected_sr = load_wav(reference_path)

    # Same effects via torchaudio, addressing the input with a plain str path.
    found, sr = sox_effects.apply_effects_file(
        source_path, effect_chain, normalize=False, channels_first=channels_first)

    assert sr == expected_sr
    self.assertEqual(found, expected)
def assert_wav(self, dtype, sample_rate, num_channels, num_frames):
    """Check that `sox_io_backend.save` writes a wav file scipy reads back unchanged."""
    target = self.get_temp_path('data.wav')
    original = get_wav_data(dtype, num_channels, num_frames=num_frames)
    sox_io_backend.save(target, original, sample_rate)

    reloaded, sr = load_wav(target)
    assert sample_rate == sr
    self.assertEqual(reloaded, original)
def test_channels_first(self, channels_first):
    """channels_first swaps axes"""
    target = self.get_temp_path('data.wav')
    samples = get_wav_data('int32', 2, channels_first=channels_first)
    sox_io_backend.save(target, samples, 8000, channels_first=channels_first)

    reloaded = load_wav(target)[0]
    # load_wav is called with its defaults, which are compared here against
    # channels-first data: transpose when the input was channels-last.
    if channels_first:
        expected = samples
    else:
        expected = samples.transpose(1, 0)
    self.assertEqual(reloaded, expected)
def test_noncontiguous(self, dtype):
    """Noncontiguous tensors are saved correctly"""
    target = self.get_temp_path('data.wav')
    # Striding every other row and column produces a non-contiguous view.
    strided = get_wav_data(dtype, 4)[::2, ::2]
    assert not strided.is_contiguous()

    sox_io_backend.save(target, strided, 8000)
    self.assertEqual(load_wav(target)[0], strided)
def test_channels_first(self, channels_first):
    """channels_first swaps axes"""
    target = self.get_temp_path("data.wav")
    samples = get_wav_data("int32", 2, channels_first=channels_first)
    soundfile_backend.save(target, samples, 8000, channels_first=channels_first)

    reloaded = load_wav(target)[0]
    # Transpose channels-last input so it can be compared with the reloaded data.
    if channels_first:
        expected = samples
    else:
        expected = samples.transpose(1, 0)
    self.assertEqual(reloaded, expected, atol=1e-4, rtol=1e-8)
def test_dtype_conversion(self, dtype, expected):
    """`save` performs dtype conversion on float32 src tensors only."""
    target = self.get_temp_path("data.wav")
    # Single-channel column of float32 samples spanning [-1, 1].
    column = torch.tensor([-1.0, -0.5, 0, 0.5, 1.0]).to(torch.float32).view(-1, 1)
    sox_io_backend.save(target, column, 8000, dtype=dtype)

    # Read back the raw (un-normalized) values to observe the stored dtype.
    stored = load_wav(target, normalize=False)[0]
    self.assertEqual(stored, expected.view(-1, 1))
def test_save_noncontiguous(self, dtype):
    """Noncontiguous tensors are saved correctly"""
    target = self.get_temp_path('data.wav')
    encoding, bits = get_enc_params(dtype)
    # Every-other-element view: deliberately non-contiguous.
    strided = get_wav_data(dtype, 4, normalize=False)[::2, ::2]
    assert not strided.is_contiguous()

    sox_io_backend.save(target, strided, 8000, encoding=encoding, bits_per_sample=bits)
    self.assertEqual(load_wav(target, normalize=False)[0], strided)
def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_files,
                            expected_num_args, get_output_fn, atol=1e-5, rtol=1e-7):
    """Compare torchaudio outputs against pre-computed Kaldi ark files.

    Each Kaldi output file name encodes the configuration that produced it
    (dash-separated arguments). For every file, the configuration is parsed,
    the torchaudio equivalent is run, and the two outputs are compared.

    Inputs:
        sound_filepath (str): The location of the sound file
        filepath_key (str): A key to `test_filepaths` which matches which files to use
        expected_num_files (int): The expected number of kaldi files to read
        expected_num_args (int): The expected number of arguments used in a kaldi configuration
        get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
            and a configuration and returns an output
        atol (float): absolute tolerance
        rtol (float): relative tolerance
    """
    sound, sr = common_utils.load_wav(sound_filepath, normalize=False)
    kaldi_files = self.test_filepaths[filepath_key]

    assert len(kaldi_files) == expected_num_files, \
        ('number of kaldi {} file changed to {}'.format(
            filepath_key, len(kaldi_files)))

    for kaldi_file in kaldi_files:
        # Printed so a failing comparison can be traced to its file.
        print(kaldi_file)

        # Each ark file is expected to hold exactly one matrix keyed 'my_id'.
        ark_path = os.path.join(self.kaldi_output_dir, kaldi_file)
        matrices = dict(torchaudio.kaldi_io.read_mat_ark(ark_path))
        assert len(matrices) == 1 and 'my_id' in matrices, 'invalid test kaldi ark file'
        kaldi_output = matrices['my_id']

        # Recover the configuration from the file name: drop the extension
        # from the last dash-separated token, then parse each token.
        raw_args = kaldi_file.split('-')
        raw_args[-1] = os.path.splitext(raw_args[-1])[0]
        assert len(raw_args) == expected_num_args, 'invalid test kaldi file name'
        parsed_args = [compliance_utils.parse(token) for token in raw_args]

        output = get_output_fn(sound, parsed_args)

        self._print_diagnostic(output, kaldi_output)
        torch.testing.assert_allclose(output, kaldi_output, atol=atol, rtol=rtol)
def test_save_wav(self, dtype, sample_rate, num_channels):
    """A TorchScript-compiled save writes the same wav as the eager function."""
    scripted_save = torch_script(py_save_func)
    expected = get_wav_data(dtype, num_channels, normalize=False)

    name_tail = f'{dtype}_{sample_rate}_{num_channels}.wav'
    py_path = self.get_temp_path(f'test_save_py_{name_tail}')
    ts_path = self.get_temp_path(f'test_save_ts_{name_tail}')

    encoding, bits = get_enc_params(dtype)
    shared_args = (expected, sample_rate, True, None, encoding, bits)
    py_save_func(py_path, *shared_args)
    scripted_save(ts_path, *shared_args)

    py_data, py_sr = load_wav(py_path, normalize=False)
    ts_data, ts_sr = load_wav(ts_path, normalize=False)

    self.assertEqual(sample_rate, py_sr)
    self.assertEqual(sample_rate, ts_sr)
    self.assertEqual(expected, py_data)
    self.assertEqual(expected, ts_data)
def test_AmplitudeToDB(self):
    """dB of |x| in magnitude mode equals dB of x**2 in power mode."""
    asset = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
    waveform = common_utils.load_wav(asset)[0]

    magnitude_to_db = transforms.AmplitudeToDB('magnitude', 80.)
    power_to_db = transforms.AmplitudeToDB('power', 80.)

    from_magnitude = magnitude_to_db(waveform.abs())
    from_power = power_to_db(waveform.pow(2))
    self.assertEqual(from_magnitude, from_power)
def test_opus(self, bitrate, num_channels, compression_level):
    """`sox_io_backend.load` can load opus file correctly."""
    asset_name = f'{bitrate}_{compression_level}_{num_channels}ch.opus'
    opus_path = get_asset_path('io', asset_name)
    wav_path = self.get_temp_path(f'{asset_name}.wav')

    # SoX-decoded wav serves as the reference.
    sox_utils.convert_audio_file(opus_path, wav_path)
    expected, sample_rate = load_wav(wav_path)

    found, sr = sox_io_backend.load(opus_path)
    assert sample_rate == sr
    self.assertEqual(expected, found)
def test_mfcc(self, kwargs):
    """mfcc should be numerically compatible with compute-mfcc-feats"""
    wave_file = get_asset_path('kaldi_file.wav')
    raw = load_wav(wave_file, normalize=False)[0]
    waveform = raw.to(dtype=self.dtype, device=self.device)

    ours = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs)

    # Kaldi reads a script file from stdin and writes an ark to stdout.
    kaldi_cmd = ['compute-mfcc-feats', *convert_args(**kwargs), 'scp:-', 'ark:-']
    theirs = run_kaldi(kaldi_cmd, 'scp', wave_file)

    self.assert_equal(ours, expected=theirs, rtol=1e-4, atol=1e-8)
def test_mel2(self):
    """MelSpectrogram + AmplitudeToDB produce well-formed dB mel spectrograms.

    Covers the default transform, a customised transform, a stereo input and
    a standalone MelScale. For each spectrogram it checks:
    - the output is rank 3;
    - dim 1 has `n_mels` entries;
    - no value lies more than `top_db` below the spectrogram's maximum.
    For each filterbank it checks that the sums along dim 1 lie in [0, 1].
    """
    top_db = 80.
    s2db = transforms.AmplitudeToDB('power', top_db)

    def assert_within_top_db(spec):
        # With top_db set, every dB value is expected to be >= max - top_db.
        self.assertTrue(spec.ge(spec.max() - top_db).all())

    def assert_fb_sums_in_unit_interval(fb):
        sums = fb.sum(1)
        self.assertTrue(sums.le(1.).all())
        self.assertTrue(sums.ge(0.).all())

    waveform = self.waveform.clone()
    waveform_scaled = self.scale(waveform)

    # check defaults
    mel_transform = transforms.MelSpectrogram()
    spectrogram_torch = s2db(mel_transform(waveform_scaled))
    self.assertEqual(spectrogram_torch.dim(), 3)
    assert_within_top_db(spectrogram_torch)
    self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
    # check correctness of filterbank conversion matrix
    assert_fb_sums_in_unit_interval(mel_transform.mel_scale.fb)

    # check options
    kwargs = {
        'window_fn': torch.hamming_window,
        'pad': 10,
        'win_length': 500,
        'hop_length': 125,
        'n_fft': 800,
        'n_mels': 50,
    }
    mel_transform2 = transforms.MelSpectrogram(**kwargs)
    spectrogram2_torch = s2db(mel_transform2(waveform_scaled))
    self.assertEqual(spectrogram2_torch.dim(), 3)
    # Previously this re-checked `spectrogram_torch`; check the new output.
    assert_within_top_db(spectrogram2_torch)
    self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
    assert_fb_sums_in_unit_interval(mel_transform2.mel_scale.fb)

    # check on multi-channel audio
    filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
    x_stereo = common_utils.load_wav(filepath)[0]  # (2, 278756), 44100
    spectrogram_stereo = s2db(mel_transform(x_stereo))
    self.assertEqual(spectrogram_stereo.dim(), 3)
    self.assertEqual(spectrogram_stereo.size(0), 2)
    # Previously this re-checked `spectrogram_torch`; check the stereo output.
    assert_within_top_db(spectrogram_stereo)
    self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)

    # check filterbank matrix creation
    fb_matrix_transform = transforms.MelScale(
        n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
    assert_fb_sums_in_unit_interval(fb_matrix_transform.fb)
    self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
    """Check that `sox_io_backend.load` reads wav files the same way scipy does.

    A wav file is written, then loaded with the scipy-based reference loader
    and with the sox_io backend (same `normalize` flag). The results must match.
    """
    wav_file = self.get_temp_path('reference.wav')
    source = get_wav_data(
        dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
    save_wav(wav_file, source, sample_rate)

    expected = load_wav(wav_file, normalize=normalize)[0]
    loaded, sr = sox_io_backend.load(wav_file, normalize=normalize)

    assert sr == sample_rate
    self.assertEqual(loaded, expected)
def test_save_wav(self, dtype, sample_rate, num_channels):
    """A scripted save, serialised and reloaded via torch.jit, matches eager save."""
    # Round-trip the scripted function through disk, so the test exercises a
    # deserialised TorchScript module rather than the in-memory one.
    archive = self.get_temp_path('save_func.zip')
    torch.jit.script(py_save_func).save(archive)
    loaded_save = torch.jit.load(archive)

    expected = get_wav_data(dtype, num_channels)
    name_tail = f'{dtype}_{sample_rate}_{num_channels}.wav'
    py_path = self.get_temp_path(f'test_save_py_{name_tail}')
    ts_path = self.get_temp_path(f'test_save_ts_{name_tail}')

    py_save_func(py_path, expected, sample_rate, True, None)
    loaded_save(ts_path, expected, sample_rate, True, None)

    py_data, py_sr = load_wav(py_path)
    ts_data, ts_sr = load_wav(ts_path)

    self.assertEqual(sample_rate, py_sr)
    self.assertEqual(sample_rate, ts_sr)
    self.assertEqual(expected, py_data)
    self.assertEqual(expected, ts_data)