Example 1
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )
        if should_stop:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
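A note on the named aggregator used above: fairseq's metrics module keeps a registry of aggregators, and everything logged inside the "train_inner" context accumulates there until reset. A minimal sketch, assuming a fairseq version where "from fairseq import metrics" resolves (newer releases expose the same module as fairseq.logging.metrics):

from fairseq import metrics

with metrics.aggregate("train_inner"):
    # scalars logged here accumulate in the "train_inner" aggregator
    metrics.log_scalar("loss", 2.5, weight=4)
    metrics.log_scalar("loss", 2.1, weight=4)

print(metrics.get_smoothed_values("train_inner")["loss"])  # 2.3, the weighted mean
metrics.reset_meters("train_inner")  # the mid-epoch reset used in the loop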
Example 2
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if (not args.disable_validation and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0):
            valid_losses = validate(args, trainer, task, epoch_itr,
                                    valid_subsets)
            checkpoint_utils.save_checkpoint(args, trainer, epoch_itr,
                                             valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
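The update_freq expression repeated across these snippets treats --update-freq as a per-epoch schedule. The same lookup as a standalone helper (hypothetical name, plain Python):

def update_freq_for_epoch(update_freq, epoch):
    # per-epoch schedule; epochs past the end of the list reuse the last entry
    return update_freq[epoch - 1] if epoch <= len(update_freq) else update_freq[-1]

assert update_freq_for_epoch([4, 2, 1], epoch=2) == 2
assert update_freq_for_epoch([4, 2, 1], epoch=9) == 1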
Example 3
    def valid_step(self, batch_itr):
        if self.model.global_steps % self.fs_args.validate_interval_updates != 0:
            return
        with torch.no_grad():
            self.model.eval()
            for subset in batch_itr.valid_dataset():
                with metrics.aggregate(new_root=True) as agg:
                    for batch, is_dummy_batch in batch_itr.valid_batch():
                        _, sample_size, logging_output = self.task.valid_step(
                            batch, self.model.module.model, self.model.module.criterion
                        )
                        logging_outputs = [logging_output]
                        if is_dummy_batch:
                            if torch.is_tensor(sample_size):
                                sample_size.zero_()
                            else:
                                sample_size *= 0.0
                        logging_outputs, (sample_size,) = torch_reduce_sum(
                            self.model.device,
                            logging_outputs,
                            sample_size,
                            ignore=is_dummy_batch,
                        )
                        logging_output = self.reduce_log(logging_outputs, sample_size)
                log_dist(
                    "Valid on step: {}, dataset: {}. {}".format(
                        self.model.global_steps,
                        subset,
                        view_log(agg.get_smoothed_values()),
                    ),
                    ranks=[0],
                )
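The dummy-batch branch above keeps distributed workers in sync: a padding batch still runs the forward pass, but its sample_size is zeroed so it contributes nothing to the reduced statistics. A self-contained sketch of just that guard:

import torch

def neutralize_dummy(sample_size, is_dummy_batch):
    # zero the weight of a dummy batch so the reduction ignores it
    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.0
    return sample_size

print(neutralize_dummy(torch.tensor(8), True))  # tensor(0)
print(neutralize_dummy(8.0, False))             # 8.0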
Example 4
    def _reduce_and_log_stats(self,
                              logging_outputs,
                              sample_size,
                              grad_norm=None):
        if grad_norm is not None:
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args.clip_norm > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.args.clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())

            # support legacy interface
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output
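The "clip" metric logged above is a 0/100 indicator, so its running average reads directly as the percentage of updates on which the gradient norm exceeded clip_norm. With illustrative values:

import torch

clip_norm = 1.0
for gnorm in (torch.tensor(0.7), torch.tensor(2.3)):
    clipped = torch.where(
        gnorm > clip_norm, gnorm.new_tensor(100), gnorm.new_tensor(0)
    )
    print(float(clipped))  # 0.0, then 100.0; averaging these gives the clip rate in %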
Example 5
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                if (cfg.dataset.max_valid_steps is not None
                        and i > cfg.dataset.max_valid_steps):
                    break
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses
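The new_root=True comment above is the key detail: it suspends the enclosing aggregators so validation scalars do not leak into the train meters. A minimal sketch, assuming fairseq is importable:

from fairseq import metrics

with metrics.aggregate("train"):
    metrics.log_scalar("loss", 1.0)
    with metrics.aggregate(new_root=True) as agg:
        metrics.log_scalar("loss", 9.0)  # isolated from "train"
    print(agg.get_smoothed_values()["loss"])         # 9.0
print(metrics.get_smoothed_values("train")["loss"])  # 1.0, unaffected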
Example 6
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    model = trainer.model

    val_conf = {
        "encoder": [{
            "self_attn": []
        } for i in range(args.encoder_layers)],
        "decoder": [{
            "self_attn": [],
            "enc_attn": []
        } for i in range(args.decoder_layers)]
    }

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = tpu_data_loader(args, itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=("tqdm"
                                if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

                # Get confidence for each head
                if args.head_confidence_method is not None:
                    val_conf = get_batch_confs(model, val_conf, args)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])

    if args.head_confidence_method is not None:
        val_conf = convert_confs(val_conf, args)

        val_conf = calc_conf_per_epoch(val_conf, args)

    return valid_losses, val_conf
Example 7
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    # reset dummy batch only for validation
    trainer._dummy_batch = "DUMMY"  # reset dummy batch

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for step, sample in enumerate(progress):
                trainer.valid_step(sample)
                stats = get_training_stats(agg.get_smoothed_values())
                plog = progress.log
                if hasattr(progress, "wrapped_bar"):
                    plog = progress.wrapped_bar.log
                plog(stats, tag='valid', step=step)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])

    # reset dummy batch again for continuing training
    trainer._dummy_batch = "DUMMY"
    return valid_losses
Example 8
def train(args, trainer, task, epoch_itr, max_update=math.inf):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )

    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, 'tpu', False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )
    progress.log_args(args, tag='train')

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(',')
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        end_of_epoch = not itr.has_next()
        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets, end_of_epoch)
        if should_stop_early(args,
                             valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
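should_stop_early above is fairseq's patience check. A hedged sketch of its contract (a simplification: the real implementation also honors maximize_best_checkpoint_metric, while this one assumes lower is better):

def should_stop_early(args, valid_loss):
    # stop once the metric fails to improve for --patience consecutive checks
    if valid_loss is None or getattr(args, "patience", 0) <= 0:
        return False
    best = getattr(should_stop_early, "best", None)
    if best is None or valid_loss < best:
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    should_stop_early.num_runs += 1
    return should_stop_early.num_runs >= args.patience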
Example 9
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        all_preds, all_labels = [], []
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                logging_outputs, preds, labels = trainer.valid_step(sample)
                if preds is not None:
                    all_preds.extend(preds)
                    all_labels.extend(labels)
                else:
                    all_preds, all_labels = None, None

        if all_preds is not None:
            all_preds = torch.cat(all_preds).cpu().numpy()
            all_labels = torch.cat(all_labels).cpu().numpy()

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values(), all_preds, all_labels)
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
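Once the loop above finishes, the per-batch prediction tensors are concatenated and moved to numpy before being handed to get_valid_stats; for instance, an accuracy computed from them (made-up values):

import torch

all_preds = [torch.tensor([1, 0]), torch.tensor([1, 1])]
all_labels = [torch.tensor([1, 0]), torch.tensor([0, 1])]
preds = torch.cat(all_preds).cpu().numpy()
labels = torch.cat(all_labels).cpu().numpy()
print(float((preds == labels).mean()))  # 0.75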
Example 10
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    assert args.max_sentences_valid == 1, 'Val only supports batch size 1!'
    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    shuffle = False
    if args.validation_max_size > 0:
        logging.info(f'Validation set truncated to {args.validation_max_size}.')
        shuffle = True
        assert args.seed == 1234, args.seed
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=1,  # instead of args.required_batch_size_multiple
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=shuffle)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress, start=1):
                if (args.validation_max_size > 0
                        and i > args.validation_max_size / args.distributed_world_size):
                    continue
                trainer.valid_step(
                    sample,
                    validation_topk=args.validation_topk,
                    validation_D=args.validation_D,
                    validation_rounds=args.validation_rounds,
                )

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
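The per-shard cap above divides by the world size because, with batch size 1 and sharded iteration, each rank sees roughly 1/world_size of the validation set; capping each rank keeps the global sample count near validation_max_size:

validation_max_size, world_size = 1000, 8
per_rank_cap = validation_max_size / world_size  # 125.0 samples on each rank
print(per_rank_cap * world_size)                 # 1000.0 globally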
Example 11
    def reduce_log(self, logging_outputs, sample_size):
        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.criterion)
                del logging_outputs
        logging_output = agg.get_smoothed_values()
        logging_output["sample_size"] = sample_size
        return logging_output
Example 12
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    cur_step,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses."""

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(cfg.common.tensorboard_logdir
                                if distributed_utils.is_master(
                                    cfg.distributed_training) else None),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else
                                "simple"),
            wandb_project=(cfg.common.wandb_project
                           if distributed_utils.is_master(
                               cfg.distributed_training) else None),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        # values must be logged via metrics.log_scalar("key", val) for them to show up in agg
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample, cur_step=cur_step)

        # log validation stats
        # stats starts from the OrderedDict returned by agg.get_smoothed_values();
        # get_valid_stats then adds a few extra values on top
        stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
        # print("In fairseq/fairseq_cli/train.py line 400:\n{} not in stats".format(cfg.checkpoint.best_checkpoint_metric))

    return valid_losses
Example 13
    def _reduce_and_log_stats(self,
                              logging_outputs,
                              sample_size,
                              grad_norm=None):
        metrics.log_scalar("kl_loss",
                           round(logging_outputs[0]["kl_loss"].item(), 3))
        metrics.log_scalar("kld", round(logging_outputs[0]["kld"].item(), 3))
        metrics.log_scalar("bow_loss",
                           round(logging_outputs[0]["bow_loss"].item(), 3))

        if grad_norm is not None and (not torch.is_tensor(grad_norm)
                                      or torch.isfinite(grad_norm)):
            metrics.log_speed("ups", 1.0, priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.cfg.optimization.clip_norm > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.cfg.optimization.clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:

            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())
                del logging_outputs

            # extra warning for criterions that don't properly log a loss value
            if "loss" not in agg:
                if "loss" not in self._warn_once:
                    self._warn_once.add("loss")
                    logger.warning(
                        "Criterion.reduce_metrics did not log a 'loss' value, "
                        "which may break some functionality")
                metrics.log_scalar("loss", -1)
            # support legacy interface
            if self.tpu:
                logging_output = {}
            else:
                logging_output = agg.get_smoothed_values()
                logging_output["sample_size"] = sample_size
                for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                    if key_to_delete in logging_output:
                        del logging_output[key_to_delete]
            return logging_output
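The _warn_once set above is a warn-exactly-once idiom; the same pattern in a self-contained form:

import logging

logger = logging.getLogger(__name__)

class WarnOnce:
    def __init__(self):
        self._warn_once = set()

    def warn(self, key, msg):
        # emit each warning key a single time, however often it recurs
        if key not in self._warn_once:
            self._warn_once.add(key)
            logger.warning(msg)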
Example 14
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=("tqdm"
                                if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        if "fg_gloss0" in stats:
            criterion = trainer.get_criterion()
            ngroups = criterion.n_groups
            baselines = torch.zeros(ngroups, device='cuda')
            for ii in range(ngroups):
                key = "fg_gloss{}".format(ii)
                baselines[ii] = stats[key]
                stats.pop(key, None)
            criterion.set_valid_baselines(baselines)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())
        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
Example 15
def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        if getattr(args, "tpu", False):
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=("tqdm"
                                if not args.no_progress_bar else "simple"),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)

        count = 0
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample)
                count += 1
                if count % 50 == 0:
                    logger.info("Processed {} batches!".format(count))

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
Example 16
    def _reduce_and_log_stats(self, logging_outputs, sample_size):
        if logging_outputs is None or len(logging_outputs) == 0:
            return {"sample_size": sample_size}
        with metrics.aggregate() as agg:
            # convert logging_outputs to CPU to avoid unnecessary
            # device-to-host transfers in reduce_metrics
            logging_outputs = utils.apply_to_sample(
                lambda t: t.to(
                    device='cpu', non_blocking=True, dtype=torch.double),
                logging_outputs)

            self.task.reduce_metrics(logging_outputs, self.get_criterion())

            # support legacy interface
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output
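utils.apply_to_sample, used above, maps a function over every tensor in a nested sample. A simplified re-implementation for illustration (fairseq's version handles a few more container types):

import torch

def apply_to_sample(f, sample):
    if torch.is_tensor(sample):
        return f(sample)
    if isinstance(sample, dict):
        return {k: apply_to_sample(f, v) for k, v in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(apply_to_sample(f, v) for v in sample)
    return sample

out = apply_to_sample(
    lambda t: t.to(device="cpu", non_blocking=True, dtype=torch.double),
    {"loss": torch.tensor(1.5), "ntokens": 3},
)
print(out)  # {'loss': tensor(1.5000, dtype=torch.float64), 'ntokens': 3}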
Example 17
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask,
          epoch_itr) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (cfg.optimization.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(cfg.optimization.update_freq) else
                   cfg.optimization.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir
                            if distributed_utils.is_master(
                                cfg.distributed_training) else None),
        default_log_format=("tqdm"
                            if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if distributed_utils.is_master(
            cfg.distributed_training) else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)),
        azureml_logging=(cfg.common.azureml_logging
                         if distributed_utils.is_master(
                             cfg.distributed_training) else False),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(cfg, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
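The record_function wrapper around train_step above exists purely for profiling: it labels the region so each step shows up as a named block in profiler traces. Standalone usage:

import torch

with torch.autograd.profiler.profile() as prof:
    with torch.autograd.profiler.record_function("train_step-0"):
        _ = torch.randn(64, 64) @ torch.randn(64, 64)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))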
Example 18
def main(cfg: DictConfig, override_args=None):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    assert (
        cfg.dataset.max_tokens is not None
        or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    if cfg.distributed_training.distributed_world_size > 1:
        data_parallel_world_size = distributed_utils.get_data_parallel_world_size(
        )
        data_parallel_rank = distributed_utils.get_data_parallel_rank()
    else:
        data_parallel_world_size = 1
        data_parallel_rank = 0

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
    else:
        overrides = None

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [cfg.common_eval.path],
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(saved_cfg)

    # Build criterion
    criterion = task.build_criterion(saved_cfg.criterion)
    criterion.eval()

    for subset in cfg.dataset.valid_subset.split(","):
        try:
            task.load_dataset(subset,
                              combine=False,
                              epoch=1,
                              task_cfg=saved_cfg.task)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception("Cannot find dataset: " + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=cfg.dataset.max_tokens,
            max_sentences=cfg.dataset.batch_size,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
            seed=cfg.common.seed,
            num_shards=data_parallel_world_size,
            shard_id=data_parallel_rank,
            num_workers=cfg.dataset.num_workers,
            data_buffer_size=cfg.dataset.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else
                                "simple"),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if data_parallel_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=cfg.common.all_gather_list_size,
                group=distributed_utils.get_data_parallel_group(),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
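When data_parallel_world_size > 1, all_gather_list above leaves each rank holding a list of per-rank lists; chain.from_iterable flattens it before reduce_metrics. With illustrative values:

from itertools import chain

per_rank = [[{"loss": 1.0}], [{"loss": 3.0}]]      # gathered from 2 ranks
log_outputs = list(chain.from_iterable(per_rank))
print(log_outputs)  # [{'loss': 1.0}, {'loss': 3.0}]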
Example 19
def main(args, override_args=None):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
            data_buffer_size=args.data_buffer_size,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(
                sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        if args.distributed_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
Example 20
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    if isinstance(epoch_itr, list):
        itrs = []
        for itr in epoch_itr:
            # Initialize data iterators
            itrs.append(
                itr.next_epoch_itr(
                    fix_batches_to_gpus=args.fix_batches_to_gpus,
                    shuffle=(itr.next_epoch_idx > args.curriculum),
                ))

        update_freq = (args.update_freq[epoch_itr[0].epoch - 1]
                       if epoch_itr[0].epoch <= len(args.update_freq) else
                       args.update_freq[-1])

        grouped_itrs = []
        for itr in itrs:
            grouped_itrs.append(iterators.GroupedIterator(itr, update_freq))

        # not supported
        # if getattr(args, "tpu", False):
        #     itr = utils.tpu_data_loader(itr)

        progress = progress_bar.progress_bar(
            grouped_itrs,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr[0].epoch,
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=("simplecluster"),
        )

        trainer.begin_epoch(epoch_itr[0].epoch)

    else:
        # Initialize data iterators
        itr = epoch_itr.next_epoch_itr(
            fix_batches_to_gpus=args.fix_batches_to_gpus,
            shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
        )

        update_freq = (args.update_freq[epoch_itr.epoch - 1]
                       if epoch_itr.epoch <= len(args.update_freq) else
                       args.update_freq[-1])

        itr = iterators.GroupedIterator(itr, update_freq)

        # not supported
        # if getattr(args, "tpu", False):
        #     itr = utils.tpu_data_loader(itr)

        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=("tqdm"
                                if not args.no_progress_bar else "simple"),
        )

        trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):

        if 'cluster_ids' not in samples[0]['net_input']:
            samples[0]['net_input']['cluster_ids'] = numpy.full(
                (1,), 0, dtype=numpy.single)

        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(
                    metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        if isinstance(epoch_itr, list):
            # itr is only bound in the single-iterator branch; in the
            # multi-iterator case, check the first grouped iterator instead
            end_of_epoch = not grouped_itrs[0].has_next()
            valid_losses, should_stop = validate_and_save(
                args, trainer, task, epoch_itr[0], valid_subsets, end_of_epoch)
        else:
            end_of_epoch = not itr.has_next()
            valid_losses, should_stop = validate_and_save(
                args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    # log end-of-epoch stats
    cur_epoch = (epoch_itr[0].epoch
                 if isinstance(epoch_itr, list) else epoch_itr.epoch)
    logger.info(
        "end of epoch {} (average epoch stats below)".format(cur_epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
Example 21
def train(args, trainer, task, epoch_itr, m_mle=None):
    """Train the model for one epoch."""
    global model_old
    global model_mle
    model_old = copy.deepcopy(trainer.model)
    if m_mle is None:
        model_mle = model_old
    else:
        model_mle = m_mle

    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        if True:  # warning: validate and save before every training step
            valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)

        with metrics.aggregate('train_inner'):
            # Debug: training goes here
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # Log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # Reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        if num_updates > 2 and num_updates % (args.policy_update_per_k_epoch) == 0:  # warning
            del model_old
            torch.cuda.empty_cache()
            model_old = copy.deepcopy(trainer.model)

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # Log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # Reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
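The periodic deepcopy above refreshes the frozen "old" policy every policy_update_per_k_epoch updates; the same step extracted into a standalone sketch (names illustrative):

import copy
import torch

def maybe_refresh_old_model(model, model_old, num_updates, k):
    if num_updates > 2 and num_updates % k == 0:
        del model_old                 # drop the stale snapshot first
        torch.cuda.empty_cache()      # release its cached GPU memory
        model_old = copy.deepcopy(model)
    return model_old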
Example 22
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch and return validation losses."""
    logger.info("begin training epoch {}".format(epoch_itr.epoch))
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    time_cost = 0
    for i, samples in enumerate(progress):

        # performance measurement window
        if args.validate_training_performance:
            performance_end_its = args.performance_begin_its + args.performance_its_count - 1
        if args.validate_training_performance and i == args.performance_begin_its:
            processed_tokens = 0

        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            time_begin = time.time()
            log_output = trainer.train_step(samples)
            time_end = time.time()
            if (args.validate_training_performance
                    and args.performance_begin_its <= i <= performance_end_its):
                time_cost += time_end - time_begin
            if log_output is None:  # OOM, overflow, ...
                continue
        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(args, trainer, task,
                                                      epoch_itr, valid_subsets,
                                                      end_of_epoch)
        if args.validate_training_performance and i >= args.performance_begin_its:
            for sample in samples:
                net_input = sample['net_input']
                bs, src_lens = net_input['src_tokens'].shape
                processed_tokens += bs * src_lens
        if args.validate_training_performance and i == performance_end_its:
            logger.info("Performance info:")
            logger.info("Begin iteration:{}".format(
                args.performance_begin_its))
            logger.info("End iteration: {}".format(performance_end_its))
            logger.info("Processed_tokens: {}".format(processed_tokens))
            logger.info("Time cost: {} s".format(time_cost))
            logger.info("Throughput:{} tokens/s".format(processed_tokens /
                                                        (time_cost)))
            should_stop = True
        if should_stop:
            break
    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(
        epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop
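The throughput report above is plain arithmetic over the measured window: tokens seen divided by accumulated train_step wall time. With made-up numbers:

processed_tokens = 3_200_000   # summed bs * src_len over the window
time_cost = 125.0              # seconds spent inside train_step
print(processed_tokens / time_cost)  # 25600.0 tokens/s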
Example 23
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        with metrics.aggregate('train_inner'):
            try:
                log_output = trainer.train_step(samples)

            except ResetTrainerException:
                trainer._wrapped_criterion = None
                trainer._wrapped_model = None
                trainer._optimizer = None

                logger.info("reset the trainer at {}".format(
                    trainer.get_num_updates()))
                log_output = trainer.train_step(samples)

            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr,
                                         valid_subsets)
        if should_stop_early(args,
                             valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training
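Example no. 23 retries the failed step after a ResetTrainerException, whose definition the snippet does not show. A minimal sketch of what such a marker exception might look like (an assumption, not the project's actual class):

class ResetTrainerException(Exception):
    """Raised from inside train_step to request that the lazily built
    criterion/model/optimizer wrappers be discarded and rebuilt.
    Sketch only; the real definition is not shown in the example."""

Setting trainer._wrapped_criterion, trainer._wrapped_model, and trainer._optimizer to None is enough because fairseq's Trainer exposes these through properties that rebuild the objects on next access.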
Example no. 24
    def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None, gvar=None, adam_mom2=None,
                              gvar_diff=None, xstd=None, ams_mom=None, acc_ratio=None, real_var=None, real_var_diff=None,
                              ad_beta=None, lr_min=None, lr_max=None, lr_median=None,
                              update_min=None, update_max=None, update_median=None, valid_ratio=None, var_adapt=None):
        if grad_norm is not None:
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args.clip_norm > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.args.clip_norm,
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )
        # log each optional statistic with its original priority
        optional_scalars = [
            ("gvar", gvar, 100),
            ("adam_mom2", adam_mom2, 100),
            ("gvar_diff", gvar_diff, 100),
            ("xstd", xstd, 100),
            ("ams_mom", ams_mom, 100),
            ("acc_ratio", acc_ratio, 50),
            ("real_var", real_var, 50),
            ("real_var_diff", real_var_diff, 50),
            ("ad_beta", ad_beta, 50),
            ("lr_min", lr_min, 50),
            ("lr_max", lr_max, 50),
            ("lr_median", lr_median, 50),
            ("update_min", update_min, 50),
            ("update_median", update_median, 50),
            ("update_max", update_max, 50),
            ("valid_ratio", valid_ratio, 49),
            ("var_adapt", var_adapt, 1),
        ]
        for name, value, priority in optional_scalars:
            if value is not None:
                metrics.log_scalar(name, value, priority=priority)

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())

            preds, labels = [], []
            for log_output in logging_outputs:
                if 'preds' in log_output:
                    preds.append(log_output['preds'])
                    labels.append(log_output['labels'])
                else:
                    # any output without predictions disables collection;
                    # break so we never call .append on None below
                    preds = None
                    labels = None
                    break

            # support legacy interface
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output, preds, labels
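_reduce_and_log_stats leans on fairseq's metrics aggregation contexts. A minimal usage sketch (assuming fairseq is installed): scalars logged inside a context are averaged and read back via get_smoothed_values:

from fairseq import metrics

with metrics.aggregate(new_root=True) as agg:
    metrics.log_scalar("loss", 2.5, priority=100, round=3)
    metrics.log_scalar("loss", 1.5, priority=100, round=3)
    smoothed = agg.get_smoothed_values()
    assert smoothed["loss"] == 2.0  # running average of the two values

The priority argument only controls the ordering of keys when stats are printed; new_root=True keeps these values from polluting any enclosing aggregators.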
Example no. 25
def train(args,
          trainer,
          task,
          epoch_itr,
          model,
          experiment_path,
          total_samples=None,
          last_epoch_num=0,
          restore=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (args.update_freq[epoch_itr.epoch - 1]
                   if epoch_itr.epoch <= len(args.update_freq) else
                   args.update_freq[-1])
    itr = iterators.GroupedIterator(itr, update_freq)
    if getattr(args, "tpu", False):
        itr = tpu_data_loader(args, itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(args.tensorboard_logdir
                            if distributed_utils.is_master(args) else None),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    num_heads = args.decoder_attention_heads
    head_dim = args.decoder_embed_dim // num_heads
    if experiment_path is not None:
        with open(experiment_path, 'r') as f:
            swaps = json.load(f)
        mhr(model, swaps, head_dim, num_heads, epoch_itr.epoch)

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False

    conf = {
        "encoder": [{
            "self_attn": []
        } for i in range(args.encoder_layers)],
        "decoder": [{
            "self_attn": [],
            "enc_attn": []
        } for i in range(args.decoder_layers)]
    }
    attentions = {
        "decoder": [{
            "self_attn": []
        } for i in range(args.decoder_layers)]
    }

    batch_regression = 1.0 - (total_samples / (160239 * 50))
    for i, samples in enumerate(progress):
        with metrics.aggregate(
                "train_inner"), torch.autograd.profiler.record_function(
                    "train_step-%d" % i):
            log_output = trainer.train_step(samples,
                                            batch_num=batch_regression)

            if log_output is None:  # OOM, overflow, ...
                continue
        total_samples += model.decoder.layers[0].self_attn.bsz
        batch_regression = 1.0 - (
            total_samples / (160239 * 40)
        )  # need to find more generic way to find total samples and epoch num.

        # Get Confidence for each Head.
        if args.head_confidence_method is not None:
            conf = get_batch_confs(model, conf, args)

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(
                metrics.get_smoothed_values("train_inner"))
            progress.log(stats, tag="train_inner", step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop, val_conf = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch)

        if should_stop:
            break

    if args.head_confidence_method is not None:
        conf = convert_confs(conf, args)

        path = args.save_dir.replace("checkpoints", "confs") + "-method={0}".format(
            args.head_confidence_method)
        os.makedirs(path, mode=0o775, exist_ok=True)
        with open(os.path.join(path, "epoch-{0}.pkl".format(epoch_itr.epoch)),
                  'wb') as fd:
            pickle.dump(conf, fd, protocol=3)

    if args.dynamic_type is not None and args.head_confidence_method is not None:
        conf = val_conf

        # The three attention groups share the same call signature; note
        # that every call passes args.encoder_layers, exactly as in the
        # original three hand-written calls.
        for idx, (key, module, attn_type) in enumerate([
                ("enc_self_attn", "encoder", "self_attn"),
                ("dec_self_attn", "decoder", "self_attn"),
                ("dec_enc_attn", "decoder", "encoder_attn"),
        ]):
            restore[key], last_epoch_num[key] = dynamic_mhr(
                model,
                int(args.start_dynamic_mhr[idx]),
                module,
                attn_type,
                restore[key],
                int(args.dynamic_swap_frequency[idx]),
                last_epoch_num[key],
                epoch_itr.epoch + 1,
                int(args.dynamic_max_switches[idx]),
                conf[idx],
                num_heads,
                head_dim,
                args.encoder_layers,
                local_only=False,
                d_type=args.dynamic_type[idx],
                rest=int(args.dynamic_rest[idx]),
                end_epoch=int(args.dynamic_end_epoch[idx]),
            )

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop, total_samples, restore, last_epoch_num
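The confidence dumps written above are plain pickles, one per epoch. A small reader for them, mirroring the directory layout from the snippet (a sketch; load_confs is a hypothetical helper, not part of the example):

import os
import pickle

def load_confs(conf_dir):
    # conf_dir is the "...confs-method=<m>" directory created above;
    # files inside are named "epoch-<n>.pkl".
    confs = {}
    for name in sorted(os.listdir(conf_dir)):
        if name.startswith("epoch-") and name.endswith(".pkl"):
            epoch = int(name[len("epoch-"):-len(".pkl")])
            with open(os.path.join(conf_dir, name), "rb") as fd:
                confs[epoch] = pickle.load(fd)
    return confs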
Example no. 26
def train(args, trainer, task, epoch_itr, max_update=math.inf, model=None):
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    for i, samples in enumerate(progress):
        with metrics.aggregate('train_inner'):
            log_output = trainer.train_step(samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')
        if i == 0:
            print('epoch: ', epoch_itr.epoch)
            selfattn_norm = []
            endeattn_norm = []
            for m in model.modules():
                if getattr(m, 'selfattn_norm', None) is not None:
                    selfattn_norm.append(m.selfattn_norm)
                if getattr(m, 'endeattn_norm', None) is not None:
                    endeattn_norm.append(m.endeattn_norm)
            print('self attention norms: ', selfattn_norm)
            print('en/decoder attn norms:', endeattn_norm)
        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')
    return valid_losses
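Example no. 26 reads selfattn_norm/endeattn_norm attributes off modules but does not show how they are set. One way to populate such attributes is a forward hook that records the output norm; the sketch below is an assumption about the mechanism, not fairseq API:

import torch

def attach_norm_recorder(attn_module, attr="selfattn_norm"):
    # Record the L2 norm of the attention output on every forward pass;
    # attention modules typically return (output, attn_weights) tuples.
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        setattr(module, attr, out.detach().norm().item())
    return attn_module.register_forward_hook(hook)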
Example no. 27
def validate_iw(args, trainer, task, epoch_itr, subsets, prune=-1, mode='iw'):
    """Evaluate the model on the validation set(s) and return the losses."""

    if mode in ('none', 'time') or args.criterior == 'lm_baseline':
        return [0]

    # top k instead of sampling to approximate sum of prototypes for evaluation
    for subset in subsets:
        task.dataset(subset).set_sampling(False)

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=1,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_iw_step(sample, mode=mode)

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag='valid_iw', step=trainer.get_num_updates())

        # valid_losses.append(stats[args.best_checkpoint_metric])

        if prune > 0:
            trainer.get_model().reset_prune_index()
            task.reset_index_map()

    return valid_losses
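validate_iw pins the validation RNG with utils.set_torch_seed so repeated validations see identical sampling decisions. A generic recipe for fixing such a seed (not fairseq's exact implementation):

import random
import numpy as np
import torch

def set_seed(seed):
    # Seed every RNG the validation path might touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)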
Example no. 28
def validate(args, trainer, task, epoch_itr, subsets, prune=-1):
    """Evaluate the model on the validation set(s) and return the losses."""

    if args.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(args.fixed_validation_seed)

    valid_losses = []
    for subset in subsets:
        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(args.tensorboard_logdir if
                                distributed_utils.is_master(args) else None),
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        # added by Junxian
        if prune > 0:
            index_map = trainer.get_model().set_prune_index(prune)
            task.set_index_map(index_map)

        # do not write templates when profiling time
        write_template_flag = args.eval_mode != 'time'

        # only one worker deals with the template file in DDP
        if args.distributed_rank == 0 and write_template_flag:
            print('write template files')

            if args.eval_mode == 'none':
                fout = open(
                    os.path.join(
                        args.save_dir, 'templates_{}_{}.txt'.format(
                            epoch_itr.epoch, trainer.get_num_updates())), 'w')
            else:
                fout = open(
                    os.path.join(args.save_dir,
                                 'templates_eval_{}.txt'.format(subset)), 'w')

            if prune <= 0:
                task.write_lambda(fout, trainer.get_model())
        else:
            fout = None

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for sample in progress:
                trainer.valid_step(sample, split=subset)

                # added by Junxian
                if args.distributed_rank == 0:
                    task.write_template(sample, trainer.get_model(), fout)

        if fout is not None:
            fout.close()

        # log validation stats
        stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[args.best_checkpoint_metric])
    return valid_losses
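The template writing above follows the usual DDP convention: only rank 0 owns the file handle, and writes elsewhere become no-ops. The same pattern distilled into a sketch (open_on_master and write_line are hypothetical helpers):

def open_on_master(path, rank):
    # Only the master worker gets a real handle; others get None.
    return open(path, "w") if rank == 0 else None

def write_line(fout, line):
    # Safe to call from every worker; no-op without a handle.
    if fout is not None:
        fout.write(line + "\n")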
Example no. 29
def main(args, override_args=None):
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if override_args is not None:
        try:
            override_args = override_args['override_args']
        except TypeError:
            pass  # already a namespace rather than a dict
        overrides = vars(override_args)
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    model = models[0]

    # Move models to GPU
    for model in models:
        if use_fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    if use_fp16:
        criterion.half()
    if use_cuda:
        criterion.cuda()
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError:
            raise Exception('Cannot find dataset: ' + subset)

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank).next_epoch_itr(shuffle=False)

        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm'
                                if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            sample = utils.apply_to_sample(
                lambda t: t.half() if t.dtype is torch.float32 else t,
                sample) if use_fp16 else sample
            try:
                with torch.no_grad():  # do not save backward passes
                    max_num_rays = 900 * 900
                    if sample['uv'].shape[3] > max_num_rays:
                        sample['ray_split'] = sample['uv'].shape[
                            3] // max_num_rays
                    _loss, _sample_size, log_output = task.valid_step(
                        sample, model, criterion)

                progress.log(log_output, step=i)
                log_outputs.append(log_output)

            except TypeError:
                break

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        # summarize all the gpus
        if args.distributed_world_size > 1:
            all_log_output = list(
                zip(*distributed_utils.all_gather_list([log_output])))[0]
            log_output = {
                key: np.mean([log[key] for log in all_log_output])
                for key in all_log_output[0]
            }

        progress.print(log_output, tag=subset, step=i)
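The final reduction above gathers one log dict per worker and averages key-wise. The same step in isolation (gathering is assumed already done, e.g. via distributed_utils.all_gather_list):

import numpy as np

def average_logs(all_log_output):
    # Every worker reports the same keys; average each across workers.
    return {
        key: np.mean([log[key] for log in all_log_output])
        for key in all_log_output[0]
    }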