import numpy as np
import torch
from math import inf
from tqdm import tqdm

# Project-local helpers; the module paths below are assumed from the standard
# pytorch-template layout and may need adjusting for this repository.
from logger import TensorboardWriter
from utils import inf_loop


class Trainer:
    """
    Implements training and validation logic
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metrics in this epoch.
        """
        self.model.train()

        epoch_target = []
        epoch_scores = []
        epoch_word_pairs = []
        epoch_loss = []
        for batch_idx, batch_data in enumerate(self.data_loader):
            if not batch_data:
                continue
            for field in ['input_ids', 'label', 'target', 'attention_mask',
                          'term1_mask', 'term2_mask']:
                batch_data[field] = batch_data[field].to(self.device)

            self.optimizer.zero_grad()
            output = self.model(batch_data)
            # class weights are taken from the training dataset
            loss = self.criterion(
                output, batch_data['target'],
                self.data_loader.dataset.class_weights.to(self.device))
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # accumulate epoch quantities
            epoch_target += [t.item() for t in batch_data['label']]
            epoch_scores += [output.cpu().detach().numpy()]
            epoch_word_pairs += batch_data['term_pair']
            epoch_loss += [loss.item()]

            # update metrics
            self.writer.add_scalar("loss", loss.item())
            for met in self.metric_ftns:
                self.writer.add_scalar(met.__name__,
                                       met(epoch_target, epoch_scores))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break

        log = {m.__name__: m(epoch_target, epoch_scores)
               for m in self.metric_ftns}
        log["loss"] = np.sum(epoch_loss) / len(self.data_loader)

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        with torch.no_grad():
            epoch_target = []
            epoch_scores = []
            epoch_word_pairs = []
            epoch_loss = []
            for batch_idx, batch_data in enumerate(tqdm(self.valid_data_loader)):
                for field in ['input_ids', 'target', 'label', 'attention_mask',
                              'term1_mask', 'term2_mask']:
                    batch_data[field] = batch_data[field].to(self.device)

                output = self.model(batch_data)
                # predicted labels (not used further here)
                pred = torch.argmax(output, dim=-1)
                # class weights are taken from the training dataset
                loss = self.criterion(
                    output, batch_data['target'].squeeze(-1),
                    self.data_loader.dataset.class_weights.to(self.device))

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')

                # accumulate epoch quantities
                epoch_target += [t.item() for t in batch_data['label']]
                epoch_scores += [output.cpu().detach().numpy()]
                epoch_word_pairs += batch_data['term_pair']
                epoch_loss += [loss.item()]

                # update metrics
                self.writer.add_scalar('loss', loss.item())
                for met in self.metric_ftns:
                    self.writer.add_scalar(met.__name__,
                                           met(epoch_target, epoch_scores))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        log = {m.__name__: m(epoch_target, epoch_scores)
               for m in self.metric_ftns}
        log["loss"] = np.sum(epoch_loss) / len(self.valid_data_loader)
        return log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric,
            # save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not,
                    # according to specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn't improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self, n_gpu_use):
        """
        setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPUs configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, save the checkpoint as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        # note: only the best-performing model is persisted here;
        # periodic (non-best) checkpoints are not written to disk
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded.")
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
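

# Both trainers rely on an `inf_loop` helper for iteration-based training
# (when `len_epoch` is given, the data loader must be iterable indefinitely).
# Its implementation is not shown in this section; the sketch below shows the
# behaviour it is assumed to have and may differ from the project's own version.
from itertools import repeat


def inf_loop(data_loader):
    """Wrap a finite DataLoader so that iteration restarts endlessly (sketch)."""
    for loader in repeat(data_loader):
        yield from loader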
import time
from abc import abstractmethod
from math import inf

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Project-local helpers; the module paths below are assumed from the standard
# pytorch-template layout and may need adjusting for this repository.
from logger import TensorboardWriter
from utils import inf_loop, MetricTracker


class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, lr_scheduler,
                 config, trainloader, validloader=None, len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])

        self.trainloader = trainloader
        self.validloader = validloader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.trainloader)
        else:
            # iteration-based training
            self.trainloader = inf_loop(trainloader)
            self.len_epoch = len_epoch

        # setup GPU device if available, move model into configured device
        n_gpu_use = torch.cuda.device_count()
        self.device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        self.model = model.to(self.device)
        # the model is always wrapped in DataParallel so that checkpointing
        # can rely on `self.model.module`
        self.model = torch.nn.DataParallel(model)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.log_step = cfg_trainer['log_step']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_step(self, batch):
        """
        Training logic for a step

        :param batch: batch of current step
        :return:
            loss: torch tensor attached to the computation graph, used for backward()
            mets: metrics computed between output and target, dict
        """
        raise NotImplementedError

    @abstractmethod
    def _valid_step(self, batch):
        """
        Validation logic for a step

        :param batch: batch of current step
        :return:
            loss: torch tensor detached from the computation graph
            mets: metrics computed between output and target, dict
        """
        raise NotImplementedError

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metrics in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        tic = time.time()
        datatime = batchtime = 0
        for batch_idx, batch in enumerate(self.trainloader):
            datatime += time.time() - tic

            # -------------------------------------------------------------------------
            loss, mets = self._train_step(batch)
            # -------------------------------------------------------------------------

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batchtime += time.time() - tic
            tic = time.time()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.train_metrics.update(key, val)

            if batch_idx % self.log_step == 0:
                processed_percent = batch_idx / self.len_epoch * 100
                self.logger.debug(
                    'Train Epoch:{} [{}/{}]({:.0f}%)\tTime:{:5.2f}/{:<5.2f}\tLoss:({:.4f}){:.4f}'
                    .format(epoch, batch_idx, self.len_epoch, processed_percent,
                            datatime, batchtime, loss.item(),
                            self.train_metrics.avg('loss')))
                datatime = batchtime = 0

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log = {'train_' + k: v for k, v in log.items()}

        if self.validloader is not None:
            val_log = self._valid_epoch(epoch)
            log.update(**{'valid_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        # disable gradient tracking during validation
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.validloader):
                # -------------------------------------------------------------------------
                loss, mets = self._valid_step(batch)
                # -------------------------------------------------------------------------

                self.writer.set_step(
                    (epoch - 1) * len(self.validloader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for key, val in mets.items():
                    self.valid_metrics.update(key, val)
        return self.valid_metrics.result()

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            lr = self.optimizer.param_groups[0]['lr']
            log = {'epoch': epoch, 'lr': lr}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:20s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric,
            # save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not,
                    # according to specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn't improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if self.lr_scheduler is not None:
                # ReduceLROnPlateau steps on the monitored metric; this assumes
                # monitoring is enabled when that scheduler is used
                if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                    self.lr_scheduler.step(log[self.mnt_metric])
                else:
                    self.lr_scheduler.step()

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

            # add histogram of model parameters to the tensorboard
            self.writer.set_step(epoch)
            for name, p in self.model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save the checkpoint as 'model_best.pth'
        """
        state = {
            'epoch': epoch,
            'model': self.model.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = str(self.checkpoint_dir / 'chkpt_{:03d}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        try:
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.module.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            self.mnt_best = checkpoint['monitor_best']
        except KeyError:
            # fall back to a checkpoint that contains only the raw model state_dict
            self.model.module.load_state_dict(checkpoint)
        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))