class DSDOpenUnmixTester():
    """Run a trained ``OpenUnmix`` checkpoint over the DSD100 'Test' data.

    Per-sample values returned by ``mss_evals`` are accumulated in the NumPy
    arrays ``sdr_list``, ``sir_list``, ``sar_list``, ``si_sdr_list`` and
    ``si_sdr_improve_list``; their means are printed at the end of
    :meth:`test`.
    """

    def __init__(self, cfg):
        """Build the model, STFT helper, dataset and loader from ``cfg``.

        Args:
            cfg: Mapping providing 'eval_path' (checkpoint file), 'dnn_cfg',
                'stft_params', 'test_data_num', 'test_batch_size' and
                'sample_len'.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4
        self.eval_path = cfg['eval_path']

        self.model = OpenUnmix(cfg['dnn_cfg']).to(self.device)
        self.model.eval()
        # NOTE(review): depending on the torch version, torch.load may
        # unpickle arbitrary objects -- only point eval_path at trusted files.
        self.model.load_state_dict(torch.load(self.eval_path, map_location=self.device))

        self.stft_module = STFTModule(cfg['stft_params'], self.device)

        self.test_data_num = cfg['test_data_num']
        self.test_batch_size = cfg['test_batch_size']
        self.sample_len = cfg['sample_len']

        self.test_dataset = DSD100Dataset(data_num=self.test_data_num,
                                          sample_len=self.sample_len,
                                          folder_type='Test',
                                          shuffle=False,
                                          device=self.device,
                                          augmentation=False)

        self.test_data_loader = FastDataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False)

        self._reset_metrics()

    def _reset_metrics(self):
        """Replace every metric accumulator with an empty array."""
        self.sdr_list = np.array([])
        self.sir_list = np.array([])
        self.sar_list = np.array([])
        self.si_sdr_list = np.array([])
        self.si_sdr_improve_list = np.array([])

    @staticmethod
    def _mean_or_nan(values):
        """Return the mean of ``values``; NaN (without a NumPy warning) if empty."""
        return np.mean(values) if len(values) else float('nan')

    def _preprocess(self, noisy):
        """Turn a waveform batch into the model input and its complex spectrogram.

        Args:
            noisy: Waveform tensor passed straight to ``stft_module.stft``.

        Returns:
            Tuple ``(noisy_mag_spec, noisy_spec)``: the output of
            ``to_normalize_mag`` applied to the spectrogram norm, and the
            spectrogram as returned by ``stft``.
        """
        with torch.no_grad():
            noisy_spec = self.stft_module.stft(noisy, pad=None)
            noisy_amp_spec = taF.complex_norm(noisy_spec)
            noisy_mag_spec = self.stft_module.to_normalize_mag(noisy_amp_spec)

            return noisy_mag_spec, noisy_spec

    def test(self, mode='test'):
        """Evaluate every batch of the test loader and print the mean metrics.

        Args:
            mode: Unused; kept so existing callers can keep passing it.
        """
        # Start from empty accumulators so repeated calls do not mix runs.
        self._reset_metrics()
        with torch.no_grad():
            # The loader yields 5-tuples; only the first and last are used.
            for noisy, _, _, _, clean in self.test_data_loader:
                start = time.time()
                noisy = noisy.to(self.dtype).to(self.device)
                clean = clean.to(self.dtype).to(self.device)
                # assumes noisy is (batch, samples) -- TODO confirm
                siglen = noisy.shape[1]
                noisy_mag_spec, noisy_spec = self._preprocess(noisy)
                est_mask = self.model(noisy_mag_spec)
                # The extra trailing axis lets the mask broadcast over the
                # last axis of the spectrogram.
                est_source = noisy_spec * est_mask[..., None]
                est_wave = self.stft_module.istft(est_source, siglen)
                # squeeze(0) only drops the leading axis when its size is 1.
                est_wave = est_wave.squeeze(0)
                clean = clean.squeeze(0)
                noisy = noisy.squeeze(0)

                sdr, sir, sar, si_sdr, si_sdr_improve = mss_evals(est_wave, clean, noisy)
                self.sdr_list = np.append(self.sdr_list, sdr)
                self.sir_list = np.append(self.sir_list, sir)
                self.sar_list = np.append(self.sar_list, sar)
                self.si_sdr_list = np.append(self.si_sdr_list, si_sdr)
                self.si_sdr_improve_list = np.append(self.si_sdr_improve_list, si_sdr_improve)
                print('test time:', time.time() - start)

            print('sdr mean:', self._mean_or_nan(self.sdr_list))
            print('sir mean:', self._mean_or_nan(self.sir_list))
            print('sar mean:', self._mean_or_nan(self.sar_list))
            print('si-sdr mean:', self._mean_or_nan(self.si_sdr_list))
            print('sdr improve mean:', self._mean_or_nan(self.si_sdr_improve_list))
class DemandCNNOpenUnmix_p2_Tester():
    """Run a trained ``CNNOpenUnmix_p2`` checkpoint over the VoicebankDemand test data.

    Per-sample values returned by ``sp_enhance_evals`` are accumulated in the
    NumPy arrays ``pesq_list``, ``stoi_list``, ``si_sdr_list`` and
    ``si_sdr_improve_list``; their means are printed at the end of
    :meth:`test`.
    """

    def __init__(self, cfg):
        """Build the model, STFT helpers, dataset and loader from ``cfg``.

        Args:
            cfg: Mapping providing 'eval_path' (checkpoint file), 'dnn_cfg',
                'stft_params', 'stft_params_ex2', 'test_data_num',
                'test_batch_size' and 'sample_len'.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4
        self.eval_path = cfg['eval_path']

        self.model = CNNOpenUnmix_p2(cfg['dnn_cfg']).to(self.device)
        self.model.eval()
        # NOTE(review): depending on the torch version, torch.load may
        # unpickle arbitrary objects -- only point eval_path at trusted files.
        self.model.load_state_dict(
            torch.load(self.eval_path, map_location=self.device))

        # Two STFT configurations: the primary one is also used for the
        # inverse transform, the "ex2" one only feeds the model.
        self.stft_module = STFTModule(cfg['stft_params'], self.device)
        self.stft_module_ex2 = STFTModule(cfg['stft_params_ex2'], self.device)

        self.test_data_num = cfg['test_data_num']
        self.test_batch_size = cfg['test_batch_size']
        self.sample_len = cfg['sample_len']

        self.test_dataset = VoicebankDemandDataset(data_num=self.test_data_num,
                                                   sample_len=self.sample_len,
                                                   folder_type='test',
                                                   shuffle=False)

        self.test_data_loader = FastDataLoader(self.test_dataset,
                                               batch_size=self.test_batch_size,
                                               shuffle=False)

        self._reset_metrics()

    def _reset_metrics(self):
        """Replace every metric accumulator with an empty array."""
        self.stoi_list = np.array([])
        self.pesq_list = np.array([])
        self.si_sdr_list = np.array([])
        self.si_sdr_improve_list = np.array([])

    @staticmethod
    def _mean_or_nan(values):
        """Return the mean of ``values``; NaN (without a NumPy warning) if empty."""
        return np.mean(values) if len(values) else float('nan')

    def _preprocess(self, noisy):
        """Compute both model inputs and the primary complex spectrogram.

        Args:
            noisy: Waveform tensor passed straight to both STFT modules.

        Returns:
            Tuple ``(noisy_mag_spec, ex2_noisy_mag_spec, noisy_spec)``: the
            normalized magnitudes from the primary and the "ex2" STFT, and the
            spectrogram returned by the primary ``stft``.
        """
        with torch.no_grad():
            noisy_spec = self.stft_module.stft(noisy, pad=False)
            noisy_amp_spec = taF.complex_norm(noisy_spec)
            noisy_mag_spec = self.stft_module.to_normalize_mag(noisy_amp_spec)

            # Same pipeline with the second STFT configuration.
            ex2_noisy_spec = self.stft_module_ex2.stft(noisy, pad=False)
            ex2_noisy_amp_spec = taF.complex_norm(ex2_noisy_spec)
            ex2_noisy_mag_spec = self.stft_module_ex2.to_normalize_mag(
                ex2_noisy_amp_spec)

            return noisy_mag_spec, ex2_noisy_mag_spec, noisy_spec

    def test(self, mode='test'):
        """Evaluate every batch of the test loader and print the mean metrics.

        Args:
            mode: Unused; kept so existing callers can keep passing it.
        """
        # Start from empty accumulators so repeated calls do not mix runs.
        self._reset_metrics()
        with torch.no_grad():
            for noisy, clean in self.test_data_loader:
                start = time.time()
                noisy = noisy.to(self.dtype).to(self.device)
                clean = clean.to(self.dtype).to(self.device)
                # assumes noisy is (batch, samples) -- TODO confirm
                siglen = noisy.shape[1]
                noisy_mag_spec, ex2_noisy_mag_spec, noisy_spec = self._preprocess(
                    noisy)
                est_mask = self.model(noisy_mag_spec, ex2_noisy_mag_spec)
                # The extra trailing axis lets the mask broadcast over the
                # last axis of the spectrogram.
                est_source = noisy_spec * est_mask[..., None]
                est_wave = self.stft_module.istft(est_source, siglen)
                # squeeze(0) only drops the leading axis when its size is 1.
                est_wave = est_wave.squeeze(0)
                clean = clean.squeeze(0)
                noisy = noisy.squeeze(0)

                pesq_val, stoi_val, si_sdr_val, si_sdr_improve = sp_enhance_evals(
                    est_wave, clean, noisy, fs=16000)
                self.pesq_list = np.append(self.pesq_list, pesq_val)
                self.stoi_list = np.append(self.stoi_list, stoi_val)
                self.si_sdr_list = np.append(self.si_sdr_list, si_sdr_val)
                self.si_sdr_improve_list = np.append(self.si_sdr_improve_list,
                                                     si_sdr_improve)
                print('test time:', time.time() - start)

            print('pesq mean:', self._mean_or_nan(self.pesq_list))
            print('stoi mean:', self._mean_or_nan(self.stoi_list))
            # NOTE(review): this label says "sdr" but the values are the
            # si_sdr results; label kept as-is for log compatibility.
            print('sdr mean:', self._mean_or_nan(self.si_sdr_list))
            print('sdr improve mean:', self._mean_or_nan(self.si_sdr_improve_list))
