Ejemplo n.º 1
0
    def test_griffinlim(self):
        """Compare ``F.griffinlim`` with ``librosa.griffinlim`` on a seeded random signal.

        A 1000-sample random signal is converted to a spectrogram. Both
        implementations then reconstruct a waveform from it with the same
        hop / momentum / iteration / length settings. The outputs must agree
        within ``atol=5e-5``.
        """
        # NOTE: This test is flaky without a fixed random seed
        # See https://github.com/pytorch/audio/issues/382
        torch.random.manual_seed(42)
        signal = torch.rand((1, 1000))

        fft_size = 400
        win_len = 400
        hop_len = 100
        hann = torch.hann_window(win_len)
        normalized = False
        mom = 0.99
        iterations = 8
        out_len = 1000
        random_init = False
        # librosa expresses the same choice through its ``init`` argument.
        librosa_init = None
        if random_init:
            librosa_init = 'random'

        # The positional ``2`` is presumably ``power``, so ``sqrt`` turns the
        # power spectrogram back into a magnitude spectrogram (TODO confirm).
        magnitude = F.spectrogram(signal, 0, hann, fft_size, hop_len, win_len, 2,
                                  normalized).sqrt()
        torchaudio_wave = F.griffinlim(magnitude, hann, fft_size, hop_len, win_len, 1,
                                       normalized, iterations, mom, out_len, random_init)
        librosa_wave = librosa.griffinlim(magnitude.squeeze(0).numpy(),
                                          n_iter=iterations,
                                          hop_length=hop_len,
                                          momentum=mom,
                                          init=librosa_init,
                                          length=out_len)
        librosa_wave = torch.from_numpy(librosa_wave).unsqueeze(0)

        self.assertTrue(torch.allclose(torchaudio_wave, librosa_wave, atol=5e-5))
    def test_griffinlim(self, momentum):
        """Check ``F.griffinlim`` against ``librosa.griffinlim`` for a given ``momentum``.

        The input is a waveform from ``get_whitenoise``. Both implementations
        reconstruct it from the same ``power=1`` spectrogram with
        deterministic (non-random) initialisation and must agree within
        ``atol=5e-5`` / ``rtol=1e-07``.
        """
        fft_size = 400
        # STFT settings shared verbatim by the spectrogram and Griffin-Lim calls.
        stft_kwargs = {
            'n_fft': fft_size,
            'win_length': fft_size,
            'hop_length': fft_size // 4,
            'window': torch.hann_window(fft_size, device=self.device),
        }
        power = 1
        n_iter = 8  # Griffin-Lim iterations

        waveform = get_whitenoise(device=self.device, dtype=self.dtype)
        num_samples = waveform.size(1)
        specgram = get_spectrogram(waveform, power=power, **stft_kwargs)

        actual = F.griffinlim(
            specgram,
            power=power,
            n_iter=n_iter,
            momentum=momentum,
            length=num_samples,
            rand_init=False,
            **stft_kwargs,
        )
        # librosa works on one channel as a numpy array; re-add the batch axis.
        reference = librosa.griffinlim(
            specgram[0].cpu().numpy(),
            n_iter=n_iter,
            hop_length=stft_kwargs['hop_length'],
            momentum=momentum,
            init=None,
            length=num_samples,
        )
        expected = torch.from_numpy(reference[None, ...])
        self.assertEqual(actual, expected, atol=5e-5, rtol=1e-07)
