def _save(self, name: str, content):
    save_checkpoint(
        checkpoint_folder=self.checkpoint_folder,
        state=content,
        checkpoint_file=name,
    )
    logging.info(f"Saved checkpoint: {self.checkpoint_folder}/{name}")

def test_save_and_load_checkpoint(self):
    checkpoint_dict = {str(i): i * 2 for i in range(1000)}

    # save to the default checkpoint file
    save_checkpoint(self.base_dir, checkpoint_dict)

    # load the checkpoint by using the default file
    loaded_checkpoint = load_checkpoint(self.base_dir)
    self.assertDictEqual(checkpoint_dict, loaded_checkpoint)

    # load the checkpoint by passing the full path
    checkpoint_path = f"{self.base_dir}/{CHECKPOINT_FILE}"
    loaded_checkpoint = load_checkpoint(checkpoint_path)
    self.assertDictEqual(checkpoint_dict, loaded_checkpoint)

    # create a new checkpoint dict
    filename = "my_checkpoint.torch"
    checkpoint_dict = {str(i): i * 3 for i in range(1000)}

    # save the checkpoint to a different file
    save_checkpoint(self.base_dir, checkpoint_dict, checkpoint_file=filename)

    # load the checkpoint by passing the full path
    checkpoint_path = f"{self.base_dir}/{filename}"
    loaded_checkpoint = load_checkpoint(checkpoint_path)
    self.assertDictEqual(checkpoint_dict, loaded_checkpoint)

def _save_checkpoint(self, task, filename):
    if getattr(task, "test_only", False):
        return
    assert PathManager.exists(
        self.checkpoint_folder
    ), "Checkpoint folder '{}' deleted unexpectedly".format(self.checkpoint_folder)

    # save checkpoint:
    logging.info("Saving checkpoint to '{}'...".format(self.checkpoint_folder))
    checkpoint_file = save_checkpoint(
        self.checkpoint_folder, get_checkpoint_dict(task, self.input_args)
    )

    # make copy of checkpoint that won't be overwritten:
    PathManager.copy(checkpoint_file, f"{self.checkpoint_folder}/{filename}")

def _save_checkpoint(self, task, filename):
    if getattr(task, "test_only", False):
        return
    assert os.path.exists(
        self.checkpoint_folder
    ), "Checkpoint folder '{}' deleted unexpectedly".format(self.checkpoint_folder)

    # save checkpoint:
    logging.info("Saving checkpoint to '{}'...".format(self.checkpoint_folder))
    checkpoint_file = save_checkpoint(
        self.checkpoint_folder, get_checkpoint_dict(task, self.input_args)
    )

    # make copy of checkpoint that won't be overwritten:
    if checkpoint_file:
        tmp_dir = tempfile.mkdtemp()
        tmp_file = os.path.join(tmp_dir, filename)
        copy2(checkpoint_file, tmp_file)
        move(tmp_file, os.path.join(self.checkpoint_folder, filename))

def _save_checkpoint(self, task, filename):
    if getattr(task, "test_only", False):
        return
    assert PathManager.exists(
        self.checkpoint_folder
    ), "Checkpoint folder '{}' deleted unexpectedly".format(self.checkpoint_folder)
    for prefix in gfs_prefix_list:
        if self.checkpoint_folder.startswith(prefix):
            logging.warning(
                "GFS is deprecating... please save checkpoint to manifold!"
            )
            break

    # save checkpoint:
    logging.info("Saving checkpoint to '{}'...".format(self.checkpoint_folder))
    checkpoint_file = save_checkpoint(
        self.checkpoint_folder, get_checkpoint_dict(task, self.input_args)
    )

    # make copy of checkpoint that won't be overwritten:
    PathManager.copy(checkpoint_file, f"{self.checkpoint_folder}/{filename}")

