def get_output_fn(sound, args):
    """Compute Kaldi-style MFCC features for ``sound`` from a positional config.

    Args:
        sound: Waveform passed as the first argument to ``kaldi.mfcc``.
        args: Indexable container; entries 1 through 21 supply the keyword
            options listed in ``option_names`` below (entry 0 is not read).

    Returns:
        Whatever ``kaldi.mfcc`` returns for the given input and options.
    """
    # Resolve the extractor first so lookup errors surface before indexing args.
    extract = kaldi.mfcc

    # Position i (1-based) in ``args`` holds the value for option_names[i - 1].
    option_names = (
        "blackman_coeff",
        "energy_floor",
        "frame_length",
        "frame_shift",
        "high_freq",
        "htk_compat",
        "low_freq",
        "num_mel_bins",
        "preemphasis_coefficient",
        "raw_energy",
        "remove_dc_offset",
        "round_to_power_of_two",
        "snip_edges",
        "subtract_mean",
        "use_energy",
        "num_ceps",
        "cepstral_lifter",
        "vtln_high",
        "vtln_low",
        "vtln_warp",
        "window_type",
    )
    options = {}
    for position, option in enumerate(option_names, start=1):
        options[option] = args[position]

    # Dither is always disabled here, independent of ``args``.
    options["dither"] = 0.0

    return extract(sound, **options)
def compute_mfcc(data,
                 num_mel_bins=23,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 num_ceps=40,
                 high_freq=0.0,
                 low_freq=20.0):
    """Lazily turn waveform samples into MFCC feature samples.

    Args:
        data: Iterable[{key, wav, label, sample_rate}]
        num_mel_bins, frame_length, frame_shift, dither, num_ceps,
        high_freq, low_freq: forwarded unchanged to ``kaldi.mfcc``.

    Returns:
        Iterable[{key, feat, label}]

    Raises:
        AssertionError: if a sample lacks one of the required keys.
    """
    required_fields = ('sample_rate', 'wav', 'key', 'label')
    for sample in data:
        for field in required_fields:
            assert field in sample

        # Scale by 2**15 before extraction (presumably float [-1, 1] audio
        # to 16-bit PCM range -- confirm against the loader).
        scaled_wav = sample['wav'] * (1 << 15)

        feat = kaldi.mfcc(scaled_wav,
                          num_mel_bins=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          num_ceps=num_ceps,
                          high_freq=high_freq,
                          low_freq=low_freq,
                          sample_frequency=sample['sample_rate'])

        # Only key, label and feat are carried forward; wav is dropped.
        yield {'key': sample['key'], 'label': sample['label'], 'feat': feat}
def get_torchaudio_fbank_or_mfcc(
    waveform: np.ndarray,
    sample_rate: float,
    n_bins: int = 80,
    feature_type: str = "fbank",
) -> np.ndarray:
    """Get mel-filter bank or mfcc features via TorchAudio.

    Args:
        waveform: Audio samples as a NumPy array; converted with
            ``torch.from_numpy`` and handed to torchaudio as-is.
        sample_rate: Sampling rate, passed as ``sample_frequency``.
        n_bins: Number of mel bins.
        feature_type: ``"fbank"`` selects filter-bank features; any other
            value selects MFCC (40 cepstra, low_freq=20, high_freq=-400).

    Returns:
        The extracted features converted back to a NumPy array.

    Raises:
        ImportError: if torchaudio is not installed.
    """
    # Keep the try body to the import only, so an ImportError raised during
    # feature extraction is not misreported as "torchaudio is missing".
    try:
        import torchaudio.compliance.kaldi as ta_kaldi
    except ImportError as err:
        raise ImportError(
            "Please install torchaudio to enable online feature extraction: pip install torchaudio"
        ) from err

    waveform = torch.from_numpy(waveform)
    if feature_type == "fbank":
        features = ta_kaldi.fbank(
            waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
        )
    else:
        features = ta_kaldi.mfcc(
            waveform,
            num_mel_bins=n_bins,
            num_ceps=40,
            low_freq=20,
            high_freq=-400,
            sample_frequency=sample_rate,
        )
    return features.numpy()
