Example No. 1
    def __call__(self, S):
        self.window = self.window.to(dtype=S.dtype, device=S.device)

        S = S.pow(1 / self.power)
        if self.normalized:
            S *= self.window.pow(2).sum().sqrt()

        # randomly initialize the phase
        angles = 2 * math.pi * torch.rand(*S.size())
        angles = torch.stack([angles.cos(), angles.sin()],
                             dim=-1).to(dtype=S.dtype, device=S.device)
        S = S.unsqueeze(-1).expand_as(angles)

        # And initialize the previous iterate to 0
        rebuilt = 0.

        for i in range(self.n_iter):
            print(f'Griffin-Lim iteration {i}/{self.n_iter}')

            # Store the previous iterate
            tprev = rebuilt

            # Invert with our current estimate of the phases
            inverse = istft(S * angles,
                            n_fft=self.n_fft,
                            hop_length=self.hop_length,
                            win_length=self.win_length,
                            window=self.window,
                            length=self.length).float()

            # Rebuild the spectrogram
            rebuilt = inverse.stft(n_fft=self.n_fft,
                                   hop_length=self.hop_length,
                                   win_length=self.win_length,
                                   window=self.window,
                                   pad_mode=self.pad_mode)

            # Update our phase estimates with momentum from the previous iterate
            angles = rebuilt - tprev * self.momentum
            angles = angles.div_(
                complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(
                    angles))

        # Return the final phase estimates
        return istft(S * angles,
                     n_fft=self.n_fft,
                     hop_length=self.hop_length,
                     win_length=self.win_length,
                     window=self.window,
                     length=self.length)
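For comparison, the same loop is much shorter on PyTorch >= 1.8, where torch.stft and torch.istft operate on complex tensors. A minimal sketch without the momentum term (all names below are illustrative, not taken from the snippet):

import math
import torch

def griffin_lim(mag, n_fft=1024, hop_length=256, n_iter=32, length=None):
    """Minimal Griffin-Lim sketch for a magnitude spectrogram of shape
    (freq, frames); assumes PyTorch >= 1.8 (complex STFT API)."""
    window = torch.hann_window(n_fft, device=mag.device)
    # Random initial phases on the unit circle.
    angles = torch.polar(torch.ones_like(mag), 2 * math.pi * torch.rand_like(mag))
    for _ in range(n_iter):
        # Invert with the current phase estimate, then re-analyze.
        wav = torch.istft(mag * angles, n_fft, hop_length=hop_length,
                          window=window, length=length)
        rebuilt = torch.stft(wav, n_fft, hop_length=hop_length,
                             window=window, return_complex=True)
        # Keep the new phases, discard the magnitudes.
        angles = rebuilt / rebuilt.abs().clamp_min(1e-16)
    return torch.istft(mag * angles, n_fft, hop_length=hop_length,
                       window=window, length=length)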
Example No. 2
def inverse_stft(stft):
    """Invert an STFT back to a waveform.

    Relies on the module-level globals ``win_length``, ``hop_length``,
    ``win`` (the analysis window) and ``FUNC``.
    """
    # Zero-pad the frequency axis up to the expected win_length // 2 + 1 bins.
    pad = win_length // 2 + 1 - stft.size(1)
    stft = FUNC.pad(stft, (0, 0, 0, 0, 0, pad))
    wav = istft(stft, win_length, hop_length=hop_length, window=win)
    return wav.detach()
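A hedged usage sketch for the snippet above, assuming FUNC is torch.nn.functional and istft is the legacy real-view torchaudio.functional.istft (torchaudio < 0.8, PyTorch < 1.8):

import torch
import torch.nn.functional as FUNC
from torchaudio.functional import istft  # legacy (..., 2) real-view API

win_length, hop_length = 1024, 256
win = torch.hann_window(win_length)

wav = torch.randn(1, 16000)
spec = torch.stft(wav, win_length, hop_length=hop_length, window=win)  # (1, 513, frames, 2)
rebuilt = inverse_stft(spec)  # back to a (1, ~16000) waveform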
Example No. 3
    def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
        # Match the (..., 2) real-view layout expected by the legacy istft:
        # real = magnitude * cos(phase), imag = magnitude * sin(phase)
        magnitude, phase = magnitude.unsqueeze(3), phase.unsqueeze(3)
        stft = torch.cat([magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=3)
        return istft(
            stft, self.n_fft, self.hop_length, self.win_length, self.window
        )
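On PyTorch >= 1.8 the polar-to-rectangular step collapses to torch.polar, and torch.istft consumes the complex tensor directly. A minimal sketch under that assumption:

import torch

def inverse_complex(magnitude, phase, n_fft, hop_length, win_length, window):
    # magnitude * exp(i * phase) as a single complex tensor
    spec = torch.polar(magnitude, phase)
    return torch.istft(spec, n_fft, hop_length=hop_length,
                       win_length=win_length, window=window)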
Example No. 4
    def istft(self, x):
        return F.istft(x,
                       n_fft=self.n_fft,
                       hop_length=self.n_hop,
                       window=self.window,
                       center=self.center,
                       normalized=False,
                       onesided=True,
                       pad_mode='reflect')
Example No. 5
    def inv_f(self, input, phase):
        # Convert magnitude + phase to the real-view (..., 2) layout.
        input = torch.stack([input * torch.cos(phase), input * torch.sin(phase)], dim=-1)

        input = istft(
            input,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=torch.hann_window(self.win_length, device=input.device),
        )

        return input
Example No. 6
File: main.py Project: mori97/MVAE
def validate(model, val_dataset, baseline, device, epoch, writer):
    model.eval()

    if_use_cuda = device != torch.device('cpu')
    xp = cp if if_use_cuda else np

    window = torch.hann_window(N_FFT).to(device)

    result = {'SDR': {}, 'SIR': {}, 'SAR': {}}
    for i, (src, mix_spec, speaker) in enumerate(val_dataset):
        if if_use_cuda:
            mix_spec = cp.asarray(mix_spec)
        separated, _ = mvae(mix_spec, model, n_iter=40, device=device)
        separated = separated.transpose(1, 0, 2)
        # Convert to PyTorch-style complex tensor (Shape = (..., 2))
        separated = xp.stack((xp.real(separated), xp.imag(separated)), axis=-1)
        if if_use_cuda:
            separated = to_tensor(separated)
        else:
            separated = torch.from_numpy(separated)
        with torch.no_grad():
            separated = istft(separated, N_FFT, HOP_LEN, window=window)
            separated = separated.cpu().numpy()

        sdr, sir, sar, _ =\
            mir_eval.separation.bss_eval_sources(src, separated)

        # Record the scores, creating the lists on first sight of a speaker
        # so the first utterance's scores are not dropped.
        if speaker not in result['SDR']:
            result['SDR'][speaker] = []
            result['SIR'][speaker] = []
            result['SAR'][speaker] = []
        result['SDR'][speaker].extend(sdr.tolist())
        result['SIR'][speaker].extend(sir.tolist())
        result['SAR'][speaker].extend(sar.tolist())

        sep_tensor0 = torch.from_numpy(separated[0, :]).unsqueeze(0)
        sep_tensor1 = torch.from_numpy(separated[1, :]).unsqueeze(0)
        writer.add_audio('eval/{}_0'.format(i), sep_tensor0, epoch, 16000)
        writer.add_audio('eval/{}_1'.format(i), sep_tensor1, epoch, 16000)

    for metric in result:
        for speaker in result[metric]:
            result[metric][speaker] = (stat.mean(result[metric][speaker]),
                                       stat.stdev(result[metric][speaker]))

    figures = bar_chart(baseline, result)
    for metric, figure in figures.items():
        writer.add_figure(f'eval/{metric}', figure, epoch)
Example No. 7
    def forward(self, audio):
        x_stft = torch.stft(audio,
                            self.n_fft,
                            self.hop_len,
                            window=self.window.to(audio.device),
                            normalized=True)  # (B, W, H, 2)
        x_conv = self.conv(x_stft.unsqueeze(1)).unbind(1)[0]  # (B, W, H, 2)
        x_crm = self.cRM(x_conv, x_stft)
        x_istft = istft(x_crm,
                        self.n_fft,
                        self.hop_len,
                        window=self.window.to(audio.device),
                        normalized=True)

        return x_istft
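self.cRM is not shown above; one common choice applies a complex ratio mask by complex multiplication in the real-view (..., 2) layout. A minimal sketch of that assumption:

import torch

def apply_crm(mask: torch.Tensor, spec: torch.Tensor) -> torch.Tensor:
    """Complex-multiply mask * spec, both in real-view (..., 2) layout."""
    real = mask[..., 0] * spec[..., 0] - mask[..., 1] * spec[..., 1]
    imag = mask[..., 0] * spec[..., 1] + mask[..., 1] * spec[..., 0]
    return torch.stack((real, imag), dim=-1)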
Example No. 8
    def time_stretch(self, batch, speedup_rate, device="cuda"):
        if speedup_rate == 1:
            return batch

        n_fft = 2048  # window size
        hop_length = n_fft // 4

        # Analysis STFT (legacy real-view output, shape (..., freq, frames, 2))
        stft = torch.stft(batch, n_fft, hop_length=hop_length)

        # Expected per-bin phase advance per hop, shape (freq, 1)
        phase_advance = torch.linspace(0, math.pi * hop_length, stft.shape[1])[..., None].to(device)
        # Time stretch via phase_vocoder (not differentiable):
        vocoded = AF.phase_vocoder(stft, rate=speedup_rate, phase_advance=phase_advance)
        stretched = AF.istft(vocoded, n_fft, hop_length=hop_length).squeeze()

        return stretched
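The same stretch on the complex API, as a hedged sketch (assumes PyTorch >= 1.8 and a torchaudio recent enough that phase_vocoder accepts complex spectrograms):

import math
import torch
import torchaudio.functional as AF

def time_stretch(wav, rate, n_fft=2048):
    hop_length = n_fft // 4
    window = torch.hann_window(n_fft, device=wav.device)
    spec = torch.stft(wav, n_fft, hop_length=hop_length, window=window,
                      return_complex=True)
    # Expected per-bin phase advance per hop, shape (freq, 1).
    phase_advance = torch.linspace(0, math.pi * hop_length, spec.shape[-2],
                                   device=wav.device)[..., None]
    stretched = AF.phase_vocoder(spec, rate, phase_advance)
    # Output is roughly len(wav) / rate samples long.
    return torch.istft(stretched, n_fft, hop_length=hop_length, window=window)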
Example No. 9
File: main.py Project: mori97/MVAE
def baseline_ilrma(val_dataset, device):
    """Evaluate with ILRMA.
    """
    if_use_cuda = device != torch.device('cpu')
    xp = cp if if_use_cuda else np

    window = torch.hann_window(N_FFT).to(device)

    ret = {'SDR': {}, 'SIR': {}, 'SAR': {}}
    for src, mix_spec, speaker in val_dataset:
        if if_use_cuda:
            mix_spec = cp.asarray(mix_spec)
        separated, _ = ilrma(mix_spec, n_iter=100)
        separated = separated.transpose(1, 0, 2)
        # Convert to PyTorch-style complex tensor (Shape = (..., 2))
        separated = xp.stack((xp.real(separated), xp.imag(separated)), axis=-1)
        if if_use_cuda:
            separated = to_tensor(separated)
        else:
            separated = torch.from_numpy(separated)
        with torch.no_grad():
            separated = istft(separated, N_FFT, HOP_LEN, window=window)
            separated = separated.cpu().numpy()

        sdr, sir, sar, _ =\
            mir_eval.separation.bss_eval_sources(src, separated)

        # Record the scores, creating the lists on first sight of a speaker
        # so the first utterance's scores are not dropped.
        if speaker not in ret['SDR']:
            ret['SDR'][speaker] = []
            ret['SIR'][speaker] = []
            ret['SAR'][speaker] = []
        ret['SDR'][speaker].extend(sdr.tolist())
        ret['SIR'][speaker].extend(sir.tolist())
        ret['SAR'][speaker].extend(sar.tolist())

    for metric in ret:
        for speaker in ret[metric]:
            ret[metric][speaker] = (stat.mean(ret[metric][speaker]),
                                    stat.stdev(ret[metric][speaker]))

    return ret
Example No. 10
def compute_istft(amplitude: Tensor, phase: Tensor, n_fft: int,
                  hop_length: int) -> Tensor:
    # Rebuild the real-view (..., 2) STFT from its polar components.
    real = amplitude * torch.cos(phase)
    imag = amplitude * torch.sin(phase)
    stft = torch.stack((real, imag), dim=-1)
    return _func.istft(stft, n_fft=n_fft, hop_length=hop_length)
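A hedged round-trip check for compute_istft, assuming _func aliases the legacy torchaudio.functional and the legacy real-view torch.stft (PyTorch < 1.8):

import torch

n_fft, hop_length = 512, 128
wav = torch.randn(16000)
spec = torch.stft(wav, n_fft, hop_length=hop_length)  # (freq, frames, 2)
amplitude = spec.pow(2).sum(-1).sqrt()
phase = torch.atan2(spec[..., 1], spec[..., 0])
rebuilt = compute_istft(amplitude, phase, n_fft, hop_length)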
Example No. 11
noisy_batch, clean_batch = next(iter(dataloader))

#  enable eval mode
model.zero_grad()
model.eval()
model.freeze()

# disable gradients to save memory
torch.set_grad_enabled(False)

# recover the FFT size from the model's frequency-bin count
n_fft = (model.n_frequency_bins - 1) * 2

x_waveform = noisy_batch

transform = Spectrogram(n_fft=n_fft, power=None)

x_stft = transform(x_waveform)
y_stft = transform(clean_batch)
# magnitude spectrograms from the real-view (..., 2) STFTs
x_ms = x_stft.pow(2).sum(-1).sqrt()
y_ms = y_stft.pow(2).sum(-1).sqrt()

y_ms_hat = model(x_ms)

# reattach the noisy phase to the denoised magnitude
y_stft_hat = torch.stack([y_ms_hat * torch.cos(angle(x_stft)),
                          y_ms_hat * torch.sin(angle(x_stft))], dim=-1)

window = torch.hann_window(n_fft)
y_waveform_hat = istft(y_stft_hat, n_fft=n_fft, hop_length=n_fft // 2,
                       win_length=n_fft, window=window,
                       length=x_waveform.shape[-1])
for i, waveform in enumerate(y_waveform_hat.numpy()):
    sf.write('denoised' + str(i) + '.wav', waveform, 16000)