コード例 #1
0
class Trainer():
    """Training loop with optional validation, checkpointing and early stopping.

    Every batch produced by the loaders is indexed with ``'image'`` and
    ``'mask'``. Training metrics are resolved by name on the ``metrics``
    module; validation additionally tracks a two-class IoU computed from the
    first output channel thresholded at 0.5.
    """

    def __init__(self,
                 model,
                 criterion,
                 metrics_name,
                 optimizer,
                 train_loader,
                 logger,
                 log_dir,
                 nb_epochs,
                 save_dir,
                 device="cuda:0",
                 log_step=10,
                 start_epoch=0,
                 enable_tensorboard=True,
                 valid_loader=None,
                 lr_scheduler=None,
                 monitor="min val_loss",
                 early_stop=10,
                 save_epoch_period=1,
                 resume=""):
        """Store the training components and optionally resume from a checkpoint.

        :param model: network; moved to ``device`` and called on the batch images
        :param criterion: callable ``(output, target)`` returning the loss
        :param metrics_name: iterable of function names looked up on ``metrics``
        :param optimizer: optimizer stepping the model parameters
        :param train_loader: sized iterable of training batches
        :param logger: object exposing ``info``/``debug``/``warning``
        :param log_dir: forwarded to ``TensorboardWriter``
        :param nb_epochs: number of the last epoch to run (inclusive)
        :param save_dir: checkpoint directory; must support ``/`` and ``glob``
            (i.e. a ``pathlib.Path``)
        :param device: device the model and the batches are moved to
        :param log_step: a progress line is logged every ``log_step`` batches
        :param start_epoch: training begins at ``start_epoch + 1``
        :param enable_tensorboard: forwarded to ``TensorboardWriter``
        :param valid_loader: validation batches; ``None`` disables validation
        :param lr_scheduler: callable ``(optimizer, batch_idx, epoch)`` that
            returns the current learning rate; when ``None`` the rate is read
            from the optimizer's first parameter group
        :param monitor: ``'off'`` or ``'<min|max> <metric name>'``
        :param early_stop: stop once the monitored metric has not improved for
            more than this many epochs; values ``<= 0`` disable early stopping
        :param save_epoch_period: a checkpoint is written every this many epochs
        :param resume: checkpoint path to resume from; ``""`` starts fresh
        """
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler
        self.log_step = log_step
        self.epochs = nb_epochs
        self.start_epoch = start_epoch + 1

        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period

        self.writer = TensorboardWriter(log_dir, self.logger,
                                        enable_tensorboard)
        self.train_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.checkpoint_dir = save_dir

        # Always define the monitoring attributes so that later code never
        # hits an AttributeError when monitoring is switched off.
        self.mnt_metric = None
        self.early_stop = early_stop
        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max'], \
                "monitor must be 'off' or start with 'min'/'max'"
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf

        # Move the model first: the optimizer state restored by a resume is
        # cast to the device of the parameters at load time, so they must
        # already live on the target device.
        self.model.to(self.device)
        if resume != "":
            self._resume_checkpoint(resume_path=resume)

    def train(self):
        """Run all remaining epochs, tracking the best metric and saving checkpoints."""
        not_improved_count = 0

        for epoch in range(self.start_epoch, self.epochs + 1):
            log = {'epoch': epoch}
            log.update(self._train_epoch(epoch))

            self.logger.info('    {:15s}: {}'.format(str("mnt best"),
                                                     self.mnt_best))
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            best = False
            if self.mnt_mode != 'off':
                try:
                    current = log[self.mnt_metric]
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False
                else:
                    # Improvement is judged against the best value seen so far.
                    improved = (self.mnt_mode == 'min' and current < self.mnt_best) or \
                               (self.mnt_mode == 'max' and current > self.mnt_best)

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if (not_improved_count > self.early_stop) and (self.early_stop
                                                               > 0):
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            # A new best model is always persisted, even when the epoch does
            # not fall on the periodic save schedule; otherwise it would be lost.
            if best or epoch % self.save_period == 0:
                self._save_checkpoint(epoch, best)

    def _current_lr(self, batch_idx, epoch):
        """Return the learning rate for this step.

        Delegates to ``lr_scheduler(optimizer, batch_idx, epoch)`` when a
        scheduler was given; otherwise reports the rate of the optimizer's
        first parameter group.
        """
        if self.lr_scheduler is None:
            return self.optimizer.param_groups[0]['lr']
        return self.lr_scheduler(self.optimizer, batch_idx, epoch)

    def _train_epoch(self, epoch):
        """Train for one epoch and return the metric log.

        :param epoch: current epoch number (1-based)
        :return: dict of averaged training metrics; when validation is enabled
            it also contains the ``val_``-prefixed validation metrics and the
            per-class IoU entries
        """
        self.model.train()
        self.train_metrics.reset()
        start_time = time.time()

        for batch_idx, sample in enumerate(self.train_loader):
            data = sample['image'].to(self.device)
            target = sample['mask'].to(self.device)
            current_lr = self._current_lr(batch_idx, epoch)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met_name in self.metrics_name:
                self.train_metrics.update(
                    met_name,
                    getattr(metrics, met_name)(output, target))

            if batch_idx % self.log_step == 0:
                now = time.time()
                elapsed = now - start_time
                start_time = now
                # Only one iteration has run when the first line is logged;
                # afterwards exactly ``log_step`` iterations separate two logs.
                iterations = 1 if batch_idx == 0 else self.log_step
                speed = iterations / elapsed if elapsed > 0 else float('inf')
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f} LR: {:.6f}  Speed: {:.4f}iters/s' \
                                  .format(epoch, self._progress(batch_idx), loss.item(), current_lr, speed))
                for met_name in self.metrics_name:
                    self.writer.add_scalar(met_name,
                                           self.train_metrics.avg(met_name))
                self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
                self.writer.add_scalar("lr", current_lr)

        log = self.train_metrics.result()
        if self.do_validation:
            self.logger.info("Start validation")
            val_log, iou_classes = self._valid_epoch(epoch)

            log.update(**{'val_' + k: v for k, v in val_log.items()})
            log.update(iou_classes)
        return log

    def _valid_epoch(self, epoch):
        """Evaluate on the validation loader.

        :param epoch: current epoch number (1-based)
        :return: tuple ``(validation metric dict, per-class IoU dict)``
        """
        self.model.eval()
        self.valid_metrics.reset()
        iou_tracker = metrics.IoU(2)
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_loader):
                data = sample['image'].to(self.device)
                target = sample['mask'].to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                # Mirror the training loop so the averages logged below are
                # backed by real measurements.
                for met_name in self.metrics_name:
                    self.valid_metrics.update(
                        met_name,
                        getattr(metrics, met_name)(output, target))

                # Binarise the first output channel at 0.5 for the IoU tracker.
                target_np = target.cpu().numpy()
                scores = output[:, 0].detach().cpu().numpy()
                pred = (scores > 0.5).astype(np.int64)
                for true_mask, pred_mask in zip(target_np, pred):
                    iou_tracker.add_batch(true_mask, pred_mask)

        iou_classes = iou_tracker.get_iou()
        for key, value in iou_classes.items():
            self.writer.add_scalar(key, value)
        self.writer.add_scalar('val_loss', self.valid_metrics.avg('loss'))

        for met_name in self.metrics_name:
            self.writer.add_scalar(met_name, self.valid_metrics.avg(met_name))

        return self.valid_metrics.result(), iou_classes

    def _progress(self, batch_idx):
        """Format the position inside the epoch as ``[current/total (pct%)]``."""
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, additionally store the state as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
        }
        filename = str(self.checkpoint_dir /
                       'checkpoint-epoch{:06d}.pth'.format(epoch))
        torch.save(state, filename)
        # Prune old epoch checkpoints right after writing the new one.
        self.delete_checkpoint()
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def delete_checkpoint(self, keep_last=5):
        """Remove all but the newest ``keep_last`` epoch checkpoints.

        File names embed a zero-padded epoch number, so sorting the paths
        orders them from oldest to newest. 'model_best.pth' does not match
        the glob pattern and is never removed.

        :param keep_last: number of most recent checkpoints to keep
        """
        checkpoints = sorted(self.checkpoint_dir.glob("checkpoint-epoch*.pth"))
        # An explicit count avoids the ``[:-0]`` pitfall when keep_last is 0.
        stale_count = max(len(checkpoints) - keep_last, 0)
        for checkpoint_file in checkpoints[:stale_count]:
            os.remove(str(checkpoint_file.absolute()))

    def _resume_checkpoint(self, resume_path):
        """Restore model, optimizer, best metric and epoch counter from a checkpoint.

        :param resume_path: path of a checkpoint written by ``_save_checkpoint``
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored onto the device this trainer runs on.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
コード例 #2
0
ファイル: trainer.py プロジェクト: iliasprc/COVIDNet
class Trainer(BaseTrainer):
    """
    Classification trainer.

    Runs training with gradient accumulation, evaluates on an optional
    validation loader, keeps a confusion matrix of the latest pass, saves
    "best" and "last" checkpoints and can write test-set predictions to CSV.
    """
    def __init__(self,
                 config,
                 model,
                 optimizer,
                 data_loader,
                 writer,
                 checkpoint_dir,
                 logger,
                 class_dict,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 metric_ftns=None):
        """
        Args:
            config: configuration object; this class reads ``cuda``,
                ``epochs``, ``seed``, ``log_interval``,
                ``gradient_accumulation``, ``dataset.type`` and
                ``dataloader.train.batch_size`` from it.
            model: network to train.
            optimizer: optimizer for the model parameters.
            data_loader: training loader yielding ``(data, target)`` pairs.
            writer: metric writer handed to the ``MetricTracker`` instances.
            checkpoint_dir: directory for checkpoints and prediction files.
            logger: logger used for progress output.
            class_dict: mapping whose length gives the number of classes.
            valid_data_loader: optional validation loader.
            test_data_loader: optional test loader used by ``predict``.
            lr_scheduler: optional scheduler whose ``step`` receives the
                validation loss.
            metric_ftns: forwarded to the base class only; the tracked
                metrics are fixed to ``'loss'`` and ``'acc'``.
        """
        super(Trainer, self).__init__(config,
                                      data_loader,
                                      writer,
                                      checkpoint_dir,
                                      logger,
                                      valid_data_loader=valid_data_loader,
                                      test_data_loader=test_data_loader,
                                      metric_ftns=metric_ftns)

        if (self.config.cuda):
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")
        self.start_epoch = 1
        self.train_data_loader = data_loader

        # Expressed in samples (batch size times number of batches).
        self.len_epoch = self.config.dataloader.train.batch_size * len(
            self.train_data_loader)
        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = self.config.log_interval
        self.model = model
        self.num_classes = len(class_dict)
        self.optimizer = optimizer

        self.mnt_best = np.inf
        if self.config.dataset.type == 'multi_target':
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
        else:
            self.criterion = torch.nn.CrossEntropyLoss(reduction='mean')
        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        self.writer = writer
        self.metric_ftns = ['loss', 'acc']
        self.train_metrics = MetricTracker(*self.metric_ftns,
                                           writer=self.writer,
                                           mode='train')
        self.valid_metrics = MetricTracker(*self.metric_ftns,
                                           writer=self.writer,
                                           mode='validation')
        self.logger = logger

        self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        Args:
            epoch (int): current training epoch.
        """

        self.model.train()
        self.confusion_matrix = 0 * self.confusion_matrix
        self.train_metrics.reset()
        # Guard against a zero/negative setting, which would otherwise divide by zero.
        gradient_accumulation = max(1, self.gradient_accumulation)
        num_batches = len(self.train_data_loader)

        # Start from clean gradients so nothing leaks in from a previous pass.
        self.optimizer.zero_grad()
        batch_idx = 0  # keeps the summary call valid for an empty loader
        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data = data.to(self.device)
            target = target.to(self.device)

            output = self.model(data)
            loss = self.criterion(output, target).mean()

            (loss / gradient_accumulation).backward()
            # Step only after a full group of batches has been accumulated
            # (counting from 1), and flush the remainder on the last batch so
            # no gradient is dropped or carried into the next epoch.
            batches_done = batch_idx + 1
            if (batches_done % gradient_accumulation == 0
                    or batches_done == num_batches):
                self.optimizer.step()
                self.optimizer.zero_grad()

            _, predicted = torch.max(output, 1)

            writer_step = (epoch - 1) * self.len_epoch + batch_idx

            self.train_metrics.update(key='loss',
                                      value=loss.item(),
                                      n=1,
                                      writer_step=writer_step)
            # 'acc' is fed the number of correct predictions with n = batch size.
            self.train_metrics.update(
                key='acc',
                value=np.sum(predicted.cpu().numpy() == target.squeeze(
                    -1).cpu().numpy()),
                n=target.size(0),
                writer_step=writer_step)
            # Rows are indexed by the target class, columns by the prediction.
            for t, p in zip(target.cpu().view(-1), predicted.cpu().view(-1)):
                self.confusion_matrix[t.long(), p.long()] += 1
            self._progress(batch_idx,
                           epoch,
                           metrics=self.train_metrics,
                           mode='train')

        self._progress(batch_idx,
                       epoch,
                       metrics=self.train_metrics,
                       mode='train',
                       print_summary=True)

    def _valid_epoch(self, epoch, mode, loader):
        """
        Evaluate the model on ``loader`` without gradient tracking.

        Args:
            epoch (int): current epoch
            mode (string): 'validation' or 'test'
            loader (dataloader): loader yielding ``(data, target)`` pairs

        Returns: validation loss

        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()
        self.confusion_matrix = 0 * self.confusion_matrix
        batch_idx = 0  # keeps the summary call valid for an empty loader
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)
                target = target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target).mean()
                writer_step = (epoch - 1) * len(loader) + batch_idx

                _, predicted = torch.max(output, 1)

                self.valid_metrics.update(key='loss',
                                          value=loss.item(),
                                          n=1,
                                          writer_step=writer_step)
                self.valid_metrics.update(
                    key='acc',
                    value=np.sum(predicted.cpu().numpy() == target.squeeze(
                        -1).cpu().numpy()),
                    n=target.size(0),
                    writer_step=writer_step)
                for t, p in zip(target.cpu().view(-1),
                                predicted.cpu().view(-1)):
                    self.confusion_matrix[t.long(), p.long()] += 1
        self._progress(batch_idx,
                       epoch,
                       metrics=self.valid_metrics,
                       mode=mode,
                       print_summary=True)

        s = sensitivity(self.confusion_matrix.numpy())
        ppv = positive_predictive_value(self.confusion_matrix.numpy())
        self.logger.info(f" s {s} ,ppv {ppv}")
        val_loss = self.valid_metrics.avg('loss')

        return val_loss

    def train(self):
        """
        Train the model for epochs ``start_epoch`` .. ``epochs`` (inclusive).

        Validation, checkpointing and the scheduler step are skipped when no
        validation loader was supplied, since all three depend on the
        validation loss.
        """
        # ``+ 1`` so that the configured number of epochs is actually run;
        # the progress lines report "[epoch/epochs]".
        for epoch in range(self.start_epoch, self.epochs + 1):
            # NOTE(review): re-seeding with the same value every epoch makes
            # each epoch draw the same random sequence - confirm it is intended.
            torch.manual_seed(self.config.seed)
            self._train_epoch(epoch)

            if self.do_validation:
                self.logger.info(f"{'!' * 10}    VALIDATION   , {'!' * 10}")
                validation_loss = self._valid_epoch(epoch, 'validation',
                                                    self.valid_data_loader)
                make_dirs(self.checkpoint_dir)

                self.checkpointer(epoch, validation_loss)
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step(validation_loss)
            if self.do_test:
                self.logger.info(f"{'!' * 10}    TEST   , {'!' * 10}")
                self.predict(epoch)

    def predict(self, epoch):
        """
        Run inference on the test loader and write the results to a CSV file.

        Args:
            epoch (int): epoch number embedded in the output file name.

        Returns:
            list of ``"<target>,<prediction>"`` strings, one per test sample.

        """
        self.model.eval()

        predictions = []
        with torch.no_grad():
            for data, target in self.test_data_loader:
                data = data.to(self.device)

                # Same call form as in training/validation.
                logits = self.model(data)

                # Index of the highest score per sample.
                _, prediction = torch.max(logits, 1)
                # Record every sample of the batch, not only the first one.
                predictions.extend(
                    f"{t},{p}"
                    for t, p in zip(target, prediction.cpu().numpy()))

        pred_name = os.path.join(
            self.checkpoint_dir,
            f'validation_predictions_epoch_{epoch:d}_.csv')
        write_csv(predictions, pred_name)
        return predictions

    def checkpointer(self, epoch, metric):
        """
        Save the "last" checkpoint and, on improvement, the "best" one.

        Args:
            epoch (int): current epoch, passed on to ``save_model``.
            metric (float): monitored value; lower is better.
        """
        current_loss = self.valid_metrics.avg('loss')
        if metric < self.mnt_best:
            self.mnt_best = metric

            self.logger.info(f"Best val loss {self.mnt_best} so far ")

            save_model(self.checkpoint_dir, self.model, self.optimizer,
                       current_loss, epoch, '_model_best')
        save_model(self.checkpoint_dir, self.model, self.optimizer,
                   current_loss, epoch, '_model_last')

    def _progress(self,
                  batch_idx,
                  epoch,
                  metrics,
                  mode='',
                  print_summary=False):
        """
        Log either an epoch summary or a periodic progress line.

        Args:
            batch_idx (int): index of the current batch.
            epoch (int): current epoch.
            metrics: tracker providing ``calc_all_metrics()``.
            mode (string): prefix for the log line, e.g. 'train'.
            print_summary (bool): log the epoch summary instead of a
                progress line.
        """
        metrics_string = metrics.calc_all_metrics()
        # The summary is handled first so a requested summary is never
        # replaced by a progress line when the interval check also matches.
        if print_summary:
            self.logger.info(
                f'{mode} summary  Epoch: [{epoch}/{self.epochs}]\t {metrics_string}'
            )
            return

        samples_seen = batch_idx * self.config.dataloader.train.batch_size
        if samples_seen % self.log_step != 0:
            return

        if metrics_string is None:
            self.logger.warning(" No metrics")
        else:
            self.logger.info(
                f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t Sample [{samples_seen:5d}/{self.len_epoch:5d}]\t {metrics_string}"
            )
コード例 #3
0
class Tester(BaseTrainer):
    """
    Tester class

    Evaluates a model on a validation/test loader and writes test-set
    predictions to a CSV file; it performs no training.
    """

    def __init__(self, config, model, data_loader, writer, checkpoint_dir, logger,
                 valid_data_loader=None, test_data_loader=None, metric_ftns=None):
        """
        Args:
            config: configuration object; this class reads ``cuda``,
                ``epochs``, ``log_interval``, ``gradient_accumulation`` and
                ``dataloader.train.batch_size`` from it.
            model: network to evaluate.
            data_loader: forwarded to the base class; may be ``None``.
            writer: metric writer handed to the ``MetricTracker``.
            checkpoint_dir: directory the prediction file is written to.
            logger: logger used for progress output.
            valid_data_loader: optional validation loader.
            test_data_loader: optional test loader used by ``predict``.
            metric_ftns: forwarded to the base class only; the tracked
                metrics are fixed to ``'loss'`` and ``'acc'``.
        """
        super(Tester, self).__init__(config, data_loader, writer, checkpoint_dir, logger,
                                     valid_data_loader=valid_data_loader,
                                     test_data_loader=test_data_loader, metric_ftns=metric_ftns)

        if (self.config.cuda):
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")
        self.start_epoch = 1

        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None

        self.log_step = self.config.log_interval
        self.model = model

        self.mnt_best = np.inf

        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        # Assigned explicitly so the tracker below does not depend on the
        # base class having stored the writer.
        self.writer = writer
        # Read by the progress line in ``_progress``; 0 when no loader is given.
        self.len_epoch = (self.config.dataloader.train.batch_size * len(data_loader)
                          if data_loader is not None else 0)
        self.metric_ftns = ['loss', 'acc']
        self.valid_metrics = MetricTracker(*self.metric_ftns, writer=self.writer, mode='validation')
        self.logger = logger

    def _valid_epoch(self, epoch, mode, loader):
        """
        Evaluate the model on ``loader`` without gradient tracking.

        Args:
            epoch (int): current epoch
            mode (string): 'validation' or 'test'
            loader (dataloader): loader yielding ``(data, target)`` pairs

        Returns: validation loss

        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()

        batch_idx = 0  # keeps the summary call valid for an empty loader
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)
                target = target.long().to(self.device)

                # Here the model is called with the target and returns
                # both its output and the loss.
                output, loss = self.model(data, target)
                loss = loss.mean()
                writer_step = (epoch - 1) * len(loader) + batch_idx

                _, predicted = torch.max(output, 1)
                correct = np.sum(predicted.cpu().numpy() == target.cpu().numpy())

                self.valid_metrics.update(key='loss', value=loss.item(), n=1, writer_step=writer_step)
                # 'acc' is fed the number of correct predictions with n = batch size.
                self.valid_metrics.update(key='acc', value=correct, n=target.size(0),
                                          writer_step=writer_step)

        self._progress(batch_idx, epoch, metrics=self.valid_metrics, mode=mode, print_summary=True)

        val_loss = self.valid_metrics.avg('loss')

        return val_loss

    def predict(self):
        """
        Run inference on the test loader and write 'predictions.csv'.

        Returns:
            list of ``"<target>,<prediction>"`` strings, one per test sample.

        """
        self.model.eval()

        predictions = []
        with torch.no_grad():
            for data, target in self.test_data_loader:
                data = data.to(self.device)

                # NOTE(review): with a target the model returns (output, loss)
                # in _valid_epoch; this assumes it returns bare logits when the
                # target is None - confirm against the model.
                logits = self.model(data, None)

                # Index of the highest score per sample.
                _, prediction = torch.max(logits, 1)

                # Record every sample of the batch, not only the first one.
                predictions.extend(
                    f"{t},{p}" for t, p in zip(target, prediction.cpu().numpy()))
        self.logger.info('Inference done')
        pred_name = os.path.join(self.checkpoint_dir, 'predictions.csv')
        write_csv(predictions, pred_name)
        return predictions

    def _progress(self, batch_idx, epoch, metrics, mode='', print_summary=False):
        """
        Log either an evaluation summary or a periodic progress line.

        Args:
            batch_idx (int): index of the current batch.
            epoch (int): current epoch.
            metrics: tracker providing ``calc_all_metrics()``.
            mode (string): prefix for the log line.
            print_summary (bool): log the summary instead of a progress line.
        """
        metrics_string = metrics.calc_all_metrics()
        # The summary is handled first so a requested summary is never
        # replaced by a progress line when the interval check also matches.
        if print_summary:
            self.logger.info(
                f'{mode} summary  Epoch: [{epoch}/{self.epochs}]\t {metrics_string}')
            return

        samples_seen = batch_idx * self.config.dataloader.train.batch_size
        if samples_seen % self.log_step != 0:
            return

        if metrics_string is None:
            self.logger.warning(" No metrics")
        else:
            self.logger.info(
                f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t Video [{samples_seen:5d}/{self.len_epoch:5d}]\t {metrics_string}")