def _checkpoint_model(
    self, task, train_phase_idx, mode_frequency, mode_num, mode="phase"
):
    """
    Checkpoint model. Can be called in 3 possible scenarios:
    1. If training becomes NaN, we checkpoint the model to facilitate debugging.
    2. After every N epochs (CHECKPOINT_FREQ), model state is checkpointed.
    3. If the user wants to checkpoint during the epoch (i.e. after every few
       training iterations), the model state is checkpointed.

    Args:
        task: Self-supervision task that holds information about training
              iteration, epoch number etc.
        train_phase_idx (int): current training phase number. Starts from 0
        mode_frequency (int): mode can be "phase" or "iteration". Frequency of
                              checkpointing for the given mode
        mode_num (int): for the checkpointing mode (phase or iteration), the
                        number of the phase or iteration at which checkpointing
                        is being done
    """
    phase_idx = task.phase_idx
    num_epochs = task.num_epochs

    # check if we need to checkpoint this phase
    is_checkpointing_phase = is_checkpoint_phase(
        mode_num, mode_frequency, train_phase_idx, num_epochs, mode
    )
    is_final_train_phase = (
        (train_phase_idx == (num_epochs - 1)) and task.train and mode == "phase"
    )

    # handle checkpoint:
    if task.train and (is_final_train_phase or is_checkpointing_phase):
        # - if the optimizer state is sharded, consolidate the state
        # /!\ All the ranks have to participate
        if hasattr(task.optimizer, "consolidate_state_dict") and mode != "phase":
            logging.info(
                f"[{mode}: {mode_num}] Consolidating sharded state on all replicas"
            )
            task.optimizer.consolidate_state_dict()

        # Model's state dict may need to be obtained on all ranks if we are running
        # with FSDP since all_gather needs to happen here.
        model_state_dict = None
        if isinstance(task.base_model, FSDP):
            model_state_dict = task.get_classy_state()

        # - save the checkpoint on the primary rank
        if is_primary():
            checkpoint_folder = task.checkpoint_folder
            logging.info(
                f"[{mode}: {mode_num}] Saving checkpoint to {checkpoint_folder}"
            )
            if model_state_dict is None:
                model_state_dict = task.get_classy_state()

            # phase_idx is already incremented at the beginning of the phase, but if
            # we are checkpointing at an iteration in the middle of the phase, we
            # should not save the incremented phase_idx as it would incorrectly
            # indicate that the model already trained for that phase.
            if mode == "iteration":
                phase_idx = phase_idx - 1
                model_state_dict["phase_idx"] = model_state_dict["phase_idx"] - 1
                if task.train:
                    train_phase_idx = train_phase_idx - 1
                    model_state_dict["train_phase_idx"] = train_phase_idx

            checkpoint_task = {
                "phase_idx": phase_idx,
                "iteration": task.iteration,
                "loss": task.loss.state_dict(),
                "iteration_num": task.local_iteration_num,
                "train_phase_idx": train_phase_idx,
                # TODO (Min): change the key to model_state_dict but we need to be
                # careful about backward compatibility.
                "classy_state_dict": model_state_dict,
            }
            ckpt_name = f"model_{mode}{mode_num}.torch"
            if is_final_train_phase:
                ckpt_name = f"model_final_checkpoint_{mode}{mode_num}.torch"
            backend = task.config["CHECKPOINT"]["BACKEND"]
            assert backend == "disk", "Only disk BACKEND supported"
            save_checkpoint(
                checkpoint_folder, checkpoint_task, checkpoint_file=ckpt_name
            )
            logging.info(f"Saved checkpoint: {checkpoint_folder}/{ckpt_name}")

            # we create the checkpoint symlink and use this symlink to load
            # checkpoints. This helps ensure that the checkpoints we load from
            # are valid. It's a particularly useful feature for resuming training.
            logging.info("Creating symlink...")
            symlink_dest_file = f"{checkpoint_folder}/checkpoint.torch"
            source_file = f"{checkpoint_folder}/{ckpt_name}"
            create_file_symlink(source_file, symlink_dest_file)
            logging.info(f"Created symlink: {symlink_dest_file}")