Beispiel #1
0
    def train(self):
        """
        Full training logic, including train and validation.

        For every epoch from ``self.start_epoch`` to ``self.epochs`` (inclusive) this
        runs ``self._train_epoch``, logs the epoch losses through ``self.logger_info``
        and, when ``self.do_validation`` is set, appends one row with the losses, the
        per-entity validation scores and a timestamp to a CSV log file. When monitoring
        is enabled it also tracks the best metric and stops early once
        ``not_improved_count`` exceeds ``self.early_stop``. A checkpoint is saved every
        ``self.save_period`` epochs.

        :return: None
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            # Only samplers that expose set_epoch (e.g. torch's DistributedSampler) need
            # the epoch number; calling it blindly fails on samplers without it.
            set_epoch = getattr(self.data_loader.sampler, 'set_epoch', None)
            if callable(set_epoch):
                set_epoch(epoch)
            result_dict = self._train_epoch(epoch)

            # print logged informations to the screen
            if self.do_validation:
                val_result_dict = result_dict['val_result_dict']

                # One "[name, mEP, mER, mEF, mEA]" cell per validation entry.
                valid_str = ','.join(
                    str([k, v['mEP'], v['mER'], v['mEF'], v['mEA']])
                    for k, v in val_result_dict.items())
                # NOTE(review): if _train_epoch already stores 'gl_loss' scaled by
                # gl_loss_lambda, the value below is scaled twice -- confirm.
                res_ = [epoch,
                        self.epochs,
                        result_dict['loss'],
                        result_dict['gl_loss'] * self.gl_loss_lambda,
                        result_dict['crf_loss']]
                # Ask for the current time directly in the target zone instead of
                # converting a naive local timestamp afterwards.
                time_ = datetime.now(pytz.timezone('Asia/Ho_Chi_Minh'))
                res_str = ','.join(str(i) for i in res_) + ',' + valid_str + ',' + str(time_)
                # The context manager closes the file; no explicit close() is needed.
                with open('/content/drive/MyDrive/SROIE_extraction/PICK_training_log.csv', 'a') as fp:
                    fp.write(res_str + '\n')

                val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict)
            else:
                val_res = ''
            # every epoch log information
            self.logger_info('[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                             'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.
                             format(epoch, self.epochs, result_dict['loss'],
                                    result_dict['gl_loss'] * self.gl_loss_lambda,
                                    result_dict['crf_loss'], val_res))

            # evaluate model performance according to configured metric, check early stop, and
            # save best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off' and self.do_validation:
                best, not_improved_count = self._is_best_monitor_metric(best, not_improved_count, val_result_dict)
                if not_improved_count > self.early_stop:
                    self.logger_info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)
Beispiel #2
0
    def train(self):
        """
        Drive the whole training schedule, epoch by epoch.

        Each epoch runs ``self._train_epoch``, reports the epoch losses (and the
        validation summary when ``self.do_validation`` is set) through
        ``self.logger_info``, optionally updates the monitored-metric bookkeeping with
        early stopping, and saves a checkpoint every ``self.save_period`` epochs.

        :return: None
        """

        # Syncing machines before training
        if self.distributed:
            dist.barrier()

        epoch_template = ('[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                          'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}')
        stop_template = ("Validation performance didn\'t improve for {} epochs. "
                         "Training stops.")

        stale_epochs = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            if self.distributed:
                # Passing the epoch to the sampler changes its random seed, so the
                # distributed workers sample different data in every epoch.
                self.data_loader.sampler.set_epoch(epoch)
            epoch_log = self._train_epoch(epoch)

            # Validation summary for the epoch report (empty without validation).
            if not self.do_validation:
                val_text = ''
            else:
                val_metrics = epoch_log['val_result_dict']
                val_text = SpanBasedF1MetricTracker.dict2str(val_metrics)

            # Per-epoch report.
            # NOTE(review): if _train_epoch already stores 'gl_loss' scaled by
            # gl_loss_lambda, the value reported here is scaled twice -- confirm.
            message = epoch_template.format(
                epoch, self.epochs, epoch_log['loss'],
                epoch_log['gl_loss'] * self.gl_loss_lambda,
                epoch_log['crf_loss'], val_text)
            self.logger_info(message)

            # Compare against the monitored metric, stop early when it has not
            # improved for too long, and remember whether this epoch is the best.
            is_best = False
            if self.monitor_mode != 'off' and self.do_validation:
                is_best, stale_epochs = self._is_best_monitor_metric(
                    is_best, stale_epochs, val_metrics)
                if stale_epochs > self.early_stop:
                    self.logger_info(stop_template.format(self.early_stop))
                    break

            # Periodic checkpoint; flagged as best when this epoch set a new record.
            if not epoch % self.save_period:
                self._save_checkpoint(epoch, save_best=is_best)
Beispiel #3
0
    def _train_epoch(self, epoch):
        '''
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log dict that contains average loss and metric in this epoch.
        '''
        self.model.train()
        self.train_loss_metrics.reset()
        ## step iteration start ##
        for step_idx, input_data_item in enumerate(self.data_loader):
            step_idx += 1
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(
                        input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(self.device,
                                                          non_blocking=True)
            if self.config['trainer']['anomaly_detection']:
                # This mode will increase the runtime and should only be enabled for debugging
                with torch.autograd.detect_anomaly():
                    self.optimizer.zero_grad()
                    # model forward
                    output = self.model(**input_data_item)
                    # calculate loss
                    gl_loss = output['gl_loss']
                    crf_loss = output['crf_loss']
                    total_loss = torch.sum(
                        crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                    # backward
                    total_loss.backward()
                    # self.average_gradients(self.model)
                    self.optimizer.step()
            else:
                self.optimizer.zero_grad()
                # model forward
                output = self.model(**input_data_item)
                # calculate loss
                gl_loss = output['gl_loss']
                crf_loss = output['crf_loss']
                total_loss = torch.sum(
                    crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                # backward
                total_loss.backward()
                # self.average_gradients(self.model)
                self.optimizer.step()

            # Use a barrier() to make sure that all process have finished forward and backward
            if self.distributed:
                dist.barrier()
                #  obtain the sum of all total_loss at all processes
                dist.all_reduce(total_loss, op=dist.reduce_op.SUM)

                size = dist.get_world_size()
            else:
                size = 1
            gl_loss /= size  # averages gl_loss across the whole world
            crf_loss /= size  # averages crf_loss across the whole world

            # calculate average loss across the batch size
            avg_gl_loss = torch.mean(gl_loss)
            avg_crf_loss = torch.mean(crf_loss)
            avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss
            # update metrics
            self.writer.set_step((epoch - 1) * self.len_step + step_idx -
                                 1) if self.local_master else None
            self.train_loss_metrics.update('loss', avg_loss.item())
            self.train_loss_metrics.update(
                'gl_loss',
                avg_gl_loss.item() * self.gl_loss_lambda)
            self.train_loss_metrics.update('crf_loss', avg_crf_loss.item())

            # log messages
            if step_idx % self.log_step == 0:
                self.logger_info(
                    'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}'
                    .format(epoch, self.epochs, step_idx, self.len_step,
                            avg_loss.item(),
                            avg_gl_loss.item() * self.gl_loss_lambda,
                            avg_crf_loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            # do validation after val_step_interval iteration
            if self.do_validation and step_idx % self.val_step_interval == 0:
                val_result_dict = self._valid_epoch(epoch)
                self.logger_info(
                    '[Step Validation] Epoch:[{}/{}] Step:[{}/{}]  \n{}'.
                    format(epoch, self.epochs, step_idx, self.len_step,
                           SpanBasedF1MetricTracker.dict2str(val_result_dict)))

                # check if best metric, if true, then save as model_best checkpoint.
                best, not_improved_count = self._is_best_monitor_metric(
                    False, 0, val_result_dict)
                if best:
                    self._save_checkpoint(epoch, best)

            # decide whether continue iter
            if step_idx == self.len_step + 1:
                break

        ## step iteration end ##

        # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss}
        log = self.train_loss_metrics.result()

        # do validation after training an epoch
        if self.do_validation:
            val_result_dict = self._valid_epoch(epoch)
            log['val_result_dict'] = val_result_dict

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log