import os
import time
from math import inf

import numpy as np
import torch

# TensorboardWriter, MetricTracker and the `metrics` module are project-level
# helpers and are assumed to be importable from the surrounding package.


class Trainer:

    def __init__(self, model, criterion, metrics_name, optimizer, train_loader,
                 logger, log_dir, nb_epochs, save_dir, device="cuda:0",
                 log_step=10, start_epoch=0, enable_tensorboard=True,
                 valid_loader=None, lr_scheduler=None, monitor="min val_loss",
                 early_stop=10, save_epoch_period=1, resume=""):
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler
        self.log_step = log_step
        self.epochs = nb_epochs
        self.start_epoch = start_epoch + 1
        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period
        self.writer = TensorboardWriter(log_dir, self.logger, enable_tensorboard)
        self.train_metrics = MetricTracker('loss', *self.metrics_name, writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *self.metrics_name, writer=self.writer)
        self.checkpoint_dir = save_dir

        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = early_stop

        if resume != "":
            self._resume_checkpoint(resume_path=resume)
        self.model.to(self.device)

    def train(self):
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)
            log = {'epoch': epoch}
            log.update(result)

            self.logger.info(' {:15s}: {}'.format(str("mnt best"), self.mnt_best))
            for key, value in log.items():
                self.logger.info(' {:15s}: {}'.format(str(key), value))

            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not,
                    # according to the specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if (not_improved_count > self.early_stop) and (self.early_stop > 0):
                    self.logger.info(
                        "Validation performance didn't improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, best)

    def _train_epoch(self, epoch):
        self.model.train()
        self.train_metrics.reset()
        start_time = time.time()
        for batch_idx, sample in enumerate(self.train_loader):
            data = sample['image']
            target = sample['mask']
            data, target = data.to(self.device), target.to(self.device)

            current_lr = self.lr_scheduler(self.optimizer, batch_idx, epoch)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met_name in self.metrics_name:
                self.train_metrics.update(
                    met_name, getattr(metrics, met_name)(output, target))

            if batch_idx % self.log_step == 0:
                time_to_run = time.time() - start_time
                start_time = time.time()
                speed = self.log_step / time_to_run
                self.logger.debug(
                    'Train Epoch: {} {} Loss: {:.6f} LR: {:.6f} Speed: {:.4f}iters/s'.format(
                        epoch, self._progress(batch_idx), loss.item(), current_lr, speed))
                for met_name in self.metrics_name:
                    self.writer.add_scalar(met_name, self.train_metrics.avg(met_name))
                self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
                self.writer.add_scalar("lr", current_lr)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            assert batch_idx <= self.len_epoch

        log = self.train_metrics.result()

        if self.do_validation:
            self.logger.info("Start validation")
            val_log, iou_classes = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            for key, value in iou_classes.items():
                log.update({key: value})

        return log

    def _valid_epoch(self, epoch):
        self.model.eval()
        self.valid_metrics.reset()
        iou_tracker = metrics.IoU(2)
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_loader):
                data = sample['image']
                target = sample['mask']
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                target = target.cpu().numpy()
                output = output[:, 0]
                output = output.data.cpu().numpy()
                pred = np.zeros_like(output)
                pred[output > 0.5] = 1
                pred = pred.astype(np.int64)
                for i in range(len(target)):
                    iou_tracker.add_batch(target[i], pred[i])

        iou_classes = iou_tracker.get_iou()
        for key, value in iou_classes.items():
            self.writer.add_scalar(key, value)
        self.writer.add_scalar('val_loss', self.valid_metrics.avg('loss'))
        for met_name in self.metrics_name:
            self.writer.add_scalar(met_name, self.valid_metrics.avg(met_name))

        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p.cpu().data.numpy(), bins='auto')

        return self.valid_metrics.result(), iou_classes

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save the checkpoint as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # 'config': self.config
        }
        filename = str(self.checkpoint_dir / 'checkpoint-epoch{:06d}.pth'.format(epoch))
        torch.save(state, filename)
        self.delete_checkpoint()
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def delete_checkpoint(self):
        # keep only the five most recent epoch checkpoints
        checkpoints_file = list(self.checkpoint_dir.glob("checkpoint-epoch*.pth"))
        checkpoints_file.sort()
        for checkpoint_file in checkpoints_file[:-5]:
            os.remove(str(checkpoint_file.absolute()))

    def _resume_checkpoint(self, resume_path):
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
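# --- Checkpoint round-trip sketch (not part of the original Trainer) ----------
# _save_checkpoint and _resume_checkpoint above serialize a plain dict with the
# keys 'arch', 'epoch', 'state_dict', 'optimizer' and 'monitor_best'. The
# self-contained demo below uses a throwaway linear model and a temporary
# directory (both hypothetical, not project code) to show the same round trip.
def checkpoint_roundtrip_demo():
    import os
    import tempfile

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    state = {
        'arch': type(model).__name__,
        'epoch': 3,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'monitor_best': 0.123,
    }
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'checkpoint-epoch{:06d}.pth'.format(state['epoch']))
        torch.save(state, path)

        # Mirrors _resume_checkpoint: restore weights, optimizer state and the
        # epoch counter, then resume from the following epoch.
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        print(checkpoint['arch'], start_epoch, checkpoint['monitor_best'])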
# The Trainer variant below builds on project helpers (BaseTrainer, MetricTracker,
# save_model, make_dirs, write_csv, sensitivity, positive_predictive_value) that
# are assumed to be importable.


class Trainer(BaseTrainer):
    """
    Trainer class
    """

    def __init__(self, config, model, optimizer, data_loader, writer,
                 checkpoint_dir, logger, class_dict, valid_data_loader=None,
                 test_data_loader=None, lr_scheduler=None, metric_ftns=None):
        super(Trainer, self).__init__(config, data_loader, writer,
                                      checkpoint_dir, logger,
                                      valid_data_loader=valid_data_loader,
                                      test_data_loader=test_data_loader,
                                      metric_ftns=metric_ftns)
        if self.config.cuda:
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")

        self.start_epoch = 1
        self.train_data_loader = data_loader
        self.len_epoch = self.config.dataloader.train.batch_size * len(self.train_data_loader)
        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = self.config.log_interval
        self.model = model
        self.num_classes = len(class_dict)
        self.optimizer = optimizer
        self.mnt_best = np.inf
        if self.config.dataset.type == 'multi_target':
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
        else:
            self.criterion = torch.nn.CrossEntropyLoss(reduction='mean')
        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        self.writer = writer
        self.metric_ftns = ['loss', 'acc']
        self.train_metrics = MetricTracker(*self.metric_ftns,
                                           writer=self.writer, mode='train')
        self.valid_metrics = MetricTracker(*self.metric_ftns,
                                           writer=self.writer, mode='validation')
        self.logger = logger
        self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        Args:
            epoch (int): current training epoch.
        """
        self.model.train()
        self.confusion_matrix = 0 * self.confusion_matrix
        self.train_metrics.reset()
        gradient_accumulation = self.gradient_accumulation

        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data = data.to(self.device)
            target = target.to(self.device)

            output = self.model(data)
            loss = self.criterion(output, target)
            loss = loss.mean()
            (loss / gradient_accumulation).backward()

            if batch_idx % gradient_accumulation == 0:
                self.optimizer.step()        # now we can do an optimizer step
                self.optimizer.zero_grad()   # reset gradient tensors

            prediction = torch.max(output, 1)
            writer_step = (epoch - 1) * self.len_epoch + batch_idx
            self.train_metrics.update(key='loss', value=loss.item(), n=1,
                                      writer_step=writer_step)
            self.train_metrics.update(
                key='acc',
                value=np.sum(prediction[1].cpu().numpy() ==
                             target.squeeze(-1).cpu().numpy()),
                n=target.size(0),
                writer_step=writer_step)
            for t, p in zip(target.cpu().view(-1), prediction[1].cpu().view(-1)):
                self.confusion_matrix[t.long(), p.long()] += 1

            self._progress(batch_idx, epoch, metrics=self.train_metrics, mode='train')

        self._progress(batch_idx, epoch, metrics=self.train_metrics,
                       mode='train', print_summary=True)

    def _valid_epoch(self, epoch, mode, loader):
        """
        Args:
            epoch (int): current epoch
            mode (str): 'validation' or 'test'
            loader (DataLoader): loader to evaluate on

        Returns:
            validation loss
        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()
        self.confusion_matrix = 0 * self.confusion_matrix
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)
                target = target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)
                loss = loss.mean()

                writer_step = (epoch - 1) * len(loader) + batch_idx
                prediction = torch.max(output, 1)
                self.valid_metrics.update(key='loss', value=loss.item(), n=1,
                                          writer_step=writer_step)
                self.valid_metrics.update(
                    key='acc',
                    value=np.sum(prediction[1].cpu().numpy() ==
                                 target.squeeze(-1).cpu().numpy()),
                    n=target.size(0),
                    writer_step=writer_step)
                for t, p in zip(target.cpu().view(-1), prediction[1].cpu().view(-1)):
                    self.confusion_matrix[t.long(), p.long()] += 1

        self._progress(batch_idx, epoch, metrics=self.valid_metrics,
                       mode=mode, print_summary=True)
        s = sensitivity(self.confusion_matrix.numpy())
        ppv = positive_predictive_value(self.confusion_matrix.numpy())
        self.logger.info(f"sensitivity {s}, ppv {ppv}")
        val_loss = self.valid_metrics.avg('loss')
        return val_loss

    def train(self):
        """
        Train the model
        """
        for epoch in range(self.start_epoch, self.epochs):
            torch.manual_seed(self.config.seed)
            self._train_epoch(epoch)

            self.logger.info(f"{'!' * 10} VALIDATION {'!' * 10}")
            validation_loss = self._valid_epoch(epoch, 'validation',
                                                self.valid_data_loader)
            make_dirs(self.checkpoint_dir)
            self.checkpointer(epoch, validation_loss)
            self.lr_scheduler.step(validation_loss)

            if self.do_test:
                self.logger.info(f"{'!' * 10} TEST {'!' * 10}")
                self.predict(epoch)

    def predict(self, epoch):
        """
        Inference on the test set

        Args:
            epoch (int): current epoch, used to name the prediction file

        Returns:
            list of "target,prediction" strings, one per test sample
        """
        self.model.eval()
        predictions = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                data = data.to(self.device)

                logits = self.model(data, None)
                # get the index of the max log-probability
                maxes, prediction = torch.max(logits, 1)
                predictions.append(f"{target[0]},{prediction.cpu().numpy()[0]}")

        pred_name = os.path.join(self.checkpoint_dir,
                                 f'validation_predictions_epoch_{epoch:d}_.csv')
        write_csv(predictions, pred_name)
        return predictions

    def checkpointer(self, epoch, metric):
        is_best = metric < self.mnt_best
        if is_best:
            self.mnt_best = metric
            self.logger.info(f"Best val loss {self.mnt_best} so far")
            save_model(self.checkpoint_dir, self.model, self.optimizer,
                       self.valid_metrics.avg('loss'), epoch, '_model_best')
        # else:
        #     self.gradient_accumulation = self.gradient_accumulation // 2
        #     if self.gradient_accumulation < 4:
        #         self.gradient_accumulation = 4
        save_model(self.checkpoint_dir, self.model, self.optimizer,
                   self.valid_metrics.avg('loss'), epoch, '_model_last')

    def _progress(self, batch_idx, epoch, metrics, mode='', print_summary=False):
        metrics_string = metrics.calc_all_metrics()
        if (batch_idx * self.config.dataloader.train.batch_size) % self.log_step == 0:
            if metrics_string is None:
                self.logger.warning("No metrics")
            else:
                self.logger.info(
                    f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t "
                    f"Sample [{batch_idx * self.config.dataloader.train.batch_size:5d}"
                    f"/{self.len_epoch:5d}]\t {metrics_string}")
        elif print_summary:
            self.logger.info(
                f'{mode} summary Epoch: [{epoch}/{self.epochs}]\t {metrics_string}')
class Tester(BaseTrainer):
    """
    Tester class
    """

    def __init__(self, config, model, data_loader, writer, checkpoint_dir,
                 logger, valid_data_loader=None, test_data_loader=None,
                 metric_ftns=None):
        super(Tester, self).__init__(config, data_loader, writer,
                                     checkpoint_dir, logger,
                                     valid_data_loader=valid_data_loader,
                                     test_data_loader=test_data_loader,
                                     metric_ftns=metric_ftns)
        if self.config.cuda:
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")

        self.start_epoch = 1
        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.log_step = self.config.log_interval
        self.model = model
        self.mnt_best = np.inf
        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        self.metric_ftns = ['loss', 'acc']
        self.valid_metrics = MetricTracker(*self.metric_ftns,
                                           writer=self.writer, mode='validation')
        self.logger = logger

    def _valid_epoch(self, epoch, mode, loader):
        """
        Args:
            epoch (int): current epoch
            mode (str): 'validation' or 'test'
            loader (DataLoader): loader to evaluate on

        Returns:
            validation loss
        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)
                target = target.long().to(self.device)

                output, loss = self.model(data, target)
                loss = loss.mean()

                writer_step = (epoch - 1) * len(loader) + batch_idx
                prediction = torch.max(output, 1)
                self.valid_metrics.update(key='loss', value=loss.item(), n=1,
                                          writer_step=writer_step)
                self.valid_metrics.update(
                    key='acc',
                    value=np.sum(prediction[1].cpu().numpy() == target.cpu().numpy()),
                    n=target.size(0),
                    writer_step=writer_step)

        self._progress(batch_idx, epoch, metrics=self.valid_metrics,
                       mode=mode, print_summary=True)
        val_loss = self.valid_metrics.avg('loss')
        return val_loss

    def predict(self):
        """
        Inference on the test set

        Returns:
            list of "target,prediction" strings, one per test sample
        """
        self.model.eval()
        predictions = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                data = data.to(self.device)

                logits = self.model(data, None)
                # get the index of the max log-probability
                maxes, prediction = torch.max(logits, 1)
                predictions.append(f"{target[0]},{prediction.cpu().numpy()[0]}")

        self.logger.info('Inference done')
        pred_name = os.path.join(self.checkpoint_dir, 'predictions.csv')
        write_csv(predictions, pred_name)
        return predictions

    def _progress(self, batch_idx, epoch, metrics, mode='', print_summary=False):
        metrics_string = metrics.calc_all_metrics()
        if (batch_idx * self.config.dataloader.train.batch_size) % self.log_step == 0:
            if metrics_string is None:
                self.logger.warning("No metrics")
            else:
                self.logger.info(
                    f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t "
                    f"Video [{batch_idx * self.config.dataloader.train.batch_size:5d}"
                    f"/{self.len_epoch:5d}]\t {metrics_string}")
        elif print_summary:
            self.logger.info(
                f'{mode} summary Epoch: [{epoch}/{self.epochs}]\t {metrics_string}')
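# --- Prediction CSV formatting sketch (not project code) ----------------------
# Tester.predict above writes one "target,prediction" row per test sample by
# taking the argmax over the logits. The self-contained demo below mirrors that
# format with random logits, using the standard csv module as a stand-in for the
# project's write_csv helper.
def demo_predictions_csv(path='predictions_demo.csv', num_samples=5, num_classes=3):
    import csv

    import torch

    rows = []
    for _ in range(num_samples):
        target = torch.randint(0, num_classes, (1,))
        logits = torch.randn(1, num_classes)
        _, prediction = torch.max(logits, 1)             # index of the max logit
        rows.append([target.item(), prediction.item()])
    with open(path, 'w', newline='') as f:
        csv.writer(f).writerows(rows)                    # one "target,prediction" row per sample
    return rows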