Example #1
# Imports as in fairseq_cli/train.py (module paths may differ across fairseq
# versions); helpers such as validate_and_save, get_training_stats and
# _flatten_config are defined elsewhere in that module.
import logging
import math
import os
from typing import List, Optional, Tuple

import torch
from omegaconf import DictConfig

from fairseq import tasks, utils
from fairseq.data import iterators
from fairseq.distributed import utils as distributed_utils
from fairseq.logging import metrics, progress_bar
from fairseq.trainer import Trainer

logger = logging.getLogger("fairseq_cli.train")


def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (cfg.optimization.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(cfg.optimization.update_freq) else
                   cfg.optimization.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(
                                cfg.distributed_training) else None),
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if distributed_utils.is_master(
            cfg.distributed_training) else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(
                             cfg.distributed_training) else False),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(cfg, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
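For context, here is a minimal sketch of how a train() call like the one above is typically driven from an outer epoch loop (modeled loosely on fairseq_cli/train.py); the train_loop name, the get_train_iterator arguments, and the stopping logic are assumptions and will vary by fairseq version.

def train_loop(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr):
    max_epoch = cfg.optimization.max_epoch or math.inf
    while epoch_itr.next_epoch_idx <= max_epoch:
        # Train for one epoch; valid_losses feeds the LR scheduler,
        # should_stop signals early stopping (patience, max updates, ...)
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break
        # Only the first validation loss is used to update the learning rate
        trainer.lr_step(epoch_itr.epoch, valid_losses[0])
        # Fetch a fresh iterator for the next epoch
        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            load_dataset=task.has_sharded_data("train"),
        )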
Example #2
# Uses the same imports as Example #1, plus:
from tqdm import tqdm


def compute_head_importance(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    head_mask=None,
) -> torch.Tensor:
    """Accumulate attention-head importance scores over one epoch of data and
    return them as a (num_attention_modules x num_heads) tensor scaled to [0, 1]."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(
                                cfg.distributed_training) else None),
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if distributed_utils.is_master(
            cfg.distributed_training) else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(
                             cfg.distributed_training) else False),
    )

    # Initialize head importance scores: one row per attention module
    # (encoder self-attention plus decoder self- and encoder-decoder attention)
    encoder_layers = trainer.cfg.model.encoder_layers
    decoder_layers = trainer.cfg.model.decoder_layers
    encoder_heads = trainer.cfg.model.encoder_attention_heads
    decoder_heads = trainer.cfg.model.decoder_attention_heads
    device = next(trainer.model.parameters()).device
    assert encoder_heads == decoder_heads
    head_importance = torch.zeros(
        [encoder_layers + 2 * decoder_layers, decoder_heads]).to(device)

    # Initialize head masks
    if head_mask is None:
        head_mask = torch.ones(
            [encoder_layers + 2 * decoder_layers, decoder_heads]).to(device)
    head_mask.requires_grad_(requires_grad=True)

    trainer.begin_epoch(epoch_itr.epoch)

    # Accumulate per-head importance contributions from every batch in the epoch
    for samples in tqdm(progress):
        head_importance += trainer.prune_step(samples, head_mask)

    # Normalize scores within each layer (L2 norm across heads)
    exponent = 2
    norm_by_layer = torch.pow(
        torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
    head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    # Rescale all scores to [0, 1] (min-max normalization)
    head_importance = (head_importance - head_importance.min()) / (
        head_importance.max() - head_importance.min())

    return head_importance
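A hypothetical follow-up (not part of the snippet above) showing one way the returned scores could be turned into a binary head_mask that prunes the globally least important heads; build_head_mask and the prune_ratio value are assumptions made here for illustration.

def build_head_mask(head_importance: torch.Tensor, prune_ratio: float = 0.3) -> torch.Tensor:
    """Zero out the `prune_ratio` fraction of heads with the lowest importance."""
    mask = torch.ones_like(head_importance)
    num_to_prune = int(head_importance.numel() * prune_ratio)
    if num_to_prune > 0:
        # Flat indices of the globally least important heads
        prune_idx = head_importance.flatten().argsort()[:num_to_prune]
        mask.view(-1)[prune_idx] = 0.0
    return mask

# Example usage:
# scores = compute_head_importance(cfg, trainer, task, epoch_itr)
# head_mask = build_head_mask(scores, prune_ratio=0.3)
# scores_masked = compute_head_importance(cfg, trainer, task, epoch_itr, head_mask=head_mask)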