def set_feats_func(self):
    """Install ``self.feats_func`` according to ``configs["feats"]["type"]``.

    Supported types are ``"mfcc_kaldi"``, ``"fbank_kaldi"`` and
    ``"spectrogram_kaldi"``; each maps to the torchaudio Kaldi-compliance
    function of the same stem and to the ``self.configs`` entry with the
    same name as the type, which supplies the keyword arguments.

    Raises:
        NotImplementedError: for any other feature type.
    """
    feats_type = self.configs["feats"]["type"]

    # Import lazily, and only the extractor actually requested.
    if feats_type == "mfcc_kaldi":
        from torchaudio.compliance.kaldi import mfcc as extractor
    elif feats_type == "fbank_kaldi":
        from torchaudio.compliance.kaldi import fbank as extractor
    elif feats_type == "spectrogram_kaldi":
        from torchaudio.compliance.kaldi import spectrogram as extractor
    else:
        raise NotImplementedError

    def _compute(x):
        # Cast to float32 and reshape to a single row before extraction.
        signal = torch.from_numpy(x.astype("float32").reshape(1, -1))
        # Options are looked up at call time, then the two output axes
        # are swapped.
        return extractor(signal, **self.configs[feats_type]).transpose(0, 1)

    self.feats_func = _compute
def plain_single_file_predict(model,
                              wav_dir,
                              train_configs,
                              out_dir,
                              window_size=400,
                              lookahead=200,
                              lookbehind=200,
                              regex=""):
    """Run the model over every matching wav file and save its predictions.

    Args:
        model: Network to evaluate; switched to eval mode and moved to CUDA.
        wav_dir: Directory searched recursively for wav files.
        train_configs: Mapping with ``["feats"]["type"]`` and, for the only
            supported type ``"mfcc_kaldi"``, the ``["mfcc_kaldi"]`` keyword
            arguments for ``mfcc``.
        out_dir: Directory receiving one output file per input wav.
        window_size: Window length given to ``overlap_add``; the hop passed
            alongside it is ``window_size // 2``.
        lookahead: Forwarded to ``overlap_add``.
        lookbehind: Forwarded to ``overlap_add``.
        regex: Fragment interpolated into the glob pattern
            ``**/*{regex}*.wav`` (a glob fragment, not a regular expression).

    Raises:
        AssertionError: if no wav file matches.
        NotImplementedError: if the feature type is not ``"mfcc_kaldi"``.
    """
    model = model.eval().cuda()
    wavs = glob(os.path.join(wav_dir, "**/*{}*.wav".format(regex)),
                recursive=True)
    assert len(wavs) > 0, "No file found"

    # The feature type does not depend on the file, so validate it once
    # instead of on every loop iteration.
    if train_configs["feats"]["type"] != "mfcc_kaldi":
        raise NotImplementedError

    def feats_func(x):
        signal = torch.from_numpy(x.astype("float32").reshape(1, -1))
        return mfcc(signal, **train_configs["mfcc_kaldi"]).transpose(0, 1)

    def pred_func(x):
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            batch = torch.from_numpy(x).unsqueeze(0).cuda()
            return model(batch).detach().cpu().numpy()

    for wav in wavs:
        print("Processing File {}".format(wav))
        audio, _ = sf.read(wav)

        tot_feats = compute_feats_windowed(feats_func, audio)
        tot_feats = tot_feats.detach().cpu().numpy()

        preds = overlap_add(tot_feats,
                            pred_func,
                            window_size,
                            window_size // 2,
                            lookahead=lookahead,
                            lookbehind=lookbehind)

        # Derive the stem with os.path so any path separator works and only
        # the final extension is removed.
        stem = os.path.splitext(os.path.basename(wav))[0]
        out_file = os.path.join(out_dir, stem + ".logits")
        # np.save appends ".npy" when the name does not already end with it
        # (per the NumPy docs).
        np.save(out_file, preds)
def _feature_fn(self, *args, **kwargs):
    """Forward every argument to torchaudio's Kaldi-compliance ``mfcc``.

    torchaudio is imported inside the call, so it is only required once
    this method is actually invoked.

    Returns:
        The value returned by ``mfcc(*args, **kwargs)``.
    """
    from torchaudio.compliance.kaldi import mfcc as kaldi_mfcc

    features = kaldi_mfcc(*args, **kwargs)
    return features