Ejemplo n.º 1
0
    def _preprocess(self, mixture, true):
        """Compute log-magnitude spectrogram inputs from ``mixture``.

        ``mixture`` is transformed by three STFT front ends
        (``self.stft_module``, ``self.stft_module_ex1`` and
        ``self.stft_module_ex2``, each called with ``pad=True``). For each,
        the complex-norm magnitude is taken, ``log10(amp + self.eps)`` is
        applied, and the first row along dim 1 is dropped.

        Args:
            mixture: signal fed to every ``stft`` call.
            true: unused here; kept so the signature matches the other
                ``_preprocess`` variants.

        Returns:
            ``(mix_mag_spec, ex1_mix_mag_spec, pad_ex2_mix_mag_spec, mix_spec)``.
            ``pad_ex2_mix_mag_spec`` is the ex2 feature zero-padded along its
            last dim to 128 entries. ``mix_spec`` is the raw output of
            ``self.stft_module.stft``.

        Raises:
            ValueError: if the ex2 feature already has more than 128 entries
                on its last dim, so it cannot be padded to that width.
        """
        # Fixed last-dim width of the ex2 branch (presumably time frames).
        ex2_target_frames = 128

        with torch.no_grad():
            mix_spec = self.stft_module.stft(mixture, pad=True)
            mix_amp_spec = taF.complex_norm(mix_spec)
            mix_amp_spec = mix_amp_spec[:, 1:, :]
            mix_mag_spec = torch.log10(mix_amp_spec + self.eps)

            # ex1: additionally keep only indices 1..512 of the last dim.
            ex1_mix_spec = self.stft_module_ex1.stft(mixture, pad=True)
            ex1_mix_amp_spec = taF.complex_norm(ex1_mix_spec)
            ex1_mix_mag_spec = torch.log10(ex1_mix_amp_spec + self.eps)
            ex1_mix_mag_spec = ex1_mix_mag_spec[:, 1:, 1:513]

            # ex2: zero-pad the last dim up to ex2_target_frames. The copy
            # region comes from the tensor's real shape, so any dim-1 size and
            # any last-dim size <= ex2_target_frames is accepted.
            ex2_mix_spec = self.stft_module_ex2.stft(mixture, pad=True)
            ex2_mix_amp_spec = taF.complex_norm(ex2_mix_spec)
            ex2_mix_mag_spec = torch.log10(ex2_mix_amp_spec + self.eps)
            ex2_mix_mag_spec = ex2_mix_mag_spec[:, 1:, :]
            batch_size, f_size, t_size = ex2_mix_mag_spec.shape
            if t_size > ex2_target_frames:
                raise ValueError(
                    'ex2 spectrogram has {0} frames; cannot pad to {1}'.format(
                        t_size, ex2_target_frames))
            pad_ex2_mix_mag_spec = torch.zeros(
                (batch_size, f_size, ex2_target_frames),
                dtype=self.dtype,
                device=self.device)
            pad_ex2_mix_mag_spec[:, :, :t_size] = ex2_mix_mag_spec

            return mix_mag_spec, ex1_mix_mag_spec, pad_ex2_mix_mag_spec, mix_spec
    def _preprocess(self, noisy, clean):
        """Turn a noisy/clean signal pair into STFT-domain tensors.

        Both signals go through ``self.stft_module.stft`` with ``pad=False``
        under ``torch.no_grad()``. Only the noisy magnitude is normalised,
        via ``self.stft_module.to_normalize_mag``.

        Returns:
            ``(normalised noisy magnitude, clean STFT, noisy STFT,
            noisy magnitude, clean magnitude)``.
        """
        stft = self.stft_module
        with torch.no_grad():
            noisy_complex = stft.stft(noisy, pad=False)
            noisy_mag = taF.complex_norm(noisy_complex)
            noisy_feature = stft.to_normalize_mag(noisy_mag)

            clean_complex = stft.stft(clean, pad=False)
            clean_mag = taF.complex_norm(clean_complex)

        return noisy_feature, clean_complex, noisy_complex, noisy_mag, clean_mag
 def _preprocess(self, noisy):
     """Build normalised magnitude features of ``noisy`` from two STFT modules.

     ``self.stft_module`` and ``self.stft_module_ex2`` each compute an STFT
     (``pad=False``). The complex-norm magnitude is then normalised with that
     same module's ``to_normalize_mag``.

     Returns:
         ``(main normalised magnitude, ex2 normalised magnitude, main STFT)``.
     """
     with torch.no_grad():
         main_module = self.stft_module
         main_spec = main_module.stft(noisy, pad=False)
         main_feature = main_module.to_normalize_mag(taF.complex_norm(main_spec))

         # second (ex2) front end
         ex2_module = self.stft_module_ex2
         ex2_spec = ex2_module.stft(noisy, pad=False)
         ex2_feature = ex2_module.to_normalize_mag(taF.complex_norm(ex2_spec))

         return main_feature, ex2_feature, main_spec
