Example #1
    def train(self):
        """
        Full training logic, including train and validation.
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            self.data_loader.sampler.set_epoch(epoch)
            result_dict = self._train_epoch(epoch)

            # print logged information to the screen
            if self.do_validation:
                val_result_dict = result_dict['val_result_dict']
                # append per-epoch losses, validation metrics, and a timestamp to a CSV log
                valid_loss = []
                for k, v in val_result_dict.items():
                    valid_loss.append([k, v['mEP'], v['mER'], v['mEF'], v['mEA']])
                valid_str = ','.join(str(i) for i in valid_loss)
                res_ = [epoch, self.epochs,
                        result_dict['loss'],
                        result_dict['gl_loss'] * self.gl_loss_lambda,
                        result_dict['crf_loss']]
                # timestamp in the Asia/Ho_Chi_Minh timezone
                time_ = datetime.now().astimezone(pytz.timezone('Asia/Ho_Chi_Minh'))
                res_str = ','.join(str(i) for i in res_)
                res_str = res_str + ',' + valid_str + ',' + str(time_)
                with open('/content/drive/MyDrive/SROIE_extraction/PICK_training_log.csv', 'a') as fp:
                    fp.write(res_str)
                    fp.write('\n')

                val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict)
            else:
                val_res = ''
            # log information at the end of every epoch
            self.logger_info('[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                             'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.
                             format(epoch, self.epochs, result_dict['loss'],
                                    result_dict['gl_loss'] * self.gl_loss_lambda,
                                    result_dict['crf_loss'], val_res))

            # evaluate model performance according to configured metric, check early stop, and
            # save best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off' and self.do_validation:
                best, not_improved_count = self._is_best_monitor_metric(best, not_improved_count, val_result_dict)
                if not_improved_count > self.early_stop:
                    self.logger_info("Validation performance didn\'t improve for {} epochs. "
                                     "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)
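
A note on Example #1: the hand-rolled string joining above will break the CSV layout if any value renders with a comma. A minimal, hypothetical sketch using the standard csv module (the helper name, its signature, and the argument handling are illustrative, not part of the trainer) could look like this:

import csv
from datetime import datetime

import pytz


def append_training_log(csv_path, epoch, total_epochs, loss, gl_loss, crf_loss, val_rows):
    # Hypothetical helper: append one epoch's metrics as a single CSV row.
    # csv.writer quotes values that contain commas, unlike plain ','.join().
    vn_time = datetime.now().astimezone(pytz.timezone('Asia/Ho_Chi_Minh'))
    row = [epoch, total_epochs, loss, gl_loss, crf_loss, *val_rows, vn_time.isoformat()]
    with open(csv_path, 'a', newline='') as fp:
        csv.writer(fp).writerow(row)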
Example #2
    def train(self):
        """
        Full training logic, including train and validation.
        """

        if self.distributed:
            dist.barrier()  # Syncing machines before training

        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            # ensure each distributed worker samples different data by
            # passing the epoch to the sampler to seed it differently
            if self.distributed:
                self.data_loader.sampler.set_epoch(epoch)
            result_dict = self._train_epoch(epoch)

            # print logged information to the screen
            if self.do_validation:
                val_result_dict = result_dict['val_result_dict']
                val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict)
            else:
                val_res = ''
            # log information at the end of every epoch
            self.logger_info(
                '[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.format(
                    epoch, self.epochs, result_dict['loss'],
                    result_dict['gl_loss'] * self.gl_loss_lambda,
                    result_dict['crf_loss'], val_res))

            # evaluate model performance according to configured metric, check early stop, and
            # save best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off' and self.do_validation:
                best, not_improved_count = self._is_best_monitor_metric(
                    best, not_improved_count, val_result_dict)
                if not_improved_count > self.early_stop:
                    self.logger_info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)
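
_is_best_monitor_metric is referenced in both training loops but not shown on this page. A minimal, hypothetical sketch of the usual pattern behind such a check (the function name, signature, and return values here are illustrative assumptions, not the project's API) is:

def is_best_monitor_metric(not_improved_count, metric_value, monitor_best, monitor_mode='max'):
    # Hypothetical stand-alone version: 'min' mode treats lower values as better,
    # 'max' mode treats higher values as better; count epochs without improvement.
    improved = (monitor_mode == 'min' and metric_value <= monitor_best) or \
               (monitor_mode == 'max' and metric_value >= monitor_best)
    if improved:
        return True, 0, metric_value
    return False, not_improved_count + 1, monitor_best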
Example #3
    def __init__(self,
                 model,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 max_len_step=None):
        '''
        Trainer constructor: sets up devices, monitoring, checkpointing and data loaders.

        :param model: model to be trained.
        :param optimizer: optimizer for the model parameters.
        :param config: parsed configuration object.
        :param data_loader: training data loader.
        :param valid_data_loader: validation data loader (validation is skipped if None).
        :param lr_scheduler: learning-rate scheduler (optional).
        :param max_len_step: controls the number of batches (steps) in each epoch.
        '''
        self.config = config
        self.distributed = config['distributed']
        if self.distributed:
            self.local_master = (config['local_rank'] == 0)
            self.global_master = (dist.get_rank() == 0)
        else:
            self.local_master = True
            self.global_master = True
        self.logger = config.get_logger(
            'trainer',
            config['trainer']['log_verbosity']) if self.local_master else None

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['local_rank'], config['local_world_size'])
        self.model = model.to(self.device)

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        monitor_open = cfg_trainer['monitor_open']
        if monitor_open:
            self.monitor = cfg_trainer.get('monitor', 'off')
        else:
            self.monitor = 'off'

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.monitor_mode = 'off'
            self.monitor_best = 0
        else:
            self.monitor_mode, self.monitor_metric = self.monitor.split()
            assert self.monitor_mode in ['min', 'max']

            self.monitor_best = inf if self.monitor_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            self.early_stop = inf if self.early_stop == -1 else self.early_stop

        self.start_epoch = 1

        if self.local_master:
            self.checkpoint_dir = config.save_dir
            # setup visualization writer instance
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        # load checkpoint for resume training
        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        # the checkpoint is resumed before the multi-GPU wrap, so its state_dict keys
        # carry no 'module.' prefix; convert BatchNorm to SyncBatchNorm before DDP
        if self.config['trainer']['sync_batch_norm'] and self.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.distributed:
            self.model = DDP(self.model,
                             device_ids=self.device_ids,
                             output_device=self.device_ids[0],
                             find_unused_parameters=True)

        self.data_loader = data_loader
        if max_len_step is None:  # max number of iteration steps per epoch
            # epoch-based training
            self.len_step = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_step = max_len_step
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        log_step = self.config['trainer']['log_step_interval']
        self.log_step = log_step if log_step != -1 and 0 < log_step < self.len_step else int(
            np.sqrt(data_loader.batch_size))

        val_step_interval = self.config['trainer']['val_step_interval']
        # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\
        #                                             else int(np.sqrt(data_loader.batch_size))
        self.val_step_interval = val_step_interval

        self.gl_loss_lambda = self.config['trainer']['gl_loss_lambda']

        self.train_loss_metrics = MetricTracker(
            'loss',
            'gl_loss',
            'crf_loss',
            writer=self.writer if self.local_master else None)
        self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls)
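
The constructor in Example #3 reads its settings from config. A sketch of the trainer-related keys it touches is shown below; the values are placeholders chosen for illustration, and only the key names mirror what Examples #3 and #4 actually access.

# Placeholder values; only the key names mirror what Examples #3 and #4 read.
example_config = {
    'distributed': False,
    'local_rank': 0,
    'local_world_size': 1,
    'trainer': {
        'log_verbosity': 2,
        'epochs': 100,
        'save_period': 10,
        'monitor_open': True,
        'monitor': 'max overall-mEF',   # '<mode> <metric>', split() in __init__
        'early_stop': 40,               # -1 disables early stopping
        'tensorboard': True,
        'sync_batch_norm': True,
        'log_step_interval': -1,        # -1 falls back to sqrt(batch_size)
        'val_step_interval': 50,
        'gl_loss_lambda': 0.01,
        'anomaly_detection': False,
    },
}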
Example #4
    def _train_epoch(self, epoch):
        '''
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log dict that contains average loss and metric in this epoch.
        '''
        self.model.train()
        self.train_loss_metrics.reset()
        ## step iteration start ##
        for step_idx, input_data_item in enumerate(self.data_loader):
            step_idx += 1
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(
                        input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(self.device,
                                                          non_blocking=True)
            if self.config['trainer']['anomaly_detection']:
                # This mode will increase the runtime and should only be enabled for debugging
                with torch.autograd.detect_anomaly():
                    self.optimizer.zero_grad()
                    # model forward
                    output = self.model(**input_data_item)
                    # calculate loss
                    gl_loss = output['gl_loss']
                    crf_loss = output['crf_loss']
                    total_loss = torch.sum(
                        crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                    # backward
                    total_loss.backward()
                    # self.average_gradients(self.model)
                    self.optimizer.step()
            else:
                self.optimizer.zero_grad()
                # model forward
                output = self.model(**input_data_item)
                # calculate loss
                gl_loss = output['gl_loss']
                crf_loss = output['crf_loss']
                total_loss = torch.sum(
                    crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                # backward
                total_loss.backward()
                # self.average_gradients(self.model)
                self.optimizer.step()

            # Use a barrier() to make sure that all processes have finished forward and backward
            if self.distributed:
                dist.barrier()
                # obtain the sum of total_loss across all processes
                dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)

                size = dist.get_world_size()
            else:
                size = 1
            gl_loss /= size  # averages gl_loss across the whole world
            crf_loss /= size  # averages crf_loss across the whole world

            # calculate average loss across the batch size
            avg_gl_loss = torch.mean(gl_loss)
            avg_crf_loss = torch.mean(crf_loss)
            avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss
            # update metrics
            self.writer.set_step((epoch - 1) * self.len_step + step_idx -
                                 1) if self.local_master else None
            self.train_loss_metrics.update('loss', avg_loss.item())
            self.train_loss_metrics.update(
                'gl_loss',
                avg_gl_loss.item() * self.gl_loss_lambda)
            self.train_loss_metrics.update('crf_loss', avg_crf_loss.item())

            # log messages
            if step_idx % self.log_step == 0:
                self.logger_info(
                    'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}'
                    .format(epoch, self.epochs, step_idx, self.len_step,
                            avg_loss.item(),
                            avg_gl_loss.item() * self.gl_loss_lambda,
                            avg_crf_loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            # do validation after val_step_interval iteration
            if self.do_validation and step_idx % self.val_step_interval == 0:
                val_result_dict = self._valid_epoch(epoch)
                self.logger_info(
                    '[Step Validation] Epoch:[{}/{}] Step:[{}/{}]  \n{}'.
                    format(epoch, self.epochs, step_idx, self.len_step,
                           SpanBasedF1MetricTracker.dict2str(val_result_dict)))

                # check whether this is the best monitored metric; if so, save the model_best checkpoint.
                best, not_improved_count = self._is_best_monitor_metric(
                    False, 0, val_result_dict)
                if best:
                    self._save_checkpoint(epoch, best)

            # decide whether to continue iterating
            if step_idx == self.len_step + 1:
                break

        ## step iteration end ##

        # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss}
        log = self.train_loss_metrics.result()

        # do validation after training an epoch
        if self.do_validation:
            val_result_dict = self._valid_epoch(epoch)
            log['val_result_dict'] = val_result_dict

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log
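
The anomaly-detection branch in Example #4 duplicates the forward/backward block. One way to fold the two branches together, sketched here as a hypothetical free function rather than the trainer's own method, is to pick the context manager up front (assuming torch.autograd.detect_anomaly is only needed as a context manager, which is how the snippet uses it):

import contextlib

import torch


def training_step(model, optimizer, input_data_item, gl_loss_lambda, anomaly_detection=False):
    # Hypothetical refactor of the duplicated branches in Example #4.
    ctx = torch.autograd.detect_anomaly() if anomaly_detection else contextlib.nullcontext()
    with ctx:
        optimizer.zero_grad()
        output = model(**input_data_item)  # model forward
        gl_loss = output['gl_loss']
        crf_loss = output['crf_loss']
        total_loss = torch.sum(crf_loss) + gl_loss_lambda * torch.sum(gl_loss)
        total_loss.backward()              # backward
        optimizer.step()
    return total_loss, gl_loss, crf_loss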