def save_checkpoint_atomic(trainer, final_filename, extra_state):
    """Wrapper around trainer.save_checkpoint to make save atomic."""
    temp_filename = final_filename + ".tmp"
    trainer.save_checkpoint(temp_filename, extra_state)

    # TODO(T56266125): Use mv() instead of copy() + rm() after it's added to
    # PathManager.
    assert PathManager.copy(
        temp_filename, final_filename, overwrite=True
    ), f"Failed to copy {temp_filename} to {final_filename}"
    PathManager.rm(temp_filename)
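# Illustrative sketch only, not part of the training flow: once a move/rename
# primitive exists (see the TODO above), the copy() + rm() pair could collapse
# into a single rename. For plain local paths, os.replace() already provides
# this today. The helper below is a hypothetical example that assumes
# checkpoints live on a local filesystem; it is not the PathManager API.
def _save_checkpoint_atomic_local(trainer, final_filename, extra_state):
    temp_filename = final_filename + ".tmp"
    trainer.save_checkpoint(temp_filename, extra_state)
    # os.replace() is atomic on POSIX when source and destination are on the
    # same filesystem, so readers never observe a partially written file.
    os.replace(temp_filename, final_filename)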
def save(
    self,
    args,
    trainer,
    extra_state: Dict[str, Any],
    new_averaged_params: OrderedDict,
) -> Dict[str, Any]:
    """Saves the model params contained in trainer.

    Takes ownership of new_averaged_params, so the caller should not modify
    them afterwards.

    Args:
        args: Parsed command-line arguments (only args.save_dir is used here).
        trainer: Trainer containing the model to be saved.
        extra_state: Dictionary containing any extra information about the
            model beyond the param weights.
        new_averaged_params: If specified, takes ownership of the params and
            sets them as the current set of averaged params. If not specified,
            we will recalculate the averaged params using the model params in
            trainer.

    Returns:
        Updated extra_state dictionary.
    """
    epoch = extra_state["epoch"]
    batch_offset = extra_state["batch_offset"]

    # batch_offset being None means that we're at the end of an epoch.
    if batch_offset is None:
        filename = os.path.join(args.save_dir, f"checkpoint{epoch}_end.pt")
    # Otherwise, we're in the middle of an epoch.
    else:
        filename = os.path.join(
            args.save_dir, f"checkpoint{epoch}_{batch_offset}.pt"
        )

    checkpoint_to_remove = self._update_state(
        new_params_filename=filename, new_averaged_params=new_averaged_params
    )
    extra_state["checkpoint_files"] = list(self._checkpoint_files)

    self.log_if_verbose(
        f"| Preparing to save checkpoints for epoch {epoch}, "
        f"offset {batch_offset}."
    )
    # Saves two copies of the checkpoint - one under a specific name
    # corresponding to its epoch/offset, and another under the generic
    # "checkpoint_last.pt" that we restore from in case training is
    # interrupted.
    save_checkpoint_atomic(
        trainer=trainer, final_filename=filename, extra_state=extra_state
    )
    # We update checkpoint_last.pt only after the new averaged checkpoint
    # and the epoch/offset-named copy have been written - so that if either
    # write fails, we can still resume from the previous checkpoint_last.pt.
    last_checkpoint_path = os.path.join(
        args.save_dir, constants.LAST_CHECKPOINT_FILENAME
    )
    assert PathManager.copy(
        filename, last_checkpoint_path, overwrite=True
    ), f"Failed to copy {filename} to {last_checkpoint_path}"
    self.log_if_verbose(
        f"| Finished saving checkpoints for epoch {epoch}, "
        f"offset {batch_offset}."
    )

    # Wait until after checkpoint_last.pt has been written to remove the
    # oldest checkpoint. This is so that if we fail to write a new
    # checkpoint_last.pt, we still have access to all the files listed in
    # the previous checkpoint_last.pt.
    self._remove_checkpoint(checkpoint_to_remove)
    return extra_state
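# Hypothetical usage sketch. The names `checkpoint_manager`, `averaged_params`,
# and the exact contents of `extra_state` below are assumptions for
# illustration, not part of this module: the caller hands ownership of the
# averaged params to save() and keeps the returned extra_state for the next
# call.
#
#     extra_state = {"epoch": 3, "batch_offset": None}
#     extra_state = checkpoint_manager.save(
#         args=args,
#         trainer=trainer,
#         extra_state=extra_state,
#         new_averaged_params=averaged_params,
#     )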