def train(self):
    """
    Full training logic, including train and validation.

    For each epoch in ``[self.start_epoch, self.epochs]`` this:
    1. runs ``self._train_epoch``;
    2. when validation is enabled, appends a summary row to the CSV training
       log (local master process only);
    3. logs the epoch summary;
    4. updates the monitored metric and early-stopping counter;
    5. saves a checkpoint every ``self.save_period`` epochs.
    """
    if self.distributed:
        dist.barrier()  # sync all processes before training starts
    not_improved_count = 0
    for epoch in range(self.start_epoch, self.epochs + 1):
        if self.distributed:
            # Give the distributed sampler a per-epoch seed so workers draw
            # different data each epoch. Non-distributed loaders use samplers
            # without set_epoch, so this must stay behind the distributed guard.
            self.data_loader.sampler.set_epoch(epoch)
        result_dict = self._train_epoch(epoch)

        # print logged information to the screen
        if self.do_validation:
            val_result_dict = result_dict['val_result_dict']
            if self.local_master:
                # Only one process per node writes, so rows are not duplicated.
                self._append_epoch_csv_row(epoch, result_dict, val_result_dict)
            val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict)
        else:
            val_res = ''

        # every epoch log information
        self.logger_info(
            '[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
            'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.format(
                epoch, self.epochs, result_dict['loss'],
                result_dict['gl_loss'] * self.gl_loss_lambda,
                result_dict['crf_loss'], val_res))

        # evaluate model performance according to configured metric, check
        # early stop, and save best checkpoint as model_best
        best = False
        if self.monitor_mode != 'off' and self.do_validation:
            best, not_improved_count = self._is_best_monitor_metric(
                best, not_improved_count, val_result_dict)
            if not_improved_count > self.early_stop:
                self.logger_info(
                    "Validation performance didn\'t improve for {} epochs. "
                    "Training stops.".format(self.early_stop))
                break

        if epoch % self.save_period == 0:
            self._save_checkpoint(epoch, save_best=best)


def _append_epoch_csv_row(self, epoch, result_dict, val_result_dict):
    """
    Append one comma-separated summary line for ``epoch`` to the CSV training log.

    The path comes from ``config['trainer']['csv_log_path']``. It defaults to
    the historical Google Drive location; set it to an empty value or None to
    disable this log.

    Line layout:
    - epoch, total epochs, total loss, lambda-weighted GL loss, CRF loss;
    - then ``str([key, mEP, mER, mEF, mEA])`` for every entry of
      ``val_result_dict``;
    - then the current Asia/Ho_Chi_Minh timestamp.

    The per-key entries are Python list reprs and contain commas themselves
    (layout kept unchanged for existing consumers of the file).
    """
    log_path = self.config['trainer'].get(
        'csv_log_path',
        '/content/drive/MyDrive/SROIE_extraction/PICK_training_log.csv')
    if not log_path:
        return
    train_fields = [epoch, self.epochs, result_dict['loss'],
                    result_dict['gl_loss'] * self.gl_loss_lambda,
                    result_dict['crf_loss']]
    valid_fields = [[key, metrics['mEP'], metrics['mER'], metrics['mEF'], metrics['mEA']]
                    for key, metrics in val_result_dict.items()]
    timestamp = datetime.now(pytz.timezone('Asia/Ho_Chi_Minh'))
    line = ','.join([
        ','.join(str(field) for field in train_fields),
        ','.join(str(field) for field in valid_fields),
        str(timestamp),
    ])
    with open(log_path, 'a') as fp:
        fp.write(line + '\n')
def train(self):
    """
    Run the complete training loop.

    Each epoch calls ``self._train_epoch``, then:
    - logs the epoch summary (with the span-F1 report when validation is on);
    - lets ``self._is_best_monitor_metric`` update the best-model flag and the
      no-improvement counter;
    - stops early once that counter exceeds ``self.early_stop``;
    - otherwise saves a checkpoint every ``self.save_period`` epochs.
    """
    if self.distributed:
        # Hold every process here until all of them are ready to train.
        dist.barrier()

    epochs_without_improvement = 0
    for current_epoch in range(self.start_epoch, self.epochs + 1):
        if self.distributed:
            # Re-seed the distributed sampler each epoch so that every worker
            # draws different data.
            self.data_loader.sampler.set_epoch(current_epoch)

        epoch_log = self._train_epoch(current_epoch)

        val_result_dict = epoch_log['val_result_dict'] if self.do_validation else None
        val_report = (SpanBasedF1MetricTracker.dict2str(val_result_dict)
                      if self.do_validation else '')

        summary = ('[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                   'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}')
        self.logger_info(summary.format(current_epoch,
                                        self.epochs,
                                        epoch_log['loss'],
                                        epoch_log['gl_loss'] * self.gl_loss_lambda,
                                        epoch_log['crf_loss'],
                                        val_report))

        # Monitoring only happens when it is switched on and there is a
        # validation result to judge; otherwise the epoch is never "best".
        is_best = False
        if self.do_validation and self.monitor_mode != 'off':
            is_best, epochs_without_improvement = self._is_best_monitor_metric(
                is_best, epochs_without_improvement, val_result_dict)
            if epochs_without_improvement > self.early_stop:
                self.logger_info(
                    "Validation performance didn't improve for {} epochs. "
                    "Training stops.".format(self.early_stop))
                break

        if current_epoch % self.save_period == 0:
            self._save_checkpoint(current_epoch, save_best=is_best)
