def _preprocess(self, mixture, true):
    """Turn the mixture into three log-magnitude inputs plus its raw STFT.

    The mixture is analysed by ``stft_module``, ``stft_module_ex1`` and
    ``stft_module_ex2``. Each result is reduced with ``taF.complex_norm``,
    compressed with ``log10(amp + self.eps)`` and has index 0 of dim 1
    dropped. The ``ex1`` branch additionally keeps only indices 1..512 of
    the last dim; the ``ex2`` branch is copied into a zero tensor whose
    last dim is 128.

    Args:
        mixture: Signal passed to every ``stft`` call (``pad=True``).
        true: Not used by this method.

    Returns:
        Tuple ``(mix_mag_spec, ex1_mix_mag_spec, pad_ex2_mix_mag_spec,
        mix_spec)`` where ``mix_spec`` is the unmodified output of
        ``self.stft_module.stft``.
    """
    with torch.no_grad():
        # Main resolution: drop the first dim-1 entry, then log-compress.
        mix_spec = self.stft_module.stft(mixture, pad=True)
        main_amp = taF.complex_norm(mix_spec)[:, 1:, :]
        main_log = torch.log10(main_amp + self.eps)

        # ex1: log-compress first, then crop dim 1 and the last dim.
        ex1_spec = self.stft_module_ex1.stft(mixture, pad=True)
        ex1_log = torch.log10(taF.complex_norm(ex1_spec) + self.eps)
        ex1_log = ex1_log[:, 1:, 1:513]

        # ex2: log-compress, drop the first dim-1 entry, then zero-pad.
        ex2_spec = self.stft_module_ex2.stft(mixture, pad=True)
        ex2_log = torch.log10(taF.complex_norm(ex2_spec) + self.eps)
        ex2_log = ex2_log[:, 1:, :]

        n_batch, n_freq, _ = ex2_log.shape
        ex2_padded = torch.zeros(
            (n_batch, n_freq, 128), dtype=self.dtype, device=self.device
        )
        # NOTE(review): the 1024 / 127 / 128 sizes are hard-coded and
        # presumably match the ex2 STFT configuration; confirm.
        ex2_padded[:, :1024, :127] = ex2_log[:, :, :]

        return main_log, ex1_log, ex2_padded, mix_spec
def _preprocess(self, noisy, clean):
    """Compute STFT-domain features for a noisy/clean signal pair.

    Both signals go through ``self.stft_module.stft`` with ``pad=False``
    and are reduced to amplitudes with ``taF.complex_norm``. The noisy
    amplitude is additionally passed through
    ``self.stft_module.to_normalize_mag``.

    Args:
        noisy: Noisy input signal.
        clean: Clean reference signal.

    Returns:
        Tuple ``(noisy_mag_spec, clean_spec, noisy_spec, noisy_amp_spec,
        clean_amp_spec)``.
    """
    front_end = self.stft_module
    with torch.no_grad():
        # Noisy branch: raw STFT -> amplitude -> normalised magnitude.
        noisy_stft = front_end.stft(noisy, pad=False)
        noisy_amp = taF.complex_norm(noisy_stft)
        noisy_mag = front_end.to_normalize_mag(noisy_amp)

        # Clean branch: raw STFT -> amplitude only.
        clean_stft = front_end.stft(clean, pad=False)
        clean_amp = taF.complex_norm(clean_stft)

        return noisy_mag, clean_stft, noisy_stft, noisy_amp, clean_amp
def _preprocess(self, noisy):
    """Compute normalised magnitudes of ``noisy`` at two STFT settings.

    The signal is analysed by ``self.stft_module`` and by
    ``self.stft_module_ex2`` (both with ``pad=False``). Each result is
    reduced with ``taF.complex_norm`` and passed through the matching
    module's ``to_normalize_mag``.

    Args:
        noisy: Noisy input signal.

    Returns:
        Tuple ``(noisy_mag_spec, ex2_noisy_mag_spec, noisy_spec)`` where
        ``noisy_spec`` is the raw output of ``self.stft_module.stft``.
    """
    primary = self.stft_module
    secondary = self.stft_module_ex2
    with torch.no_grad():
        noisy_stft = primary.stft(noisy, pad=False)
        primary_mag = primary.to_normalize_mag(taF.complex_norm(noisy_stft))

        # ex2
        secondary_stft = secondary.stft(noisy, pad=False)
        secondary_mag = secondary.to_normalize_mag(
            taF.complex_norm(secondary_stft)
        )

        return primary_mag, secondary_mag, noisy_stft
