def __init__(
    self,
    n_fft: int,
    win_length: int,
    hop_length: int,
    n_iter: int,
    window_fn=torch.hann_window,
):
    """Initialize the ``GriffinLim`` module: forward transform plus inverse basis.

    Args:
        n_fft: FFT size; passed to ``TTSSpectrogram``, ``get_fourier_basis``
            and ``get_window``.
        win_length: Window length; passed to ``TTSSpectrogram`` and
            ``get_window``.
        hop_length: Hop size; passed to ``TTSSpectrogram`` and used to scale
            the Fourier basis (by ``n_fft / hop_length``) before inversion.
        n_iter: Stored as ``self.n_iter`` (presumably the number of
            reconstruction iterations run elsewhere in the class -- confirm).
        window_fn: Window factory, default ``torch.hann_window``. It is used
            for BOTH the forward transform and the inverse basis, so the two
            directions stay consistent.
    """
    super(GriffinLim, self).__init__()
    # The forward transform also returns phase (return_phase=True). window_fn is
    # forwarded explicitly. Before this change the forward transform silently
    # used TTSSpectrogram's own default window, so a non-default window_fn only
    # affected the inverse basis below.
    self.transform = TTSSpectrogram(
        n_fft, win_length, hop_length, window_fn=window_fn, return_phase=True
    )

    # Inverse basis: pseudo-inverse of the (n_fft / hop_length)-scaled Fourier
    # basis, transposed, given a singleton middle dimension, then multiplied
    # by the same window used by the forward transform.
    basis = get_fourier_basis(n_fft)
    basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :]
    basis *= get_window(window_fn, n_fft, win_length)
    # A buffer (not a parameter): part of the module state and moved by
    # .to()/.cuda(), but never updated by an optimizer.
    self.register_buffer("basis", basis)

    self.n_fft = n_fft
    self.win_length = win_length
    self.hop_length = hop_length
    self.n_iter = n_iter

    # Numerically equal to torch.finfo(torch.float32).tiny, the smallest
    # positive normal float32. Presumably used elsewhere as a floor to avoid
    # division by zero -- confirm against the rest of the class.
    self.tiny = 1.1754944e-38
def extract_logmel_spectrogram(
    waveform: torch.Tensor,
    sample_rate: int,
    output_path: Optional[Path] = None,
    win_length: int = 1024,
    hop_length: int = 256,
    n_fft: int = 1024,
    win_fn: callable = torch.hann_window,
    n_mels: int = 80,
    f_min: float = 0.,
    f_max: float = 8000,
    eps: float = 1e-5,
    overwrite: bool = False,
    target_length: Optional[int] = None,
):
    """Compute a log-mel spectrogram of ``waveform``; optionally save it as .npy.

    Args:
        waveform: Input audio. It must produce a mel spectrogram of shape
            (1, D, T); this is asserted below. Presumably a (1, num_samples)
            tensor -- confirm against callers.
        sample_rate: Sample rate passed to the mel filterbank.
        output_path: If given, the result is written here with ``np.save``
            instead of being returned.
        win_length, hop_length, n_fft, win_fn: STFT settings passed to
            ``TTSSpectrogram``.
        n_mels, f_min, f_max: Mel filterbank settings passed to ``TTSMelScale``
            (its ``n_stft`` is ``n_fft // 2 + 1``).
        eps: Lower clamp applied before ``log`` so zero energy does not give -inf.
        overwrite: If False and ``output_path`` already exists, return
            immediately without computing anything.
        target_length: If given, the frame axis is trimmed or padded to this
            length via ``trim_or_pad_to_target_length``.

    Returns:
        The log-mel spectrogram laid out as T x D when ``output_path`` is None.
        When ``target_length`` is set, the type is whatever
        ``trim_or_pad_to_target_length`` returns. Returns None when the result
        was saved, or when the existing file was kept.

    Raises:
        AssertionError: If the mel spectrogram is not 3-D with a size-1 leading
            dimension.
    """
    if output_path is not None and output_path.is_file() and not overwrite:
        return

    spectrogram_transform = TTSSpectrogram(
        n_fft=n_fft, win_length=win_length, hop_length=hop_length,
        window_fn=win_fn
    )
    mel_scale_transform = TTSMelScale(
        n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
        n_stft=n_fft // 2 + 1
    )
    spectrogram = spectrogram_transform(waveform)
    mel_spec = mel_scale_transform(spectrogram)
    logmel_spec = torch.clamp(mel_spec, min=eps).log()
    assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
    # Drop only the leading size-1 dim. A bare squeeze() would also drop a
    # size-1 D or T (e.g. a single-frame clip) and leave a 1-D tensor, on
    # which .t() is a no-op.
    logmel_spec = logmel_spec.squeeze(0).t()  # D x T -> T x D
    if target_length is not None:
        logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)

    if output_path is not None:
        np.save(output_path.as_posix(), logmel_spec)
    else:
        return logmel_spec