예제 #1
0
def save_checkpoint_atomic(trainer, final_filename, extra_state):
    """Wrapper around trainer.save_checkpoint to make the save atomic.

    The checkpoint is first written to ``final_filename + ".tmp"``. Only after
    that write returns is it copied over ``final_filename``, so a failure
    while ``trainer.save_checkpoint`` is writing does not leave a partially
    written ``final_filename``. The temporary file is removed afterwards.

    Args:
        trainer: Object exposing ``save_checkpoint(filename, extra_state)``.
        final_filename: Destination path of the checkpoint.
        extra_state: Extra state forwarded to ``trainer.save_checkpoint``.

    Raises:
        OSError: If ``PathManager.copy`` reports that copying the temporary
            file to ``final_filename`` failed.
    """
    temp_filename = final_filename + ".tmp"
    trainer.save_checkpoint(temp_filename, extra_state)
    # TODO(T56266125): Use mv() instead of copy() + rm() after it's added to
    # PathManager.
    # The copy must NOT live inside an ``assert``: under ``python -O`` assert
    # statements (including their side effects) are stripped, which would
    # skip the copy and then delete the only written checkpoint below.
    copied = PathManager.copy(temp_filename, final_filename, overwrite=True)
    if not copied:
        raise OSError(f"Failed to copy {temp_filename} to {final_filename}")
    PathManager.rm(temp_filename)
예제 #2
0
파일: utils.py 프로젝트: vednig/translate
def load_diverse_ensemble_for_inference(filenames: List[str],
                                        task: Optional[
                                            tasks.FairseqTask] = None):
    """Load an ensemble of diverse models for inference.

    This method is similar to fairseq.utils.load_ensemble_for_inference
    but allows loading diverse models with non-uniform args.

    Args:
        filenames: List of file names to checkpoints. Every checkpoint's
            storages are restored onto CPU.
        task: Optional[FairseqTask]. If this isn't provided, we set up the task
            from the first checkpoint's task config (its ``cfg["task"]`` entry,
            or its ``args`` when the checkpoint has no ``cfg``).

    Returns:
        A 3-tuple ``(models, args, task)``:
            models: List of the loaded models, in the order of ``filenames``.
            args: List of the corresponding model configurations.
            task: Either the input task or the task created within this
                function.

    Raises:
        IOError: If any file in ``filenames`` does not exist.
        ValueError: If ``filenames`` is empty and no ``task`` is given, since
            there is then no checkpoint to set the task up from.
    """
    if task is None and not filenames:
        raise ValueError(
            "Cannot set up a task: no checkpoint filenames were given and "
            "no task was provided.")

    # Load model architectures and weights onto CPU.
    checkpoints_data = []
    for filename in filenames:
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        with PathManager.open(filename, "rb") as f:
            checkpoints_data.append(
                torch.load(
                    f,
                    map_location=lambda s, _: torch.serialization.
                    default_restore_location(s, "cpu"),
                ))

    def get_cfg(cp, key):
        # Checkpoints with a "cfg" entry are indexed by section key; otherwise
        # the whole "args" object is returned regardless of key.
        if "cfg" in cp:
            return cp["cfg"][key]
        return cp["args"]

    if task is None:
        task_cfg = get_cfg(checkpoints_data[0], "task")
        # Configs that expose a ``mode`` attribute are switched to "eval".
        if hasattr(task_cfg, "mode"):
            task_cfg.mode = "eval"
        task = tasks.setup_task(task_cfg)

    # Build the ensemble: one model per checkpoint, alongside its config.
    ensemble = []
    args_list = []
    for checkpoint_data in checkpoints_data:
        model_cfg = get_cfg(checkpoint_data, "model")
        model = task.build_model(model_cfg)
        model.load_state_dict(checkpoint_data["model"])
        ensemble.append(model)
        args_list.append(model_cfg)
    return ensemble, args_list, task