def _preprocess(self, mixture, true):
    """Compute amplitude features for a mixture and its target signal.

    Both signals go through ``self.stft_module.stft`` with ``pad=True``,
    are reduced with ``taF.complex_norm`` and have index 0 of dim 1
    dropped. The mixture amplitude is also log10-compressed after adding
    ``self.eps``.

    Args:
        mixture: Mixture signal.
        true: Target signal.

    Returns:
        Tuple ``(mix_mag_spec, true_amp_spec, mix_amp_spec)``.
    """
    with torch.no_grad():
        mixture_amp = taF.complex_norm(
            self.stft_module.stft(mixture, pad=True)
        )[:, 1:, :]
        mixture_log = torch.log10(mixture_amp + self.eps)

        target_amp = taF.complex_norm(
            self.stft_module.stft(true, pad=True)
        )[:, 1:, :]

        return mixture_log, target_amp, mixture_amp
def train(self):
    """Run the train/validation loop for ``self.epoch_num`` epochs.

    Every epoch calls ``self._run`` once in ``'train'`` mode and once in
    ``'validation'`` mode (the latter under ``torch.no_grad()``) and
    appends both losses to running NumPy arrays. On every 10th epoch the
    latest validation outputs are plotted with ``show_TF_domein_result``
    and the model ``state_dict`` is saved under ``self.save_path``.
    """
    train_history = np.array([])
    valid_history = np.array([])
    print("start train")

    for epoch in range(self.epoch_num):
        # train
        print('epoch{0}'.format(epoch))
        epoch_started = time.time()

        self.model.train()
        epoch_train_loss, _, _, _, _ = self._run(
            mode='train', data_loader=self.train_data_loader)
        train_history = np.append(
            train_history, epoch_train_loss.cpu().clone().numpy())

        # validation
        self.model.eval()
        with torch.no_grad():
            (epoch_valid_loss, est_source, est_mask,
             noisy_amp_spec, clean_amp_spec) = self._run(
                mode='validation', data_loader=self.valid_data_loader)
            valid_history = np.append(
                valid_history, epoch_valid_loss.cpu().clone().numpy())

        is_report_epoch = (epoch + 1) % 10 == 0
        if is_report_epoch:
            plot_started = time.time()
            est_amp = taF.complex_norm(est_source)
            # Only the first item of the batch is visualised.
            show_TF_domein_result(train_history,
                                  valid_history,
                                  noisy_amp_spec[0, :, :],
                                  est_mask[0, :, :],
                                  est_amp[0, :, :],
                                  clean_amp_spec[0, :, :])
            print('plot_time:', time.time() - plot_started)
            # NOTE(review): checkpointing is tied to the plotting cadence
            # (every 10th epoch); confirm that is intended.
            torch.save(self.model.state_dict(),
                       self.save_path + 'u_net{0}.ckpt'.format(epoch + 1))

        epoch_finished = time.time()
        print('----excute time: {0}'.format(epoch_finished - epoch_started))
def _preprocess(self, mixture, true_sources):
    """Compute mixture, per-source and residual amplitude features.

    Args:
        mixture: Mixture signal; also repeated ``self.inst_num`` times
            along a new dim 1 to form the residual against
            ``true_sources``.
        true_sources: Stacked source signals passed to ``stft_3D``.

    Returns:
        Tuple ``(mix_mag_spec, true_sources_amp_spec, true_res_amp_spec,
        mix_phase, mix_amp_spec)``.
    """
    stft = self.stft_module

    # Mixture branch.
    mixture_stft = stft.stft(mixture, pad=True)
    # NOTE(review): this takes index 1 along dim 2 of the raw STFT; if the
    # layout is (..., time, complex=2) that is a time frame, not a phase.
    # Confirm the intended axis.
    mix_phase = mixture_stft[:, :, 1]
    mix_amp = taF.complex_norm(mixture_stft)[:, 1:, :]
    # NOTE(review): dim 1 is sliced a second time here, so the log
    # magnitude has one fewer dim-1 entry than ``mix_amp``. Confirm this
    # is intentional.
    mix_log = torch.log10(mix_amp + self.eps)[:, 1:, :]

    # Per-source branch: drop the first entry along dim 2.
    sources_amp = taF.complex_norm(stft.stft_3D(true_sources, pad=True))
    sources_amp = sources_amp[:, :, 1:, :]

    # Residual branch: what remains of the mixture after removing each
    # source. No slicing is applied to this output.
    tiled_mixture = mixture.unsqueeze(1).repeat(1, self.inst_num, 1)
    residual = tiled_mixture - true_sources
    residual_amp = taF.complex_norm(stft.stft_3D(residual, pad=True))

    return mix_log, sources_amp, residual_amp, mix_phase, mix_amp
