Example 1
def save_checkpoint_atomic(trainer, final_filename, extra_state):
    """Wrapper around trainer.save_checkpoint to make save atomic."""
    temp_filename = os.path.join(final_filename + ".tmp")
    trainer.save_checkpoint(temp_filename, extra_state)
    # TODO(T56266125): Use mv() instead of copy() + rm() after it's added to
    # PathManager.
    assert PathManager.copy(
        temp_filename, final_filename, overwrite=True
    ), f"Failed to copy {temp_filename} to {final_filename}"
    PathManager.rm(temp_filename)
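
The TODO above notes that the copy-then-remove pair should eventually become a single move. As a minimal sketch of that intent, assuming a purely local filesystem where os.replace provides an atomic rename (PathManager's remote backends are ignored here, and the helper name below is hypothetical):

import os


def save_checkpoint_atomic_local(trainer, final_filename, extra_state):
    """Hypothetical local-only variant: write to a temp file, then atomically
    rename it over the final path so readers never see a partial checkpoint."""
    temp_filename = final_filename + ".tmp"
    trainer.save_checkpoint(temp_filename, extra_state)
    # os.replace is atomic when source and destination live on the same
    # filesystem, which is why the temp file is created next to the target.
    os.replace(temp_filename, final_filename)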
Example 2
    def save(
        self,
        args,
        trainer,
        extra_state: Dict[str, Any],
        new_averaged_params: OrderedDict,
    ) -> Dict[str, Any]:
        """Saves the model params contained in trainer.

        Takes ownership of new_averaged_params, so the caller should not modify
        them afterwards.

        Args:
          args: Parsed command-line arguments; args.save_dir determines where
              the checkpoint files are written.
          trainer: Trainer containing the model to be saved.
          extra_state: Dictionary containing any extra information about the
              model beyond the param weights.
          new_averaged_params: If specified, takes ownership of the params and
              sets them as the current set of averaged params. If not
              specified, we will recalculate the averaged params using the
              model params in trainer.

        Returns:
          Updated extra_state dictionary.
        """
        epoch = extra_state["epoch"]
        batch_offset = extra_state["batch_offset"]

        # batch_offset being None means that we're at the end of an epoch.
        if batch_offset is None:
            filename = os.path.join(args.save_dir, f"checkpoint{epoch}_end.pt")
        # Otherwise, we're in the middle of an epoch.
        else:
            filename = os.path.join(
                args.save_dir, f"checkpoint{epoch}_{batch_offset}.pt"
            )

        checkpoint_to_remove = self._update_state(
            new_params_filename=filename, new_averaged_params=new_averaged_params
        )
        extra_state["checkpoint_files"] = list(self._checkpoint_files)

        self.log_if_verbose(
            f"| Preparing to save checkpoints for epoch {epoch}, "
            f"offset {batch_offset}."
        )
        # Saves two copies of the checkpoint - one under a specific name
        # corresponding to its epoch/offset, and another under the generic
        # "checkpoint_last.pt" that we restore from in case training is
        # interrupted.
        save_checkpoint_atomic(
            trainer=trainer, final_filename=filename, extra_state=extra_state
        )
        # We update checkpoint_last.pt only after the new averaged checkpoint
        # and epoch/offset-named copy have been written - so that in case
        # either write fails, we'd still be able to resume from the previous
        # checkpoint_last.pt.
        last_checkpoint_path = os.path.join(
            args.save_dir, constants.LAST_CHECKPOINT_FILENAME
        )
        assert PathManager.copy(
            filename, last_checkpoint_path, overwrite=True
        ), f"Failed to copy {filename} to {last_checkpoint_path}"
        self.log_if_verbose(
            f"| Finished saving checkpoints for epoch {epoch}, "
            f"offset {batch_offset}."
        )

        # Wait until after checkpoint_last.pt has been written to remove the
        # oldest checkpoint. This is so that in case we fail to write a new
        # checkpoint_last.pt, we'd still have access to all the files listed in
        # the previous checkpoint_last.pt.
        self._remove_checkpoint(checkpoint_to_remove)
        return extra_state
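
For reference, a hypothetical call site sketching how save() might be driven at the end of an epoch; checkpoint_manager, args, trainer, epoch, and averaged_params are assumed to come from the surrounding training loop and are not defined in the code above:

# Illustrative usage only; every name here except the extra_state keys is an
# assumption about the surrounding training loop.
extra_state = {
    "epoch": epoch,
    "batch_offset": None,  # None marks the end of an epoch.
}
extra_state = checkpoint_manager.save(
    args=args,
    trainer=trainer,
    extra_state=extra_state,
    new_averaged_params=averaged_params,
)
# save() returns the updated extra_state, including the "checkpoint_files"
# list maintained by the manager.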