def set_feats_func(self):
    """Initialise ``self.feats_func`` from ``self.configs["feats"]["type"]``.

    Supported types are ``"mfcc_kaldi"``, ``"fbank_kaldi"`` and
    ``"spectrogram_kaldi"``. Each selects the ``torchaudio.compliance.kaldi``
    function named by the part before ``_kaldi`` (``mfcc``, ``fbank``,
    ``spectrogram``). That function is called with keyword arguments taken
    from ``self.configs[<type>]``.

    The installed callable takes an input ``x`` that supports
    ``.astype("float32").reshape(1, -1)`` (presumably a NumPy waveform array
    -- TODO confirm). It wraps that as a ``(1, -1)`` float32 tensor via
    ``torch.from_numpy``, applies the selected function, and returns the
    result with dimensions 0 and 1 swapped.

    Raises:
        NotImplementedError: if the configured feature type is not supported.
    """
    # Feature type -> torchaudio.compliance.kaldi function name. The type string
    # is also the key of the config sub-dict holding that function's kwargs.
    kaldi_fn_names = {
        "mfcc_kaldi": "mfcc",
        "fbank_kaldi": "fbank",
        "spectrogram_kaldi": "spectrogram",
    }
    feats_type = self.configs["feats"]["type"]
    try:
        fn_name = kaldi_fn_names[feats_type]
    except KeyError:
        raise NotImplementedError(
            "Unsupported feats type {!r}; expected one of {}".format(
                feats_type, sorted(kaldi_fn_names)
            )
        ) from None

    # Deferred import: torchaudio is only needed once a valid type is chosen.
    from torchaudio.compliance import kaldi

    feats_fn = getattr(kaldi, fn_name)

    def feats_func(x):
        waveform = torch.from_numpy(x.astype("float32").reshape(1, -1))
        # kwargs are looked up at call time (as the original lambdas did), so
        # later edits to self.configs[feats_type] still take effect.
        return feats_fn(waveform, **self.configs[feats_type]).transpose(0, 1)

    self.feats_func = feats_func
def get_output_fn(sound, args):
    """Run ``kaldi.spectrogram`` on *sound* with settings taken from *args*.

    ``args[1]`` through ``args[12]`` are passed, in order, as the keyword
    arguments named in ``option_names``. ``args[0]`` and anything after
    ``args[12]`` are ignored. Values are forwarded unchanged (no type
    conversion), and an ``args`` shorter than 13 items raises ``IndexError``.
    """
    option_names = (
        'blackman_coeff',
        'dither',
        'energy_floor',
        'frame_length',
        'frame_shift',
        'preemphasis_coefficient',
        'raw_energy',
        'remove_dc_offset',
        'round_to_power_of_two',
        'snip_edges',
        'subtract_mean',
        'window_type',
    )
    # Explicit indexing (not zip) so a too-short args still raises IndexError.
    options = {name: args[pos] for pos, name in enumerate(option_names, start=1)}
    return kaldi.spectrogram(sound, **options)
def test_spectrogram(self):
    """Compare ``kaldi.spectrogram`` against reference Kaldi ark files.

    Every file under ``<test_dirpath>/assets/kaldi`` whose name starts with
    ``spec`` must hold exactly one matrix, keyed ``my_id``. Its hyphen-separated
    file name (extension stripped from the last field) encodes the
    ``kaldi.spectrogram`` keyword arguments in fields 1-12. The torchaudio
    output for ``self.test_filepath`` must have the same shape as the
    reference and match it element-wise within ``atol=1e-3``.
    """
    sound, _ = torchaudio.load_wav(self.test_filepath)
    kaldi_output_dir = os.path.join(self.test_dirpath, 'assets', 'kaldi')
    # Sorted so the per-file log lines appear in a stable order.
    files = sorted(f for f in os.listdir(kaldi_output_dir) if f.startswith('spec'))
    print('Results:', len(files))
    for f in files:
        print(f)
        kaldi_output_path = os.path.join(kaldi_output_dir, f)
        kaldi_output_dict = dict(torchaudio.kaldi_io.read_mat_ark(kaldi_output_path))
        assert len(
            kaldi_output_dict
        ) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
        kaldi_output = kaldi_output_dict['my_id']

        args = f.split('-')
        args[-1] = os.path.splitext(args[-1])[0]
        assert len(args) == 13, 'invalid test kaldi file name'

        spec_output = kaldi.spectrogram(
            sound,
            blackman_coeff=float(args[1]),
            dither=float(args[2]),
            energy_floor=float(args[3]),
            frame_length=float(args[4]),
            frame_shift=float(args[5]),
            preemphasis_coefficient=float(args[6]),
            raw_energy=args[7] == 'true',
            remove_dc_offset=args[8] == 'true',
            round_to_power_of_two=args[9] == 'true',
            snip_edges=args[10] == 'true',
            subtract_mean=args[11] == 'true',
            window_type=args[12])

        # Check the shape first. The previous assertTrue(a, b) treated b as a
        # message and passed for any non-empty shape. Mismatched shapes could
        # also broadcast in the subtraction and allclose below.
        self.assertEqual(spec_output.shape, kaldi_output.shape)

        error = spec_output - kaldi_output
        mse = error.pow(2).sum() / spec_output.numel()
        max_error = torch.max(error.abs())
        print('mse:', mse.item(), 'max_error:', max_error.item())
        self.assertTrue(
            torch.allclose(spec_output, kaldi_output, atol=1e-3, rtol=0))
def _feature_fn(self, *args, **kwargs):
    """Compute features with torchaudio's Kaldi-compatible ``spectrogram``.

    All positional and keyword arguments are forwarded unchanged, and the
    result is returned as-is.
    """
    from torchaudio.compliance import kaldi as kaldi_compliance
    return kaldi_compliance.spectrogram(*args, **kwargs)