Example #1

import time
from abc import abstractmethod

import numpy as np
import torch
from numpy import inf
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# Project-specific helpers used below; the module paths are assumptions
# (pytorch-template layout) and may need adjusting to your project.
from logger import TensorboardWriter
from utils import MetricTracker, inf_loop


class Trainer:
    """
    Implements training and validation logic
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])
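
        # For reference, an example of the "trainer" block this constructor
        # expects in the config (an illustration inferred from the keys read
        # above; the exact values are assumptions, not from the original
        # project):
        #
        #   "trainer": {
        #       "verbosity": 2,
        #       "epochs": 100,
        #       "save_period": 1,
        #       "monitor": "min val_loss",
        #       "early_stop": 10,
        #       "tensorboard": true
        #   }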

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains the average loss and metrics for this epoch.
        """
        self.model.train()

        epoch_target = []
        epoch_scores = []
        epoch_word_pairs = []
        epoch_loss = []

        for batch_idx, batch_data in enumerate(self.data_loader):

            if not batch_data:
                continue

            for field in [
                    'input_ids', 'label', 'target', 'attention_mask',
                    'term1_mask', 'term2_mask'
            ]:
                batch_data[field] = batch_data[field].to(self.device)

            self.optimizer.zero_grad()
            output = self.model(batch_data)
            loss = self.criterion(
                output, batch_data['target'],
                self.data_loader.dataset.class_weights.to(self.device))
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # accumulate epoch quantities
            epoch_target += [t.item() for t in batch_data['label']]
            epoch_scores += [output.cpu().detach().numpy()]
            epoch_word_pairs += batch_data['term_pair']
            epoch_loss += [loss.item()]

            # update metrics
            self.writer.add_scalar("loss", loss.item())
            for met in self.metric_ftns:
                self.writer.add_scalar(met.__name__,
                                       met(epoch_target, epoch_scores))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break

        log = {
            m.__name__: m(epoch_target, epoch_scores)
            for m in self.metric_ftns
        }
        log["loss"] = np.sum(epoch_loss) / len(self.data_loader)

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        with torch.no_grad():

            epoch_target = []
            epoch_scores = []
            epoch_word_pairs = []
            epoch_loss = []

            for batch_idx, batch_data in enumerate(tqdm(
                    self.valid_data_loader)):

                for field in [
                        'input_ids', 'target', 'label', 'attention_mask',
                        'term1_mask', 'term2_mask'
                ]:
                    batch_data[field] = batch_data[field].to(self.device)

                output = self.model(batch_data)
                pred = torch.argmax(output, dim=-1)
                loss = self.criterion(
                    output, batch_data['target'].squeeze(-1),
                    self.data_loader.dataset.class_weights.to(self.device))

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')

                # accumulate epoch quantities
                epoch_target += [t.item() for t in batch_data['label']]
                epoch_scores += [output.cpu().detach().numpy()]
                epoch_word_pairs += batch_data['term_pair']
                epoch_loss += [loss.item()]

                # update metrics
                self.writer.add_scalar('loss', loss.item())
                for met in self.metric_ftns:
                    self.writer.add_scalar(met.__name__,
                                           met(epoch_target, epoch_scores))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        log = {
            m.__name__: m(epoch_target, epoch_scores)
            for m in self.metric_ftns
        }
        log["loss"] = np.sum(epoch_loss) / len(self.valid_data_loader)
        return log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged information into the log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to the specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self, n_gpu_use):
        """
        setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There's no GPU available on this machine, "
                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        # save a regular checkpoint every `save_period` epochs
        filename = str(self.checkpoint_dir /
                       'chkpt_{:03d}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
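

# A minimal sketch (not part of the original code) of a metric function that
# matches how `metric_ftns` are invoked in Trainer above: each metric receives
# the accumulated `epoch_target` (a flat list of integer labels) and
# `epoch_scores` (a list of (batch_size, num_classes) score arrays).
def accuracy(epoch_target, epoch_scores):
    """Fraction of examples whose arg-max score matches the label."""
    scores = np.concatenate(epoch_scores, axis=0)
    preds = np.argmax(scores, axis=-1)
    return float(np.mean(preds == np.asarray(epoch_target)))

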
class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 lr_scheduler,
                 config,
                 trainloader,
                 validloader=None,
                 len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])
        self.trainloader = trainloader
        self.validloader = validloader

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.trainloader)
        else:
            # iteration-based training
            self.trainloader = inf_loop(trainloader)
            self.len_epoch = len_epoch

        # setup GPU device if available, move model into configured device
        n_gpu_use = torch.cuda.device_count()
        self.device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        self.model = model.to(self.device)
        self.model = torch.nn.DataParallel(model)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.log_step = cfg_trainer['log_step']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])
        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_step(self, batch):
        """
        Training logic for a step

        :param batch: batch of current step
        :return: 
            loss: torch Variable with map for backwarding
            mets: metrics computed between output and target, dict
        """
        raise NotImplementedError

    @abstractmethod
    def _valid_step(self, batch):
        """
        Valid logic for a step

        :param batch: batch of current step
        :return:
            loss: torch Variable without map
            mets: metrics computed between output and target, dict
        """
        raise NotImplementedError

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains the average loss and metrics for this epoch.
        """
        self.model.train()
        self.train_metrics.reset()

        tic = time.time()
        datatime = batchtime = 0
        for batch_idx, batch in enumerate(self.trainloader):
            datatime += time.time() - tic
            # -------------------------------------------------------------------------
            loss, mets = self._train_step(batch)
            # -------------------------------------------------------------------------
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batchtime += time.time() - tic
            tic = time.time()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.train_metrics.update(key, val)

            if batch_idx % self.log_step == 0:
                processed_percent = batch_idx / self.len_epoch * 100
                self.logger.debug(
                    'Train Epoch:{} [{}/{}]({:.0f}%)\tTime:{:5.2f}/{:<5.2f}\tLoss:({:.4f}){:.4f}'
                    .format(epoch, batch_idx, self.len_epoch,
                            processed_percent, datatime, batchtime,
                            loss.item(), self.train_metrics.avg('loss')))
                datatime = batchtime = 0

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log = {'train_' + k: v for k, v in log.items()}

        if self.validloader is not None:
            val_log = self._valid_epoch(epoch)
            log.update(**{'valid_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        for batch_idx, batch in enumerate(self.validloader):
            # -------------------------------------------------------------------------
            loss, mets = self._valid_step(batch)
            # -------------------------------------------------------------------------
            self.writer.set_step(
                (epoch - 1) * len(self.validloader) + batch_idx, 'valid')
            self.valid_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.valid_metrics.update(key, val)

        return self.valid_metrics.result()

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged information into the log dict
            lr = self.optimizer.param_groups[0]['lr']
            log = {'epoch': epoch, 'lr': lr}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:20s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to the specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if self.lr_scheduler is not None:
                if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                    self.lr_scheduler.step(log[self.mnt_metric])
                else:
                    self.lr_scheduler.step()

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

            # add histogram of model parameters to the tensorboard
            self.writer.set_step(epoch)
            for name, p in self.model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        state = {
            'epoch': epoch,
            'model': self.model.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = str(self.checkpoint_dir / 'chkpt_{:03d}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        try:
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.module.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            self.mnt_best = checkpoint['monitor_best']
        except KeyError:
            self.model.module.load_state_dict(checkpoint)

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
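

# A minimal sketch (not part of the original code) of a concrete trainer built
# on BaseTrainer: only `_train_step` and `_valid_step` need to be implemented.
# The batch layout (an (inputs, targets) tuple) and the metric calling
# convention (tensor output and targets) are assumptions for illustration.
class ClassificationTrainer(BaseTrainer):
    def _train_step(self, batch):
        inputs, targets = batch
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        output = self.model(inputs)
        loss = self.criterion(output, targets)
        # metrics are returned as a dict keyed by the metric function names,
        # matching the names registered in self.train_metrics
        mets = {m.__name__: m(output.detach(), targets)
                for m in self.metric_ftns}
        return loss, mets

    def _valid_step(self, batch):
        with torch.no_grad():
            inputs, targets = batch
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            output = self.model(inputs)
            loss = self.criterion(output, targets)
            mets = {m.__name__: m(output, targets)
                    for m in self.metric_ftns}
        return loss, mets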