def __init__(self, model, optimizer, config, data_loader,
             valid_data_loader=None, lr_scheduler=None, max_len_step=None):
    '''
    Set up the trainer.

    Covers: process roles, logging, device placement, monitoring and
    early-stopping settings, optional checkpoint resume, distributed wrapping,
    data iteration and the metric trackers.

    :param model: network to train; it is moved to the prepared device.
    :param optimizer: optimizer used for the training steps.
    :param config: parsed configuration object. It is indexed like a dict and
        also provides ``get_logger``, ``save_dir``, ``log_dir`` and ``resume``.
    :param data_loader: training data loader.
    :param valid_data_loader: optional validation loader; validation is enabled
        only when this is not None.
    :param lr_scheduler: optional learning-rate scheduler.
    :param max_len_step: controls number of batches(steps) in each epoch.
        When None, an epoch is one full pass over ``data_loader``. Otherwise
        the loader is wrapped with ``inf_loop`` and ``self.len_step`` is set
        to ``max_len_step``.
    '''
    self.config = config
    self.distributed = config['distributed']
    if not self.distributed:
        self.local_master = self.global_master = True
    else:
        self.local_master = config['local_rank'] == 0
        self.global_master = dist.get_rank() == 0

    # Only the local master process gets a logger; the others keep None.
    self.logger = None
    if self.local_master:
        self.logger = config.get_logger('trainer', config['trainer']['log_verbosity'])

    # setup GPU device if available, move model into configured device
    self.device, self.device_ids = self._prepare_device(config['local_rank'],
                                                        config['local_world_size'])
    self.model = model.to(self.device)
    self.optimizer = optimizer

    trainer_cfg = config['trainer']
    self.epochs = trainer_cfg['epochs']
    self.save_period = trainer_cfg['save_period']
    self.monitor = trainer_cfg.get('monitor', 'off') if trainer_cfg['monitor_open'] else 'off'

    # Monitor spec is either 'off' or '<min|max> <metric name>'.
    if self.monitor != 'off':
        self.monitor_mode, self.monitor_metric = self.monitor.split()
        assert self.monitor_mode in ['min', 'max']
        self.monitor_best = {'min': inf, 'max': -inf}[self.monitor_mode]
    else:
        self.monitor_mode = 'off'
        self.monitor_best = 0

    # -1 (or a missing key) means "never stop early".
    configured_early_stop = trainer_cfg.get('early_stop', inf)
    self.early_stop = inf if configured_early_stop == -1 else configured_early_stop

    self.start_epoch = 1

    if self.local_master:
        self.checkpoint_dir = config.save_dir
        # setup visualization writer instance (local master only)
        self.writer = TensorboardWriter(config.log_dir, self.logger, trainer_cfg['tensorboard'])

    # Resume BEFORE SyncBatchNorm conversion / DDP wrapping so the checkpoint
    # is loaded into the bare model and avoids the 'module.' key prefix.
    if config.resume is not None:
        self._resume_checkpoint(config.resume)

    use_sync_bn = trainer_cfg['sync_batch_norm']
    if self.distributed:
        if use_sync_bn:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model = DDP(self.model,
                         device_ids=self.device_ids,
                         output_device=self.device_ids[0],
                         find_unused_parameters=True)

    if max_len_step is not None:
        # iteration-based training: cycle the loader indefinitely
        self.data_loader = inf_loop(data_loader)
        self.len_step = max_len_step
    else:
        # epoch-based training: one epoch is one pass over the loader
        self.data_loader = data_loader
        self.len_step = len(data_loader)

    self.valid_data_loader = valid_data_loader
    self.do_validation = valid_data_loader is not None
    self.lr_scheduler = lr_scheduler

    # Use the configured log interval only when it lies strictly inside
    # (0, len_step); otherwise fall back to sqrt(batch size) of the original loader.
    log_step = trainer_cfg['log_step_interval']
    log_step_usable = log_step != -1 and 0 < log_step < self.len_step
    self.log_step = log_step if log_step_usable else int(np.sqrt(data_loader.batch_size))

    # Taken verbatim from the config (no fallback applied).
    self.val_step_interval = trainer_cfg['val_step_interval']

    self.gl_loss_lambda = trainer_cfg['gl_loss_lambda']
    metrics_writer = self.writer if self.local_master else None
    self.train_loss_metrics = MetricTracker('loss', 'gl_loss', 'crf_loss', writer=metrics_writer)
    self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls)
