Example #1
class Trainer:
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 max_len_step=None):
        '''
        :param model: model to train
        :param optimizer: optimizer for the model parameters
        :param config: parsed configuration object
        :param data_loader: data loader for training data
        :param valid_data_loader: data loader for validation data (optional)
        :param lr_scheduler: learning rate scheduler (optional)
        :param max_len_step: controls number of batches (steps) in each epoch
        '''
        self.config = config
        self.distributed = config['distributed']
        if self.distributed:
            self.local_master = (config['local_rank'] == 0)
            self.global_master = (dist.get_rank() == 0)
        else:
            self.local_master = True
            self.global_master = True
        self.logger = config.get_logger(
            'trainer',
            config['trainer']['log_verbosity']) if self.local_master else None

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['local_rank'], config['local_world_size'])
        self.model = model.to(self.device)

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        monitor_open = cfg_trainer['monitor_open']
        if monitor_open:
            self.monitor = cfg_trainer.get('monitor', 'off')
        else:
            self.monitor = 'off'

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.monitor_mode = 'off'
            self.monitor_best = 0
        else:
            self.monitor_mode, self.monitor_metric = self.monitor.split()
            assert self.monitor_mode in ['min', 'max']

            self.monitor_best = inf if self.monitor_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            self.early_stop = inf if self.early_stop == -1 else self.early_stop

        self.start_epoch = 1

        if self.local_master:
            self.checkpoint_dir = config.save_dir
            # setup visualization writer instance
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        # load checkpoint for resume training
        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        # load the checkpoint before converting to SyncBatchNorm / wrapping with DDP,
        # so that state_dict keys carry no 'module.' prefix
        if self.config['trainer']['sync_batch_norm'] and self.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.distributed:
            self.model = DDP(self.model,
                             device_ids=self.device_ids,
                             output_device=self.device_ids[0],
                             find_unused_parameters=True)
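            # find_unused_parameters=True lets DDP tolerate parameters that receive no
            # gradient in a given forward pass, at the cost of an extra graph traversal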

        self.data_loader = data_loader
        if max_len_step is None:  # max length of iteration step of every epoch
            # epoch-based training
            self.len_step = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_step = max_len_step
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        log_step = self.config['trainer']['log_step_interval']
        if log_step != -1 and 0 < log_step < self.len_step:
            self.log_step = log_step
        else:
            self.log_step = int(np.sqrt(data_loader.batch_size))

        val_step_interval = self.config['trainer']['val_step_interval']
        # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\
        #                                             else int(np.sqrt(data_loader.batch_size))
        self.val_step_interval = val_step_interval

        self.gl_loss_lambda = self.config['trainer']['gl_loss_lambda']

        self.train_loss_metrics = MetricTracker(
            'loss',
            'gl_loss',
            'crf_loss',
            writer=self.writer if self.local_master else None)
        self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls)

    def train(self):
        """
        Full training logic, including train and validation.
        """

        if self.distributed:
            dist.barrier()  # Syncing machines before training

        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            # ensure each distributed worker samples different data by
            # seeding the sampler with the current epoch
            if self.distributed:
                self.data_loader.sampler.set_epoch(epoch)
            result_dict = self._train_epoch(epoch)

            # format validation results for logging
            if self.do_validation:
                val_result_dict = result_dict['val_result_dict']
                val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict)
            else:
                val_res = ''
            # every epoch log information
            self.logger_info(
                '[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.format(
                    epoch, self.epochs, result_dict['loss'],
                    result_dict['gl_loss'] * self.gl_loss_lambda,
                    result_dict['crf_loss'], val_res))

            # evaluate model performance according to configured metric, check early stop, and
            # save best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off' and self.do_validation:
                best, not_improved_count = self._is_best_monitor_metric(
                    best, not_improved_count, val_result_dict)
                if not_improved_count > self.early_stop:
                    self.logger_info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _is_best_monitor_metric(self, best, not_improved_count,
                                val_result_dict):
        '''
        Check whether the monitored metric improved on this validation result.
        :param best: bool, current best flag
        :param not_improved_count: int, number of validations without improvement so far
        :param val_result_dict: dict of validation results
        :return: updated (best, not_improved_count)
        '''
        entity_name, metric = self.monitor_metric.split('-')
        try:
            # check whether model performance improved or not, according to specified metric (monitor_metric)
            val_monitor_metric_res = val_result_dict[entity_name][metric]
            improved = (self.monitor_mode == 'min' and val_monitor_metric_res <= self.monitor_best) or \
                       (self.monitor_mode == 'max' and val_monitor_metric_res >= self.monitor_best)
        except KeyError:
            self.logger_warning(
                "Warning: Metric '{}' is not found. "
                "Model performance monitoring is disabled.".format(
                    self.monitor_metric))
            self.monitor_mode = 'off'
            improved = False
        if improved:
            self.monitor_best = val_monitor_metric_res
            not_improved_count = 0
            best = True
        else:
            not_improved_count += 1
        return best, not_improved_count

    def _train_epoch(self, epoch):
        '''
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log dict that contains average loss and metric in this epoch.
        '''
        self.model.train()
        self.train_loss_metrics.reset()
        ## step iteration start ##
        for step_idx, input_data_item in enumerate(self.data_loader):
            step_idx += 1
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(
                        input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(self.device,
                                                          non_blocking=True)
            if self.config['trainer']['anomaly_detection']:
                # This mode will increase the runtime and should only be enabled for debugging
                with torch.autograd.detect_anomaly():
                    self.optimizer.zero_grad()
                    # model forward
                    output = self.model(**input_data_item)
                    # calculate loss
                    gl_loss = output['gl_loss']
                    crf_loss = output['crf_loss']
                    total_loss = torch.sum(
                        crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                    # backward
                    total_loss.backward()
                    # self.average_gradients(self.model)
                    self.optimizer.step()
            else:
                self.optimizer.zero_grad()
                # model forward
                output = self.model(**input_data_item)
                # calculate loss
                gl_loss = output['gl_loss']
                crf_loss = output['crf_loss']
                total_loss = torch.sum(
                    crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                # backward
                total_loss.backward()
                # self.average_gradients(self.model)
                self.optimizer.step()

            # Use a barrier() to make sure that all processes have finished forward and backward
            if self.distributed:
                dist.barrier()
                # obtain the sum of total_loss over all processes
                # (dist.ReduceOp replaces the deprecated dist.reduce_op)
                dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
                size = dist.get_world_size()
            else:
                size = 1

            # average each loss over the batch, then over all processes,
            # so the logged values are means across the whole world
            avg_gl_loss = torch.mean(gl_loss.detach())
            avg_crf_loss = torch.mean(crf_loss.detach())
            if self.distributed:
                dist.all_reduce(avg_gl_loss, op=dist.ReduceOp.SUM)
                dist.all_reduce(avg_crf_loss, op=dist.ReduceOp.SUM)
            avg_gl_loss = avg_gl_loss / size
            avg_crf_loss = avg_crf_loss / size
            avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss
            # update metrics
            self.writer.set_step((epoch - 1) * self.len_step + step_idx -
                                 1) if self.local_master else None
            self.train_loss_metrics.update('loss', avg_loss.item())
            self.train_loss_metrics.update(
                'gl_loss',
                avg_gl_loss.item() * self.gl_loss_lambda)
            self.train_loss_metrics.update('crf_loss', avg_crf_loss.item())

            # log messages
            if step_idx % self.log_step == 0:
                self.logger_info(
                    'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}'
                    .format(epoch, self.epochs, step_idx, self.len_step,
                            avg_loss.item(),
                            avg_gl_loss.item() * self.gl_loss_lambda,
                            avg_crf_loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            # do validation after val_step_interval iteration
            if self.do_validation and step_idx % self.val_step_interval == 0:
                val_result_dict = self._valid_epoch(epoch)
                self.logger_info(
                    '[Step Validation] Epoch:[{}/{}] Step:[{}/{}]  \n{}'.
                    format(epoch, self.epochs, step_idx, self.len_step,
                           SpanBasedF1MetricTracker.dict2str(val_result_dict)))

                # check if best metric, if true, then save as model_best checkpoint.
                best, not_improved_count = self._is_best_monitor_metric(
                    False, 0, val_result_dict)
                if best:
                    self._save_checkpoint(epoch, best)

            # stop once len_step batches have been processed (relevant for iteration-based training)
            if step_idx == self.len_step:
                break

        ## step iteration end ##

        # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss}
        log = self.train_loss_metrics.result()

        # do validation after training an epoch
        if self.do_validation:
            val_result_dict = self._valid_epoch(epoch)
            log['val_result_dict'] = val_result_dict

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        '''
        Validate after training an epoch or at a regular step interval; this can be
        time-consuming when the validation set is large.
        :param epoch: Integer, current training epoch.
        :return: A dict that contains information about validation
        '''

        self.model.eval()
        self.valid_f1_metrics.reset()
        with torch.no_grad():
            for step_idx, input_data_item in enumerate(self.valid_data_loader):
                for key, input_value in input_data_item.items():
                    if input_value is not None and isinstance(
                            input_value, torch.Tensor):
                        input_data_item[key] = input_value.to(
                            self.device, non_blocking=True)

                output = self.model(**input_data_item)
                # print("awesome 307")
                logits = output['logits']
                new_mask = output['new_mask']
                if hasattr(self.model, 'module'):
                    #  List[(List[int], torch.Tensor)] contain the tag indices of the maximum likelihood tag sequence.
                    #  and the score of the viterbi path.
                    best_paths = self.model.module.decoder.crf_layer.viterbi_tags(
                        logits, mask=new_mask, logits_batch_first=True)
                else:
                    best_paths = self.model.decoder.crf_layer.viterbi_tags(
                        logits, mask=new_mask, logits_batch_first=True)
                predicted_tags = []
                for path, score in best_paths:
                    predicted_tags.append(path)

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step_idx, 'valid') \
                    if self.local_master else None

                # calculate and update f1 metrics
                # (B, N*T, out_dim)
                predicted_tags_hard_prob = logits * 0
                for i, instance_tags in enumerate(predicted_tags):
                    for j, tag_id in enumerate(instance_tags):
                        predicted_tags_hard_prob[i, j, tag_id] = 1

                golden_tags = input_data_item['iob_tags_label']
                mask = input_data_item['mask']
                union_iob_tags = iob_tags_to_union_iob_tags(golden_tags, mask)

                if self.distributed:
                    dist.barrier()  # keep processes in sync before updating metrics
                self.valid_f1_metrics.update(predicted_tags_hard_prob.long(),
                                             union_iob_tags, new_mask)

        # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')

        f1_result_dict = self.valid_f1_metrics.result()

        # rollback to train mode
        self.model.train()

        return f1_result_dict

    def average_gradients(self, model):
        '''
        Average gradients across all processes in place
        (manual alternative to DDP's built-in gradient synchronization).
        :param model: model whose parameter gradients are averaged
        :return:
        '''
        size = float(dist.get_world_size())
        for param in model.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size

    def logger_info(self, msg):
        self.logger.info(msg) if self.local_master else None

    def logger_warning(self, msg):
        self.logger.warning(msg) if self.local_master else None

    def _prepare_device(self, local_rank, local_world_size):
        '''
        Set up GPU device(s) if available.
        :param local_rank: rank of this process on the local node
        :param local_world_size: number of processes per node (used as the number of GPUs to use in non-distributed mode)
        :return: (torch.device, list of device ids)
        '''
        if self.distributed:
            ngpu_per_process = torch.cuda.device_count() // local_world_size
            device_ids = list(
                range(local_rank * ngpu_per_process,
                      (local_rank + 1) * ngpu_per_process))

            if torch.cuda.is_available() and local_rank != -1:
                torch.cuda.set_device(
                    device_ids[0]
                )  # device_ids[0] =local_rank if local_world_size = n_gpu per node
                device = 'cuda'
                self.logger_info(
                    f"[Process {os.getpid()}] world_size = {dist.get_world_size()}, "
                    +
                    f"rank = {dist.get_rank()}, n_gpu/process = {ngpu_per_process}, device_ids = {device_ids}"
                )
            else:
                self.logger_warning('Training will be using CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, device_ids
        else:
            n_gpu = torch.cuda.device_count()
            print(f"NUMBER GPU {n_gpu}")
            n_gpu_use = local_world_size
            if n_gpu_use > 0 and n_gpu == 0:
                self.logger_warning(
                    "Warning: There\'s no GPU available on this machine,"
                    "training will be performed on CPU.")
                n_gpu_use = 0
            if n_gpu_use > n_gpu:
                self.logger_warning(
                    "Warning: The number of GPU\'s configured to use is {}, but only {} are available "
                    "on this machine.".format(n_gpu_use, n_gpu))
                n_gpu_use = n_gpu

            list_ids = list(range(n_gpu_use))
            if n_gpu_use > 0:
                torch.cuda.set_device(
                    list_ids[0])  # only use first available gpu as devices
                self.logger_warning(f'Training is using GPU {list_ids[0]}!')
                device = 'cuda'
            else:
                self.logger_warning('Training is using CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        '''
        Saving checkpoints
        :param epoch:  current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        :return:
        '''
        # only local master process do save model
        if not self.local_master:
            return

        if hasattr(self.model, 'module'):
            arch = type(self.model.module).__name__
            state_dict = self.model.module.state_dict()
        else:
            arch = type(self.model).__name__
            state_dict = self.model.state_dict()
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger_info("Saving current best: model_best.pth ...")
        else:
            filename = str(self.checkpoint_dir /
                           'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)
            self.logger_info("Saving checkpoint: {} ...".format(filename))

    def _resume_checkpoint(self, resume_path):
        '''
        Resume from saved checkpoints
        :param resume_path: Checkpoint path to be resumed
        :return:
        '''
        resume_path = str(resume_path)
        self.logger_info("Loading checkpoint: {} ...".format(resume_path))
        # map_location = {'cuda:%d' % 0: 'cuda:%d' % self.config['local_rank']}
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.monitor_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['model_arch'] != self.config['model_arch']:
            self.logger_warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger_warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger_info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
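
The distributed branch of _train_epoch above uses barrier/all_reduce so that the logged losses are averages over every process, not just the local one. A minimal, self-contained sketch of that pattern (illustrative only; global_mean_loss is a hypothetical helper, not part of the code above):

import torch
import torch.distributed as dist


def global_mean_loss(local_loss: torch.Tensor, distributed: bool) -> float:
    """Return the loss averaged over all processes (or the local value when not distributed)."""
    if not distributed:
        return local_loss.item()
    reduced = local_loss.detach().clone()           # keep the original tensor untouched
    dist.barrier()                                  # make sure every rank reached this point
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)  # in-place sum across ranks
    return (reduced / dist.get_world_size()).item()
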
Example #2
class Trainer:
    """
    Implements training and validation logic
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1

        self.checkpoint_dir = config.save_dir

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()

        epoch_target = []
        epoch_scores = []
        epoch_word_pairs = []
        epoch_loss = []

        for batch_idx, batch_data in enumerate(self.data_loader):

            if not batch_data:
                continue

            for field in [
                    'input_ids', 'label', 'target', 'attention_mask',
                    'term1_mask', 'term2_mask'
            ]:
                batch_data[field] = batch_data[field].to(self.device)

            self.optimizer.zero_grad()
            output = self.model(batch_data)
            loss = self.criterion(
                output, batch_data['target'],
                self.data_loader.dataset.class_weights.to(self.device))
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # accumulate epoch quantities
            epoch_target += [t.item() for t in batch_data['label']]
            epoch_scores += [output.cpu().detach().numpy()]
            epoch_word_pairs += batch_data['term_pair']
            epoch_loss += [loss.item()]

            # update metrics
            self.writer.add_scalar("loss", loss.item())
            for met in self.metric_ftns:
                self.writer.add_scalar(met.__name__,
                                       met(epoch_target, epoch_scores))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                #self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break

        log = {
            m.__name__: m(epoch_target, epoch_scores)
            for m in self.metric_ftns
        }
        log["loss"] = np.sum(epoch_loss) / len(self.data_loader)

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        with torch.no_grad():

            epoch_target = []
            epoch_scores = []
            epoch_word_pairs = []
            epoch_loss = []

            for batch_idx, batch_data in enumerate(tqdm(
                    self.valid_data_loader)):

                for field in [
                        'input_ids', 'target', 'label', 'attention_mask',
                        'term1_mask', 'term2_mask'
                ]:
                    batch_data[field] = batch_data[field].to(self.device)

                output = self.model(batch_data)
                pred = torch.argmax(output, dim=-1)
                loss = self.criterion(
                    output, batch_data['target'].squeeze(-1),
                    self.data_loader.dataset.class_weights.to(self.device))

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')

                # accumulate epoch quantities
                epoch_target += [t.item() for t in batch_data['label']]
                epoch_scores += [output.cpu().detach().numpy()]
                epoch_word_pairs += batch_data['term_pair']
                epoch_loss += [loss.item()]

                # update metrics
                self.writer.add_scalar('loss', loss.item())
                for met in self.metric_ftns:
                    self.writer.add_scalar(met.__name__,
                                           met(epoch_target, epoch_scores))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        log = {
            m.__name__: m(epoch_target, epoch_scores)
            for m in self.metric_ftns
        }
        log["loss"] = np.sum(epoch_loss) / len(self.valid_data_loader)
        return log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged information into log dict
            log = {'epoch': epoch}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _prepare_device(self, n_gpu_use):
        """
        setup GPU device if available, move model into configured device
        """
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning(
                "Warning: There\'s no GPU available on this machine,"
                "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning(
                "Warning: The number of GPU\'s configured to use is {}, but only {} are available "
                "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, save the checkpoint as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
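
The trainer above switches between epoch-based and iteration-based training by wrapping data_loader in inf_loop when len_epoch is given. A minimal sketch of such a helper (assumed implementation; the repository's own version may differ):

from itertools import repeat


def inf_loop(data_loader):
    """Yield batches endlessly by cycling over the data loader (used for iteration-based training)."""
    for loader in repeat(data_loader):
        yield from loader
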
Example #3
class Trainer:
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 max_len_step=None):
        '''
        :param model: model to train
        :param optimizer: optimizer for the model parameters
        :param config: parsed configuration object
        :param data_loader: data loader for training data
        :param valid_data_loader: data loader for validation data (optional)
        :param lr_scheduler: learning rate scheduler (optional)
        :param max_len_step: controls number of batches (steps) in each epoch
        '''
        self.config = config
        self.distributed = config['distributed']
        if self.distributed:
            self.local_master = (config['local_rank'] == 0)
            self.global_master = (dist.get_rank() == 0)
        else:
            self.local_master = True
            self.global_master = True
        self.logger = config.get_logger(
            'trainer',
            config['trainer']['log_verbosity']) if self.local_master else None

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['local_rank'], config['local_world_size'])
        self.model = model.to(self.device)

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        monitor_open = cfg_trainer['monitor_open']
        if monitor_open:
            self.monitor = cfg_trainer.get('monitor', 'off')
        else:
            self.monitor = 'off'

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.monitor_mode = 'off'
            self.monitor_best = 0
        else:
            self.monitor_mode, self.monitor_metric = self.monitor.split()
            assert self.monitor_mode in ['min', 'max']

            self.monitor_best = inf if self.monitor_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            self.early_stop = inf if self.early_stop == -1 else self.early_stop

        self.start_epoch = 1

        if self.local_master:
            self.checkpoint_dir = config.save_dir
            # setup visualization writer instance
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        # load checkpoint for resume training or finetune
        self.finetune = config['finetune']
        if config.resume is not None:
            self._resume_checkpoint(config.resume)
        else:
            if self.finetune:
                self.logger_warning(
                    "Finetune mode must set resume args to specific checkpoint path"
                )
                raise RuntimeError(
                    "Finetune mode must set resume args to specific checkpoint path"
                )
        # load the checkpoint before converting to SyncBatchNorm / wrapping with DDP,
        # so that state_dict keys carry no 'module.' prefix
        if self.config['trainer']['sync_batch_norm'] and self.distributed:
            # sync_batch_norm only support one gpu per process mode
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.distributed:  # move model to distributed gpu
            self.model = DDP(self.model,
                             device_ids=self.device_ids,
                             output_device=self.device_ids[0],
                             find_unused_parameters=True)

        # length of an epoch: all batches in the loader, optionally capped by max_len_step
        self.len_step = len(data_loader)
        self.data_loader = data_loader
        if max_len_step is not None:  # cap the number of iteration steps per epoch
            self.len_step = min(max_len_step, self.len_step)
        self.valid_data_loader = valid_data_loader

        do_validation = self.config['trainer']['do_validation']
        self.validation_start_epoch = self.config['trainer'][
            'validation_start_epoch']
        self.do_validation = (self.valid_data_loader is not None
                              and do_validation)
        self.lr_scheduler = lr_scheduler

        log_step = self.config['trainer']['log_step_interval']
        if log_step != -1 and 0 < log_step < self.len_step:
            self.log_step = log_step
        else:
            self.log_step = int(np.sqrt(data_loader.batch_size))

        # do validation interval
        val_step_interval = self.config['trainer']['val_step_interval']
        # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\
        #                                             else int(np.sqrt(data_loader.batch_size))
        self.val_step_interval = val_step_interval

        # build metric trackers and wrap the tensorboard writer.
        self.train_metrics = AverageMetricTracker(
            'loss', writer=self.writer if self.local_master else None)
        self.val_metrics = AverageMetricTracker(
            'loss',
            'word_acc',
            'word_acc_case_insensitive',
            'edit_distance_acc',
            writer=self.writer if self.local_master else None)

    def train(self):
        """
        Full training logic, including train and validation.
        """

        if self.distributed:
            dist.barrier()  # Syncing machines before training

        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            # ensure each distributed worker samples different data by
            # seeding the samplers with the current epoch
            if self.distributed:
                self.data_loader.sampler.set_epoch(epoch)

            if self.valid_data_loader is not None and \
                    self.valid_data_loader.batch_sampler is not None:
                self.valid_data_loader.batch_sampler.set_epoch(epoch)

            torch.cuda.empty_cache()
            result_dict = self._train_epoch(epoch)

            # validate after training an epoch
            if self.do_validation and epoch >= self.validation_start_epoch:
                val_metric_res_dict = self._valid_epoch(epoch)
                val_res = f"\nValidation result after {epoch} epoch: " \
                          f"Word_acc: {val_metric_res_dict['word_acc']:.6f} " \
                          f"Word_acc_case_ins: {val_metric_res_dict['word_acc_case_insensitive']:.6f} " \
                          f"Edit_distance_acc: {val_metric_res_dict['edit_distance_acc']:.6f}"
            else:
                val_res = ''

            # update lr after training an epoch, epoch-wise
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # every epoch log information
            self.logger_info(
                '[Epoch End] Epoch:[{}/{}] Loss: {:.6f} LR: {:.8f}'.format(
                    epoch, self.epochs, result_dict['loss'], self._get_lr()) +
                val_res)

            # evaluate model performance according to configured metric, check early stop, and
            # save best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off' and self.do_validation and epoch >= self.validation_start_epoch:
                best, not_improved_count = self._is_best_monitor_metric(
                    best, not_improved_count, val_metric_res_dict)
                if not_improved_count > self.early_stop:  # epoch level count
                    self.logger_info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break
            # epoch-level save period
            if best or (epoch % self.save_period == 0
                        and epoch >= self.validation_start_epoch):
                self._save_checkpoint(epoch, save_best=best)

    def _is_best_monitor_metric(self,
                                best,
                                not_improved_count,
                                val_result_dict,
                                update_not_improved_count=True):
        '''
        Check whether the monitored metric improved on this validation result.
        :param best: bool
        :param not_improved_count: int
        :param val_result_dict: dict
        :param update_not_improved_count: bool, True: update not_improved_count during epoch-level validation
        :return: updated (best, not_improved_count)
        '''
        try:
            # check whether model performance improved or not, according to specified metric (monitor_metric)
            val_monitor_metric_res = val_result_dict[self.monitor_metric]
            improved = (self.monitor_mode == 'min' and val_monitor_metric_res <= self.monitor_best) or \
                       (self.monitor_mode == 'max' and val_monitor_metric_res >= self.monitor_best)
        except KeyError:
            self.logger_warning(
                "Warning: Metric '{}' is not found. "
                "Model performance monitoring is disabled.".format(
                    self.monitor_metric))
            self.monitor_mode = 'off'
            improved = False
        if improved:
            self.monitor_best = val_monitor_metric_res
            not_improved_count = 0
            best = True
        else:
            if update_not_improved_count:  # update when do epoch-level validation, step-level not changed count
                not_improved_count += 1
        return best, not_improved_count

    def _train_epoch(self, epoch):
        '''
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log dict that contains average loss and metric in this epoch.
        '''
        self.model.train()

        self.train_metrics.reset()

        ## step iteration start ##
        for step_idx, input_data_item in enumerate(self.data_loader):

            batch_size = input_data_item['batch_size']
            if batch_size == 0: continue

            images = input_data_item['images']
            text_label = input_data_item['labels']

            # # step-wise lr scheduler, comment this, using epoch-wise lr_scheduler
            # if self.lr_scheduler is not None:
            #     self.lr_scheduler.step()

            step_idx += 1
            # prepare input data
            images = images.to(self.device)
            target = LabelTransformer.encode(text_label)
            target = target.to(self.device)
            target = target.permute(1, 0)
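            # LabelTransformer.encode presumably returns a (T, B) tensor; permuting to (B, T)
            # makes the <SOS>/<EOS> slicing below operate per sample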

            if self.config['trainer']['anomaly_detection']:
                # This mode will increase the runtime and should only be enabled for debugging
                with torch.autograd.detect_anomaly():
                    # forward
                    outputs = self.model(
                        images,
                        target[:, :-1])  # need to remove <EOS> in target
                    loss = F.cross_entropy(
                        outputs.contiguous().view(-1, outputs.shape[-1]),
                        target[:, 1:].contiguous().view(
                            -1),  # need to remove <SOS> in target
                        ignore_index=LabelTransformer.PAD)
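                    # ignore_index=PAD keeps padded target positions out of the loss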

                    # backward and update parameters
                    self.optimizer.zero_grad()
                    loss.backward()
                    # self.average_gradients(self.model)
                    self.optimizer.step()
            else:
                # forward
                outputs = self.model(
                    images, target[:, :-1])  # need to remove <EOS> in target
                loss = F.cross_entropy(
                    outputs.contiguous().view(-1, outputs.shape[-1]),
                    target[:, 1:].contiguous().view(
                        -1),  # need to remove <SOS> in target
                    ignore_index=LabelTransformer.PAD)

                # backward and update parameters
                self.optimizer.zero_grad()
                loss.backward()
                # self.average_gradients(self.model)
                self.optimizer.step()

            ## Train batch done. Logging results

            # due to training mode (bn, dropout), we don't calculate acc

            batch_total = images.shape[0]
            reduced_loss = loss.item()  # mean results of ce

            if self.distributed:
                # obtain the sum of all train metrics at all processes by all_reduce operation
                # Must keep track of global batch size,
                # since not all machines are guaranteed equal batches at the end of an epoch
                reduced_metrics_tensor = torch.tensor(
                    [batch_total, reduced_loss]).float().to(self.device)
                # Use a barrier() to make sure that all processes have finished above code
                dist.barrier()
                # sum the metric tensor across the whole world, then average the loss
                # reduced_metrics_tensor = self.mean_reduce_tensor(reduced_metrics_tensor)
                reduced_metrics_tensor = self.sum_tensor(reduced_metrics_tensor)
                batch_total, reduced_loss = reduced_metrics_tensor.cpu().numpy()
                reduced_loss = reduced_loss / dist.get_world_size()
            # update metrics and write to tensorboard
            global_step = (epoch - 1) * self.len_step + step_idx - 1
            self.writer.set_step(global_step,
                                 mode='train') if self.local_master else None
            # write tag is loss/train (mode =train)
            self.train_metrics.update(
                'loss', reduced_loss, batch_total
            )  # here, loss is mean results over batch, accumulate values

            # log messages
            if step_idx % self.log_step == 0 or step_idx == 1:
                self.logger_info(
                    'Train Epoch:[{}/{}] Step:[{}/{}] Loss: {:.6f} Loss_avg: {:.6f} LR: {:.8f}'
                    .format(epoch, self.epochs, step_idx, self.len_step,
                            self.train_metrics.val('loss'),
                            self.train_metrics.avg('loss'), self._get_lr()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            # do validation after val_step_interval iteration
            if self.do_validation and step_idx % self.val_step_interval == 0 and epoch >= self.validation_start_epoch:
                val_metric_res_dict = self._valid_epoch(
                    epoch)  # average metric
                self.logger_info(
                    '[Step Validation] Epoch:[{}/{}] Step:[{}/{}] Word_acc: {:.6f} Word_acc_case_ins: {:.6f} '
                    'Edit_distance_acc: {:.6f}'.format(
                        epoch, self.epochs, step_idx, self.len_step,
                        val_metric_res_dict['word_acc'],
                        val_metric_res_dict['word_acc_case_insensitive'],
                        val_metric_res_dict['edit_distance_acc']))
                # check if best metric, if true, then save as model_best checkpoint.
                best, not_improved_count = self._is_best_monitor_metric(
                    False,
                    0,
                    val_metric_res_dict,
                    update_not_improved_count=False)
                if best:  # save checkpoint when step-level validation hits a new best
                    self._save_checkpoint(epoch, best, step_idx)

            # decide whether continue iter
            if step_idx == self.len_step:
                break
        ## step iteration end ##

        log_dict = self.train_metrics.result()
        return log_dict

    def _valid_epoch(self, epoch):
        '''
        Validate after training an epoch or at a regular step interval; this can be
        time-consuming when the validation set is large.
        :param epoch: Integer, current training epoch.
        :return: A dict that contains information about validation
        '''

        self.model.eval()
        self.val_metrics.reset()

        for step_idx, input_data_item in enumerate(self.valid_data_loader):

            batch_size = input_data_item['batch_size']
            images = input_data_item['images']
            text_label = input_data_item['labels']

            if self.distributed:
                word_acc, word_acc_case_ins, edit_distance_acc, total_distance_ref, batch_total = \
                    self._distributed_predict(batch_size, images, text_label)
            else:  # one cpu or gpu non-distributed mode
                with torch.no_grad():
                    images = images.to(self.device)
                    # target = LabelTransformer.encode(text_label)
                    # target = target.to(self.device)
                    # target = target.permute(1, 0)

                    if hasattr(self.model, 'module'):
                        model = self.model.module
                    else:
                        model = self.model

                    # (bs, max_len)
                    outputs = decode_util.greedy_decode(
                        model,
                        images,
                        LabelTransformer.max_length,
                        LabelTransformer.SOS,
                        padding_symbol=LabelTransformer.PAD,
                        device=images.device,
                        padding=True)
                    correct = 0
                    correct_case_ins = 0
                    total_distance_ref = 0
                    total_edit_distance = 0
                    for index, (pred, text_gold) in enumerate(
                            zip(outputs[:, 1:], text_label)):
                        predict_text = ""
                        for i in range(len(pred)):  # decode one sample
                            if pred[i] == LabelTransformer.EOS: break
                            if pred[i] == LabelTransformer.UNK: continue

                            decoded_char = LabelTransformer.decode(pred[i])
                            predict_text += decoded_char

                        # calculate edit distance
                        ref = len(text_gold)
                        edit_distance = distance.levenshtein(
                            text_gold, predict_text)
                        total_distance_ref += ref
                        total_edit_distance += edit_distance

                        # calculate word accuracy related
                        # predict_text = predict_text.strip()
                        # text_gold = text_gold.strip()
                        if predict_text == text_gold:
                            correct += 1
                        if predict_text.lower() == text_gold.lower():
                            correct_case_ins += 1
                    batch_total = images.shape[
                        0]  # valid batch size of current steps
                    # calculate accuracy directly, due to non-distributed
                    word_acc = correct / batch_total
                    word_acc_case_ins = correct_case_ins / batch_total
                    edit_distance_acc = 1 - total_edit_distance / total_distance_ref
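                    # normalized edit-distance accuracy: 1 - (summed Levenshtein distance / summed reference length)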

            # update valid metric and write to tensorboard,
            self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step_idx, 'valid') \
                if self.local_master else None
            # self.val_metrics.update('loss', loss, batch_total)  # tag is loss/valid (mode =valid)
            self.val_metrics.update('word_acc', word_acc, batch_total)
            self.val_metrics.update('word_acc_case_insensitive',
                                    word_acc_case_ins, batch_total)
            self.val_metrics.update('edit_distance_acc', edit_distance_acc,
                                    total_distance_ref)

        val_metric_res_dict = self.val_metrics.result()

        # rollback to train mode
        self.model.train()

        return val_metric_res_dict

    def _distributed_predict(self, batch_size, images, text_label):
        # Allows distributed prediction on uneven batches.
        # Test set isn't always large enough for every GPU to get a batch

        # obtain the sum of all val metrics at all processes by all_reduce operation
        # dist.barrier()
        # batch_size = images.size(0)
        correct = correct_case_ins = valid_batches = total_edit_distance = total_distance_ref = 0
        if batch_size:  # not empty samples at current gpu validation process
            with torch.no_grad():
                images = images.to(self.device)
                # target = LabelTransformer.encode(text_label)
                # target = target.to(self.device)
                # target = target.permute(1, 0)

                if hasattr(self.model, 'module'):
                    model = self.model.module
                else:
                    model = self.model
                outputs = decode_util.greedy_decode(
                    model,
                    images,
                    LabelTransformer.max_length,
                    LabelTransformer.SOS,
                    padding_symbol=LabelTransformer.PAD,
                    device=images.device,
                    padding=True)

                for index, (pred, text_gold) in enumerate(
                        zip(outputs[:, 1:], text_label)):
                    predict_text = ""
                    for i in range(len(pred)):  # decode one sample
                        if pred[i] == LabelTransformer.EOS: break
                        if pred[i] == LabelTransformer.UNK: continue

                        decoded_char = LabelTransformer.decode(pred[i])
                        predict_text += decoded_char

                    # calculate edit distance
                    ref = len(text_gold)
                    edit_distance = distance.levenshtein(
                        text_gold, predict_text)
                    total_distance_ref += ref
                    total_edit_distance += edit_distance

                    # calculate word accuracy related
                    # predict_text = predict_text.strip()
                    # text_gold = text_gold.strip()
                    if predict_text == text_gold:
                        correct += 1
                    if predict_text.lower() == text_gold.lower():
                        correct_case_ins += 1

            valid_batches = 1  # after the all_reduce below, this sums to the number of processes that had data
        # sum metrics across all valid process
        sum_metrics_tensor = torch.tensor([
            batch_size, valid_batches, correct, correct_case_ins,
            total_edit_distance, total_distance_ref
        ]).float().to(self.device)
        # # Use a barrier() to make sure that all process have finished above code
        # dist.barrier()
        sum_metrics_tensor = self.sum_tensor(sum_metrics_tensor)
        sum_metrics_tensor = sum_metrics_tensor.cpu().numpy()
        batch_total, valid_batches = sum_metrics_tensor[0:2]
        # averages metric across the valid process
        # loss= sum_metrics_tensor[2] / valid_batches
        correct, correct_case_ins, total_edit_distance, total_distance_ref = sum_metrics_tensor[
            2:]
        word_acc = correct / batch_total
        word_acc_case_ins = correct_case_ins / batch_total
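        # character-level score: 1 - (total edit distance / total reference length)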
        edit_distance_acc = 1 - total_edit_distance / total_distance_ref
        return word_acc, word_acc_case_ins, edit_distance_acc, total_distance_ref, batch_total

    def _get_lr(self):
        for group in self.optimizer.param_groups:
            return group['lr']

    def average_gradients(self, model):
        '''
        Average the gradients of all model parameters across the whole world (in place).
        :param model: model whose parameter gradients are averaged
        :return:
        '''
        size = float(dist.get_world_size())
        for param in model.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size

    def mean_reduce_tensor(self, tensor: torch.Tensor):
        '''Average the tensor across all processes (sum, then divide by world size).'''
        sum_tensor = self.sum_tesnor(tensor)
        return sum_tensor / dist.get_world_size()

    def sum_tesnor(self, tensor: torch.Tensor):
        '''Sum the tensor across all processes via all_reduce.'''
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        return rt

    def logger_info(self, msg):
        if self.local_master:
            self.logger.info(msg)

    def logger_warning(self, msg):
        if self.local_master:
            self.logger.warning(msg)

    def _prepare_device(self, local_rank, local_world_size):
        '''
        Set up GPU device(s) if available and return the device plus the device ids this process should use.
        :param local_rank: rank of this process on the local node
        :param local_world_size: number of processes per node
        :return: (torch.device, list of device ids)
        '''
        if self.distributed:
            ngpu_per_process = torch.cuda.device_count() // local_world_size
            device_ids = list(
                range(local_rank * ngpu_per_process,
                      (local_rank + 1) * ngpu_per_process))
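            # e.g. 8 visible GPUs and local_world_size = 4 -> 2 GPUs per process; process r owns GPUs [2r, 2r + 1]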

            if torch.cuda.is_available() and local_rank != -1:
                torch.cuda.set_device(
                    device_ids[0]
                )  # device_ids[0] == local_rank when local_world_size == n_gpu per node
                device = 'cuda'
                self.logger_info(
                    f"[Process {os.getpid()}] world_size = {dist.get_world_size()}, "
                    +
                    f"rank = {dist.get_rank()}, n_gpu/process = {ngpu_per_process}, device_ids = {device_ids}"
                )
            else:
                self.logger_warning('Training will be performed on CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, device_ids
        else:
            n_gpu = torch.cuda.device_count()
            n_gpu_use = local_world_size
            if n_gpu_use > 0 and n_gpu == 0:
                self.logger_warning(
                    "Warning: There's no GPU available on this machine, "
                    "training will be performed on CPU.")
                n_gpu_use = 0
            if n_gpu_use > n_gpu:
                self.logger_warning(
                    "Warning: The number of GPUs configured to use is {}, but only {} are available "
                    "on this machine.".format(n_gpu_use, n_gpu))
                n_gpu_use = n_gpu

            list_ids = list(range(n_gpu_use))
            if n_gpu_use > 0:
                torch.cuda.set_device(
                    list_ids[0])  # only use the first available GPU as the device
                self.logger_warning(f'Training is using GPU {list_ids[0]}!')
                device = 'cuda'
            else:
                self.logger_warning('Training is using CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False, step_idx=None):
        '''
        Saving checkpoints
        :param epoch: current epoch number
        :param save_best: if True, also copy the saved checkpoint to 'model_best.pth'
        :param step_idx: optional step index; if given, it is included in the checkpoint filename
        :return:
        '''
        # only the process that is both local and global master saves the model
        if not (self.local_master and self.global_master):
            return

        if hasattr(self.model, 'module'):
            arch_name = type(self.model.module).__name__
            model_state_dict = self.model.module.state_dict()
        else:
            arch_name = type(self.model).__name__
            model_state_dict = self.model.state_dict()
        state = {
            'arch': arch_name,
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        if step_idx is None:
            filename = str(self.checkpoint_dir /
                           'checkpoint-epoch{}.pth'.format(epoch))
        else:
            filename = str(
                self.checkpoint_dir /
                'checkpoint-epoch{}-step{}.pth'.format(epoch, step_idx))
        torch.save(state, filename)
        self.logger_info("Saving checkpoint: {} ...".format(filename))

        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            shutil.copyfile(filename, best_path)
            self.logger_info(
                f"Saving current best (at {epoch} epoch): model_best.pth Best {self.monitor_metric}: {self.monitor_best:.6f}"
            )

        # if save_best:
        #     best_path = str(self.checkpoint_dir / 'model_best.pth')
        #     torch.save(state, best_path)
        #     self.logger_info(
        #         f"Saving current best: model_best.pth Best {self.monitor_metric}: {self.monitor_best:.6f}.")
        # else:
        #     filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
        #     torch.save(state, filename)
        #     self.logger_info("Saving checkpoint: {} ...".format(filename))

    def _resume_checkpoint(self, resume_path):
        '''
        Resume from saved checkpoints
        :param resume_path: Checkpoint path to be resumed
        :return:
        '''
        resume_path = str(resume_path)
        self.logger_info("Loading checkpoint: {} ...".format(resume_path))
        # map_location = {'cuda:%d' % 0: 'cuda:%d' % self.config['local_rank']}
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1 if not self.finetune else 1
        self.monitor_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['model_arch'] != self.config[
                'model_arch']:  # TODO verify adapt and adv arch
            self.logger_warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded."
            )
        # self.model.load_state_dict(checkpoint['state_dict'])
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if not self.finetune:  # resume mode will load optimizer state and continue train
            if checkpoint['config']['optimizer']['type'] != self.config[
                    'optimizer']['type']:
                self.logger_warning(
                    "Warning: Optimizer type given in config file is different from that of checkpoint. "
                    "Optimizer parameters not being resumed.")
            else:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

        if self.finetune:
            self.logger_info(
                "Checkpoint loaded. Finetune training from epoch {}".format(
                    self.start_epoch))
        else:
            self.logger_info(
                "Checkpoint loaded. Resume training from epoch {}".format(
                    self.start_epoch))
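
The uneven-batch trick in _distributed_predict above can be reduced to a few lines: every rank packs its local counts into one tensor (all zeros when it received no samples) and a single all_reduce yields the global totals, from which the metrics are recomputed. A minimal standalone sketch, assuming torch.distributed is already initialized; the helper name all_reduce_counts is illustrative, not part of the code above.

import torch
import torch.distributed as dist

def all_reduce_counts(batch_size, correct, total_edits, total_ref, device):
    # Ranks that received no samples contribute zeros, so the sums stay
    # correct even when the validation set does not split evenly across GPUs.
    local = torch.tensor(
        [batch_size, correct, total_edits, total_ref],
        dtype=torch.float32, device=device)
    dist.all_reduce(local, op=dist.ReduceOp.SUM)
    batch_total, correct, total_edits, total_ref = local.tolist()
    word_acc = correct / max(batch_total, 1.0)
    edit_distance_acc = 1.0 - total_edits / max(total_ref, 1.0)
    return word_acc, edit_distance_acc
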
Example #4
0
class Trainer(BaseTrainer):
    """
    Trainer class
    Note:
        Inherited from BaseTrainer.
    """
    def __init__(self, model, metrics, optimizer, config, train_dataset, valid_datasets, lr_scheduler=None):
        super().__init__(model, metrics, optimizer, config, train_dataset)
        self.config = config
        self.config['data_loaders']['valid']['args']['batch_size'] = self.data_loader.batch_size
        self.valid_data_loaders = {}
        for corpus, valid_dataset in valid_datasets.items():
            self.valid_data_loaders[corpus] = config.init_obj('data_loaders.valid', module_loader, valid_dataset)
        self.lr_scheduler = lr_scheduler
        self.log_step = math.ceil(len(self.data_loader.dataset) / np.sqrt(self.data_loader.batch_size) / 200)
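        # yields roughly 200 / sqrt(batch_size) progress log lines per epoch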
        self.writer = TensorboardWriter(config.log_dir)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch
        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.
        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log
            The metrics in log must have the key 'metrics'.
        """
        self.model.train()
        self.optimizer.zero_grad()

        total_loss = 0
        accum_count = 0
        for step, batch in enumerate(self.data_loader):
            # (input_ids, input_mask, segment_ids, ng_token_mask, target, deps, task)
            batch = {label: t.to(self.device, non_blocking=True) for label, t in batch.items()}
            current_step = (epoch - 1) * len(self.data_loader) + step

            loss, *_ = self.model(**batch, progress=current_step / self.total_step)

            if len(loss.size()) > 0:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            loss_value = loss.item()  # mean loss for this step
            total_loss += loss_value * next(iter(batch.values())).size(0)

            if step % self.log_step == 0:
                self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)] Time: {} Loss: {:.6f}'.format(
                    epoch,
                    step * self.data_loader.batch_size,
                    len(self.data_loader.dataset),
                    100.0 * step / len(self.data_loader),
                    datetime.datetime.now().strftime('%H:%M:%S'),
                    loss_value))

            gradient_accumulation_steps = self.gradient_accumulation_steps
            if step >= (len(self.data_loader) // self.gradient_accumulation_steps) * self.gradient_accumulation_steps:
                gradient_accumulation_steps = len(self.data_loader) % self.gradient_accumulation_steps  # the last, incomplete window
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps  # scale loss
            loss.backward()
            accum_count += 1
            if accum_count == gradient_accumulation_steps:
                self.writer.set_step(self.writer.step + 1)
                self.writer.add_scalar('lr', self.lr_scheduler.get_last_lr()[0])
                self.writer.add_scalar('loss', loss_value)
                self.writer.add_scalar('progress', current_step / self.total_step)

                self.optimizer.step()
                self.optimizer.zero_grad()
                accum_count = 0
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()

        log = {
            'loss': total_loss / len(self.data_loader.dataset),
        }
        corpus2result = {}
        for corpus, valid_data_loader in self.valid_data_loaders.items():
            val_log = self._valid_epoch(valid_data_loader, corpus)
            log[f'val_{corpus}_loss'] = val_log['loss']
            corpus2result[corpus] = val_log['result']
            if 'all' not in corpus2result:
                corpus2result['all'] = val_log['result']
            else:
                corpus2result['all'] += val_log['result']
        for corpus, result in corpus2result.items():
            val_metrics = self._eval_metrics(result.to_dict(), corpus)
            log.update({f'val_{corpus}_{met.__name__}': v for met, v in zip(self.metrics, val_metrics)})

        return log

    def _valid_epoch(self, data_loader, corpus):
        """
        Validate after training an epoch
        :return: A log that contains information about validation
        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        self.model.eval()
        total_loss = 0
        arguments_set: List[List[List[int]]] = []
        contingency_set: List[int] = []
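        # NOTE: contingency_set is only read by _eval_commonsense below; as written,
        # nothing in the validation loop appends to it, so the commonsense branch
        # scores against an empty prediction list.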
        with torch.no_grad():
            for step, batch in enumerate(data_loader):
                batch = {label: t.to(self.device, non_blocking=True) for label, t in batch.items()}

                loss, *output = self.model(**batch)

                if len(loss.size()) > 0:
                    loss = loss.mean()
                pas_scores = output[0]  # (b, seq, case, seq)

                if corpus != 'commonsense':
                    arguments_set += torch.argmax(pas_scores, dim=3).tolist()  # (b, seq, case)

                total_loss += loss.item() * pas_scores.size(0)

                if step % self.log_step == 0:
                    self.logger.info('Validation [{}/{} ({:.0f}%)] Time: {}'.format(
                        step * data_loader.batch_size,
                        len(data_loader.dataset),
                        100.0 * step / len(data_loader),
                        datetime.datetime.now().strftime('%H:%M:%S')))

        log = {'loss': total_loss / len(data_loader.dataset)}
        self.writer.add_scalar(f'loss/{corpus}', log['loss'])

        if corpus != 'commonsense':
            dataset = data_loader.dataset
            prediction_writer = PredictionKNPWriter(dataset, self.logger)
            documents_pred = prediction_writer.write(arguments_set, None, add_pas_tag=False)
            targets2label = {tuple(): '', ('pred',): 'pred', ('noun',): 'noun', ('pred', 'noun'): 'all'}

            scorer = Scorer(documents_pred, dataset.gold_documents,
                            target_cases=dataset.target_cases,
                            target_exophors=dataset.target_exophors,
                            coreference=dataset.coreference,
                            bridging=dataset.bridging,
                            pas_target=targets2label[tuple(dataset.pas_targets)])
            result = scorer.run()
            log['result'] = result
        else:
            log['f1'] = self._eval_commonsense(contingency_set)

        return log

    def _eval_commonsense(self, contingency_set: List[int]) -> float:
        valid_data_loader = self.valid_data_loaders['commonsense']
        gold = [f.label for f in valid_data_loader.dataset.features]
        f1 = f1_score(gold, contingency_set)
        self.writer.add_scalar('commonsense_f1', f1)
        return f1

    def _eval_metrics(self, result: dict, corpus: str):
        f1_metrics = np.zeros(len(self.metrics))
        for i, metric in enumerate(self.metrics):
            f1_metrics[i] += metric(result)
            self.writer.add_scalar(f'{metric.__name__}/{corpus}', f1_metrics[i])
        return f1_metrics
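
The loss scaling in _train_epoch above divides by a smaller factor for the trailing steps that do not fill a whole accumulation window, so every window still averages its gradients over the steps it actually contains. A standalone sketch of that arithmetic; window_size is a hypothetical helper, not part of the class above.

def window_size(step, num_batches, accum):
    # Steps past the last full window scale their loss by the remainder
    # instead of accum, mirroring the check in _train_epoch.
    if step >= (num_batches // accum) * accum:
        return num_batches % accum
    return accum

# With 10 batches and accum = 4, steps 8 and 9 form a final window of size 2:
assert [window_size(s, 10, 4) for s in range(10)] == [4, 4, 4, 4, 4, 4, 4, 4, 2, 2]
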
class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 lr_scheduler,
                 config,
                 trainloader,
                 validloader=None,
                 len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])
        self.trainloader = trainloader
        self.validloader = validloader

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.trainloader)
        else:
            # iteration-based training
            self.trainloader = inf_loop(trainloader)
            self.len_epoch = len_epoch

        # setup GPU device if available, move model into configured device
        n_gpu_use = torch.cuda.device_count()
        self.device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        self.model = model.to(self.device)
        self.model = torch.nn.DataParallel(model)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.log_step = cfg_trainer['log_step']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])
        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_step(self, batch):
        """
        Training logic for a step

        :param batch: batch of current step
        :return:
            loss: torch.Tensor attached to the computation graph, used for backward()
            mets: dict of metrics computed between output and target
        """
        raise NotImplementedError

    @abstractmethod
    def _valid_step(self, batch):
        """
        Valid logic for a step

        :param batch: batch of current step
        :return:
            loss: torch.Tensor detached from the computation graph (no backward pass needed)
            mets: dict of metrics computed between output and target
        """
        raise NotImplementedError

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()

        tic = time.time()
        datatime = batchtime = 0
        for batch_idx, batch in enumerate(self.trainloader):
            datatime += time.time() - tic
            # -------------------------------------------------------------------------
            loss, mets = self._train_step(batch)
            # -------------------------------------------------------------------------
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batchtime += time.time() - tic
            tic = time.time()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.train_metrics.update(key, val)

            if batch_idx % self.log_step == 0:
                processed_percent = batch_idx / self.len_epoch * 100
                self.logger.debug(
                    'Train Epoch:{} [{}/{}]({:.0f}%)\tTime:{:5.2f}/{:<5.2f}\tLoss:({:.4f}){:.4f}'
                    .format(epoch, batch_idx, self.len_epoch,
                            processed_percent, datatime, batchtime,
                            loss.item(), self.train_metrics.avg('loss')))
                datatime = batchtime = 0

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log = {'train_' + k: v for k, v in log.items()}

        if self.validloader is not None:
            val_log = self._valid_epoch(epoch)
            log.update(**{'valid_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        for batch_idx, batch in enumerate(self.validloader):
            # -------------------------------------------------------------------------
            loss, mets = self._valid_step(batch)
            # -------------------------------------------------------------------------
            self.writer.set_step(
                (epoch - 1) * len(self.validloader) + batch_idx, 'valid')
            self.valid_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.valid_metrics.update(key, val)

        return self.valid_metrics.result()

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            lr = self.optimizer.param_groups[0]['lr']
            log = {'epoch': epoch, 'lr': lr}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:20s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if self.lr_scheduler is not None:
                if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                    self.lr_scheduler.step(log[self.mnt_metric])
                else:
                    self.lr_scheduler.step()

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

            # add histogram of model parameters to the tensorboard
            self.writer.set_step(epoch)
            for name, p in self.model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, additionally save the checkpoint as 'model_best.pth'
        """
        state = {
            'epoch': epoch,
            'model': self.model.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = str(self.checkpoint_dir / 'chkpt_{:03d}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        try:
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.module.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            self.mnt_best = checkpoint['monitor_best']
        except KeyError:
            self.model.module.load_state_dict(checkpoint)

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
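
BaseTrainer leaves only _train_step and _valid_step abstract; a subclass plugs in the per-batch computation and inherits the training loop, logging, checkpointing, and early stopping. A minimal sketch under the assumption of a plain (data, target) batch layout; the class name SupervisedTrainer and that layout are illustrative, not taken from the examples above.

import torch

class SupervisedTrainer(BaseTrainer):
    def _train_step(self, batch):
        # Assumed batch layout: a (data, target) pair of tensors.
        data, target = (t.to(self.device) for t in batch)
        output = self.model(data)
        loss = self.criterion(output, target)
        mets = {m.__name__: m(output, target) for m in self.metric_ftns}
        return loss, mets

    def _valid_step(self, batch):
        # Same computation, but without building a graph for backward().
        with torch.no_grad():
            return self._train_step(batch)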