Ejemplo n.º 4
0
    def _preprocess(self, mixture, true):
        """Compute mixture log-magnitude input and magnitude target.

        Both signals go through ``self.stft_module.stft(..., pad=True)``. The
        complex norm is taken and the first row along dim 1 is dropped.

        Returns:
            ``(log10(mix_amp + self.eps), true_amp, mix_amp)``.
        """
        def amplitude(signal):
            # STFT -> complex-norm magnitude -> drop first row of dim 1.
            spectrum = self.stft_module.stft(signal, pad=True)
            return taF.complex_norm(spectrum)[:, 1:, :]

        with torch.no_grad():
            mix_amp_spec = amplitude(mixture)
            mix_mag_spec = torch.log10(mix_amp_spec + self.eps)
            true_amp_spec = amplitude(true)

        return mix_mag_spec, true_amp_spec, mix_amp_spec
    def train(self):
        """Train ``self.model`` for ``self.epoch_num`` epochs.

        Each epoch runs ``self._run`` once in ``'train'`` mode and once in
        ``'validation'`` mode (the latter under ``torch.no_grad()``), then
        records both losses. On every 10th epoch it calls
        ``show_TF_domein_result`` with the first validation example and saves
        ``self.model.state_dict()`` to
        ``self.save_path + 'u_net<epoch>.ckpt'``. Wall-clock timings are
        printed.
        """
        train_history = np.array([])
        valid_history = np.array([])
        print("start train")

        for epoch in range(self.epoch_num):
            print('epoch{0}'.format(epoch))
            epoch_start = time.time()

            # --- training pass ---
            self.model.train()
            loss_tr, _, _, _, _ = self._run(
                mode='train', data_loader=self.train_data_loader)
            train_history = np.append(train_history,
                                      loss_tr.cpu().clone().numpy())

            # --- validation pass (no autograd graph) ---
            self.model.eval()
            with torch.no_grad():
                (loss_va, est_source, est_mask,
                 noisy_amp_spec, clean_amp_spec) = self._run(
                     mode='validation', data_loader=self.valid_data_loader)
                valid_history = np.append(valid_history,
                                          loss_va.cpu().clone().numpy())

            is_checkpoint_epoch = (epoch + 1) % 10 == 0
            if is_checkpoint_epoch:
                plot_start = time.time()
                est_amp = taF.complex_norm(est_source)
                show_TF_domein_result(train_history, valid_history,
                                      noisy_amp_spec[0, :, :],
                                      est_mask[0, :, :], est_amp[0, :, :],
                                      clean_amp_spec[0, :, :])
                print('plot_time:', time.time() - plot_start)
                ckpt_name = 'u_net{0}.ckpt'.format(epoch + 1)
                torch.save(self.model.state_dict(), self.save_path + ckpt_name)

            print('----excute time: {0}'.format(time.time() - epoch_start))
Ejemplo n.º 6
0
 def _preprocess(self, mixture, true_sources):
     """Compute mixture features and per-source / residual magnitude targets.

     The first row along the frequency dim (presumably; dim 1 of the mixture
     spectrogram, dim 2 of the ``stft_3D`` outputs) is dropped exactly once
     from every magnitude output. All returned magnitudes therefore share the
     same size on that dim, matching the other ``_preprocess`` variants.

     Args:
         mixture: mixture signal. ``mixture.unsqueeze(1)`` is repeated
             ``self.inst_num`` times along dim 1 to form per-source residuals.
         true_sources: per-source reference signals passed to ``stft_3D``.

     Returns:
         ``(mix_mag_spec, true_sources_amp_spec, true_res_amp_spec,
         mix_phase, mix_amp_spec)``, where
         ``mix_mag_spec == log10(mix_amp_spec + self.eps)`` and the residual
         is ``mixture - true_sources`` for each source.
     """
     with torch.no_grad():
         mix_spec = self.stft_module.stft(mixture, pad=True)
         # NOTE(review): this takes index 1 of dim 2. If mix_spec is
         # (batch, freq, time, 2), that is one time frame rather than a phase;
         # taF.angle(mix_spec) may have been intended -- confirm.
         mix_phase = mix_spec[:, :, 1]
         mix_amp_spec = taF.complex_norm(mix_spec)[:, 1:, :]
         # Trimmed once, via mix_amp_spec (was trimmed a second time before).
         mix_mag_spec = torch.log10(mix_amp_spec + self.eps)

         true_sources_spec = self.stft_module.stft_3D(true_sources, pad=True)
         true_sources_amp_spec = taF.complex_norm(true_sources_spec)[:, :, 1:, :]

         true_res = mixture.unsqueeze(1).repeat(1, self.inst_num, 1) - true_sources
         true_res_spec = self.stft_module.stft_3D(true_res, pad=True)
         # Trim the residual like true_sources_amp_spec so the targets align.
         true_res_amp_spec = taF.complex_norm(true_res_spec)[:, :, 1:, :]
     return mix_mag_spec, true_sources_amp_spec, true_res_amp_spec, mix_phase, mix_amp_spec