def test_complex_norm(complex_tensor, power):
    """Check ``F.complex_norm`` against a direct computation.

    The reference value is the sum of squares over the last dim, raised
    to ``power / 2``.
    """
    actual = F.complex_norm(complex_tensor, power)

    squared_magnitude = complex_tensor.pow(2).sum(-1)
    expected = squared_magnitude.pow(power / 2)

    torch.testing.assert_allclose(actual, expected, atol=1e-5, rtol=1e-5)
def test_complex_norm(self, shape, power):
    """Check ``F.complex_norm`` on a seeded random tensor of ``shape``.

    The reference value is the sum of squares over the last dim, raised
    to ``power / 2``.
    """
    # Fixed seed keeps the random input reproducible across runs.
    torch.random.manual_seed(42)
    sample = torch.randn(*shape)

    actual = F.complex_norm(sample, power)
    expected = sample.pow(2).sum(-1).pow(power / 2)

    self.assertEqual(actual, expected, atol=1e-5, rtol=1e-5)
def forward(self, complex_tensor: Tensor) -> Tensor:
    r"""Apply ``F.complex_norm`` with this module's ``power``.

    Args:
        complex_tensor (Tensor): Input of shape `(..., complex=2)`.

    Returns:
        Tensor: Norm of the input, of shape `(..., )`.
    """
    exponent = self.power
    return F.complex_norm(complex_tensor, exponent)
def __call__(self, x):
    """Return the ``complex_norm`` of the STFT of ``x``.

    The STFT uses this object's ``n_fft``, ``hop_length``, ``win_length``,
    ``window`` and ``pad_mode``. When ``self.normalized`` is true the STFT
    is divided by the square root of the summed squared window first.
    The norm is taken with ``power=self.power``.
    """
    stft_options = dict(
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=self.window,
        pad_mode=self.pad_mode,
    )
    spec = x.stft(self.n_fft, **stft_options)

    if self.normalized:
        window_energy = self.window.pow(2).sum().sqrt()
        spec /= window_energy

    return complex_norm(spec, power=self.power)
def __call__(self, x):
    """Return ``complex_norm(STFT(x))`` raised to ``self.power``.

    ``self.window`` is first converted to the dtype and device of ``x``
    and stored back on the instance. When ``self.normalized`` is true the
    STFT is divided by the square root of the summed squared window
    before the norm is taken.
    """
    # Keep the stored window compatible with the incoming tensor.
    self.window = self.window.to(dtype=x.dtype, device=x.device)

    stft_options = dict(
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=self.window,
        pad_mode=self.pad_mode,
    )
    spec = x.stft(self.n_fft, **stft_options)

    if self.normalized:
        window_energy = self.window.pow(2).sum().sqrt()
        spec /= window_energy

    magnitude = complex_norm(spec)
    return magnitude.pow(self.power)
