Example #1
def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    do_eval_tune_loss: bool,
    do_save: bool,
    do_eval_bleu: bool,
) -> Tuple[Dict[str, Any], bool, Optional[list]]:
    # Clear any remaining metrics from previous steps. This should already
    # have been done before, but just in case - to make sure we catch
    # any case where extra_state does not get populated correctly.
    extra_state = clear_per_step_extra_state(extra_state)

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.
    stop_due_to_tune_loss = False
    if do_eval_tune_loss:
        extra_state, stop_due_to_tune_loss = eval_tune_loss(
            args=args,
            trainer=trainer,
            task=task,
            subset=args.valid_subset,
            extra_state=extra_state,
        )

    # Only save checkpoints and eval tune BLEU on the master - all other
    # processes will just get the results from the master.
    master_extra_state = None
    master_stop_training = None
    translation_samples = None
    if distributed_utils.is_master(args):
        stop_due_to_tune_bleu = False
        if do_save:
            extra_state = save_checkpoint(
                trainer=trainer, args=args, extra_state=extra_state
            )
        if do_eval_bleu and not do_save:
            raise ValueError(
                "do_save should always be true when do_eval_bleu is true "
                "since a new BLEU eval can only be done when there's a new "
                "checkpoint."
            )
        if do_eval_bleu:
            extra_state, stop_due_to_tune_bleu, translation_samples = evaluate_bleu(
                args=args, task=task, extra_state=extra_state
            )
        master_extra_state = extra_state
        master_stop_training = stop_due_to_tune_loss or stop_due_to_tune_bleu

    # We don't all_gather the translation_samples since the sample sentences
    # could be pretty long, and only the master uses it anyway.
    extra_state, stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_extra_state, master_stop_training]
    )

    # Basic sanity checks that extra_state is populated correctly.
    assert not (
        do_eval_tune_loss
        and (
            extra_state["tune_eval"]["loss"] is None
            or extra_state["tune_eval"]["perplexity"] is None
        )
    )
    assert not (do_eval_bleu and extra_state["tune_bleu"]["current"] is None)
    return extra_state, stop_training, translation_samples
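Example #1 (and the variants below) relies on pytorch_translate_utils.all_gather_from_master to propagate the master's extra_state and stop flag to every worker. A minimal sketch of how such a helper could be built on torch.distributed follows; the broadcast-based implementation and the use of broadcast_object_list are assumptions for illustration, not the actual pytorch_translate code:

from typing import Any, List

import torch.distributed as dist


def all_gather_from_master(args, data: List[Any]) -> List[Any]:
    # Sketch only: rank 0 (the master) broadcasts its Python objects, and all
    # other ranks overwrite their local (None) placeholders in place.
    # `args` is kept only to mirror the call sites in the examples.
    dist.broadcast_object_list(data, src=0)
    return data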
Example #2
def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    checkpoint_manager: Optional[checkpoint.CheckpointManager],
    end_of_epoch=False,
) -> Tuple[Dict[str, Any], bool, Optional[List]]:
    # Checks for time limit stopping criterion even when we're not doing
    # eval/saving checkpoints.
    max_update = args.max_update or math.inf
    stop_due_to_max_update = trainer.get_num_updates() > max_update
    stop_due_to_time_limit = is_training_over_time_limit(
        extra_state, args.stop_time_hr)
    if not end_of_epoch and (
            args.save_interval_updates <= 0 or
        (extra_state["num_iterations"] % args.save_interval_updates != 0)):
        return extra_state, stop_due_to_time_limit, None

    # Update training time before saving the checkpoint.
    time_now: float = time.time()
    extra_state[
        "previous_training_time"] += time_now - extra_state["start_time"]
    extra_state["start_time"] = time_now

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.

    # Tune loss
    extra_state, stop_due_to_tune_loss = eval_tune_loss(
        args=args,
        trainer=trainer,
        task=task,
        subset=args.valid_subset,
        extra_state=extra_state,
    )

    is_master: bool = distributed_utils.is_master(args)
    if is_master:
        assert checkpoint_manager is not None, (
            f"Master worker (rank {args.distributed_rank}) should "
            f"have a checkpoint_manager defined.")
    else:
        assert checkpoint_manager is None, (
            f"Non-master worker (rank {args.distributed_rank}) should not "
            f"have a checkpoint_manager defined.")

    # trick to prepare the task for evaluation, e.g. in latent variable model we need to set eval_key in RoundRobinZipDataset
    if hasattr(task, "prepare_for_eval") and callable(task.prepare_for_eval):
        task.prepare_for_eval()
    # Only save checkpoints and eval tune BLEU on the master - all other
    # processes will just get the results from the master.
    translation_samples: Optional[List] = None
    if is_master:
        averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
            new_params=trainer.get_model().state_dict())

        # TODO: fix after masked lm work completes
        if "save_only" not in args or not args.save_only:
            (
                extra_state,
                stop_due_to_tune_bleu,
                new_best_averaged_checkpoint,
                translation_samples,
            ) = evaluate_bleu(
                args=args,
                task=task,
                extra_state=extra_state,
                trainer=trainer,
                averaged_params=averaged_params,
            )
        else:
            new_best_averaged_checkpoint = True
            stop_due_to_tune_bleu = False
        # checkpoint_manager takes ownership of averaged_params.
        extra_state = checkpoint_manager.save(
            args=args,
            trainer=trainer,
            extra_state=extra_state,
            new_averaged_params=averaged_params,
        )
        if new_best_averaged_checkpoint:
            checkpoint_manager.save_best_averaged_checkpoint(
                args=args, trainer=trainer, extra_state=extra_state)
    if hasattr(task, "prepare_for_train") and callable(task.prepare_for_train):
        task.prepare_for_train()

    # extra_state["tune_bleu"] needs to be sync'ed between master and workers
    # since we only do BLEU eval on master, but then need that info for
    # determining when to do lr_shrink on all workers.
    master_tune_bleu = None
    master_stop_training = None
    if is_master:
        master_tune_bleu = extra_state["tune_bleu"]
        master_stop_training = (stop_due_to_time_limit or stop_due_to_tune_loss
                                or stop_due_to_tune_bleu
                                or stop_due_to_max_update)
    tune_bleu, stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_tune_bleu, master_stop_training])
    extra_state["tune_bleu"] = tune_bleu

    # TODO: fix after masked lm work completes
    if "save_only" not in args or not args.save_only:
        # Basic sanity checks that extra_state is populated correctly.
        assert (extra_state["tune_eval"]["loss"] is not None
                and extra_state["tune_eval"]["perplexity"] is not None
                and extra_state["tune_bleu"]["current"] is not None)
    return extra_state, stop_training, translation_samples
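Example #2 is meant to be called on every update: the early return at the top makes it a no-op except at --save-interval-updates boundaries or at the end of an epoch. A hypothetical caller might look like the following (the loop structure and the epoch_itr name are illustrative assumptions, not from the source):

# Hypothetical training loop around the Example #2 signature.
for samples in epoch_itr:
    trainer.train_step(samples)
    extra_state["num_iterations"] += 1
    extra_state, stop_training, translation_samples = save_and_eval(
        args=args,
        trainer=trainer,
        task=task,
        extra_state=extra_state,
        # Only the master process owns a CheckpointManager; workers pass None.
        checkpoint_manager=checkpoint_manager,
        end_of_epoch=False,
    )
    if stop_training:
        break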
Example #3
def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    checkpoint_manager: Optional[checkpoint.CheckpointManager],
    end_of_epoch=False,
) -> Tuple[Dict[str, Any], bool]:
    # Checks for time limit stopping criterion even when we're not doing
    # eval/saving checkpoints.
    max_update = args.max_update or math.inf
    stop_due_to_max_update = trainer.get_num_updates() > max_update
    stop_due_to_time_limit = is_training_over_time_limit(
        extra_state, args.stop_time_hr)
    if not end_of_epoch and (
            args.save_interval_updates <= 0 or
        (extra_state["num_iterations"] % args.save_interval_updates != 0)):
        return extra_state, stop_due_to_time_limit

    # Update training time before saving the checkpoint.
    time_now: float = time.time()
    extra_state[
        "previous_training_time"] += time_now - extra_state["start_time"]
    extra_state["start_time"] = time_now

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.

    # Tune loss
    extra_state, stop_due_to_tune_loss = eval_tune_loss(
        args=args,
        trainer=trainer,
        task=task,
        subset=args.valid_subset,
        extra_state=extra_state,
    )

    is_master: bool = distributed_utils.is_master(args)
    if is_master:
        assert checkpoint_manager is not None, (
            f"Master worker (rank {args.distributed_rank}) should "
            f"have a checkpoint_manager defined.")
    else:
        assert checkpoint_manager is None, (
            f"Non-master worker (rank {args.distributed_rank}) should not "
            f"have a checkpoint_manager defined.")

    if is_master:
        averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
            new_params=trainer.get_model().state_dict())
        new_best_averaged_checkpoint = extra_state["tune_eval"][
            "num_since_best"] == 0
        # checkpoint_manager takes ownership of averaged_params.
        extra_state = checkpoint_manager.save(
            args=args,
            trainer=trainer,
            extra_state=extra_state,
            new_averaged_params=averaged_params,
        )
        if new_best_averaged_checkpoint:
            checkpoint_manager.save_best_averaged_checkpoint(
                args=args, trainer=trainer, extra_state=extra_state)

    master_stop_training = None
    if is_master:
        master_stop_training = (stop_due_to_time_limit or stop_due_to_tune_loss
                                or stop_due_to_max_update)
    stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_stop_training])[0]

    # TODO: fix after masked lm work completes
    if "save_only" not in args or not args.save_only:
        # Basic sanity checks that extra_state is populated correctly.
        assert (extra_state["tune_eval"]["loss"] is not None
                and extra_state["tune_eval"]["perplexity"] is not None)
    return extra_state, stop_training
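Examples #2-#4 gate training on is_training_over_time_limit, which compares accumulated training time against --stop-time-hr. The sketch below is consistent with how previous_training_time and start_time are maintained above; the actual helper may differ in details (e.g. how "no time limit" is encoded):

import time
from typing import Any, Dict


def is_training_over_time_limit(extra_state: Dict[str, Any], stop_time_hr: float) -> bool:
    # Sketch only. Assumes a negative stop_time_hr means "no time limit".
    if stop_time_hr < 0:
        return False
    elapsed_hr = (
        extra_state["previous_training_time"]
        + (time.time() - extra_state["start_time"])
    ) / 3600
    return elapsed_hr > stop_time_hr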
Example #4
def save_and_eval(
    args,
    trainer,
    task,
    extra_state: Dict[str, Any],
    checkpoint_manager: Optional[checkpoint.CheckpointManager],
    end_of_epoch=False,
) -> Tuple[Dict[str, Any], bool, Optional[List]]:
    # Checks for time limit stopping criterion even when we're not doing
    # eval/saving checkpoints.
    stop_due_to_time_limit = is_training_over_time_limit(
        extra_state, args.stop_time_hr)
    if not end_of_epoch and (
            args.save_interval_updates <= 0 or
        (extra_state["num_iterations"] % args.save_interval_updates != 0)):
        return extra_state, stop_due_to_time_limit, None

    # Update training time before saving the checkpoint.
    time_now: float = time.time()
    extra_state[
        "previous_training_time"] += time_now - extra_state["start_time"]
    extra_state["start_time"] = time_now

    # Under multiprocessing, each process will run eval over a different
    # shard of the tune data set and then aggregate the results across all
    # processes, so the eval stats from all processes' trainer should
    # remain synchronized.

    # Tune loss
    extra_state, stop_due_to_tune_loss = eval_tune_loss(
        args=args,
        trainer=trainer,
        task=task,
        subset=args.valid_subset,
        extra_state=extra_state,
    )

    is_master: bool = distributed_utils.is_master(args)
    if is_master:
        assert checkpoint_manager is not None, (
            f"Master worker (rank {args.distributed_rank}) should "
            f"have a checkpoint_manager defined.")
    else:
        assert checkpoint_manager is None, (
            f"Non-master worker (rank {args.distributed_rank}) should not "
            f"have a checkpoint_manager defined.")

    # Only save checkpoints and eval tune BLEU on the master - all other
    # processes will just get the results from the master.
    translation_samples: Optional[List] = None
    if is_master:
        averaged_params: OrderedDict = checkpoint_manager.get_averaged_params(
            new_params=trainer.get_model().state_dict())
        (
            extra_state,
            stop_due_to_tune_bleu,
            new_best_averaged_checkpoint,
            translation_samples,
        ) = evaluate_bleu(
            args=args,
            task=task,
            extra_state=extra_state,
            trainer=trainer,
            averaged_params=averaged_params,
        )
        # checkpoint_manager takes ownership of averaged_params.
        extra_state = checkpoint_manager.save(
            args=args,
            trainer=trainer,
            extra_state=extra_state,
            new_averaged_params=averaged_params,
        )
        if new_best_averaged_checkpoint:
            checkpoint_manager.save_best_averaged_checkpoint(
                args=args, trainer=trainer, extra_state=extra_state)

    # We don't all_gather the translation_samples since the sample sentences
    # could be pretty long, and only the master uses it anyway.
    master_extra_state = None
    master_stop_training = None
    if is_master:
        master_extra_state = extra_state
        master_stop_training = (stop_due_to_time_limit or stop_due_to_tune_loss
                                or stop_due_to_tune_bleu)
    extra_state, stop_training = pytorch_translate_utils.all_gather_from_master(
        args=args, data=[master_extra_state, master_stop_training])

    # Basic sanity checks that extra_state is populated correctly.
    assert (extra_state["tune_eval"]["loss"] is not None
            and extra_state["tune_eval"]["perplexity"] is not None
            and extra_state["tune_bleu"]["current"] is not None)
    return extra_state, stop_training, translation_samples
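Examples #2-#4 evaluate BLEU on parameters averaged across recent checkpoints rather than on the latest model alone. A rough sketch of that idea follows, assuming a simple rolling average over the last N parameter dicts; the real checkpoint.CheckpointManager additionally handles saving and takes ownership of the averaged tensors:

from collections import OrderedDict, deque


class RollingParamAverager:
    """Illustrative rolling average over the last `max_checkpoints` param dicts."""

    def __init__(self, max_checkpoints: int):
        self._history = deque(maxlen=max_checkpoints)

    def get_averaged_params(self, new_params: OrderedDict) -> OrderedDict:
        # Clone so later optimizer steps don't mutate the stored snapshot.
        self._history.append(
            OrderedDict((k, v.detach().clone().float()) for k, v in new_params.items())
        )
        averaged = OrderedDict()
        for key in new_params:
            averaged[key] = sum(p[key] for p in self._history) / len(self._history)
        return averaged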