def save_checkpoint_atomic(trainer, final_filename, extra_state):
    """Wrapper around trainer.save_checkpoint to make the save atomic."""
    temp_filename = final_filename + ".tmp"
    trainer.save_checkpoint(temp_filename, extra_state)
    # TODO(T56266125): Use mv() instead of copy() + rm() after it's added to
    # PathManager.
    assert PathManager.copy(
        temp_filename, final_filename, overwrite=True
    ), f"Failed to copy {temp_filename} to {final_filename}"
    PathManager.rm(temp_filename)
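# Hedged usage sketch (added for illustration, not part of the original source):
# shows the intended call pattern for save_checkpoint_atomic(). The `trainer`
# parameter is assumed to be any object exposing
# save_checkpoint(filename, extra_state); the path and extra_state values are
# placeholders. Writing to a temp file and only then copying over the final
# name means a partially written file never clobbers the previous checkpoint.
def _example_save_checkpoint_atomic(trainer):
    save_checkpoint_atomic(
        trainer=trainer,
        final_filename="/tmp/checkpoints/checkpoint_last.pt",
        extra_state={"epoch": 3, "batch_offset": None},
    )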
def load_diverse_ensemble_for_inference(
    filenames: List[str], task: Optional[tasks.FairseqTask] = None
):
    """Load an ensemble of diverse models for inference.

    This method is similar to fairseq.utils.load_ensemble_for_inference
    but allows loading diverse models with non-uniform args.

    Args:
        filenames: List of file names to checkpoints.
        task: Optional[FairseqTask]. If this isn't provided, we set up the task
            using the first checkpoint's model args loaded from the saved state.

    Returns:
        models, args: Tuple of lists. models contains the loaded models, args
            the corresponding configurations.
        task: Either the input task or the task created within this function
            using args.
    """
    # load model architectures and weights
    checkpoints_data = []
    for filename in filenames:
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        with PathManager.open(filename, "rb") as f:
            checkpoints_data.append(
                torch.load(
                    f,
                    map_location=lambda s, l: torch.serialization.default_restore_location(
                        s, "cpu"
                    ),
                )
            )

    def get_cfg(cp, key):
        if "cfg" in cp:
            return cp["cfg"][key]
        else:
            return cp["args"]

    # build ensemble
    ensemble = []
    if task is None:
        cfg = get_cfg(checkpoints_data[0], "task")
        if hasattr(cfg, "mode"):
            cfg.mode = "eval"
        task = tasks.setup_task(cfg)
    for checkpoint_data in checkpoints_data:
        cfg = get_cfg(checkpoint_data, "model")
        model = task.build_model(cfg)
        model.load_state_dict(checkpoint_data["model"])
        ensemble.append(model)
    args_list = [get_cfg(s, "model") for s in checkpoints_data]
    return ensemble, args_list, task
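# Hedged usage sketch (added for illustration, not part of the original source):
# loads two hypothetical checkpoint paths into an ensemble and switches the
# models to eval mode before inference. The filenames are placeholders.
def _example_load_ensemble():
    models, model_args, task = load_diverse_ensemble_for_inference(
        ["/tmp/checkpoints/model1.pt", "/tmp/checkpoints/model2.pt"]
    )
    for model in models:
        model.eval()
    return models, model_args, task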
def _remove_checkpoint(self, checkpoint_to_remove: Optional[str]):
    if checkpoint_to_remove:
        self.log_if_verbose(
            f"| Preparing to remove old checkpoint {checkpoint_to_remove}."
        )
        try:
            PathManager.rm(checkpoint_to_remove)
            self.log_if_verbose(
                f"| Finished removing old checkpoint {checkpoint_to_remove}."
            )
        except OSError as e:
            print(
                f"| Failed to remove old checkpoint {checkpoint_to_remove} "
                f"- exception: {e}",
                flush=True,
            )
def load_to_gpu(path: str) -> Dict[str, Any]:
    """
    Similar to load_to_cpu, but loads the model to CUDA.
    """
    with PathManager.open(path, "rb") as f:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cuda")
            ),
        )
    return state
def load_to_cpu(path: str) -> Dict[str, Any]:
    """
    This is just fairseq's utils.load_checkpoint_to_cpu(), except we don't try
    to upgrade the state dict for backward compatibility - to make cases where
    we only care about loading the model params easier to unit test.
    """
    with PathManager.open(path, "rb") as f:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
            ),
        )
    return state
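# Hedged usage sketch (added for illustration, not part of the original source):
# inspects a checkpoint's parameter names and shapes without allocating GPU
# memory. The path is a placeholder; the "model" key holds the state dict of
# parameter tensors, as used by load_diverse_ensemble_for_inference() above.
def _example_inspect_checkpoint():
    state = load_to_cpu("/tmp/checkpoints/checkpoint_last.pt")
    for name, tensor in state["model"].items():
        print(name, tuple(tensor.shape))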
def load_existing_checkpoint(
    checkpoint_path, trainer, restore_state=True
) -> Tuple[bool, Optional[Dict]]:
    loaded = False
    extra_state = None

    if not PathManager.isfile(checkpoint_path):
        print(
            f"| No existing checkpoint at {checkpoint_path}. "
            f"Starting training from scratch."
        )
        return loaded, extra_state

    if restore_state:
        extra_state = trainer.load_checkpoint(checkpoint_path)
        if extra_state is None:
            loaded = False
            print(f"| Failed to load checkpoint and state from {checkpoint_path}.")
        else:
            loaded = True
            print(
                f"| Loaded checkpoint {checkpoint_path} "
                f"(epoch {extra_state['epoch']}) with restored extra state."
            )
            # batch_offset being None denotes this was a checkpoint saved at
            # the end of an epoch (after the last batch).
            if extra_state["batch_offset"] is None:
                trainer.lr_step(extra_state["epoch"])
                extra_state["epoch"] += 1
                extra_state["batch_offset"] = 0
    else:
        dummy_state = trainer.load_checkpoint(checkpoint_path, reset_optimizer=True)
        if dummy_state is None:
            loaded = False
            print(f"| Failed to load checkpoint weights from {checkpoint_path}.")
        else:
            loaded = True
            print(f"| Loaded checkpoint weights from {checkpoint_path}.")

    return loaded, extra_state
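# Hedged usage sketch (added for illustration, not part of the original source):
# resumes full training state from checkpoint_last.pt when it exists; the
# function itself handles the "no checkpoint" case by starting from scratch.
# `trainer` is assumed to expose load_checkpoint() as used above, and the path
# is a placeholder.
def _example_resume(trainer):
    loaded, extra_state = load_existing_checkpoint(
        checkpoint_path="/tmp/checkpoints/checkpoint_last.pt",
        trainer=trainer,
        restore_state=True,
    )
    if loaded and extra_state is not None:
        print(f"| Resuming from epoch {extra_state['epoch']}.")
    return loaded, extra_state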
def setup_training_state(args, trainer, task, epoch_itr):
    """Set up the directory for saving checkpoints.
    Load pretrained model if specified."""
    PathManager.mkdirs(args.save_dir)

    # If --restore-file is already present under --save-dir, use that one
    # instead of --pretrained-checkpoint-file. The idea is that
    # --pretrained-checkpoint-file allows the user to specify restoring from a
    # different run's checkpoint (possibly with different training params),
    # while not polluting the previous run's checkpoint directory
    # with new checkpoints. However, if training gets interrupted
    # and the user restarts training, we want to resume from
    # the checkpoints under --save-dir, instead of
    # restarting again from the old run's checkpoint at
    # --pretrained-checkpoint-file.
    #
    # Note that if args.restore_file is an absolute path, os.path.join() will
    # ignore previous directory args and just use the absolute path as is.
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    restore_state = True
    if PathManager.isfile(checkpoint_path):
        print(
            f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
        )
    elif args.pretrained_checkpoint_file and PathManager.isfile(
        args.pretrained_checkpoint_file
    ):
        checkpoint_path = args.pretrained_checkpoint_file
        restore_state = args.load_pretrained_checkpoint_state
        print(
            f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
            f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
        )

    extra_state = default_extra_state(args)
    if not PathManager.isfile(checkpoint_path) and args.multi_model_restore_files:
        print(f"| Restoring individual models from {args.multi_model_restore_files}")
        multi_model.import_individual_models(args.multi_model_restore_files, trainer)
    else:
        loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
            checkpoint_path=checkpoint_path,
            trainer=trainer,
            restore_state=restore_state,
        )
        if loaded_extra_state:
            extra_state.update(loaded_extra_state)

    # Reset the start time for the current training run.
    extra_state["start_time"] = time.time()

    # Skip printing the full training progress to prevent log spam.
    training_progress = extra_state["training_progress"]
    extra_state["training_progress"] = (
        ["...truncated...", training_progress[-1]]
        if len(training_progress) > 0
        else []
    )
    print(f"| extra_state: {extra_state}")
    extra_state["training_progress"] = training_progress

    epoch = extra_state["epoch"]
    if extra_state["batch_offset"] == 0:
        epoch -= 1  # this will be incremented when we call epoch_itr.next_epoch_itr()
    epoch_itr.load_state_dict(
        {"epoch": epoch, "iterations_in_epoch": extra_state["batch_offset"]}
    )

    checkpoint_manager = None
    if distributed_utils.is_master(args):
        checkpoint_manager = checkpoint.CheckpointManager(
            num_avg_checkpoints=args.num_avg_checkpoints,
            auto_clear_checkpoints=args.auto_clear_checkpoints,
            log_verbose=args.log_verbose,
            checkpoint_files=extra_state["checkpoint_files"],
        )

    return extra_state, epoch_itr, checkpoint_manager
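# Hedged usage sketch (added for illustration, not part of the original source):
# shows how the three return values are typically threaded back into the
# training loop. `args`, `trainer`, `task`, and `epoch_itr` are assumed to come
# from the surrounding training setup, which is not shown here.
def _example_setup(args, trainer, task, epoch_itr):
    extra_state, epoch_itr, checkpoint_manager = setup_training_state(
        args=args, trainer=trainer, task=task, epoch_itr=epoch_itr
    )
    # checkpoint_manager is None on non-master workers, so guard any save calls.
    if checkpoint_manager is not None:
        print("| This worker will manage checkpoint saving.")
    return extra_state, epoch_itr, checkpoint_manager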
def save(
    self,
    args,
    trainer,
    extra_state: Dict[str, Any],
    new_averaged_params: OrderedDict,
) -> Dict[str, Any]:
    """Saves the model params contained in trainer.

    Takes ownership of new_averaged_params, so the caller should not
    modify them afterwards.

    Args:
        trainer: Trainer containing the model to be saved.
        extra_state: Dictionary containing any extra information about the
            model beyond the param weights.
        new_averaged_params: If specified, takes ownership of the params and
            sets them as the current set of averaged params. If not specified,
            we will recalculate the averaged params using the model params in
            trainer.

    Returns:
        Updated extra_state dictionary.
    """
    epoch = extra_state["epoch"]
    batch_offset = extra_state["batch_offset"]

    # batch_offset being None means that we're at the end of an epoch.
    if batch_offset is None:
        filename = os.path.join(args.save_dir, f"checkpoint{epoch}_end.pt")
    # Otherwise, we're in the middle of an epoch.
    else:
        filename = os.path.join(args.save_dir, f"checkpoint{epoch}_{batch_offset}.pt")

    checkpoint_to_remove = self._update_state(
        new_params_filename=filename, new_averaged_params=new_averaged_params
    )
    extra_state["checkpoint_files"] = list(self._checkpoint_files)

    self.log_if_verbose(
        f"| Preparing to save checkpoints for epoch {epoch}, "
        f"offset {batch_offset}."
    )
    # Saves two copies of the checkpoint - one under a specific name
    # corresponding to its epoch/offset, and another under the generic
    # "checkpoint_last.pt" that we restore from in case training is
    # interrupted.
    save_checkpoint_atomic(
        trainer=trainer, final_filename=filename, extra_state=extra_state
    )
    # We update checkpoint_last.pt only after the new averaged checkpoint
    # and epoch/offset-named copy have been written - so that in case either
    # write fails, we'd still be able to resume from the previous
    # checkpoint_last.pt.
    last_checkpoint_path = os.path.join(
        args.save_dir, constants.LAST_CHECKPOINT_FILENAME
    )
    assert PathManager.copy(
        filename, last_checkpoint_path, overwrite=True
    ), f"Failed to copy {filename} to {last_checkpoint_path}"
    self.log_if_verbose(
        f"| Finished saving checkpoints for epoch {epoch}, "
        f"offset {batch_offset}."
    )

    # Wait until after checkpoint_last.pt has been written to remove the
    # oldest checkpoint. This is so that in case we fail to write a new
    # checkpoint_last.pt, we'd still have access to all the files listed in
    # the previous checkpoint_last.pt.
    self._remove_checkpoint(checkpoint_to_remove)
    return extra_state
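# Hedged usage sketch (added for illustration, not part of the original source):
# a CheckpointManager.save() call as it might appear after a training or
# validation step. `checkpoint_manager`, `args`, `trainer`, `extra_state`, and
# `averaged_params` are assumed to come from the surrounding training loop,
# which is not shown here.
def _example_periodic_save(checkpoint_manager, args, trainer, extra_state, averaged_params):
    # Only the master worker holds a CheckpointManager (see setup_training_state).
    if checkpoint_manager is not None:
        extra_state = checkpoint_manager.save(
            args=args,
            trainer=trainer,
            extra_state=extra_state,
            new_averaged_params=averaged_params,
        )
    return extra_state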