class NetFeeder(object): def __init__(self, device, win_size=320, hop_size=160): self.eps = torch.finfo(torch.float32).eps self.stft = STFT(win_size, hop_size).to(device) def __call__(self, mix, sph): real_mix, imag_mix = self.stft.stft(mix) feat = torch.stack([real_mix, imag_mix], dim=1) real_sph, imag_sph = self.stft.stft(sph) lbl = torch.stack([real_sph, imag_sph], dim=1) return feat, lbl
class NetFeeder(object): def __init__(self, device, win_size=320, hop_size=160): self.eps = torch.finfo(torch.float32).eps self.stft = STFT(win_size, hop_size).to(device) def __call__(self, mix, sph): real_mix, imag_mix = self.stft.stft(mix) mag_mix = torch.sqrt(real_mix**2 + imag_mix**2) feat = mag_mix real_sph, imag_sph = self.stft.stft(sph) mag_sph = torch.sqrt(real_sph**2 + imag_sph**2) lbl = mag_sph return feat, lbl
class Resynthesizer(object): def __init__(self, device, win_size=320, hop_size=160): self.stft = STFT(win_size, hop_size).to(device) def __call__(self, est, mix): real_mix, imag_mix = self.stft.stft(mix) pha_mix = torch.atan2(imag_mix.data, real_mix.data) real_est = est * torch.cos(pha_mix) imag_est = est * torch.sin(pha_mix) sph_est = self.stft.istft(torch.stack([real_est, imag_est], dim=1)) sph_est = F.pad(sph_est, [0, mix.shape[1]-sph_est.shape[1]]) return sph_est