# Example no. 3
# 0
class UNet_pp_Tester():
    """Run a trained ``UNet_pp`` checkpoint over the DSD100 'test' data.

    Per-sample values returned by ``mss_evals`` are accumulated in the NumPy
    arrays ``sdr_list``, ``sar_list`` and ``sir_list``; their means are
    printed at the end of :meth:`test`.
    """

    def __init__(self, cfg):
        """Build the model, STFT helpers, dataset and loader from ``cfg``.

        Args:
            cfg: Mapping providing 'eval_path' (checkpoint file),
                'stft_params', 'stft_params_ex1', 'stft_params_ex2',
                'test_data_num', 'test_batch_size' and 'sample_len'.
        """
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.eps = 1e-4
        self.eval_path = cfg['eval_path']

        self.model = UNet_pp().to(self.device)
        self.model.eval()
        # NOTE(review): depending on the torch version, torch.load may
        # unpickle arbitrary objects -- only point eval_path at trusted files.
        self.model.load_state_dict(
            torch.load(self.eval_path, map_location=self.device))

        # Three STFT configurations: the primary one is also used for the
        # inverse transform, "ex1"/"ex2" only feed the model.
        self.stft_module = STFTModule(cfg['stft_params'], self.device)
        self.stft_module_ex1 = STFTModule(cfg['stft_params_ex1'], self.device)
        self.stft_module_ex2 = STFTModule(cfg['stft_params_ex2'], self.device)

        self.test_data_num = cfg['test_data_num']
        self.test_batch_size = cfg['test_batch_size']
        self.sample_len = cfg['sample_len']
        self.test_dataset = DSD100Dataset(data_num=self.test_data_num,
                                          sample_len=self.sample_len,
                                          folder_type='test',
                                          device=self.device,
                                          shuffle=False)
        self.test_data_loader = FastDataLoader(self.test_dataset,
                                               batch_size=self.test_batch_size,
                                               shuffle=False)

        self._reset_metrics()

    def _reset_metrics(self):
        """Replace every metric accumulator with an empty array."""
        self.sdr_list = np.array([])
        self.sar_list = np.array([])
        self.sir_list = np.array([])

    @staticmethod
    def _mean_or_nan(values):
        """Return the mean of ``values``; NaN (without a NumPy warning) if empty."""
        return np.mean(values) if len(values) else float('nan')

    def _preprocess(self, mixture, true):
        """Compute the three log-magnitude model inputs and the primary spectrogram.

        Args:
            mixture: Waveform tensor passed to all three STFT modules.
            true: Unused; kept so existing callers can keep passing it.

        Returns:
            Tuple ``(mix_mag_spec, ex1_mix_mag_spec, pad_ex2_mix_mag_spec,
            mix_spec)`` where ``mix_spec`` is the spectrogram returned by the
            primary ``stft`` and the others are ``log10(norm + eps)`` features
            cropped/padded as described inline.
        """
        with torch.no_grad():
            mix_spec = self.stft_module.stft(mixture, pad=True)
            mix_amp_spec = taF.complex_norm(mix_spec)
            # Drop index 0 along axis 1 (the axis unpacked as f_size below).
            mix_amp_spec = mix_amp_spec[:, 1:, :]
            mix_mag_spec = torch.log10(mix_amp_spec + self.eps)

            # ex1: drop index 0 on axis 1 and keep indices 1..512 on axis 2.
            ex1_mix_spec = self.stft_module_ex1.stft(mixture, pad=True)
            ex1_mix_amp_spec = taF.complex_norm(ex1_mix_spec)
            ex1_mix_mag_spec = torch.log10(ex1_mix_amp_spec + self.eps)
            ex1_mix_mag_spec = ex1_mix_mag_spec[:, 1:, 1:513]

            # ex2: drop index 0 on axis 1, then zero-pad axis 2 up to 128.
            ex2_mix_spec = self.stft_module_ex2.stft(mixture, pad=True)
            ex2_mix_amp_spec = taF.complex_norm(ex2_mix_spec)
            ex2_mix_mag_spec = torch.log10(ex2_mix_amp_spec + self.eps)
            ex2_mix_mag_spec = ex2_mix_mag_spec[:, 1:, :]
            batch_size, f_size, _ = ex2_mix_mag_spec.shape
            pad_ex2_mix_mag_spec = torch.zeros((batch_size, f_size, 128),
                                               dtype=self.dtype,
                                               device=self.device)
            # NOTE(review): the hard-coded 1024/127 presumably match the ex2
            # STFT output for the configured sample length -- confirm before
            # changing the STFT parameters.
            pad_ex2_mix_mag_spec[:, :1024, :127] = ex2_mix_mag_spec[:, :, :]

            return mix_mag_spec, ex1_mix_mag_spec, pad_ex2_mix_mag_spec, mix_spec

    def _postprocess(self, x):
        """Remove axis 1 of ``x`` (if size 1) and zero-pad one row at each end of the new axis 1.

        Args:
            x: Tensor that is 3-D after ``squeeze(1)``.

        Returns:
            Tensor of shape ``(batch, f_size + 2, t_size)`` on ``self.device``
            with ``x`` in rows ``1..f_size`` and zeros in the first/last row.
        """
        x = x.squeeze(1)
        batch_size, f_size, t_size = x.shape
        pad_x = torch.zeros((batch_size, f_size + 2, t_size),
                            dtype=self.dtype,
                            device=self.device)
        pad_x[:, 1:-1, :] = x[:, :, :]
        return pad_x

    def test(self, mode='test'):
        """Evaluate every batch of the test loader and print the mean metrics.

        Args:
            mode: Unused; kept so existing callers can keep passing it.
        """
        # Start from empty accumulators so repeated calls do not mix runs.
        self._reset_metrics()
        with torch.no_grad():
            # The loader yields 5-tuples; only the first and last are used.
            for mixture, _, _, _, vocals in self.test_data_loader:
                start = time.time()
                mixture = mixture.squeeze(0).to(self.dtype).to(self.device)
                true = vocals.squeeze(0).to(self.dtype).to(self.device)

                mix_mag_spec, ex1_mix_mag_spec, ex2_mix_mag_spec, mix_spec = self._preprocess(
                    mixture, true)
                est_mask = self.model(mix_mag_spec.unsqueeze(1),
                                      ex1_mix_mag_spec.unsqueeze(1),
                                      ex2_mix_mag_spec.unsqueeze(1))
                # Restore the rows removed/omitted before the model so the
                # mask lines up with mix_spec along axis 1.
                est_mask = self._postprocess(est_mask)
                # The extra trailing axis lets the mask broadcast over the
                # last axis of the spectrogram.
                est_source = mix_spec * est_mask[..., None]
                est_wave = self.stft_module.istft(est_source)

                est_wave = est_wave.flatten()
                mixture = mixture.flatten()
                true = true.flatten()
                # The accompaniment is taken as whatever remains of the
                # mixture once the (true / estimated) target is subtracted.
                true_accompany = mixture - true
                est_accompany = mixture - est_wave
                sdr, sir, sar = mss_evals(est_wave, est_accompany, true,
                                          true_accompany)
                self.sdr_list = np.append(self.sdr_list, sdr)
                self.sar_list = np.append(self.sar_list, sar)
                self.sir_list = np.append(self.sir_list, sir)
                print('test time:', time.time() - start)

            print('sdr mean:', self._mean_or_nan(self.sdr_list))
            print('sir mean:', self._mean_or_nan(self.sir_list))
            print('sar mean:', self._mean_or_nan(self.sar_list))