예제 #3
0
 def _remove_checkpoint(self, checkpoint_to_remove: Optional[str]):
     """Delete an old checkpoint file via PathManager, if a path was given.

     Nothing happens when ``checkpoint_to_remove`` is falsy (``None`` or an
     empty string). A ``FileNotFoundError`` from the removal is reported on
     stdout instead of being raised.

     Args:
         checkpoint_to_remove: Path of the checkpoint to delete, or ``None``.
     """
     if not checkpoint_to_remove:
         return
     path = checkpoint_to_remove
     self.log_if_verbose(f"| Preparing to remove old checkpoint {path}.")
     try:
         PathManager.rm(path)
         self.log_if_verbose(f"| Finished removing old checkpoint {path}.")
     except FileNotFoundError:
         print(
             f"| Unable to find old checkpoint {path} for removal",
             flush=True,
         )
예제 #4
0
 def _remove_checkpoint(self, checkpoint_to_remove: Optional[str]):
     """Best-effort deletion of an old checkpoint file via PathManager.

     Nothing happens when ``checkpoint_to_remove`` is falsy (``None`` or an
     empty string). Any ``OSError`` raised during removal (which includes
     ``FileNotFoundError``) is printed to stdout rather than propagated.

     Args:
         checkpoint_to_remove: Path of the checkpoint to delete, or ``None``.
     """
     if not checkpoint_to_remove:
         return
     target = checkpoint_to_remove
     self.log_if_verbose(f"| Preparing to remove old checkpoint {target}.")
     try:
         PathManager.rm(target)
         self.log_if_verbose(f"| Finished removing old checkpoint {target}.")
     except OSError as err:
         print(
             f"| Failed to remove old checkpoint {target} - exception: {err}",
             flush=True,
         )
예제 #5
0
def load_to_gpu(path: str) -> Dict[str, Any]:
    """Load a serialized checkpoint from ``path`` with storages placed on CUDA.

    Counterpart of ``load_to_cpu``: each storage is restored through
    ``torch.serialization.default_restore_location`` with location ``"cuda"``,
    ignoring the location recorded when the file was saved.

    Args:
        path: File path, opened in binary mode through ``PathManager``.

    Returns:
        The object produced by ``torch.load``.
    """
    def _restore_on_cuda(storage, _saved_location):
        return torch.serialization.default_restore_location(storage, "cuda")

    with PathManager.open(path, "rb") as handle:
        return torch.load(handle, map_location=_restore_on_cuda)
예제 #6
0
def load_to_cpu(path: str) -> Dict[str, Any]:
    """Load a serialized checkpoint from ``path`` with storages placed on CPU.

    Equivalent to fairseq's ``utils.load_checkpoint_to_cpu()`` minus the
    backward-compatibility upgrade of the state dict, which keeps tests that
    only care about loading model params simple.

    Args:
        path: File path, opened in binary mode through ``PathManager``.

    Returns:
        The object produced by ``torch.load``.
    """
    def _restore_on_cpu(storage, _saved_location):
        return torch.serialization.default_restore_location(storage, "cpu")

    with PathManager.open(path, "rb") as handle:
        return torch.load(handle, map_location=_restore_on_cpu)
