def __init__(
    self,
    n_fft: int,
    win_length: int,
    hop_length: int,
    n_iter: int,
    window_fn=torch.hann_window,
):
    """Set up the forward transform and the windowed inverse basis.

    Args:
        n_fft: FFT size; passed to the spectrogram transform, the Fourier
            basis helper and the window helper.
        win_length: Window length; passed to the spectrogram transform
            and the window helper.
        hop_length: Hop between frames; passed to the spectrogram
            transform and used to scale the Fourier basis.
        n_iter: Stored on the instance as ``n_iter`` (presumably the
            number of Griffin-Lim iterations; its use is outside this
            method).
        window_fn: Window factory handed to ``get_window``. Note that it
            is not forwarded to ``TTSSpectrogram`` here.
    """
    super(GriffinLim, self).__init__()

    # Forward transform; asked to return phase alongside its usual output.
    self.transform = TTSSpectrogram(
        n_fft, win_length, hop_length, return_phase=True
    )

    # Pseudo-invert the Fourier basis scaled by n_fft / hop_length, then
    # transpose and insert a singleton middle axis.
    fourier_basis = get_fourier_basis(n_fft)
    scaled_basis = n_fft / hop_length * fourier_basis
    inverse_basis = torch.pinverse(scaled_basis).T[:, None, :]

    # Apply the window in place; it must broadcast against the basis.
    window = get_window(window_fn, n_fft, win_length)
    inverse_basis *= window
    self.register_buffer("basis", inverse_basis)

    # Keep the hyper-parameters on the instance.
    self.n_fft = n_fft
    self.win_length = win_length
    self.hop_length = hop_length
    self.n_iter = n_iter

    # Equals the smallest positive normal float32 value; where it is used
    # is not visible from this method.
    self.tiny = 1.1754944e-38
def get_window_sum_square(
    cls, n_frames, hop_length, win_length, n_fft, window_fn=torch.hann_window
) -> torch.Tensor:
    """Overlap-add the squared window across ``n_frames`` hops.

    The window is obtained from ``get_window(window_fn, n_fft, win_length)``
    and squared element-wise. One copy is then accumulated into the output
    at each offset ``frame * hop_length``.

    Args:
        cls: Unused in the body. NOTE(review): no ``@classmethod``
            decorator is visible next to this definition; confirm how
            this is meant to be called.
        n_frames: Number of frames (window placements) to accumulate.
        hop_length: Offset in samples between consecutive placements.
        win_length: Window length forwarded to ``get_window``.
        n_fft: FFT size; forwarded to ``get_window`` and used as the span
            written per frame.
        window_fn: Window factory forwarded to ``get_window``.

    Returns:
        A 1-D float32 tensor of length ``n_fft + hop_length * (n_frames - 1)``
        holding the accumulated squared window.
    """
    window_sq = get_window(window_fn, n_fft, win_length) ** 2

    total_len = n_fft + hop_length * (n_frames - 1)
    envelope = torch.zeros(total_len, dtype=torch.float32)

    for frame in range(n_frames):
        start = frame * hop_length
        # Clamp both slices to the output buffer, as the original did.
        stop = min(total_len, start + n_fft)
        take = max(0, min(n_fft, total_len - start))
        envelope[start:stop] += window_sq[:take]

    return envelope