Example #1
    def validation(self, epoch):
        self.logger.info(
            'Start Validation from epoch: {:d}, iter: {:d}'.format(epoch, 1))
        self.dpcl.eval()  # disable dropout / batch-norm updates for evaluation
        num_batchs = len(self.val_dataloader)
        num_index = 1
        total_loss = 0.0
        start_time = time.time()
        with torch.no_grad():  # no gradients are needed during validation
            for mix_wave, target_waves, non_silent in self.val_dataloader:
                mix_wave = mix_wave.to(self.device)
                target_waves = target_waves.to(self.device)
                non_silent = non_silent.to(self.device)
                mix_embs = self.dpcl(mix_wave)
                l = Loss(mix_embs, target_waves, non_silent, self.num_spks)
                epoch_loss = l.loss()
                total_loss += epoch_loss.item()
                if num_index % self.print_freq == 0:
                    message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}>'.format(
                        epoch, num_index, self.optimizer.param_groups[0]['lr'],
                        total_loss / num_index)
                    self.logger.info(message)
                num_index += 1
        end_time = time.time()
        total_loss = total_loss / num_batchs
        message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min>'.format(
            epoch, num_batchs, self.optimizer.param_groups[0]['lr'],
            total_loss, (end_time - start_time) / 60)
        self.logger.info(message)
        return total_loss
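
The Loss class referenced above is not shown. For reference, here is a minimal sketch of the standard deep-clustering objective (Hershey et al., 2016) that such a helper typically computes; the dpcl_loss name, tensor shapes, and expansion are assumptions about the implementation, not the repository's actual API.

import torch

def dpcl_loss(embs, targets, non_silent, num_spks):
    # Hypothetical sketch, not the repository's Loss class.
    # embs:       (B, T*F, D) embeddings produced by the network
    # targets:    (B, T*F, num_spks) one-hot ideal-assignment labels
    # non_silent: (B, T*F) binary weights that zero out silent TF bins
    assert targets.size(-1) == num_spks
    w = non_silent.unsqueeze(-1).float()
    v = embs * w
    y = targets.float() * w
    # ||VV^T - YY^T||_F^2, expanded so the (T*F) x (T*F) affinity
    # matrices are never materialized:
    loss = (torch.bmm(v.transpose(1, 2), v).pow(2).sum()
            + torch.bmm(y.transpose(1, 2), y).pow(2).sum()
            - 2 * torch.bmm(v.transpose(1, 2), y).pow(2).sum())
    return loss / embs.size(0)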
Example #2
    def train(self, epoch):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(epoch, 1))
        self.dpcl.train()
        num_batchs = len(self.train_dataloader)
        total_loss = 0.0
        num_index = 1
        start_time = time.time()
        # Wrap the model in DataParallel once, before the loop; re-wrapping it
        # inside the loop on every batch is wasted work.
        model = torch.nn.DataParallel(self.dpcl)
        for mix_wave, target_waves, non_silent in self.train_dataloader:
            mix_wave = mix_wave.to(self.device)
            target_waves = target_waves.to(self.device)
            non_silent = non_silent.to(self.device)
            mix_embs = model(mix_wave)
            l = Loss(mix_embs, target_waves, non_silent, self.num_spks)
            epoch_loss = l.loss()
            total_loss += epoch_loss.item()
            self.optimizer.zero_grad()
            epoch_loss.backward()
            
            if self.clip_norm:
                # Clip the gradient norm to stabilize training.
                torch.nn.utils.clip_grad_norm_(self.dpcl.parameters(), self.clip_norm)

            self.optimizer.step()
            if num_index % self.print_freq == 0:
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, loss:{:.3f}>'.format(
                    epoch, num_index, self.optimizer.param_groups[0]['lr'], total_loss / num_index)
                self.logger.info(message)
            num_index += 1  # advance the batch counter so periodic logging fires
        end_time = time.time()
        total_loss = total_loss / num_batchs
        message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min>'.format(
            epoch, num_batchs, self.optimizer.param_groups[0]['lr'], total_loss, (end_time - start_time) / 60)
        self.logger.info(message)
        return total_loss
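
In use, train and validation are typically driven by an outer epoch loop that keeps the best checkpoint. A hypothetical driver sketch follows; the run function, the num_epochs default, and the checkpoint path are assumptions, not part of the examples above.

import torch

def run(trainer, num_epochs=100):
    # Hypothetical driver; assumes `trainer` exposes the train/validation
    # methods shown above and a `dpcl` model attribute.
    best_loss = float('inf')
    for epoch in range(1, num_epochs + 1):
        trainer.train(epoch)
        val_loss = trainer.validation(epoch)
        if val_loss < best_loss:
            best_loss = val_loss
            # keep the weights with the best validation loss seen so far
            torch.save(trainer.dpcl.state_dict(), 'best_dpcl.pt')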
Example #3
    def train(self, epoch):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
            epoch, 1))
        self.danet.train()
        num_batchs = len(self.train_dataloader)
        total_loss = 0.0
        num_index = 1
        start_time = time.time()
        for mix_samp, wf, ibm, non_silent in self.train_dataloader:
            # Move the batch to the training device; plain tensors suffice,
            # the deprecated torch.autograd.Variable wrapper is not needed.
            mix_samp = mix_samp.contiguous().to(self.device)
            wf = wf.contiguous().to(self.device)
            ibm = ibm.contiguous().to(self.device)
            non_silent = non_silent.contiguous().to(self.device)

            # Reset the recurrent state for each new batch.
            hidden = self.danet.init_hidden(mix_samp.size(0))

            input_list = [mix_samp, ibm, non_silent, hidden]
            self.optimizer.zero_grad()

            if self.gpuid:
                # mask = torch.nn.parallel.data_parallel(self.danet, input_list, device_ids=self.gpuid)
                mask, hidden = self.danet(input_list)
            else:
                mask, hidden = self.danet(mix_samp, ibm, non_silent)

            l = Loss(mix_samp, wf, mask)
            epoch_loss = l.loss()
            total_loss += epoch_loss.item()
            epoch_loss.backward()

            #if self.clip_norm:
            #    torch.nn.utils.clip_grad_norm_(
            #        self.danet.parameters(), self.clip_norm)

            self.optimizer.step()
            if num_index % self.print_freq == 0:
                message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}>'.format(
                    epoch, num_index, self.optimizer.param_groups[0]['lr'],
                    total_loss / num_index)
                self.logger.info(message)
            num_index += 1
        end_time = time.time()
        # num_index ends at num_batchs + 1, so average over num_batchs instead.
        total_loss = total_loss / num_batchs
        message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min>'.format(
            epoch, num_batchs, self.optimizer.param_groups[0]['lr'], total_loss,
            (end_time - start_time) / 60)
        self.logger.info(message)
        return total_loss
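
Here Loss takes the mixture, a target tensor wf, and the estimated masks, which suggests a magnitude-domain masking objective. Below is a minimal sketch under that assumption; the mask_mse_loss name and the tensor shapes are hypothetical, not the repository's actual Loss class.

import torch
import torch.nn.functional as F

def mask_mse_loss(mix_spec, targets, masks):
    # Hypothetical sketch, not the repository's Loss class.
    # mix_spec: (B, T, F)            mixture magnitude spectrogram
    # targets:  (B, num_spks, T, F)  per-speaker target magnitudes
    # masks:    (B, num_spks, T, F)  estimated masks
    est = masks * mix_spec.unsqueeze(1)  # apply each speaker's mask to the mixture
    return F.mse_loss(est, targets)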
Example #4
    def validation(self, epoch):
        self.logger.info(
            'Start Validation from epoch: {:d}, iter: {:d}'.format(epoch, 1))
        self.danet.eval()
        num_batchs = len(self.val_dataloader)
        num_index = 1
        total_loss = 0.0
        start_time = time.time()
        with torch.no_grad():
            for mix_samp, wf, ibm, non_silent in self.val_dataloader:
                # Move the batch to the evaluation device; the deprecated
                # torch.autograd.Variable wrapper is not needed.
                mix_samp = mix_samp.contiguous().to(self.device)
                wf = wf.contiguous().to(self.device)
                ibm = ibm.contiguous().to(self.device)
                non_silent = non_silent.contiguous().to(self.device)

                hidden = self.danet.init_hidden(mix_samp.size(0))
                input_list = [mix_samp, ibm, non_silent, hidden]

                if self.gpuid:
                    # mask = torch.nn.parallel.data_parallel(self.danet, input_list, device_ids=self.gpuid)
                    mask, hidden = self.danet(input_list)
                else:
                    mask, hidden = self.danet(mix_samp, ibm, non_silent)

                l = Loss(mix_samp, wf, mask)
                epoch_loss = l.loss()
                total_loss += epoch_loss.item()
                if num_index % self.print_freq == 0:
                    message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}>'.format(
                        epoch, num_index, self.optimizer.param_groups[0]['lr'],
                        total_loss / num_index)
                    self.logger.info(message)
                num_index += 1
        end_time = time.time()
        # num_index ends at num_batchs + 1, so average over num_batchs instead.
        total_loss = total_loss / num_batchs
        message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.3f}, Total time:{:.3f} min>'.format(
            epoch, num_batchs, self.optimizer.param_groups[0]['lr'], total_loss,
            (end_time - start_time) / 60)
        self.logger.info(message)
        return total_loss
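
Both DANet examples call danet.init_hidden(batch_size) before the forward pass, so the model presumably carries an explicit recurrent state. A minimal sketch of such a method for an LSTM core follows; the num_layers, hidden_size, and bidirectional attributes are assumptions.

def init_hidden(self, batch_size):
    # Hypothetical implementation: zeroed LSTM states with the same
    # dtype and device as the model's parameters.
    num_directions = 2 if self.bidirectional else 1
    weight = next(self.parameters())
    shape = (self.num_layers * num_directions, batch_size, self.hidden_size)
    return (weight.new_zeros(shape), weight.new_zeros(shape))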