def validation(self, epoch):
    """Evaluate ``self.dpcl`` on ``self.val_dataloader`` for one epoch.

    The model is put in eval mode and every batch is scored under
    ``torch.no_grad()``. A running average of the loss is logged every
    ``self.print_freq`` batches, followed by a summary line that includes
    the elapsed wall-clock time in minutes.

    Args:
        epoch: Epoch number; used only in the log messages.

    Returns:
        The mean of the per-batch ``Loss(...).loss()`` values, or 0.0 when
        the validation loader yields no batches.
    """
    self.logger.info(
        'Start Validation from epoch: {:d}, iter: {:d}'.format(epoch, 1))
    self.dpcl.eval()
    num_batchs = len(self.val_dataloader)
    total_loss = 0.0
    start_time = time.time()
    with torch.no_grad():
        for num_index, (mix_wave, target_waves, non_silent) in enumerate(
                self.val_dataloader, start=1):
            mix_wave = mix_wave.to(self.device)
            target_waves = target_waves.to(self.device)
            non_silent = non_silent.to(self.device)
            mix_embs = self.dpcl(mix_wave)
            criterion = Loss(mix_embs, target_waves, non_silent, self.num_spks)
            epoch_loss = criterion.loss()
            total_loss += epoch_loss.item()
            if num_index % self.print_freq == 0:
                message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}>'.format(
                    epoch, num_index, self.optimizer.param_groups[0]['lr'],
                    total_loss / num_index)
                self.logger.info(message)
    end_time = time.time()
    # max(..., 1): an empty loader yields 0.0 instead of ZeroDivisionError.
    total_loss = total_loss / max(num_batchs, 1)
    message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min> '.format(
        epoch, num_batchs, self.optimizer.param_groups[0]['lr'], total_loss,
        (end_time - start_time) / 60)
    self.logger.info(message)
    return total_loss
def train(self, epoch):
    """Train ``self.dpcl`` for one epoch over ``self.train_dataloader``.

    Each batch is moved to ``self.device`` and run through a
    ``torch.nn.DataParallel`` wrapper of ``self.dpcl``. The ``Loss`` is
    backpropagated, and gradients are clipped to ``self.clip_norm`` when
    that value is truthy. The optimizer then takes a step. A running
    average of the loss is logged every ``self.print_freq`` batches,
    followed by a summary line with the elapsed wall-clock time in minutes.

    Args:
        epoch: Epoch number; used only in the log messages.

    Returns:
        The mean of the per-batch training losses, or 0.0 when the loader
        yields no batches.
    """
    self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(epoch, 1))
    self.dpcl.train()
    num_batchs = len(self.train_dataloader)
    total_loss = 0.0
    start_time = time.time()
    # Wrap once per epoch instead of once per batch. DataParallel keeps a
    # reference to self.dpcl, so the optimized parameters are unchanged.
    model = torch.nn.DataParallel(self.dpcl)
    # enumerate(start=1) replaces the hand-maintained counter, which was
    # never incremented (periodic logging always saw iter 1).
    for num_index, (mix_wave, target_waves, non_silent) in enumerate(
            self.train_dataloader, start=1):
        mix_wave = mix_wave.to(self.device)
        target_waves = target_waves.to(self.device)
        non_silent = non_silent.to(self.device)
        mix_embs = model(mix_wave)
        criterion = Loss(mix_embs, target_waves, non_silent, self.num_spks)
        epoch_loss = criterion.loss()
        total_loss += epoch_loss.item()
        self.optimizer.zero_grad()
        epoch_loss.backward()
        if self.clip_norm:
            torch.nn.utils.clip_grad_norm_(self.dpcl.parameters(), self.clip_norm)
        self.optimizer.step()
        if num_index % self.print_freq == 0:
            message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, loss:{:.3f}>'.format(
                epoch, num_index, self.optimizer.param_groups[0]['lr'],
                total_loss / num_index)
            self.logger.info(message)
    end_time = time.time()
    # max(..., 1): an empty loader yields 0.0 instead of ZeroDivisionError.
    total_loss = total_loss / max(num_batchs, 1)
    message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min> '.format(
        epoch, num_batchs, self.optimizer.param_groups[0]['lr'], total_loss,
        (end_time - start_time) / 60)
    self.logger.info(message)
    return total_loss