예제 #7
0
def load_existing_checkpoint(
        checkpoint_path,
        trainer,
        restore_state=True) -> Tuple[bool, Optional[Dict]]:
    """Try to load the checkpoint at ``checkpoint_path`` into ``trainer``.

    Args:
        checkpoint_path: Path of the checkpoint file.
        trainer: Object providing ``load_checkpoint`` and ``lr_step``.
        restore_state: If True, load the checkpoint together with its extra
            state. If False, call ``trainer.load_checkpoint`` with
            ``reset_optimizer=True`` and return no extra state.

    Returns:
        ``(loaded, extra_state)``. ``loaded`` is True when
        ``trainer.load_checkpoint`` returned a non-None value. ``extra_state``
        is that value in the ``restore_state=True`` path, otherwise ``None``.
    """
    if not PathManager.isfile(checkpoint_path):
        print(f"| No existing checkpoint at {checkpoint_path}. "
              f"Starting training from scratch.")
        return False, None

    if not restore_state:
        weights_state = trainer.load_checkpoint(checkpoint_path,
                                                reset_optimizer=True)
        if weights_state is None:
            print(
                f"| Failed to load checkpoint weights from {checkpoint_path}.")
            return False, None
        print(f"| Loaded checkpoint weights from {checkpoint_path}.")
        return True, None

    restored = trainer.load_checkpoint(checkpoint_path)
    if restored is None:
        print(
            f"| Failed to load checkpoint and state from {checkpoint_path}.")
        return False, None

    print(f"| Loaded checkpoint {checkpoint_path} "
          f"(epoch {restored['epoch']}) with restored extra state.")
    # A batch_offset of None marks a checkpoint written at the end of an
    # epoch (after its last batch): call lr_step for that epoch, then move
    # the resume point to the start of the next epoch.
    if restored["batch_offset"] is None:
        trainer.lr_step(restored["epoch"])
        restored["epoch"] += 1
        restored["batch_offset"] = 0
    return True, restored
예제 #8
0
def setup_training_state(args, trainer, task, epoch_itr):
    """Set up the directory for saving checkpoints and restore training state.

    Creates ``args.save_dir``, then picks what to restore from:
      * ``args.restore_file`` joined onto ``args.save_dir`` if that file
        exists (full extra state is restored);
      * otherwise ``args.pretrained_checkpoint_file`` if set and it exists
        (extra state restored only if
        ``args.load_pretrained_checkpoint_state``);
      * if neither file exists and ``args.multi_model_restore_files`` is set,
        the individual models are imported from those files instead.

    Args:
        args: Training arguments (``save_dir``, ``restore_file``,
            ``pretrained_checkpoint_file``, ... are read).
        trainer: Trainer that checkpoints/models are loaded into.
        task: Not used by this function.
        epoch_itr: Epoch iterator whose position is restored from the
            resulting extra state.

    Returns:
        ``(extra_state, epoch_itr, checkpoint_manager)``. ``checkpoint_manager``
        is ``None`` unless ``distributed_utils.is_master(args)``.
    """
    PathManager.mkdirs(args.save_dir)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. The idea is that
    # --pretrained-checkpoint-file allows the user to specify restoring from a
    # different run's checkpoint (possibly with different training params),
    # while not polluting the previous run's checkpoint directory
    # with new checkpoints. However, if training gets interrupted
    # and the user restarts training, we want to resume from
    # the checkpoints under --save-dir, instead of
    # restarting again from the old run's checkpoint at
    # --pretrained-checkpoint-file.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if PathManager.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and PathManager.isfile(
            args.pretrained_checkpoint_file):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not PathManager.isfile(
            checkpoint_path) and args.multi_model_restore_files:
        print(
            f"| Restoring individual models from {args.multi_model_restore_files}"
        )
        multi_model.import_individual_models(args.multi_model_restore_files,
                                             trainer)
    else:
        # Only the restored extra state is needed; the "loaded" flag is not.
        _, loaded_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)

    # Reset the start time for the current training run.
    extra_state["start_time"] = time.time()

    # Skip printing all training progress to prevent log spam: temporarily
    # replace it with its last entry. The full list is restored in `finally`
    # so extra_state is never left truncated, even if printing raises.
    training_progress = extra_state["training_progress"]
    extra_state["training_progress"] = ([
        "...truncated...", training_progress[-1]
    ] if len(training_progress) > 0 else [])
    try:
        print(f"| extra_state: {extra_state}")
    finally:
        extra_state["training_progress"] = training_progress

    epoch = extra_state["epoch"]
    if extra_state["batch_offset"] == 0:
        epoch -= 1  # this will be incremented when we call epoch_itr.next_epoch_itr()
    epoch_itr.load_state_dict({
        "epoch":
        epoch,
        "iterations_in_epoch":
        extra_state["batch_offset"]
    })

    checkpoint_manager = None
    if distributed_utils.is_master(args):
        checkpoint_manager = checkpoint.CheckpointManager(
            num_avg_checkpoints=args.num_avg_checkpoints,
            auto_clear_checkpoints=args.auto_clear_checkpoints,
            log_verbose=args.log_verbose,
            checkpoint_files=extra_state["checkpoint_files"],
        )

    return extra_state, epoch_itr, checkpoint_manager