def spec2wav(x, cos, sin, wav_len, syn_phase=0, device="cuda"):
    """Reconstruct a waveform from a magnitude spectrogram.

    Args:
        x: Magnitude spectrogram, laid out as channels * frames * bins (per
            the original note). Its last axis is zero-padded by one bin before
            synthesis. Presumably this restores the ``N_FFT // 2 + 1`` bins
            expected by the inverse STFT; TODO confirm against callers.
        cos, sin: Presumably the cosine / sine of the phase. They must
            broadcast against the padded ``x``. Used by modes 0 and 2;
            ignored by mode 1.
        wav_len: Output length in samples for modes 0 and 2. Mode 1 ignores
            it and derives the length from the frame count instead.
        syn_phase: Synthesis mode.
            0 -- a single inverse STFT using the given ``cos`` / ``sin``.
            1 -- ``AF.griffinlim`` with ``N_ITER`` iterations, zero momentum
                 and non-random initialisation.
            2 -- 100 rounds of inverse STFT. After every round except the
                 last, ``cos`` / ``sin`` are re-estimated from the waveform
                 via ``wav2spec``.
        device: ``"cuda"`` moves ``FFT_WINDOW`` to the GPU. Any other value
            uses ``FFT_WINDOW`` as-is.

    Returns:
        The reconstructed waveform tensor.

    Raises:
        ValueError: If ``syn_phase`` is not 0, 1 or 2. (Previously an unknown
            mode fell through every branch and crashed with
            ``UnboundLocalError`` on ``return wav``.)
    """
    if syn_phase not in (0, 1, 2):
        raise ValueError(f"syn_phase must be 0, 1 or 2, got {syn_phase!r}")

    x = F.pad(x, (0, 1), "constant", 0)
    fft_window = FFT_WINDOW.cuda() if device == "cuda" else FFT_WINDOW

    def _istft(cos_phase, sin_phase):
        # Stack real/imag parts on a new last axis, then swap axes 1 and 2 so
        # bins come before frames, the layout passed to torch.istft.
        spec = torch.stack([x * cos_phase, x * sin_phase], -1).transpose(1, 2)
        return torch.istft(spec,
                           n_fft=N_FFT,
                           hop_length=HOP_LEN,
                           win_length=WINDOW_SIZE,
                           window=fft_window,
                           center=True,
                           normalized=False,
                           onesided=None,
                           length=wav_len,
                           return_complex=False)

    if syn_phase == 1:
        # Griffin-Lim derives its own output length from the (padded) frame count.
        gl_len = int((x.shape[-2] - 1) / FRAMES_PER_SECOND * SAMPLE_RATE)
        return AF.griffinlim(x.transpose(1, 2),
                             window=fft_window,
                             n_fft=N_FFT,
                             hop_length=HOP_LEN,
                             win_length=WINDOW_SIZE,
                             power=1,
                             normalized=False,
                             length=gl_len,
                             n_iter=N_ITER,
                             momentum=0,
                             rand_init=False)

    if syn_phase == 2:
        n_rounds = 100
        for i in range(n_rounds):
            wav = _istft(cos, sin)
            if i < n_rounds - 1:
                # Re-estimate the phase from the current waveform estimate.
                _, cos, sin = wav2spec(wav)
        return wav

    # syn_phase == 0
    return _istft(cos, sin)
Ejemplo n.º 4
0
    def forward(self, specgram: Tensor) -> Tensor:
        r"""Run Griffin-Lim phase reconstruction with this module's stored settings.

        Args:
            specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        # Keyword form makes the attribute -> parameter mapping explicit.
        return F.griffinlim(
            specgram,
            window=self.window,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            power=self.power,
            normalized=self.normalized,
            n_iter=self.n_iter,
            momentum=self.momentum,
            length=self.length,
            rand_init=self.rand_init,
        )
Ejemplo n.º 5
0
 def func(tensor):
     """Apply ``F.griffinlim`` to ``tensor`` with fixed settings.

     The settings are n_fft=400, a 400-sample Hann window matched to
     ``tensor``'s device and dtype, hop 200, power 2, 32 iterations,
     momentum 0.99, output length 1000 and non-random initialisation.

     Arguments are passed by keyword rather than position. Elsewhere,
     ``F.griffinlim`` has been called both with and without a positional
     ``normalized`` argument, so a purely positional call can silently bind
     values to the wrong parameters.
     """
     n_fft = 400
     win_length = 400
     hop_length = 200
     window = torch.hann_window(win_length, device=tensor.device, dtype=tensor.dtype)
     power = 2.
     momentum = 0.99
     n_iter = 32
     length = 1000
     rand_init = False  # was misspelled ``rand_int``
     return F.griffinlim(
         tensor,
         window=window,
         n_fft=n_fft,
         hop_length=hop_length,
         win_length=win_length,
         power=power,
         n_iter=n_iter,
         momentum=momentum,
         length=length,
         rand_init=rand_init,
     )
Ejemplo n.º 6
0
 def forward(self, S):
     """Reconstruct a waveform from spectrogram ``S`` with ``F.griffinlim``.

     All Griffin-Lim settings come from attributes stored on the module.
     """
     return F.griffinlim(
         S,
         window=self.window,
         n_fft=self.n_fft,
         hop_length=self.hop_length,
         win_length=self.win_length,
         power=self.power,
         normalized=self.normalized,
         n_iter=self.n_iter,
         momentum=self.momentum,
         length=self.length,
         rand_init=self.rand_init,
     )