Ejemplo n.º 7
0
def test_complex_norm(complex_tensor, power):
    """Compare ``F.complex_norm`` with ``(sum of squares over last dim) ** (power / 2)``."""
    squared_magnitude = complex_tensor.pow(2).sum(-1)
    expected_norm_tensor = squared_magnitude.pow(power / 2)
    torch.testing.assert_allclose(F.complex_norm(complex_tensor, power),
                                  expected_norm_tensor,
                                  atol=1e-5, rtol=1e-5)
Ejemplo n.º 8
0
 def test_complex_norm(self, shape, power):
     """Check ``F.complex_norm`` on a seeded random tensor of ``shape``."""
     torch.random.manual_seed(42)
     tensor = torch.randn(*shape)
     reference = tensor.pow(2).sum(-1).pow(power / 2)
     result = F.complex_norm(tensor, power)
     self.assertEqual(result, reference, atol=1e-5, rtol=1e-5)
Ejemplo n.º 9
0
    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""Return the norm of ``complex_tensor`` raised to ``self.power``.

        Args:
            complex_tensor (Tensor): real tensor whose last dim holds the
                (real, imag) pair, i.e. shape ``(..., 2)``.

        Returns:
            Tensor: ``F.complex_norm(complex_tensor, self.power)``, shape ``(...)``.
        """
        power = self.power
        return F.complex_norm(complex_tensor, power)
Ejemplo n.º 10
0
    def __call__(self, x):
        """Compute the spectrogram of ``x`` raised to ``self.power``.

        ``x.stft`` is called with this object's ``n_fft``, ``hop_length``,
        ``win_length``, ``window`` and ``pad_mode``. When ``self.normalized``
        is set, the STFT is divided by the window's L2 norm before
        ``complex_norm(..., power=self.power)`` is applied.
        """
        # Keep the window on the input's dtype/device: torch.stft fails when
        # they differ (e.g. CUDA input with a CPU window). Mirrors the
        # behaviour of the sibling Spectrogram.__call__.
        self.window = self.window.to(dtype=x.dtype, device=x.device)
        x = x.stft(self.n_fft,
                   hop_length=self.hop_length,
                   win_length=self.win_length,
                   window=self.window,
                   pad_mode=self.pad_mode)

        if self.normalized:
            x /= self.window.pow(2).sum().sqrt()
        return complex_norm(x, power=self.power)
Ejemplo n.º 11
0
    def __call__(self, x):
        """Return ``|STFT(x)| ** self.power``.

        The window is first moved (and cached on ``self``) to ``x``'s dtype
        and device. With ``self.normalized``, the STFT is divided by the
        window's L2 norm before the magnitude is taken.
        """
        window = self.window.to(dtype=x.dtype, device=x.device)
        self.window = window

        spectrum = x.stft(self.n_fft,
                          hop_length=self.hop_length,
                          win_length=self.win_length,
                          window=window,
                          pad_mode=self.pad_mode)
        if self.normalized:
            spectrum /= window.pow(2).sum().sqrt()

        magnitude = complex_norm(spectrum)
        return magnitude.pow(self.power)
Ejemplo n.º 12
0
    def __call__(self, S):
        """Reconstruct a waveform from spectrogram ``S`` with (fast) Griffin-Lim.

        ``S`` is raised to ``1 / self.power`` and, if ``self.normalized``, scaled
        by the window's L2 norm. Phases start random and are refined for
        ``self.n_iter`` iterations of ``istft`` -> ``stft``. Each iteration
        applies a momentum step (``self.momentum``) against the previous
        rebuilt STFT.

        Returns:
            The ``istft`` of ``S`` combined with the final phase estimate
            (``length=self.length``).
        """
        self.window = self.window.to(dtype=S.dtype, device=S.device)

        # Inverse of the power applied when the spectrogram was computed.
        S = S.pow(1 / self.power)
        if self.normalized:
            S *= self.window.pow(2).sum().sqrt()

        # randomly initialize the phase as (cos, sin) pairs in a new last dim
        angles = 2 * math.pi * torch.rand(*S.size())
        angles = torch.stack([angles.cos(), angles.sin()],
                             dim=-1).to(dtype=S.dtype, device=S.device)
        S = S.unsqueeze(-1).expand_as(angles)

        # And initialize the previous iterate to 0
        rebuilt = 0.

        for i in range(self.n_iter):
            print(f'Griffin-Lim iteration {i}/{self.n_iter}')

            # Store the previous iterate
            tprev = rebuilt

            # Invert with our current estimate of the phases
            inverse = istft(S * angles,
                            n_fft=self.n_fft,
                            hop_length=self.hop_length,
                            win_length=self.win_length,
                            window=self.window,
                            length=self.length).float()

            # Rebuild the spectrogram
            rebuilt = inverse.stft(n_fft=self.n_fft,
                                   hop_length=self.hop_length,
                                   win_length=self.win_length,
                                   window=self.window,
                                   pad_mode=self.pad_mode)

            # Momentum step: rebuilt - momentum * previous iterate. Computing
            # (rebuilt - momentum) * tprev instead zeroes every phase on the
            # first iteration (tprev == 0.) and yields silence.
            angles = rebuilt - self.momentum * tprev
            # Project back onto unit magnitude (1e-16 guards division by zero).
            angles = angles.div_(
                complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(
                    angles))

        # Return the final phase estimates
        return istft(S * angles,
                     n_fft=self.n_fft,
                     hop_length=self.hop_length,
                     win_length=self.win_length,
                     window=self.window,
                     length=self.length)
Ejemplo n.º 13
0
    def forward(self, x):
        """
        Input: (batch_size, nb_channels, nb_timesteps)
        Output: ``A_hat`` combined with the input phase as a trailing
        (cos, sin) pair, with dims -3/-2 swapped back to the transform's
        layout. # TODO: confirm exact output shape
        """
        spec = self.transform(x).transpose(-3, -2)

        magnitude = F.complex_norm(spec)
        phi = F.angle(spec)

        features = self.compute_features(phi)
        est_magnitude = self.estimator(magnitude, features)

        # Re-attach the mixture phase to the estimated magnitude.
        unit_phase = torch.stack((torch.cos(phi), torch.sin(phi)), dim=-1)
        estimate = est_magnitude.unsqueeze(-1) * unit_phase

        return estimate.transpose(-3, -2)
 def func(tensor):
     # complex norm with power 2 (sum of squares over the last dim)
     return F.complex_norm(tensor, 2.)
Ejemplo n.º 15
0
def test_complex_norm(complex_tensor, power):
    """``F.complex_norm`` must match a direct sum-of-squares reference."""
    reference = complex_tensor.pow(2).sum(-1).pow(power / 2)
    actual = F.complex_norm(complex_tensor, power)

    matches = torch.allclose(reference, actual, atol=1e-5)
    assert matches
Ejemplo n.º 16
0
#
# ``torchaudio`` implements ``TimeStretch``, ``TimeMasking`` and
# ``FrequencyMasking``.
#

######################################################################
# TimeStretch
# ~~~~~~~~~~~
#

# Complex spectrogram (power=None) and a TimeStretch transform with defaults.
spec = get_spectrogram(power=None)
strech = T.TimeStretch()

# Stretch by 1.2 and plot the magnitude of the first item.
rate = 1.2
spec_ = strech(spec, rate)
plot_spectrogram(F.complex_norm(spec_[0]), title=f"Stretched x{rate}", aspect='equal', xmax=304)

# Unstretched reference, same axes settings.
plot_spectrogram(F.complex_norm(spec[0]), title="Original", aspect='equal', xmax=304)

# Stretch by 0.9 and plot again.
rate = 0.9
spec_ = strech(spec, rate)
plot_spectrogram(F.complex_norm(spec_[0]), title=f"Stretched x{rate}", aspect='equal', xmax=304)