Example #1
0
    def set_feats_func(self):
        """Bind ``self.feats_func`` to the Kaldi-style extractor named in the config.

        Reads ``self.configs["feats"]["type"]``. Supported values are
        ``"mfcc_kaldi"``, ``"fbank_kaldi"`` and ``"spectrogram_kaldi"``; the
        matching function is imported from ``torchaudio.compliance.kaldi``
        only for the branch that is actually selected.

        The installed ``self.feats_func(x)`` casts ``x`` to float32, reshapes
        it to a single row, wraps it with ``torch.from_numpy``, calls the
        extractor with the keyword arguments stored in ``self.configs`` under
        the same key as the feats type, and returns the result with
        dimensions 0 and 1 swapped.

        Raises:
            NotImplementedError: if the configured type is not one of the
                supported values; the message names the rejected type.
        """
        feats_type = self.configs["feats"]["type"]

        # Import lazily so that only the selected extractor is loaded.
        if feats_type == "mfcc_kaldi":
            from torchaudio.compliance.kaldi import mfcc as extractor
        elif feats_type == "fbank_kaldi":
            from torchaudio.compliance.kaldi import fbank as extractor
        elif feats_type == "spectrogram_kaldi":
            from torchaudio.compliance.kaldi import spectrogram as extractor
        else:
            raise NotImplementedError(
                "unsupported feats type: {!r}".format(feats_type))

        def feats_func(x):
            waveform = torch.from_numpy(x.astype("float32").reshape(1, -1))
            # The extractor options are looked up at call time, so later
            # changes to self.configs[feats_type] are still honoured.
            return extractor(waveform, **self.configs[feats_type]).transpose(0, 1)

        self.feats_func = feats_func
 def get_output_fn(sound, args):
     """Call ``kaldi.spectrogram`` on ``sound`` with options taken from ``args``.

     ``args[1]`` through ``args[12]`` are passed, in order, as the keyword
     options listed below; ``args[0]`` is not read. The spectrogram result is
     returned unchanged.
     """
     # Keyword names in positional order; the first one maps to args[1].
     option_names = (
         'blackman_coeff',
         'dither',
         'energy_floor',
         'frame_length',
         'frame_shift',
         'preemphasis_coefficient',
         'raw_energy',
         'remove_dc_offset',
         'round_to_power_of_two',
         'snip_edges',
         'subtract_mean',
         'window_type',
     )
     options = {
         name: args[position]
         for position, name in enumerate(option_names, start=1)
     }
     return kaldi.spectrogram(sound, **options)
Example #3
0
    def test_spectrogram(self):
        """Check ``kaldi.spectrogram`` against stored Kaldi reference outputs.

        Every file in ``<test_dirpath>/assets/kaldi`` whose name starts with
        ``'spec'`` is read as a Kaldi matrix archive that must contain exactly
        one entry keyed ``'my_id'``. The file name (extension removed) is split
        on ``'-'`` into 13 fields; fields 1-12 supply the spectrogram options
        (floats, ``'true'``-flags and the window type). The computed
        spectrogram must have the same shape as the reference and match it
        within an absolute tolerance of 1e-3.
        """
        # Only the waveform is needed; the sample rate is not used here.
        sound, _ = torchaudio.load_wav(self.test_filepath)
        kaldi_output_dir = os.path.join(self.test_dirpath, 'assets', 'kaldi')
        files = [
            name for name in os.listdir(kaldi_output_dir)
            if name.startswith('spec')
        ]
        print('Results:', len(files))

        for f in files:
            print(f)
            kaldi_output_path = os.path.join(kaldi_output_dir, f)
            kaldi_output_dict = dict(
                torchaudio.kaldi_io.read_mat_ark(kaldi_output_path))

            # The archive must hold exactly one matrix, keyed 'my_id'.
            self.assertEqual(
                list(kaldi_output_dict), ['my_id'],
                'invalid test kaldi ark file')
            kaldi_output = kaldi_output_dict['my_id']

            # Options are encoded in the file name, separated by '-'.
            args = f.split('-')
            args[-1] = os.path.splitext(args[-1])[0]
            self.assertEqual(len(args), 13, 'invalid test kaldi file name')

            spec_output = kaldi.spectrogram(
                sound,
                blackman_coeff=float(args[1]),
                dither=float(args[2]),
                energy_floor=float(args[3]),
                frame_length=float(args[4]),
                frame_shift=float(args[5]),
                preemphasis_coefficient=float(args[6]),
                raw_energy=args[7] == 'true',
                remove_dc_offset=args[8] == 'true',
                round_to_power_of_two=args[9] == 'true',
                snip_edges=args[10] == 'true',
                subtract_mean=args[11] == 'true',
                window_type=args[12])

            # Compare shapes first: assertTrue(a, b) would treat b as the
            # failure message and never compare, and a shape mismatch would
            # otherwise surface as a broadcasting error in the subtraction.
            self.assertEqual(spec_output.shape, kaldi_output.shape)

            error = spec_output - kaldi_output
            mse = error.pow(2).sum() / spec_output.numel()
            max_error = torch.max(error.abs())

            print('mse:', mse.item(), 'max_error:', max_error.item())
            self.assertTrue(
                torch.allclose(spec_output, kaldi_output, atol=1e-3, rtol=0))
Example #4
0
    def _feature_fn(self, *args, **kwargs):
        """Forward all arguments to ``torchaudio.compliance.kaldi.spectrogram``.

        The import is done inside the method, so torchaudio is only loaded
        when the feature function is actually called. The spectrogram result
        is returned unchanged.
        """
        from torchaudio.compliance import kaldi as kaldi_compliance

        features = kaldi_compliance.spectrogram(*args, **kwargs)
        return features