Example #1
    def _reduce_and_log_stats(self,
                              logging_outputs,
                              sample_size,
                              grad_norm=None):
        if grad_norm is not None:
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args['optimization']['clip_norm'] > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.args['optimization']['clip_norm'],
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())

            # support legacy interface
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output
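
The "clip" statistic above is a neat trick: it logs 100 when the gradient norm exceeds the clipping threshold and 0 otherwise, so the smoothed meter reads directly as a clipping percentage. A standalone sketch in plain PyTorch (the threshold and norm values are made up):

import torch

clip_norm = 1.0                # hypothetical clip threshold
grad_norm = torch.tensor(2.5)  # hypothetical gradient norm for one update
clip = torch.where(
    grad_norm > clip_norm,
    grad_norm.new_tensor(100),  # 100 -> this update was clipped
    grad_norm.new_tensor(0),    # 0   -> this update was not clipped
)
print(clip.item())  # 100.0; averaged across updates this is a percentage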
Example #2
    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """[deprecated] Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        with metrics.aggregate() as agg:
            self.reduce_metrics(logging_outputs, criterion)
            return agg.get_smoothed_values()
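
For reference, the non-deprecated path relies on an "active" aggregator: any metrics.log_scalar call made inside the with-block lands in the aggregator opened by metrics.aggregate(). A minimal sketch, assuming the fairseq-style metrics module these examples use (the import path is an assumption; in fairseq it is fairseq.logging.metrics):

from fairseq.logging import metrics  # assumed import path

with metrics.aggregate() as agg:
    metrics.log_scalar("loss", 2.0)
    metrics.log_scalar("loss", 4.0)
    print(agg.get_smoothed_values())  # {'loss': 3.0}, a running average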
Example #3
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        set_seed.set_torch_seed(args['dataset']['fixed_validation_seed'])

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir'] if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar']
                                else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(
            stats[args['checkpoint']['best_checkpoint_metric']])

    return valid_losses
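
Note that validate calls a get_valid_stats helper it never defines. In fairseq-style scripts this helper mainly stamps the smoothed values with the global update counter (and sometimes the running best of the checkpoint metric); a minimal sketch of that assumption:

def get_valid_stats(args, trainer, stats):
    # attach the global step so validation rows line up with training logs
    stats['num_updates'] = trainer.get_num_updates()
    return stats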
Example #4
@metrics.aggregate('train')
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args['distributed_training']['fix_batches_to_gpus'],
        # shuffle=(epoch_itr.next_epoch_idx > args['dataset']['curriculum']),
        shuffle=False,
    )
    update_freq = (args['optimization']['update_freq'][epoch_itr.epoch - 1] if
                   epoch_itr.epoch <= len(args['optimization']['update_freq'])
                   else args['optimization']['update_freq'][-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args['common']['log_format'],
        log_interval=args['common']['log_interval'],
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args['common']['tensorboard_logdir']
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args['common']['no_progress_bar']
                            else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args['dataset']['valid_subset'].split(',')
    max_update = args['optimization']['max_update'] or math.inf
    num_updates = 0  # start at 0 so zero-shot runs (no updates) still report stats
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args['common']['log_interval'] == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch meters; end-of-epoch stats are preserved under 'train'
            metrics.reset_meters('train_inner')

        if (not args['dataset']['disable_validation']
                and args['checkpoint']['save_interval_updates'] > 0 and
                num_updates % args['checkpoint']['save_interval_updates'] == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
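
Two details of this loop are easy to miss. First, GroupedIterator yields update_freq batches at a time, so each trainer.train_step call can accumulate gradients over several batches before one optimizer step. Second, the 'train_inner' aggregator is reset at every log interval, while the 'train' aggregator opened for the whole epoch keeps accumulating. A pure-Python sketch of the grouping behaviour (not the project's actual implementation):

def grouped(iterable, chunk_size):
    chunk = []
    for x in iterable:
        chunk.append(x)
        if len(chunk) == chunk_size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk  # trailing partial group, like an epoch's final batches

print(list(grouped(range(10), 4)))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]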
Example #5
def validate(args, trainer, task, epoch_itr, valid_subsets, dev_subsets,
             dev_refs):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args['dataset']['fixed_validation_seed'] is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args['dataset']['fixed_validation_seed'])

    for subset in valid_subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args['dataset']['max_tokens_valid'],
            max_sentences=args['dataset']['max_sentences_valid'],
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
            required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
            seed=args['common']['seed'],
            num_shards=args['distributed_training']['distributed_world_size'],
            shard_id=args['distributed_training']['distributed_rank'],
            num_workers=args['dataset']['num_workers'],
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args['common']['log_format'],
            log_interval=args['common']['log_interval'],
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args['common']['tensorboard_logdir'] if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm' if not args['common']['no_progress_bar']
                                else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        # calculate accuracy
        match = stats.pop('match')
        total = stats.pop('total')
        valid_acc = match / total
        progress.print(
            {
                'accuracy': f'{round(100. * valid_acc, 2)}%',
                'bleu': stats['bleu'],
                'loss': stats['loss'],
            },
            tag=subset,
            step=trainer.get_num_updates())

    # for subset in dev_subsets:
    #     hypotheses, references = {}, dev_refs
    #
    #     # Initialize data iterator
    #     itr = task.get_batch_iterator(
    #         dataset=task.dataset(subset),
    #         max_tokens=args['dataset']['max_tokens_valid'],
    #         max_sentences=args['dataset']['max_sentences_valid'],
    #         max_positions=utils.resolve_max_positions(
    #             task.max_positions(),
    #             trainer.get_model().max_positions(),
    #         ),
    #         ignore_invalid_inputs=args['dataset']['skip_invalid_size_inputs_valid_test'],
    #         required_batch_size_multiple=args['dataset']['required_batch_size_multiple'],
    #         seed=args['common']['seed'],
    #         num_shards=args['distributed_training']['distributed_world_size'],
    #         shard_id=args['distributed_training']['distributed_rank'],
    #         num_workers=args['dataset']['num_workers'],
    #     ).next_epoch_itr(shuffle=False)
    #     progress = progress_bar.progress_bar(
    #         itr,
    #         log_format=args['common']['log_format'],
    #         log_interval=args['common']['log_interval'],
    #         epoch=epoch_itr.epoch,
    #         prefix=f"valid on '{subset}' subset",
    #         tensorboard_logdir=(
    #             args['common']['tensorboard_logdir'] if distributed_utils.is_master(args) else None
    #         ),
    #         default_log_format=('tqdm' if not args['common']['no_progress_bar'] else 'simple'),
    #     )
    #
    #     # create a new root metrics aggregator so validation metrics
    #     # don't pollute other aggregators (e.g., train meters)
    #     with metrics.aggregate(new_root=True) as agg:
    #         for sample in progress:
    #             with torch.no_grad():
    #                 trainer.model.eval()
    #                 trainer.criterion.eval()
    #                 sample = trainer._prepare_sample(sample)
    #                 hyps, _, _, ids = trainer.task.step_out(sample, trainer.model)
    #                 for idx, hypo in zip(ids, hyps):
    #                     hypotheses[idx] = hypo
    #
    #     from third_party.pycocoevalcap.bleu.google_bleu import compute_bleu
    #     assert set(hypotheses.keys()) == set(references.keys())
    #     bleus = [
    #         compute_bleu([references[idx]], [hypotheses[idx]], smooth=True)[0]
    #         for idx in hypotheses.keys()
    #     ]
    #     dev_bleu = round(100. * sum(bleus) / len(bleus), 2)
    #     # log validation stats
    #     stats = agg.get_smoothed_values()
    #     stats['bleu'] = dev_bleu
    #     stats = get_dev_stats(args, trainer, stats)
    #     progress.print(stats, tag=subset, step=trainer.get_num_updates())
    # return valid_acc, dev_bleu
    return valid_acc, None
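
Note how accuracy is computed: the criterion is assumed to log raw 'match' and 'total' counts, the aggregator sums them across batches, and the ratio is taken only after reduction. Averaging per-batch accuracies instead would over-weight small final batches. A toy illustration with made-up counts:

match, total = 370, 500  # hypothetical counts summed over the whole subset
valid_acc = match / total
print(f'{round(100. * valid_acc, 2)}%')  # 74.0%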