def __init__(self, cfg): self.device = torch.device( 'cuda:0' if torch.cuda.is_available() else 'cpu') self.dtype = torch.float32 self.eps = 1e-8 self.stft_module = STFTModule(cfg['stft_params'], self.device) self.stft_module_ex1 = STFTModule(cfg['stft_params_ex1'], self.device) self.stft_module_ex2 = STFTModule(cfg['stft_params_ex2'], self.device) self.train_data_num = cfg['train_data_num'] self.valid_data_num = cfg['valid_data_num'] self.sample_len = cfg['sample_len'] self.epoch_num = cfg['epoch_num'] self.train_batch_size = cfg['train_batch_size'] self.valid_batch_size = cfg['valid_batch_size'] self.train_full_data_num = cfg['train_full_data_num'] self.valid_full_data_num = cfg['valid_full_data_num'] self.save_path = cfg['save_path'] self.train_dataset = VoicebankDemandDataset( data_num=self.train_data_num, full_data_num=self.train_full_data_num, sample_len=self.sample_len, folder_type='train') self.valid_dataset = VoicebankDemandDataset( data_num=self.valid_data_num, full_data_num=self.valid_full_data_num, sample_len=self.sample_len, folder_type='validation') self.train_data_loader = FastDataLoader( self.train_dataset, batch_size=self.train_batch_size, shuffle=True) self.valid_data_loader = FastDataLoader( self.valid_dataset, batch_size=self.valid_batch_size, shuffle=True) self.model = DemandUNet_pp().to(self.device) self.criterion = PSA() self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3) self.early_stopping = EarlyStopping(patience=10)
def __init__(self, cfg): self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.dtype= torch.float32 self.eps = 1e-8 self.stft_module = STFTModule(cfg['stft_params'], self.device) self.train_data_num = cfg['train_data_num'] self.valid_data_num = cfg['valid_data_num'] self.sample_len = cfg['sample_len'] self.epoch_num = cfg['epoch_num'] self.train_batch_size = cfg['train_batch_size'] self.valid_batch_size = cfg['valid_batch_size'] self.train_full_data_num = cfg['train_full_data_num'] self.valid_full_data_num = cfg['valid_full_data_num'] self.train_dataset = DSD100Dataset(data_num=self.train_data_num, full_data_num=self.train_full_data_num, sample_len=self.sample_len, folder_type='Dev', shuffle=True, device=self.device, augmentation=True) self.valid_dataset =DSD100Dataset(data_num=self.valid_data_num, full_data_num=self.valid_full_data_num, sample_len=self.sample_len, folder_type='Validation', shuffle=True, device=self.device, augmentation=False) self.train_data_loader = FastDataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=True) self.valid_data_loader = FastDataLoader(self.valid_dataset, batch_size=self.valid_batch_size, shuffle=True) self.model = OpenUnmix(cfg['dnn_cfg']).to(self.device) self.criterion = MSE() self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3) self.save_path = cfg['save_path'] self.early_stopping = EarlyStopping(patience=10)
class DemandUNet_pp_Runner(): def __init__(self, cfg): self.device = torch.device( 'cuda:0' if torch.cuda.is_available() else 'cpu') self.dtype = torch.float32 self.eps = 1e-8 self.stft_module = STFTModule(cfg['stft_params'], self.device) self.stft_module_ex1 = STFTModule(cfg['stft_params_ex1'], self.device) self.stft_module_ex2 = STFTModule(cfg['stft_params_ex2'], self.device) self.train_data_num = cfg['train_data_num'] self.valid_data_num = cfg['valid_data_num'] self.sample_len = cfg['sample_len'] self.epoch_num = cfg['epoch_num'] self.train_batch_size = cfg['train_batch_size'] self.valid_batch_size = cfg['valid_batch_size'] self.train_full_data_num = cfg['train_full_data_num'] self.valid_full_data_num = cfg['valid_full_data_num'] self.save_path = cfg['save_path'] self.train_dataset = VoicebankDemandDataset( data_num=self.train_data_num, full_data_num=self.train_full_data_num, sample_len=self.sample_len, folder_type='train') self.valid_dataset = VoicebankDemandDataset( data_num=self.valid_data_num, full_data_num=self.valid_full_data_num, sample_len=self.sample_len, folder_type='validation') self.train_data_loader = FastDataLoader( self.train_dataset, batch_size=self.train_batch_size, shuffle=True) self.valid_data_loader = FastDataLoader( self.valid_dataset, batch_size=self.valid_batch_size, shuffle=True) self.model = DemandUNet_pp().to(self.device) self.criterion = PSA() self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3) self.early_stopping = EarlyStopping(patience=10) def _preprocess(self, noisy, clean): with torch.no_grad(): noisy_spec = self.stft_module.stft(noisy, pad=False) noisy_amp_spec = taF.complex_norm(noisy_spec) noisy_mag_spec = self.stft_module.to_normalize_mag(noisy_amp_spec) clean_spec = self.stft_module.stft(clean, pad=False) clean_amp_spec = taF.complex_norm(clean_spec) #ex1 ex1_noisy_spec = self.stft_module_ex1.stft(noisy, pad=False) ex1_noisy_amp_spec = taF.complex_norm(ex1_noisy_spec) ex1_noisy_mag_spec = self.stft_module_ex1.to_normalize_mag( ex1_noisy_amp_spec) #ex2 ex2_noisy_spec = self.stft_module_ex2.stft(noisy, pad=False) ex2_noisy_amp_spec = taF.complex_norm(ex2_noisy_spec) ex2_noisy_mag_spec = self.stft_module_ex2.to_normalize_mag( ex2_noisy_amp_spec) return noisy_mag_spec, ex1_noisy_mag_spec, ex2_noisy_mag_spec, clean_amp_spec, noisy_amp_spec, noisy_spec, clean_spec def _run(self, mode=None, data_loader=None): running_loss = 0 for i, (noisy, clean) in enumerate(data_loader): noisy = noisy.to(self.dtype).to(self.device) clean = clean.to(self.dtype).to(self.device) noisy_mag_spec, ex1_noisy_mag_spec, ex2_noisy_mag_spec, clean_amp_spec, noisy_amp_spec, noisy_spec, clean_spec = self._preprocess( noisy, clean) self.model.zero_grad() est_mask = self.model(noisy_mag_spec, ex1_noisy_mag_spec, ex2_noisy_mag_spec) est_source = noisy_spec * est_mask[..., None] if mode == 'train' or mode == 'validation': loss = 10 * self.criterion(est_source, clean_spec) running_loss += loss.data if mode == 'train': loss.backward() self.optimizer.step() return (running_loss / (i + 1)), est_source, est_mask, noisy_amp_spec, clean_amp_spec def train(self): train_loss = np.array([]) valid_loss = np.array([]) print("start train") for epoch in range(self.epoch_num): # train print('epoch{0}'.format(epoch)) start = time.time() self.model.train() tmp_train_loss, _, _, _, _ = self._run( mode='train', data_loader=self.train_data_loader) train_loss = np.append(train_loss, tmp_train_loss.cpu().clone().numpy()) self.model.eval() with torch.no_grad(): tmp_valid_loss, est_source, est_mask, noisy_amp_spec, clean_amp_spec = self._run( mode='validation', data_loader=self.valid_data_loader) if self.early_stopping.validation(tmp_valid_loss): torch.save( self.model.state_dict(), self.save_path + 'u_net{0}.ckpt'.format(epoch + 1)) break valid_loss = np.append(valid_loss, tmp_valid_loss.cpu().clone().numpy()) est_source = taF.complex_norm(est_source) if (epoch + 1) % 10 == 0: show_TF_domein_result(train_loss, valid_loss, noisy_amp_spec[0, :, :], est_mask[0, :, :], est_source[0, :, :], clean_amp_spec[0, :, :]) torch.save(self.model.state_dict(), self.save_path + 'u_net{0}.ckpt'.format(epoch + 1)) end = time.time() print('----excute time: {0}'.format(end - start))