class Trainer: """ Trainer class """ def __init__(self, model, optimizer, config, data_loader, valid_data_loader=None, lr_scheduler=None, max_len_step=None): ''' :param model: :param optimizer: :param config: :param data_loader: :param valid_data_loader: :param lr_scheduler: :param max_len_step: controls number of batches(steps) in each epoch. ''' self.config = config self.distributed = config['distributed'] if self.distributed: self.local_master = (config['local_rank'] == 0) self.global_master = (dist.get_rank() == 0) else: self.local_master = True self.global_master = True self.logger = config.get_logger( 'trainer', config['trainer']['log_verbosity']) if self.local_master else None # setup GPU device if available, move model into configured device self.device, self.device_ids = self._prepare_device( config['local_rank'], config['local_world_size']) self.model = model.to(self.device) self.optimizer = optimizer cfg_trainer = config['trainer'] self.epochs = cfg_trainer['epochs'] self.save_period = cfg_trainer['save_period'] monitor_open = cfg_trainer['monitor_open'] if monitor_open: self.monitor = cfg_trainer.get('monitor', 'off') else: self.monitor = 'off' # configuration to monitor model performance and save best if self.monitor == 'off': self.monitor_mode = 'off' self.monitor_best = 0 else: self.monitor_mode, self.monitor_metric = self.monitor.split() assert self.monitor_mode in ['min', 'max'] self.monitor_best = inf if self.monitor_mode == 'min' else -inf self.early_stop = cfg_trainer.get('early_stop', inf) self.early_stop = inf if self.early_stop == -1 else self.early_stop self.start_epoch = 1 if self.local_master: self.checkpoint_dir = config.save_dir # setup visualization writer instance self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) # load checkpoint for resume training if config.resume is not None: self._resume_checkpoint(config.resume) # load checkpoint following load to multi-gpu, avoid 'module.' prefix if self.config['trainer']['sync_batch_norm'] and self.distributed: self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( self.model) if self.distributed: self.model = DDP(self.model, device_ids=self.device_ids, output_device=self.device_ids[0], find_unused_parameters=True) self.data_loader = data_loader if max_len_step is None: # max length of iteration step of every epoch # epoch-based training self.len_step = len(self.data_loader) else: # iteration-based training self.data_loader = inf_loop(data_loader) self.len_step = max_len_step self.valid_data_loader = valid_data_loader self.do_validation = self.valid_data_loader is not None self.lr_scheduler = lr_scheduler log_step = self.config['trainer']['log_step_interval'] self.log_step = log_step if log_step != -1 and 0 < log_step < self.len_step else int( np.sqrt(data_loader.batch_size)) val_step_interval = self.config['trainer']['val_step_interval'] # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\ # else int(np.sqrt(data_loader.batch_size)) self.val_step_interval = val_step_interval self.gl_loss_lambda = self.config['trainer']['gl_loss_lambda'] self.train_loss_metrics = MetricTracker( 'loss', 'gl_loss', 'crf_loss', writer=self.writer if self.local_master else None) self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls) def train(self): """ Full training logic, including train and validation. 
""" if self.distributed: dist.barrier() # Syncing machines before training not_improved_count = 0 for epoch in range(self.start_epoch, self.epochs + 1): # ensure distribute worker sample different data, # set different random seed by passing epoch to sampler if self.distributed: self.data_loader.sampler.set_epoch(epoch) result_dict = self._train_epoch(epoch) # print logged informations to the screen if self.do_validation: val_result_dict = result_dict['val_result_dict'] val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict) else: val_res = '' # every epoch log information self.logger_info( '[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} ' 'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.format( epoch, self.epochs, result_dict['loss'], result_dict['gl_loss'] * self.gl_loss_lambda, result_dict['crf_loss'], val_res)) # evaluate model performance according to configured metric, check early stop, and # save best checkpoint as model_best best = False if self.monitor_mode != 'off' and self.do_validation: best, not_improved_count = self._is_best_monitor_metric( best, not_improved_count, val_result_dict) if not_improved_count > self.early_stop: self.logger_info( "Validation performance didn\'t improve for {} epochs. " "Training stops.".format(self.early_stop)) break if epoch % self.save_period == 0: self._save_checkpoint(epoch, save_best=best) def _is_best_monitor_metric(self, best, not_improved_count, val_result_dict): ''' monitor metric :param best: :param not_improved_count: :param val_result_dict: :return: ''' entity_name, metric = self.monitor_metric.split('-') val_monitor_metric_res = val_result_dict[entity_name][metric] try: # check whether model performance improved or not, according to specified metric(monitor_metric) improved = (self.monitor_mode == 'min' and val_monitor_metric_res <= self.monitor_best) or \ (self.monitor_mode == 'max' and val_monitor_metric_res >= self.monitor_best) except KeyError: self.logger_warning( "Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format( self.monitor_metric)) self.monitor_mode = 'off' improved = False if improved: self.monitor_best = val_monitor_metric_res not_improved_count = 0 best = True else: not_improved_count += 1 return best, not_improved_count def _train_epoch(self, epoch): ''' Training logic for an epoch :param epoch: Integer, current training epoch. :return: A log dict that contains average loss and metric in this epoch. 
''' self.model.train() self.train_loss_metrics.reset() ## step iteration start ## for step_idx, input_data_item in enumerate(self.data_loader): step_idx += 1 for key, input_value in input_data_item.items(): if input_value is not None and isinstance( input_value, torch.Tensor): input_data_item[key] = input_value.to(self.device, non_blocking=True) if self.config['trainer']['anomaly_detection']: # This mode will increase the runtime and should only be enabled for debugging with torch.autograd.detect_anomaly(): self.optimizer.zero_grad() # model forward output = self.model(**input_data_item) # calculate loss gl_loss = output['gl_loss'] crf_loss = output['crf_loss'] total_loss = torch.sum( crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss) # backward total_loss.backward() # self.average_gradients(self.model) self.optimizer.step() else: self.optimizer.zero_grad() # model forward output = self.model(**input_data_item) # calculate loss gl_loss = output['gl_loss'] crf_loss = output['crf_loss'] total_loss = torch.sum( crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss) # backward total_loss.backward() # self.average_gradients(self.model) self.optimizer.step() # Use a barrier() to make sure that all process have finished forward and backward if self.distributed: dist.barrier() # obtain the sum of all total_loss at all processes dist.all_reduce(total_loss, op=dist.reduce_op.SUM) size = dist.get_world_size() else: size = 1 gl_loss /= size # averages gl_loss across the whole world crf_loss /= size # averages crf_loss across the whole world # calculate average loss across the batch size avg_gl_loss = torch.mean(gl_loss) avg_crf_loss = torch.mean(crf_loss) avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss # update metrics self.writer.set_step((epoch - 1) * self.len_step + step_idx - 1) if self.local_master else None self.train_loss_metrics.update('loss', avg_loss.item()) self.train_loss_metrics.update( 'gl_loss', avg_gl_loss.item() * self.gl_loss_lambda) self.train_loss_metrics.update('crf_loss', avg_crf_loss.item()) # log messages if step_idx % self.log_step == 0: self.logger_info( 'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}' .format(epoch, self.epochs, step_idx, self.len_step, avg_loss.item(), avg_gl_loss.item() * self.gl_loss_lambda, avg_crf_loss.item())) # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) # do validation after val_step_interval iteration if self.do_validation and step_idx % self.val_step_interval == 0: val_result_dict = self._valid_epoch(epoch) self.logger_info( '[Step Validation] Epoch:[{}/{}] Step:[{}/{}] \n{}'. format(epoch, self.epochs, step_idx, self.len_step, SpanBasedF1MetricTracker.dict2str(val_result_dict))) # check if best metric, if true, then save as model_best checkpoint. best, not_improved_count = self._is_best_monitor_metric( False, 0, val_result_dict) if best: self._save_checkpoint(epoch, best) # decide whether continue iter if step_idx == self.len_step + 1: break ## step iteration end ## # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss} log = self.train_loss_metrics.result() # do validation after training an epoch if self.do_validation: val_result_dict = self._valid_epoch(epoch) log['val_result_dict'] = val_result_dict if self.lr_scheduler is not None: self.lr_scheduler.step() return log def _valid_epoch(self, epoch): ''' Validate after training an epoch or regular step, this is a time-consuming procedure if validation data is big. 
:param epoch: Integer, current training epoch. :return: A dict that contains information about validation ''' self.model.eval() self.valid_f1_metrics.reset() with torch.no_grad(): for step_idx, input_data_item in enumerate(self.valid_data_loader): for key, input_value in input_data_item.items(): if input_value is not None and isinstance( input_value, torch.Tensor): input_data_item[key] = input_value.to( self.device, non_blocking=True) output = self.model(**input_data_item) # print("awesome 307") logits = output['logits'] new_mask = output['new_mask'] if hasattr(self.model, 'module'): # List[(List[int], torch.Tensor)] contain the tag indices of the maximum likelihood tag sequence. # and the score of the viterbi path. best_paths = self.model.module.decoder.crf_layer.viterbi_tags( logits, mask=new_mask, logits_batch_first=True) else: best_paths = self.model.decoder.crf_layer.viterbi_tags( logits, mask=new_mask, logits_batch_first=True) predicted_tags = [] for path, score in best_paths: predicted_tags.append(path) self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step_idx, 'valid') \ if self.local_master else None # calculate and update f1 metrics # (B, N*T, out_dim) predicted_tags_hard_prob = logits * 0 for i, instance_tags in enumerate(predicted_tags): for j, tag_id in enumerate(instance_tags): predicted_tags_hard_prob[i, j, tag_id] = 1 golden_tags = input_data_item['iob_tags_label'] mask = input_data_item['mask'] union_iob_tags = iob_tags_to_union_iob_tags(golden_tags, mask) if self.distributed: dist.barrier() # self.valid_f1_metrics.update(predicted_tags_hard_prob.long(), union_iob_tags, new_mask) # add histogram of model parameters to the tensorboard # for name, p in self.model.named_parameters(): # self.writer.add_histogram(name, p, bins='auto') f1_result_dict = self.valid_f1_metrics.result() # rollback to train mode self.model.train() return f1_result_dict def average_gradients(self, model): ''' Gradient averaging :param model: :return: ''' size = float(dist.get_world_size()) for param in model.parameters(): dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM) param.grad.data /= size def logger_info(self, msg): self.logger.info(msg) if self.local_master else None def logger_warning(self, msg): self.logger.warning(msg) if self.local_master else None def _prepare_device(self, local_rank, local_world_size): ''' setup GPU device if available, move model into configured device :param local_rank: :param local_world_size: :return: ''' if self.distributed: ngpu_per_process = torch.cuda.device_count() // local_world_size device_ids = list( range(local_rank * ngpu_per_process, (local_rank + 1) * ngpu_per_process)) if torch.cuda.is_available() and local_rank != -1: torch.cuda.set_device( device_ids[0] ) # device_ids[0] =local_rank if local_world_size = n_gpu per node device = 'cuda' self.logger_info( f"[Process {os.getpid()}] world_size = {dist.get_world_size()}, " + f"rank = {dist.get_rank()}, n_gpu/process = {ngpu_per_process}, device_ids = {device_ids}" ) else: self.logger_warning('Training will be using CPU!') device = 'cpu' device = torch.device(device) return device, device_ids else: n_gpu = torch.cuda.device_count() print(f"NUMBER GPU {n_gpu}") n_gpu_use = local_world_size if n_gpu_use > 0 and n_gpu == 0: self.logger_warning( "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") n_gpu_use = 0 if n_gpu_use > n_gpu: self.logger_warning( "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this 
machine.".format(n_gpu_use, n_gpu)) n_gpu_use = n_gpu list_ids = list(range(n_gpu_use)) if n_gpu_use > 0: torch.cuda.set_device( list_ids[0]) # only use first available gpu as devices self.logger_warning(f'Training is using GPU {list_ids[0]}!') device = 'cuda' else: self.logger_warning('Training is using CPU!') device = 'cpu' device = torch.device(device) return device, list_ids def _save_checkpoint(self, epoch, save_best=False): ''' Saving checkpoints :param epoch: current epoch number :param save_best: if True, rename the saved checkpoint to 'model_best.pth' :return: ''' # only local master process do save model if not self.local_master: return if hasattr(self.model, 'module'): arch = type(self.model.module).__name__ state_dict = self.model.module.state_dict() else: arch = type(self.model).__name__ state_dict = self.model.state_dict() state = { 'arch': arch, 'epoch': epoch, 'state_dict': state_dict, 'optimizer': self.optimizer.state_dict(), 'monitor_best': self.monitor_best, 'config': self.config } if save_best: best_path = str(self.checkpoint_dir / 'model_best.pth') torch.save(state, best_path) self.logger_info("Saving current best: model_best.pth ...") else: filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) torch.save(state, filename) self.logger_info("Saving checkpoint: {} ...".format(filename)) def _resume_checkpoint(self, resume_path): ''' Resume from saved checkpoints :param resume_path: Checkpoint path to be resumed :return: ''' resume_path = str(resume_path) self.logger_info("Loading checkpoint: {} ...".format(resume_path)) # map_location = {'cuda:%d' % 0: 'cuda:%d' % self.config['local_rank']} checkpoint = torch.load(resume_path, map_location=self.device) self.start_epoch = checkpoint['epoch'] + 1 self.monitor_best = checkpoint['monitor_best'] # load architecture params from checkpoint. if checkpoint['config']['model_arch'] != self.config['model_arch']: self.logger_warning( "Warning: Architecture configuration given in config file is different from that of " "checkpoint. This may yield an exception while state_dict is being loaded." ) self.model.load_state_dict(checkpoint['state_dict']) # load optimizer state from checkpoint only when optimizer type is not changed. if checkpoint['config']['optimizer']['type'] != self.config[ 'optimizer']['type']: self.logger_warning( "Warning: Optimizer type given in config file is different from that of checkpoint. " "Optimizer parameters not being resumed.") else: self.optimizer.load_state_dict(checkpoint['optimizer']) self.logger_info( "Checkpoint loaded. Resume training from epoch {}".format( self.start_epoch))
class Trainer: """ Implements training and validation logic """ def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None): self.config = config self.logger = config.get_logger('trainer', config['trainer']['verbosity']) # setup GPU device if available, move model into configured device self.device, device_ids = self._prepare_device(config['n_gpu']) self.model = model.to(self.device) if len(device_ids) > 1: self.model = torch.nn.DataParallel(model, device_ids=device_ids) self.criterion = criterion self.metric_ftns = metric_ftns self.optimizer = optimizer cfg_trainer = config['trainer'] self.epochs = cfg_trainer['epochs'] self.save_period = cfg_trainer['save_period'] self.monitor = cfg_trainer.get('monitor', 'off') # configuration to monitor model performance and save best if self.monitor == 'off': self.mnt_mode = 'off' self.mnt_best = 0 else: self.mnt_mode, self.mnt_metric = self.monitor.split() assert self.mnt_mode in ['min', 'max'] self.mnt_best = inf if self.mnt_mode == 'min' else -inf self.early_stop = cfg_trainer.get('early_stop', inf) self.start_epoch = 1 self.checkpoint_dir = config.save_dir # setup visualization writer instance self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) if config.resume is not None: self._resume_checkpoint(config.resume) self.config = config self.data_loader = data_loader if len_epoch is None: # epoch-based training self.len_epoch = len(self.data_loader) else: # iteration-based training self.data_loader = inf_loop(data_loader) self.len_epoch = len_epoch self.valid_data_loader = valid_data_loader self.do_validation = self.valid_data_loader is not None self.lr_scheduler = lr_scheduler self.log_step = int(np.sqrt(data_loader.batch_size)) def _train_epoch(self, epoch): """ Training logic for an epoch :param epoch: Integer, current training epoch. :return: A log that contains average loss and metric in this epoch. 
""" self.model.train() epoch_target = [] epoch_scores = [] epoch_word_pairs = [] epoch_loss = [] for batch_idx, batch_data in enumerate(self.data_loader): if not batch_data: continue for field in [ 'input_ids', 'label', 'target', 'attention_mask', 'term1_mask', 'term2_mask' ]: batch_data[field] = batch_data[field].to(self.device) self.optimizer.zero_grad() output = self.model(batch_data) loss = self.criterion( output, batch_data['target'], self.data_loader.dataset.class_weights.to(self.device)) loss.backward() self.optimizer.step() self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) # accumulate epoch quantities epoch_target += [t.item() for t in batch_data['label']] epoch_scores += [output.cpu().detach().numpy()] epoch_word_pairs += batch_data['term_pair'] epoch_loss += [loss.item()] # update metrics self.writer.add_scalar("loss", loss.item()) for met in self.metric_ftns: self.writer.add_scalar(met.__name__, met(epoch_target, epoch_scores)) if batch_idx % self.log_step == 0: self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format( epoch, self._progress(batch_idx), loss.item())) #self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) if batch_idx == self.len_epoch: break log = { m.__name__: m(epoch_target, epoch_scores) for m in self.metric_ftns } log["loss"] = np.sum(epoch_loss) / len(self.data_loader) if self.do_validation: val_log = self._valid_epoch(epoch) log.update(**{'val_' + k: v for k, v in val_log.items()}) if self.lr_scheduler is not None: self.lr_scheduler.step() return log def _valid_epoch(self, epoch): """ Validate after training an epoch :param epoch: Integer, current training epoch. :return: A log that contains information about validation """ self.model.eval() with torch.no_grad(): epoch_target = [] epoch_scores = [] epoch_word_pairs = [] epoch_loss = [] for batch_idx, batch_data in enumerate(tqdm( self.valid_data_loader)): for field in [ 'input_ids', 'target', 'label', 'attention_mask', 'term1_mask', 'term2_mask' ]: batch_data[field] = batch_data[field].to(self.device) output = self.model(batch_data) pred = torch.argmax(output, dim=-1) loss = self.criterion( output, batch_data['target'].squeeze(-1), self.data_loader.dataset.class_weights.to(self.device)) self.writer.set_step( (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') # accumulate epoch quantities epoch_target += [t.item() for t in batch_data['label']] epoch_scores += [output.cpu().detach().numpy()] epoch_word_pairs += batch_data['term_pair'] epoch_loss += [loss.item()] # update metrics self.writer.add_scalar('loss', loss.item()) for met in self.metric_ftns: self.writer.add_scalar(met.__name__, met(epoch_target, epoch_scores)) # add histogram of model parameters to the tensorboard for name, p in self.model.named_parameters(): self.writer.add_histogram(name, p, bins='auto') log = { m.__name__: m(epoch_target, epoch_scores) for m in self.metric_ftns } log["loss"] = np.sum(epoch_loss) / len(self.valid_data_loader) return log def _progress(self, batch_idx): base = '[{}/{} ({:.0f}%)]' if hasattr(self.data_loader, 'n_samples'): current = batch_idx * self.data_loader.batch_size total = self.data_loader.n_samples else: current = batch_idx total = self.len_epoch return base.format(current, total, 100.0 * current / total) def train(self): """ Full training logic """ not_improved_count = 0 for epoch in range(self.start_epoch, self.epochs + 1): result = self._train_epoch(epoch) # save logged informations into log dict log = {'epoch': epoch} log.update(result) # print logged 
informations to the screen for key, value in log.items(): self.logger.info(' {:15s}: {}'.format(str(key), value)) # evaluate model performance according to configured metric, save best checkpoint as model_best best = False if self.mnt_mode != 'off': try: # check whether model performance improved or not, according to specified metric(mnt_metric) improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) except KeyError: self.logger.warning( "Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format( self.mnt_metric)) self.mnt_mode = 'off' improved = False if improved: self.mnt_best = log[self.mnt_metric] not_improved_count = 0 best = True else: not_improved_count += 1 if not_improved_count > self.early_stop: self.logger.info( "Validation performance didn\'t improve for {} epochs. " "Training stops.".format(self.early_stop)) break if epoch % self.save_period == 0: self._save_checkpoint(epoch, save_best=best) def _prepare_device(self, n_gpu_use): """ setup GPU device if available, move model into configured device """ n_gpu = torch.cuda.device_count() if n_gpu_use > 0 and n_gpu == 0: self.logger.warning( "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") n_gpu_use = 0 if n_gpu_use > n_gpu: self.logger.warning( "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(n_gpu_use, n_gpu)) n_gpu_use = n_gpu device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') list_ids = list(range(n_gpu_use)) return device, list_ids def _save_checkpoint(self, epoch, save_best=False): """ Saving checkpoints :param epoch: current epoch number :param log: logging information of the epoch :param save_best: if True, rename the saved checkpoint to 'model_best.pth' """ arch = type(self.model).__name__ state = { 'arch': arch, 'epoch': epoch, 'state_dict': self.model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'monitor_best': self.mnt_best, 'config': self.config } if save_best: best_path = str(self.checkpoint_dir / 'model_best.pth') torch.save(state, best_path) self.logger.info("Saving current best: model_best.pth ...") def _resume_checkpoint(self, resume_path): """ Resume from saved checkpoints :param resume_path: Checkpoint path to be resumed """ resume_path = str(resume_path) self.logger.info("Loading checkpoint: {} ...".format(resume_path)) checkpoint = torch.load(resume_path) self.start_epoch = checkpoint['epoch'] + 1 self.mnt_best = checkpoint['monitor_best'] # load architecture params from checkpoint. if checkpoint['config']['arch'] != self.config['arch']: self.logger.warning( "Warning: Architecture configuration given in config file is different from that of " "checkpoint. This may yield an exception while state_dict is being loaded." ) self.model.load_state_dict(checkpoint['state_dict']) # load optimizer state from checkpoint only when optimizer type is not changed. if checkpoint['config']['optimizer']['type'] != self.config[ 'optimizer']['type']: self.logger.warning( "Warning: Optimizer type given in config file is different from that of checkpoint. " "Optimizer parameters not being resumed.") else: self.optimizer.load_state_dict(checkpoint['optimizer']) self.logger.info( "Checkpoint loaded. Resume training from epoch {}".format( self.start_epoch))
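# The monitor / early-stop bookkeeping used by train() above, pulled out into a
# small runnable sketch. The config entry has the form "<min|max> <metric_name>"
# (e.g. "min val_loss"); the helper names below are illustrative, not part of
# the trainer itself.
from math import inf


def parse_monitor(monitor: str):
    """Parse the 'monitor' config entry into (mode, metric_name, initial_best)."""
    if monitor == 'off':
        return 'off', None, 0
    mode, metric = monitor.split()
    assert mode in ('min', 'max')
    return mode, metric, inf if mode == 'min' else -inf


def improved(mode: str, current: float, best: float) -> bool:
    """True when `current` is at least as good as `best` under the given mode."""
    return (mode == 'min' and current <= best) or (mode == 'max' and current >= best)


if __name__ == '__main__':
    mode, metric, best = parse_monitor('min val_loss')
    not_improved_count, early_stop = 0, 2
    for epoch_val_loss in [0.9, 0.7, 0.8, 0.78, 0.77]:
        if improved(mode, epoch_val_loss, best):
            best, not_improved_count = epoch_val_loss, 0
        else:
            not_improved_count += 1
        if not_improved_count > early_stop:
            break          # mirrors the "didn't improve for N epochs" exit above
    print(best)            # 0.7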
class Trainer: """ Trainer class """ def __init__(self, model, optimizer, config, data_loader, valid_data_loader=None, lr_scheduler=None, max_len_step=None): ''' :param model: :param optimizer: :param config: :param data_loader: :param valid_data_loader: :param lr_scheduler: :param max_len_step: controls number of batches(steps) in each epoch. ''' self.config = config self.distributed = config['distributed'] if self.distributed: self.local_master = (config['local_rank'] == 0) self.global_master = (dist.get_rank() == 0) else: self.local_master = True self.global_master = True self.logger = config.get_logger( 'trainer', config['trainer']['log_verbosity']) if self.local_master else None # setup GPU device if available, move model into configured device self.device, self.device_ids = self._prepare_device( config['local_rank'], config['local_world_size']) self.model = model.to(self.device) self.optimizer = optimizer cfg_trainer = config['trainer'] self.epochs = cfg_trainer['epochs'] self.save_period = cfg_trainer['save_period'] monitor_open = cfg_trainer['monitor_open'] if monitor_open: self.monitor = cfg_trainer.get('monitor', 'off') else: self.monitor = 'off' # configuration to monitor model performance and save best if self.monitor == 'off': self.monitor_mode = 'off' self.monitor_best = 0 else: self.monitor_mode, self.monitor_metric = self.monitor.split() assert self.monitor_mode in ['min', 'max'] self.monitor_best = inf if self.monitor_mode == 'min' else -inf self.early_stop = cfg_trainer.get('early_stop', inf) self.early_stop = inf if self.early_stop == -1 else self.early_stop self.start_epoch = 1 if self.local_master: self.checkpoint_dir = config.save_dir # setup visualization writer instance self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) # load checkpoint for resume training or finetune self.finetune = config['finetune'] if config.resume is not None: self._resume_checkpoint(config.resume) else: if self.finetune: self.logger_warning( "Finetune mode must set resume args to specific checkpoint path" ) raise RuntimeError( "Finetune mode must set resume args to specific checkpoint path" ) # load checkpoint then load to multi-gpu, avoid 'module.' 
prefix if self.config['trainer']['sync_batch_norm'] and self.distributed: # sync_batch_norm only support one gpu per process mode self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( self.model) if self.distributed: # move model to distributed gpu self.model = DDP(self.model, device_ids=self.device_ids, output_device=self.device_ids[0], find_unused_parameters=True) # iteration-based training self.len_step = len(data_loader) self.data_loader = data_loader if max_len_step is not None: # max length of iteration step of every epoch self.len_step = min(max_len_step, self.len_step) self.valid_data_loader = valid_data_loader do_validation = self.config['trainer']['do_validation'] self.validation_start_epoch = self.config['trainer'][ 'validation_start_epoch'] self.do_validation = (self.valid_data_loader is not None and do_validation) self.lr_scheduler = lr_scheduler log_step = self.config['trainer']['log_step_interval'] self.log_step = log_step if log_step != -1 and 0 < log_step < self.len_step else int( np.sqrt(data_loader.batch_size)) # do validation interval val_step_interval = self.config['trainer']['val_step_interval'] # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\ # else int(np.sqrt(data_loader.batch_size)) self.val_step_interval = val_step_interval # build metrics tracker and wrapper tensorboard writer. self.train_metrics = AverageMetricTracker( 'loss', writer=self.writer if self.local_master else None) self.val_metrics = AverageMetricTracker( 'loss', 'word_acc', 'word_acc_case_insensitive', 'edit_distance_acc', writer=self.writer if self.local_master else None) def train(self): """ Full training logic, including train and validation. """ if self.distributed: dist.barrier() # Syncing machines before training not_improved_count = 0 for epoch in range(self.start_epoch, self.epochs + 1): # ensure distribute worker sample different data, # set different random seed by passing epoch to sampler if self.distributed: self.data_loader.sampler.set_epoch(epoch) self.valid_data_loader.batch_sampler.set_epoch( epoch ) if self.valid_data_loader.batch_sampler is not None else None torch.cuda.empty_cache() result_dict = self._train_epoch(epoch) # import pdb;pdb.set_trace() # validate after training an epoch if self.do_validation and epoch >= self.validation_start_epoch: val_metric_res_dict = self._valid_epoch(epoch) # import pdb;pdb.set_trace() val_res = f"\nValidation result after {epoch} epoch: " \ f"Word_acc: {val_metric_res_dict['word_acc']:.6f} " \ f"Word_acc_case_ins: {val_metric_res_dict['word_acc_case_insensitive']:.6f} " \ f"Edit_distance_acc: {val_metric_res_dict['edit_distance_acc']:.6f}" else: val_res = '' # update lr after training an epoch, epoch-wise if self.lr_scheduler is not None: self.lr_scheduler.step() # every epoch log information self.logger_info( '[Epoch End] Epoch:[{}/{}] Loss: {:.6f} LR: {:.8f}'.format( epoch, self.epochs, result_dict['loss'], self._get_lr()) + val_res) # evaluate model performance according to configured metric, check early stop, and # save best checkpoint as model_best best = False if self.monitor_mode != 'off' and self.do_validation and epoch >= self.validation_start_epoch: best, not_improved_count = self._is_best_monitor_metric( best, not_improved_count, val_metric_res_dict) if not_improved_count > self.early_stop: # epoch level count self.logger_info( "Validation performance didn\'t improve for {} epochs. 
" "Training stops.".format(self.early_stop)) break # epoch-level save period if best or (epoch % self.save_period == 0 and epoch >= self.validation_start_epoch): self._save_checkpoint(epoch, save_best=best) def _is_best_monitor_metric(self, best, not_improved_count, val_result_dict, update_not_improved_count=True): ''' monitor metric :param best: bool :param not_improved_count: int :param val_result_dict: dict :param update_monitor_best: bool, true: update monitor_best when epoch-level validation :return: ''' val_monitor_metric_res = val_result_dict[self.monitor_metric] try: # check whether model performance improved or not, according to specified metric(monitor_metric) improved = (self.monitor_mode == 'min' and val_monitor_metric_res <= self.monitor_best) or \ (self.monitor_mode == 'max' and val_monitor_metric_res >= self.monitor_best) except KeyError: self.logger_warning( "Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format( self.monitor_metric)) self.monitor_mode = 'off' improved = False if improved: self.monitor_best = val_monitor_metric_res not_improved_count = 0 best = True else: if update_not_improved_count: # update when do epoch-level validation, step-level not changed count not_improved_count += 1 return best, not_improved_count def _train_epoch(self, epoch): ''' Training logic for an epoch :param epoch: Integer, current training epoch. :return: A log dict that contains average loss and metric in this epoch. ''' self.model.train() self.train_metrics.reset() ## step iteration start ## for step_idx, input_data_item in enumerate(self.data_loader): batch_size = input_data_item['batch_size'] if batch_size == 0: continue images = input_data_item['images'] text_label = input_data_item['labels'] # # step-wise lr scheduler, comment this, using epoch-wise lr_scheduler # if self.lr_scheduler is not None: # self.lr_scheduler.step() # for step_idx in range(self.len_step): step_idx += 1 # import pdb;pdb.set_trace() # prepare input data images = images.to(self.device) target = LabelTransformer.encode(text_label) target = target.to(self.device) target = target.permute(1, 0) if self.config['trainer']['anomaly_detection']: # This mode will increase the runtime and should only be enabled for debugging with torch.autograd.detect_anomaly(): # forward outputs = self.model( images, target[:, :-1]) # need to remove <EOS> in target loss = F.cross_entropy( outputs.contiguous().view(-1, outputs.shape[-1]), target[:, 1:].contiguous().view( -1), # need to remove <SOS> in target ignore_index=LabelTransformer.PAD) # backward and update parameters self.optimizer.zero_grad() loss.backward() # self.average_gradients(self.model) self.optimizer.step() else: # forward outputs = self.model( images, target[:, :-1]) # need to remove <EOS> in target loss = F.cross_entropy( outputs.contiguous().view(-1, outputs.shape[-1]), target[:, 1:].contiguous().view( -1), # need to remove <SOS> in target ignore_index=LabelTransformer.PAD) # backward and update parameters self.optimizer.zero_grad() loss.backward() # self.average_gradients(self.model) self.optimizer.step() ## Train batch done. 
Logging results # due to training mode (bn, dropout), we don't calculate acc batch_total = images.shape[0] reduced_loss = loss.item() # mean results of ce if self.distributed: # obtain the sum of all train metrics at all processes by all_reduce operation # Must keep track of global batch size, # since not all machines are guaranteed equal batches at the end of an epoch reduced_metrics_tensor = torch.tensor( [batch_total, reduced_loss]).float().to(self.device) # Use a barrier() to make sure that all process have finished above code dist.barrier() # averages metric tensor across the whole world # import pdb;pdb.set_trace() # reduced_metrics_tensor = self.mean_reduce_tensor(reduced_metrics_tensor) reduced_metrics_tensor = self.sum_tesnor( reduced_metrics_tensor) batch_total, reduced_loss = reduced_metrics_tensor.cpu().numpy( ) reduced_loss = reduced_loss / dist.get_world_size() # update metrics and write to tensorboard global_step = (epoch - 1) * self.len_step + step_idx - 1 self.writer.set_step(global_step, mode='train') if self.local_master else None # write tag is loss/train (mode =train) self.train_metrics.update( 'loss', reduced_loss, batch_total ) # here, loss is mean results over batch, accumulate values # log messages if step_idx % self.log_step == 0 or step_idx == 1: self.logger_info( 'Train Epoch:[{}/{}] Step:[{}/{}] Loss: {:.6f} Loss_avg: {:.6f} LR: {:.8f}' .format(epoch, self.epochs, step_idx, self.len_step, self.train_metrics.val('loss'), self.train_metrics.avg('loss'), self._get_lr())) # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True)) # do validation after val_step_interval iteration if self.do_validation and step_idx % self.val_step_interval == 0 and epoch >= self.validation_start_epoch: val_metric_res_dict = self._valid_epoch( epoch) # average metric self.logger_info( '[Step Validation] Epoch:[{}/{}] Step:[{}/{}] Word_acc: {:.6f} Word_acc_case_ins {:.6f}' 'Edit_distance_acc: {:.6f}'.format( epoch, self.epochs, step_idx, self.len_step, val_metric_res_dict['word_acc'], val_metric_res_dict['word_acc_case_insensitive'], val_metric_res_dict['edit_distance_acc'])) # check if best metric, if true, then save as model_best checkpoint. best, not_improved_count = self._is_best_monitor_metric( False, 0, val_metric_res_dict, update_not_improved_count=False) if best: # step-level valida then save model self._save_checkpoint(epoch, best, step_idx) # decide whether continue iter if step_idx == self.len_step: break ## step iteration end ## log_dict = self.train_metrics.result() return log_dict def _valid_epoch(self, epoch): ''' Validate after training an epoch or regular step, this is a time-consuming procedure if validation data is big. :param epoch: Integer, current training epoch. 
:return: A dict that contains information about validation ''' self.model.eval() self.val_metrics.reset() for step_idx, input_data_item in enumerate(self.valid_data_loader): batch_size = input_data_item['batch_size'] images = input_data_item['images'] text_label = input_data_item['labels'] if self.distributed: word_acc, word_acc_case_ins, edit_distance_acc, total_distance_ref, batch_total = \ self._distributed_predict(batch_size, images, text_label) else: # one cpu or gpu non-distributed mode with torch.no_grad(): images = images.to(self.device) # target = LabelTransformer.encode(text_label) # target = target.to(self.device) # target = target.permute(1, 0) if hasattr(self.model, 'module'): model = self.model.module else: model = self.model # (bs, max_len) outputs = decode_util.greedy_decode( model, images, LabelTransformer.max_length, LabelTransformer.SOS, padding_symbol=LabelTransformer.PAD, device=images.device, padding=True) correct = 0 correct_case_ins = 0 total_distance_ref = 0 total_edit_distance = 0 for index, (pred, text_gold) in enumerate( zip(outputs[:, 1:], text_label)): predict_text = "" for i in range(len(pred)): # decode one sample if pred[i] == LabelTransformer.EOS: break if pred[i] == LabelTransformer.UNK: continue decoded_char = LabelTransformer.decode(pred[i]) predict_text += decoded_char # calculate edit distance ref = len(text_gold) edit_distance = distance.levenshtein( text_gold, predict_text) total_distance_ref += ref total_edit_distance += edit_distance # calculate word accuracy related # predict_text = predict_text.strip() # text_gold = text_gold.strip() if predict_text == text_gold: correct += 1 if predict_text.lower() == text_gold.lower(): correct_case_ins += 1 batch_total = images.shape[ 0] # valid batch size of current steps # calculate accuracy directly, due to non-distributed word_acc = correct / batch_total word_acc_case_ins = correct_case_ins / batch_total edit_distance_acc = 1 - total_edit_distance / total_distance_ref # update valid metric and write to tensorboard, self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step_idx, 'valid') \ if self.local_master else None # self.val_metrics.update('loss', loss, batch_total) # tag is loss/valid (mode =valid) self.val_metrics.update('word_acc', word_acc, batch_total) self.val_metrics.update('word_acc_case_insensitive', word_acc_case_ins, batch_total) self.val_metrics.update('edit_distance_acc', edit_distance_acc, total_distance_ref) val_metric_res_dict = self.val_metrics.result() # rollback to train mode self.model.train() return val_metric_res_dict def _distributed_predict(self, batch_size, images, text_label): # Allows distributed prediction on uneven batches. 
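# Aside on the pattern implemented in this method: each process packs its raw
# counts (batch size, correct predictions, edit distances, ...) into one tensor,
# the tensors are summed across all processes with all_reduce, and ratios are
# formed only from the global counts, so ranks that received an empty
# validation batch still participate. A minimal sketch, assuming
# torch.distributed has already been initialized elsewhere; the single-process
# fallback only keeps the helper runnable on one machine.
import torch
import torch.distributed as dist


def global_word_accuracy(correct: int, total: int, device) -> float:
    counts = torch.tensor([correct, total], dtype=torch.float32, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)   # sum counts over the whole world
    global_correct, global_total = counts.tolist()
    return global_correct / max(global_total, 1.0)      # guard against an empty epoch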
# Test set isn't always large enough for every GPU to get a batch # obtain the sum of all val metrics at all processes by all_reduce operation # dist.barrier() # batch_size = images.size(0) correct = correct_case_ins = valid_batches = total_edit_distance = total_distance_ref = 0 if batch_size: # not empty samples at current gpu validation process with torch.no_grad(): images = images.to(self.device) # target = LabelTransformer.encode(text_label) # target = target.to(self.device) # target = target.permute(1, 0) if hasattr(self.model, 'module'): model = self.model.module else: model = self.model outputs = decode_util.greedy_decode( model, images, LabelTransformer.max_length, LabelTransformer.SOS, padding_symbol=LabelTransformer.PAD, device=images.device, padding=True) for index, (pred, text_gold) in enumerate( zip(outputs[:, 1:], text_label)): predict_text = "" for i in range(len(pred)): # decode one sample if pred[i] == LabelTransformer.EOS: break if pred[i] == LabelTransformer.UNK: continue decoded_char = LabelTransformer.decode(pred[i]) predict_text += decoded_char # calculate edit distance ref = len(text_gold) edit_distance = distance.levenshtein( text_gold, predict_text) total_distance_ref += ref total_edit_distance += edit_distance # calculate word accuracy related # predict_text = predict_text.strip() # text_gold = text_gold.strip() if predict_text == text_gold: correct += 1 if predict_text.lower() == text_gold.lower(): correct_case_ins += 1 valid_batches = 1 # can be regard as dist.world_size # sum metrics across all valid process sum_metrics_tensor = torch.tensor([ batch_size, valid_batches, correct, correct_case_ins, total_edit_distance, total_distance_ref ]).float().to(self.device) # # Use a barrier() to make sure that all process have finished above code # dist.barrier() sum_metrics_tensor = self.sum_tesnor(sum_metrics_tensor) sum_metrics_tensor = sum_metrics_tensor.cpu().numpy() batch_total, valid_batches = sum_metrics_tensor[0:2] # averages metric across the valid process # loss= sum_metrics_tensor[2] / valid_batches correct, correct_case_ins, total_edit_distance, total_distance_ref = sum_metrics_tensor[ 2:] word_acc = correct / batch_total word_acc_case_ins = correct_case_ins / batch_total edit_distance_acc = 1 - total_edit_distance / total_distance_ref return word_acc, word_acc_case_ins, edit_distance_acc, total_distance_ref, batch_total def _get_lr(self): for group in self.optimizer.param_groups: return group['lr'] def average_gradients(self, model): ''' Gradient averaging :param model: :return: ''' size = float(dist.get_world_size()) for param in model.parameters(): dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM) param.grad.data /= size def mean_reduce_tensor(self, tensor: torch.Tensor): ''' averages tensor across the whole world''' sum_tensor = self.sum_tesnor(tensor) return sum_tensor / dist.get_world_size() def sum_tesnor(self, tensor: torch.Tensor): '''obtain the sum of tensor at all processes''' rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) return rt def logger_info(self, msg): self.logger.info(msg) if self.local_master else None def logger_warning(self, msg): self.logger.warning(msg) if self.local_master else None def _prepare_device(self, local_rank, local_world_size): ''' setup GPU device if available, move model into configured device :param local_rank: :param local_world_size: :return: ''' if self.distributed: ngpu_per_process = torch.cuda.device_count() // local_world_size device_ids = list( range(local_rank * ngpu_per_process, (local_rank 
+ 1) * ngpu_per_process)) if torch.cuda.is_available() and local_rank != -1: torch.cuda.set_device( device_ids[0] ) # device_ids[0] =local_rank if local_world_size = n_gpu per node device = 'cuda' self.logger_info( f"[Process {os.getpid()}] world_size = {dist.get_world_size()}, " + f"rank = {dist.get_rank()}, n_gpu/process = {ngpu_per_process}, device_ids = {device_ids}" ) else: self.logger_warning('Training will be using CPU!') device = 'cpu' device = torch.device(device) return device, device_ids else: n_gpu = torch.cuda.device_count() n_gpu_use = local_world_size if n_gpu_use > 0 and n_gpu == 0: self.logger_warning( "Warning: There\'s no GPU available on this machine," "training will be performed on CPU.") n_gpu_use = 0 if n_gpu_use > n_gpu: self.logger_warning( "Warning: The number of GPU\'s configured to use is {}, but only {} are available " "on this machine.".format(n_gpu_use, n_gpu)) n_gpu_use = n_gpu list_ids = list(range(n_gpu_use)) if n_gpu_use > 0: torch.cuda.set_device( list_ids[0]) # only use first available gpu as devices self.logger_warning(f'Training is using GPU {list_ids[0]}!') device = 'cuda' else: self.logger_warning('Training is using CPU!') device = 'cpu' device = torch.device(device) return device, list_ids def _save_checkpoint(self, epoch, save_best=False, step_idx=None): ''' Saving checkpoints :param epoch: current epoch number :param save_best: if True, rename the saved checkpoint to 'model_best.pth' :return: ''' # only both local and global master process do save model if not (self.local_master and self.global_master): return if hasattr(self.model, 'module'): arch_name = type(self.model.module).__name__ model_state_dict = self.model.module.state_dict() else: arch_name = type(self.model).__name__ model_state_dict = self.model.state_dict() state = { 'arch': arch_name, 'epoch': epoch, 'model_state_dict': model_state_dict, 'optimizer': self.optimizer.state_dict(), 'monitor_best': self.monitor_best, 'config': self.config } if step_idx is None: filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) else: filename = str( self.checkpoint_dir / 'checkpoint-epoch{}-step{}.pth'.format(epoch, step_idx)) torch.save(state, filename) self.logger_info("Saving checkpoint: {} ...".format(filename)) if save_best: best_path = str(self.checkpoint_dir / 'model_best.pth') shutil.copyfile(filename, best_path) self.logger_info( f"Saving current best (at {epoch} epoch): model_best.pth Best {self.monitor_metric}: {self.monitor_best:.6f}" ) # if save_best: # best_path = str(self.checkpoint_dir / 'model_best.pth') # torch.save(state, best_path) # self.logger_info( # f"Saving current best: model_best.pth Best {self.monitor_metric}: {self.monitor_best:.6f}.") # else: # filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) # torch.save(state, filename) # self.logger_info("Saving checkpoint: {} ...".format(filename)) def _resume_checkpoint(self, resume_path): ''' Resume from saved checkpoints :param resume_path: Checkpoint path to be resumed :return: ''' resume_path = str(resume_path) self.logger_info("Loading checkpoint: {} ...".format(resume_path)) # map_location = {'cuda:%d' % 0: 'cuda:%d' % self.config['local_rank']} checkpoint = torch.load(resume_path, map_location=self.device) self.start_epoch = checkpoint['epoch'] + 1 if not self.finetune else 1 self.monitor_best = checkpoint['monitor_best'] # load architecture params from checkpoint. 
if checkpoint['config']['model_arch'] != self.config[ 'model_arch']: # TODO verify adapt and adv arch self.logger_warning( "Warning: Architecture configuration given in config file is different from that of " "checkpoint. This may yield an exception while state_dict is being loaded." ) # self.model.load_state_dict(checkpoint['state_dict']) self.model.load_state_dict(checkpoint['model_state_dict']) # load optimizer state from checkpoint only when optimizer type is not changed. if not self.finetune: # resume mode will load optimizer state and continue train if checkpoint['config']['optimizer']['type'] != self.config[ 'optimizer']['type']: self.logger_warning( "Warning: Optimizer type given in config file is different from that of checkpoint. " "Optimizer parameters not being resumed.") else: self.optimizer.load_state_dict(checkpoint['optimizer']) if self.finetune: self.logger_info( "Checkpoint loaded. Finetune training from epoch {}".format( self.start_epoch)) else: self.logger_info( "Checkpoint loaded. Resume training from epoch {}".format( self.start_epoch))
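# The validation loops of this trainer score predictions with three numbers:
# exact word accuracy, case-insensitive word accuracy, and an edit-distance
# accuracy defined as 1 - total_edits / total_reference_length. A minimal
# sketch of the same bookkeeping, with a small pure-Python Levenshtein routine
# standing in for the third-party `distance.levenshtein` call used above.
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance between two strings."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]


def text_metrics(predictions, references):
    correct = correct_ci = total_edits = total_ref_len = 0
    for pred, gold in zip(predictions, references):
        correct += pred == gold
        correct_ci += pred.lower() == gold.lower()
        total_edits += levenshtein(gold, pred)
        total_ref_len += len(gold)
    n = max(len(references), 1)
    return {'word_acc': correct / n,
            'word_acc_case_insensitive': correct_ci / n,
            'edit_distance_acc': 1 - total_edits / max(total_ref_len, 1)}


# text_metrics(['Hello', 'world'], ['hello', 'world'])
# -> {'word_acc': 0.5, 'word_acc_case_insensitive': 1.0, 'edit_distance_acc': 0.9}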
class Trainer(BaseTrainer): """ Trainer class Note: Inherited from BaseTrainer. """ def __init__(self, model, metrics, optimizer, config, train_dataset, valid_datasets, lr_scheduler=None): super().__init__(model, metrics, optimizer, config, train_dataset) self.config = config self.config['data_loaders']['valid']['args']['batch_size'] = self.data_loader.batch_size self.valid_data_loaders = {} for corpus, valid_dataset in valid_datasets.items(): self.valid_data_loaders[corpus] = config.init_obj('data_loaders.valid', module_loader, valid_dataset) self.lr_scheduler = lr_scheduler self.log_step = math.ceil(len(self.data_loader.dataset) / np.sqrt(self.data_loader.batch_size) / 200) self.writer = TensorboardWriter(config.log_dir) def _train_epoch(self, epoch): """ Training logic for an epoch :param epoch: Current training epoch. :return: A log that contains all information you want to save. Note: If you have additional information to record, for example: > additional_log = {"x": x, "y": y} merge it with log before return. i.e. > log = {**log, **additional_log} > return log The metrics in log must have the key 'metrics'. """ self.model.train() self.optimizer.zero_grad() total_loss = 0 accum_count = 0 for step, batch in enumerate(self.data_loader): # (input_ids, input_mask, segment_ids, ng_token_mask, target, deps, task) batch = {label: t.to(self.device, non_blocking=True) for label, t in batch.items()} current_step = (epoch - 1) * len(self.data_loader) + step loss, *_ = self.model(**batch, progress=current_step / self.total_step) if len(loss.size()) > 0: loss = loss.mean() # mean() to average on multi-gpu parallel training loss_value = loss.item() # mean loss for this step total_loss += loss_value * next(iter(batch.values())).size(0) if step % self.log_step == 0: self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)] Time: {} Loss: {:.6f}'.format( epoch, step * self.data_loader.batch_size, len(self.data_loader.dataset), 100.0 * step / len(self.data_loader), datetime.datetime.now().strftime('%H:%M:%S'), loss_value)) gradient_accumulation_steps = self.gradient_accumulation_steps if step >= (len(self.data_loader) // self.gradient_accumulation_steps) * self.gradient_accumulation_steps: gradient_accumulation_steps = len(self.data_loader) % self.gradient_accumulation_steps # fraction if gradient_accumulation_steps > 1: loss = loss / gradient_accumulation_steps # scale loss loss.backward() accum_count += 1 if accum_count == gradient_accumulation_steps: self.writer.set_step(self.writer.step + 1) self.writer.add_scalar('lr', self.lr_scheduler.get_last_lr()[0]) self.writer.add_scalar('loss', loss_value) self.writer.add_scalar('progress', current_step / self.total_step) self.optimizer.step() self.optimizer.zero_grad() accum_count = 0 if self.lr_scheduler is not None: self.lr_scheduler.step() log = { 'loss': total_loss / len(self.data_loader.dataset), } corpus2result = {} for corpus, valid_data_loader in self.valid_data_loaders.items(): val_log = self._valid_epoch(valid_data_loader, corpus) log[f'val_{corpus}_loss'] = val_log['loss'] corpus2result[corpus] = val_log['result'] if 'all' not in corpus2result: corpus2result['all'] = val_log['result'] else: corpus2result['all'] += val_log['result'] for corpus, result in corpus2result.items(): val_metrics = self._eval_metrics(result.to_dict(), corpus) log.update({f'val_{corpus}_{met.__name__}': v for met, v in zip(self.metrics, val_metrics)}) return log def _valid_epoch(self, data_loader, corpus): """ Validate after training an epoch :return: A log that contains 
information about validation Note: The validation metrics in log must have the key 'val_metrics'. """ self.model.eval() total_loss = 0 arguments_set: List[List[List[int]]] = [] contingency_set: List[int] = [] with torch.no_grad(): for step, batch in enumerate(data_loader): batch = {label: t.to(self.device, non_blocking=True) for label, t in batch.items()} loss, *output = self.model(**batch) if len(loss.size()) > 0: loss = loss.mean() pas_scores = output[0] # (b, seq, case, seq) if corpus != 'commonsense': arguments_set += torch.argmax(pas_scores, dim=3).tolist() # (b, seq, case) total_loss += loss.item() * pas_scores.size(0) if step % self.log_step == 0: self.logger.info('Validation [{}/{} ({:.0f}%)] Time: {}'.format( step * data_loader.batch_size, len(data_loader.dataset), 100.0 * step / len(data_loader), datetime.datetime.now().strftime('%H:%M:%S'))) log = {'loss': total_loss / len(data_loader.dataset)} self.writer.add_scalar(f'loss/{corpus}', log['loss']) if corpus != 'commonsense': dataset = data_loader.dataset prediction_writer = PredictionKNPWriter(dataset, self.logger) documents_pred = prediction_writer.write(arguments_set, None, add_pas_tag=False) targets2label = {tuple(): '', ('pred',): 'pred', ('noun',): 'noun', ('pred', 'noun'): 'all'} scorer = Scorer(documents_pred, dataset.gold_documents, target_cases=dataset.target_cases, target_exophors=dataset.target_exophors, coreference=dataset.coreference, bridging=dataset.bridging, pas_target=targets2label[tuple(dataset.pas_targets)]) result = scorer.run() log['result'] = result else: log['f1'] = self._eval_commonsense(contingency_set) return log def _eval_commonsense(self, contingency_set: List[int]) -> float: valid_data_loader = self.valid_data_loaders['commonsense'] gold = [f.label for f in valid_data_loader.dataset.features] f1 = f1_score(gold, contingency_set) self.writer.add_scalar(f'commonsense_f1', f1) return f1 def _eval_metrics(self, result: dict, corpus: str): f1_metrics = np.zeros(len(self.metrics)) for i, metric in enumerate(self.metrics): f1_metrics[i] += metric(result) self.writer.add_scalar(f'{metric.__name__}/{corpus}', f1_metrics[i]) return f1_metrics
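# The epoch loop in the Trainer above steps the optimizer only every
# `gradient_accumulation_steps` micro-batches, scaling each loss by the group
# size so the accumulated gradient matches one large batch, and shrinking the
# final group when the number of batches is not a multiple of the accumulation
# factor. A minimal sketch of that pattern on a toy model; the function and
# variable names are illustrative, not part of the trainer above.
import torch


def train_one_epoch(model, optimizer, batches, accum_steps):
    model.train()
    optimizer.zero_grad()
    full_groups = (len(batches) // accum_steps) * accum_steps
    accum_count = 0
    for step, (x, y) in enumerate(batches):
        # the trailing, possibly smaller group uses its own size as the divisor
        group_size = accum_steps if step < full_groups else (len(batches) - full_groups)
        loss = torch.nn.functional.mse_loss(model(x), y)
        (loss / group_size).backward()          # scale so gradients average over the group
        accum_count += 1
        if accum_count == group_size:
            optimizer.step()
            optimizer.zero_grad()
            accum_count = 0


if __name__ == '__main__':
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    batches = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(7)]
    train_one_epoch(model, optimizer, batches, accum_steps=3)   # groups of 3, 3, then 1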
class BaseTrainer: """ Base class for all trainers """ def __init__(self, model, criterion, metric_ftns, optimizer, lr_scheduler, config, trainloader, validloader=None, len_epoch=None): self.config = config self.logger = config.get_logger('trainer', config['trainer']['verbosity']) self.trainloader = trainloader self.validloader = validloader if len_epoch is None: # epoch-based training self.len_epoch = len(self.trainloader) else: # iteration-based training self.trainloader = inf_loop(trainloader) self.len_epoch = len_epoch # setup GPU device if available, move model into configured device n_gpu_use = torch.cuda.device_count() self.device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') self.model = model.to(self.device) self.model = torch.nn.DataParallel(model) self.criterion = criterion self.metric_ftns = metric_ftns self.optimizer = optimizer self.lr_scheduler = lr_scheduler cfg_trainer = config['trainer'] self.epochs = cfg_trainer['epochs'] self.log_step = cfg_trainer['log_step'] self.save_period = cfg_trainer['save_period'] self.monitor = cfg_trainer.get('monitor', 'off') self.start_epoch = 1 self.checkpoint_dir = config.save_dir # configuration to monitor model performance and save best if self.monitor == 'off': self.mnt_mode = 'off' self.mnt_best = 0 else: self.mnt_mode, self.mnt_metric = self.monitor.split() assert self.mnt_mode in ['min', 'max'] self.mnt_best = inf if self.mnt_mode == 'min' else -inf self.early_stop = cfg_trainer.get('early_stop', inf) # setup visualization writer instance self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) self.train_metrics = MetricTracker( 'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) self.valid_metrics = MetricTracker( 'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer) if config.resume is not None: self._resume_checkpoint(config.resume) @abstractmethod def _train_step(self, batch): """ Training logic for a step :param batch: batch of current step :return: loss: torch Variable with map for backwarding mets: metrics computed between output and target, dict """ raise NotImplementedError @abstractmethod def _valid_step(self, batch): """ Valid logic for a step :param batch: batch of current step :return: loss: torch Variable without map mets: metrics computed between output and target, dict """ raise NotImplementedError def _train_epoch(self, epoch): """ Training logic for an epoch :param epoch: Integer, current training epoch. :return: A log that contains average loss and metric in this epoch. 
""" self.model.train() self.train_metrics.reset() tic = time.time() datatime = batchtime = 0 for batch_idx, batch in enumerate(self.trainloader): datatime += time.time() - tic # ------------------------------------------------------------------------- loss, mets = self._train_step(batch) # ------------------------------------------------------------------------- self.optimizer.zero_grad() loss.backward() self.optimizer.step() batchtime += time.time() - tic tic = time.time() self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx) self.train_metrics.update('loss', loss.item()) for key, val in mets.items(): self.train_metrics.update(key, val) if batch_idx % self.log_step == 0: processed_percent = batch_idx / self.len_epoch * 100 self.logger.debug( 'Train Epoch:{} [{}/{}]({:.0f}%)\tTime:{:5.2f}/{:<5.2f}\tLoss:({:.4f}){:.4f}' .format(epoch, batch_idx, self.len_epoch, processed_percent, datatime, batchtime, loss.item(), self.train_metrics.avg('loss'))) datatime = batchtime = 0 if batch_idx == self.len_epoch: break log = self.train_metrics.result() log = {'train_' + k: v for k, v in log.items()} if self.validloader is not None: val_log = self._valid_epoch(epoch) log.update(**{'valid_' + k: v for k, v in val_log.items()}) return log def _valid_epoch(self, epoch): """ Validate after training an epoch :param epoch: Integer, current training epoch. :return: A log that contains information about validation """ self.model.eval() self.valid_metrics.reset() for batch_idx, batch in enumerate(self.validloader): # ------------------------------------------------------------------------- loss, mets = self._valid_step(batch) # ------------------------------------------------------------------------- self.writer.set_step( (epoch - 1) * len(self.validloader) + batch_idx, 'valid') self.valid_metrics.update('loss', loss.item()) for key, val in mets.items(): self.valid_metrics.update(key, val) return self.valid_metrics.result() def train(self): """ Full training logic """ not_improved_count = 0 for epoch in range(self.start_epoch, self.epochs + 1): result = self._train_epoch(epoch) # save logged informations into log dict lr = self.optimizer.param_groups[0]['lr'] log = {'epoch': epoch, 'lr': lr} log.update(result) # print logged informations to the screen for key, value in log.items(): self.logger.info(' {:20s}: {}'.format(str(key), value)) # evaluate model performance according to configured metric, save best checkpoint as model_best best = False if self.mnt_mode != 'off': try: # check whether model performance improved or not, according to specified metric(mnt_metric) improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) except KeyError: self.logger.warning( "Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format( self.mnt_metric)) self.mnt_mode = 'off' improved = False if improved: self.mnt_best = log[self.mnt_metric] not_improved_count = 0 best = True else: not_improved_count += 1 if not_improved_count > self.early_stop: self.logger.info( "Validation performance didn\'t improve for {} epochs. 
" "Training stops.".format(self.early_stop)) break if self.lr_scheduler is not None: if isinstance(self.lr_scheduler, ReduceLROnPlateau): self.lr_scheduler.step(log[self.mnt_metric]) else: self.lr_scheduler.step() if epoch % self.save_period == 0: self._save_checkpoint(epoch, save_best=best) # add histogram of model parameters to the tensorboard self.writer.set_step(epoch) for name, p in self.model.named_parameters(): self.writer.add_histogram(name, p, bins='auto') def _save_checkpoint(self, epoch, save_best=False): """ Saving checkpoints :param epoch: current epoch number :param log: logging information of the epoch :param save_best: if True, rename the saved checkpoint to 'model_best.pth' """ state = { 'epoch': epoch, 'model': self.model.module.state_dict(), 'optimizer': self.optimizer.state_dict(), 'lr_scheduler': self.lr_scheduler.state_dict(), 'monitor_best': self.mnt_best } filename = str(self.checkpoint_dir / 'chkpt_{:03d}.pth'.format(epoch)) torch.save(state, filename) self.logger.info("Saving checkpoint: {} ...".format(filename)) if save_best: best_path = str(self.checkpoint_dir / 'model_best.pth') torch.save(state, best_path) self.logger.info("Saving current best: model_best.pth ...") def _resume_checkpoint(self, resume_path): """ Resume from saved checkpoints :param resume_path: Checkpoint path to be resumed """ resume_path = str(resume_path) self.logger.info("Loading checkpoint: {} ...".format(resume_path)) checkpoint = torch.load(resume_path) try: self.start_epoch = checkpoint['epoch'] + 1 self.model.module.load_state_dict(checkpoint['model']) self.optimizer.load_state_dict(checkpoint['optimizer']) self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) self.mnt_best = checkpoint['monitor_best'] except KeyError: self.model.module.load_state_dict(checkpoint) self.logger.info( "Checkpoint loaded. Resume training from epoch {}".format( self.start_epoch))