예제 #9
0
    def save(
        self,
        args,
        trainer,
        extra_state: Dict[str, Any],
        new_averaged_params: OrderedDict,
    ) -> Dict[str, Any]:
        """Saves the model params contained in trainer.

        Takes ownership of new_averaged_params, so the caller should not modify
        them afterwards.

        Args:
          args: Arguments providing ``save_dir``, the directory checkpoints are
              written to.
          trainer: Trainer containing the model to be saved.
          extra_state: Dictionary containing any extra information about the
              model beyond the param weights. Its "epoch" and "batch_offset"
              entries select the checkpoint filename; its "checkpoint_files"
              entry is overwritten.
          new_averaged_params: If specified, takes ownership of the params and
              sets them as current set of averaged params. If not specified,
              we will recalculate the averaged params using the model params
              in trainer.

        Returns:
          Updated extra_state dictionary.

        Raises:
          OSError: If ``PathManager.copy`` reports that copying the new
              checkpoint to the last-checkpoint path failed.
        """
        epoch = extra_state["epoch"]
        batch_offset = extra_state["batch_offset"]

        # batch_offset being None means that we're at the end of an epoch.
        if batch_offset is None:
            filename = os.path.join(args.save_dir, f"checkpoint{epoch}_end.pt")
        # Otherwise, we're in the middle of an epoch.
        else:
            filename = os.path.join(
                args.save_dir, f"checkpoint{epoch}_{batch_offset}.pt"
            )

        checkpoint_to_remove = self._update_state(
            new_params_filename=filename, new_averaged_params=new_averaged_params
        )
        extra_state["checkpoint_files"] = list(self._checkpoint_files)

        self.log_if_verbose(
            f"| Preparing to save checkpoints for epoch {epoch}, "
            f"offset {batch_offset}."
        )
        # Saves two copies of the checkpoint - one under a specific name
        # corresponding to its epoch/offset, and another under the generic
        # last-checkpoint name (constants.LAST_CHECKPOINT_FILENAME) that we
        # restore from in case training is interrupted.
        save_checkpoint_atomic(
            trainer=trainer, final_filename=filename, extra_state=extra_state
        )
        # We update the last checkpoint only after the new averaged checkpoint
        # and epoch/offset-named copy have been written - so that in case either
        # write fails, we'd still be able to resume from the previous
        # last checkpoint.
        last_checkpoint_path = os.path.join(
            args.save_dir, constants.LAST_CHECKPOINT_FILENAME
        )
        # Copy outside of an ``assert``: under ``python -O`` the assert (and
        # the copy inside it) would be stripped, silently skipping the update.
        copied = PathManager.copy(filename, last_checkpoint_path, overwrite=True)
        if not copied:
            raise OSError(f"Failed to copy {filename} to {last_checkpoint_path}")
        self.log_if_verbose(
            f"| Finished saving checkpoints for epoch {epoch}, "
            f"offset {batch_offset}."
        )

        # Wait until after the last checkpoint has been written to remove the
        # oldest checkpoint. This is so that in case we fail to write a new
        # last checkpoint, we'd still have access to all the files listed in
        # the previous one.
        self._remove_checkpoint(checkpoint_to_remove)
        return extra_state