def _train_epoch(self, epoch):
    '''
    Training logic for an epoch.

    Steps performed:
    1. Run at most ``self.len_step`` optimisation steps, logging the per-step
       losses.
    2. Optionally validate every ``self.val_step_interval`` steps (only when
       that value is positive).
    3. Validate once more at the end of the epoch when validation is enabled.
    4. Step the learning-rate scheduler, if there is one.

    :param epoch: Integer, current training epoch (1-based; used in log messages
        and to derive the writer's global step).
    :return: A log dict that contains average loss and metric in this epoch, as
        returned by ``self.train_loss_metrics.result()``. When validation is
        enabled it also holds ``'val_result_dict'``.
    '''
    import contextlib  # nullcontext() is the no-op stand-in when anomaly detection is off

    self.model.train()
    self.train_loss_metrics.reset()
    # Anomaly detection greatly increases runtime; it should only be enabled for debugging.
    use_anomaly_detection = self.config['trainer'].get('anomaly_detection', False)

    ## step iteration start ##
    for step_idx, input_data_item in enumerate(self.data_loader, start=1):
        for key, input_value in input_data_item.items():
            if isinstance(input_value, torch.Tensor):
                input_data_item[key] = input_value.to(self.device, non_blocking=True)

        anomaly_ctx = (torch.autograd.detect_anomaly() if use_anomaly_detection
                       else contextlib.nullcontext())
        with anomaly_ctx:
            self.optimizer.zero_grad()
            # model forward
            output = self.model(**input_data_item)
            # calculate loss
            gl_loss = output['gl_loss']
            crf_loss = output['crf_loss']
            total_loss = torch.sum(crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
            # backward
            total_loss.backward()
            self.optimizer.step()

        # Per-process average loss across the batch (detached: for logging only).
        avg_gl_loss = torch.mean(gl_loss.detach())
        avg_crf_loss = torch.mean(crf_loss.detach())
        if self.distributed:
            # Make sure all processes have finished forward and backward.
            dist.barrier()
            # Sum the per-process means over all processes, then divide by the
            # world size to get the world-wide average of each loss term.
            reduced = torch.stack([avg_gl_loss, avg_crf_loss])
            dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
            avg_gl_loss, avg_crf_loss = reduced / dist.get_world_size()
        avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss

        loss_value = avg_loss.item()
        gl_value = avg_gl_loss.item() * self.gl_loss_lambda
        crf_value = avg_crf_loss.item()

        # update metrics (the writer only exists on the local master)
        if self.local_master:
            self.writer.set_step((epoch - 1) * self.len_step + step_idx - 1)
        self.train_loss_metrics.update('loss', loss_value)
        self.train_loss_metrics.update('gl_loss', gl_value)
        self.train_loss_metrics.update('crf_loss', crf_value)

        # log messages
        if step_idx % self.log_step == 0:
            self.logger_info(
                'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}'
                .format(epoch, self.epochs, step_idx, self.len_step,
                        loss_value, gl_value, crf_value))

        # do validation after val_step_interval iterations; a non-positive
        # interval disables step validation (0 would divide by zero, and
        # % -1 would trigger it every step).
        if (self.do_validation and self.val_step_interval > 0
                and step_idx % self.val_step_interval == 0):
            val_result_dict = self._valid_epoch(epoch)
            self.logger_info(
                '[Step Validation] Epoch:[{}/{}] Step:[{}/{}] \n{}'.format(
                    epoch, self.epochs, step_idx, self.len_step,
                    SpanBasedF1MetricTracker.dict2str(val_result_dict)))
            # check if best metric, if true, then save as model_best checkpoint.
            best, _ = self._is_best_monitor_metric(False, 0, val_result_dict)
            if best:
                self._save_checkpoint(epoch, best)
            # NOTE(review): restore train mode in case _valid_epoch left the
            # model in eval mode; harmless if it already does this itself.
            self.model.train()

        # Stop after exactly len_step steps. This is required when data_loader
        # is an infinite inf_loop, and a no-op for a finite loader of len_step batches.
        if step_idx >= self.len_step:
            break
    ## step iteration end ##

    # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss}
    log = self.train_loss_metrics.result()

    # do validation after training an epoch
    if self.do_validation:
        log['val_result_dict'] = self._valid_epoch(epoch)

    if self.lr_scheduler is not None:
        self.lr_scheduler.step()

    return log