def train(self, epoch):
    """Train ``self.danet`` for one epoch over ``self.train_dataloader``.

    Each batch ``(mix_samp, wf, ibm, non_silent)`` is made contiguous and
    moved to ``self.device``. A fresh hidden state is built with
    ``self.danet.init_hidden`` and the model is run to obtain a mask. The
    ``Loss`` of that mask is then backpropagated and the optimizer takes a
    step. A running average of the loss is logged every
    ``self.print_freq`` batches, followed by a summary line with the
    elapsed wall-clock time in minutes.

    Args:
        epoch: Epoch number; used only in the log messages.

    Returns:
        The mean of the per-batch training losses, or 0.0 when the loader
        yields no batches.
    """
    self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        epoch, 1))
    self.danet.train()
    num_batchs = len(self.train_dataloader)
    total_loss = 0.0
    start_time = time.time()
    for num_index, (mix_samp, wf, ibm, non_silent) in enumerate(
            self.train_dataloader, start=1):
        # torch.autograd.Variable has been a deprecated wrapper since
        # PyTorch 0.4; plain tensors are used directly.
        mix_samp = mix_samp.contiguous().to(self.device)
        wf = wf.contiguous().to(self.device)
        ibm = ibm.contiguous().to(self.device)
        non_silent = non_silent.contiguous().to(self.device)
        hidden = self.danet.init_hidden(mix_samp.size(0))
        input_list = [mix_samp, ibm, non_silent, hidden]
        self.optimizer.zero_grad()
        if self.gpuid:
            mask, _ = self.danet(input_list)
        else:
            # NOTE(review): this branch passes three positional tensors and
            # no hidden state, unlike the gpuid branch above. Confirm that
            # the model's forward() accepts both call forms.
            mask, _ = self.danet(mix_samp, ibm, non_silent)
        criterion = Loss(mix_samp, wf, mask)
        epoch_loss = criterion.loss()
        total_loss += epoch_loss.item()
        epoch_loss.backward()
        # NOTE: gradient clipping is disabled for this model; re-enable with
        # torch.nn.utils.clip_grad_norm_(self.danet.parameters(), clip_norm)
        # if needed.
        self.optimizer.step()
        if num_index % self.print_freq == 0:
            message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}>'.format(
                epoch, num_index, self.optimizer.param_groups[0]['lr'],
                total_loss / num_index)
            self.logger.info(message)
    end_time = time.time()
    # Divide by the number of batches (not a post-loop counter, which would
    # be one too large); max(..., 1) keeps an empty loader at 0.0.
    total_loss = total_loss / max(num_batchs, 1)
    message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min> '.format(
        epoch, num_batchs, self.optimizer.param_groups[0]['lr'], total_loss,
        (end_time - start_time) / 60)
    self.logger.info(message)
    return total_loss
def validation(self, epoch):
    """Evaluate ``self.danet`` on ``self.val_dataloader`` for one epoch.

    The model is put in eval mode and every batch is scored under
    ``torch.no_grad()``. Inputs are made contiguous and moved to
    ``self.device``, and a fresh hidden state is built with
    ``self.danet.init_hidden``. A running average of the loss is logged
    every ``self.print_freq`` batches, followed by a summary line with the
    elapsed wall-clock time in minutes.

    Args:
        epoch: Epoch number; used only in the log messages.

    Returns:
        The mean of the per-batch validation losses, or 0.0 when the
        loader yields no batches.
    """
    self.logger.info(
        'Start Validation from epoch: {:d}, iter: {:d}'.format(epoch, 1))
    self.danet.eval()
    num_batchs = len(self.val_dataloader)
    total_loss = 0.0
    start_time = time.time()
    with torch.no_grad():
        for num_index, (mix_samp, wf, ibm, non_silent) in enumerate(
                self.val_dataloader, start=1):
            # torch.autograd.Variable has been a deprecated wrapper since
            # PyTorch 0.4; plain tensors are used directly.
            mix_samp = mix_samp.contiguous().to(self.device)
            wf = wf.contiguous().to(self.device)
            ibm = ibm.contiguous().to(self.device)
            non_silent = non_silent.contiguous().to(self.device)
            hidden = self.danet.init_hidden(mix_samp.size(0))
            input_list = [mix_samp, ibm, non_silent, hidden]
            if self.gpuid:
                mask, _ = self.danet(input_list)
            else:
                # NOTE(review): this branch passes three positional tensors
                # and no hidden state, unlike the gpuid branch above. Confirm
                # that the model's forward() accepts both call forms.
                mask, _ = self.danet(mix_samp, ibm, non_silent)
            criterion = Loss(mix_samp, wf, mask)
            epoch_loss = criterion.loss()
            total_loss += epoch_loss.item()
            if num_index % self.print_freq == 0:
                message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}>'.format(
                    epoch, num_index, self.optimizer.param_groups[0]['lr'],
                    total_loss / num_index)
                self.logger.info(message)
    end_time = time.time()
    # Divide by the number of batches (not a post-loop counter, which would
    # be one too large); max(..., 1) keeps an empty loader at 0.0.
    total_loss = total_loss / max(num_batchs, 1)
    message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min> '.format(
        epoch, num_batchs, self.optimizer.param_groups[0]['lr'], total_loss,
        (end_time - start_time) / 60)
    self.logger.info(message)
    return total_loss