def __call__(self, S):
    """Reconstruct a signal from spectrogram ``S`` by Griffin-Lim iteration.

    ``S`` is raised to ``1 / self.power`` and, when ``self.normalized`` is
    true, multiplied by the square root of the summed squared window.
    Starting from random phases, the method alternates ``istft`` and
    ``stft`` for ``self.n_iter`` rounds, re-estimating the phase each
    round, and finally inverts ``S`` with the last phase estimate.

    Args:
        S: Spectrogram tensor. ``self.window`` is converted to its dtype
            and device and stored back on the instance.

    Returns:
        The output of the final ``istft`` call, with ``length=self.length``.
    """
    self.window = self.window.to(dtype=S.dtype, device=S.device)

    S = S.pow(1 / self.power)
    if self.normalized:
        S *= self.window.pow(2).sum().sqrt()

    # Randomly initialize the phase as unit (cos, sin) pairs in a new
    # trailing dim, then broadcast the magnitudes to the same shape.
    angles = 2 * math.pi * torch.rand(*S.size())
    angles = torch.stack([angles.cos(), angles.sin()],
                         dim=-1).to(dtype=S.dtype, device=S.device)
    S = S.unsqueeze(-1).expand_as(angles)

    # Arguments shared by every stft / istft call below.
    shared = dict(n_fft=self.n_fft,
                  hop_length=self.hop_length,
                  win_length=self.win_length,
                  window=self.window)

    # Initialize the previous iterate to 0.
    rebuilt = 0.

    for i in range(self.n_iter):
        print(f'Griffin-Lim iteration {i}/{self.n_iter}')

        # Store the previous iterate.
        tprev = rebuilt

        # Invert with our current estimate of the phases.
        inverse = istft(S * angles, length=self.length, **shared).float()

        # Rebuild the spectrogram.
        rebuilt = inverse.stft(pad_mode=self.pad_mode, **shared)

        # Update the phase estimate with the momentum rule
        # ``rebuilt - momentum * tprev``. The previous expression,
        # ``(rebuilt - momentum) * tprev``, multiplied by the initial
        # ``tprev == 0`` and wiped the phase on the first iteration.
        # NOTE(review): assumes ``self.momentum`` already holds the
        # effective coefficient (e.g. m / (1 + m)); confirm in __init__.
        angles = rebuilt - self.momentum * tprev
        # Normalise each (real, imag) pair to unit length; the epsilon
        # guards against division by zero.
        angles = angles.div_(
            complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles))

    # Return the signal reconstructed from the final phase estimates.
    return istft(S * angles, length=self.length, **shared)
def forward(self, x):
    """Estimate a new amplitude and recombine it with the input phase.

    Input: (batch_size, nb_channels, nb_timesteps)
    Output: ()  # TODO: find appropriate output

    ``self.transform(x)`` is split into amplitude (``F.complex_norm``) and
    phase (``F.angle``). ``self.estimator`` receives the amplitude and the
    result of ``self.compute_features`` on the phase; its output is
    multiplied by the (cos, sin) of the original phase.
    """
    # Swap dims -3 and -2 for processing; swapped back on return.
    spec = self.transform(x).transpose(-3, -2)

    amplitude = F.complex_norm(spec)
    phi = F.angle(spec)

    phase_features = self.compute_features(phi)
    amplitude_hat = self.estimator(amplitude, phase_features)

    # Unit (cos, sin) pairs carrying the original phase.
    unit_phase = torch.stack((torch.cos(phi), torch.sin(phi)), dim=-1)
    estimate = amplitude_hat.unsqueeze(-1) * unit_phase

    return estimate.transpose(-3, -2)
def func(tensor):
    """Return ``F.complex_norm`` of ``tensor`` with the power fixed at 2."""
    return F.complex_norm(tensor, 2.)
def test_complex_norm(complex_tensor, power):
    """Check ``F.complex_norm`` against a direct computation.

    The reference value is the sum of squares over the last dim, raised
    to ``power / 2``; the comparison uses ``atol=1e-5``.
    """
    actual = F.complex_norm(complex_tensor, power)

    squared_magnitude = complex_tensor.pow(2).sum(-1)
    expected = squared_magnitude.pow(power / 2)

    assert torch.allclose(expected, actual, atol=1e-5)
# # ``torchaudio`` implements ``TimeStrech``, ``TimeMasking`` and # ``FrequencyMasking``. # ###################################################################### # TimeStrech # ~~~~~~~~~~ # spec = get_spectrogram(power=None) strech = T.TimeStretch() rate = 1.2 spec_ = strech(spec, rate) plot_spectrogram(F.complex_norm(spec_[0]), title=f"Stretched x{rate}", aspect='equal', xmax=304) plot_spectrogram(F.complex_norm(spec[0]), title="Original", aspect='equal', xmax=304) rate = 0.9 spec_ = strech(spec, rate) plot_spectrogram(F.complex_norm(spec_[0]), title=f"Stretched x{rate}", aspect='equal', xmax=304)