Example #1
    def _save(self, name: str, content):
        save_checkpoint(
            checkpoint_folder=self.checkpoint_folder,
            state=content,
            checkpoint_file=name,
        )
        logging.info(f"Saved checkpoint: {self.checkpoint_folder}/{name}")
Example #2
    def test_save_and_load_checkpoint(self):
        checkpoint_dict = {str(i): i * 2 for i in range(1000)}

        # save to the default checkpoint file
        save_checkpoint(self.base_dir, checkpoint_dict)

        # load the checkpoint by using the default file
        loaded_checkpoint = load_checkpoint(self.base_dir)
        self.assertDictEqual(checkpoint_dict, loaded_checkpoint)

        # load the checkpoint by passing the full path
        checkpoint_path = f"{self.base_dir}/{CHECKPOINT_FILE}"
        loaded_checkpoint = load_checkpoint(checkpoint_path)
        self.assertDictEqual(checkpoint_dict, loaded_checkpoint)

        # create a new checkpoint dict
        filename = "my_checkpoint.torch"
        checkpoint_dict = {str(i): i * 3 for i in range(1000)}

        # save the checkpoint to a different file
        save_checkpoint(self.base_dir,
                        checkpoint_dict,
                        checkpoint_file=filename)

        # load the checkpoint by passing the full path
        checkpoint_path = f"{self.base_dir}/{filename}"
        loaded_checkpoint = load_checkpoint(checkpoint_path)
        self.assertDictEqual(checkpoint_dict, loaded_checkpoint)
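For context, a test like the one above needs self.base_dir to point at a writable scratch directory. A minimal sketch of such a fixture, assuming unittest and a temporary directory (the class name and fixture bodies are illustrative, not part of the original listing):

import shutil
import tempfile
import unittest


class TestCheckpoint(unittest.TestCase):
    def setUp(self):
        # Hypothetical fixture: scratch directory used as self.base_dir.
        self.base_dir = tempfile.mkdtemp()

    def tearDown(self):
        # Remove the scratch directory and any checkpoint files written to it.
        shutil.rmtree(self.base_dir, ignore_errors=True)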
Example #3
    def _save_checkpoint(self, task, filename):
        if getattr(task, "test_only", False):
            return
        assert PathManager.exists(
            self.checkpoint_folder
        ), "Checkpoint folder '{}' deleted unexpectedly".format(self.checkpoint_folder)

        # save checkpoint:
        logging.info("Saving checkpoint to '{}'...".format(self.checkpoint_folder))
        checkpoint_file = save_checkpoint(
            self.checkpoint_folder, get_checkpoint_dict(task, self.input_args)
        )

        # make copy of checkpoint that won't be overwritten:
        PathManager.copy(checkpoint_file, f"{self.checkpoint_folder}/{filename}")
Example #4
    def _save_checkpoint(self, task, filename):
        if getattr(task, "test_only", False):
            return
        assert os.path.exists(
            self.checkpoint_folder
        ), "Checkpoint folder '{}' deleted unexpectedly".format(self.checkpoint_folder)

        # save checkpoint:
        logging.info("Saving checkpoint to '{}'...".format(self.checkpoint_folder))
        checkpoint_file = save_checkpoint(
            self.checkpoint_folder, get_checkpoint_dict(task, self.input_args)
        )

        # make copy of checkpoint that won't be overwritten:
        if checkpoint_file:
            tmp_dir = tempfile.mkdtemp()
            tmp_file = os.path.join(tmp_dir, filename)
            copy2(checkpoint_file, tmp_file)
            move(tmp_file, os.path.join(self.checkpoint_folder, filename))
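The example above stages the copy in a temporary directory and only then moves it into the checkpoint folder, so a reader of that folder is unlikely to see a half-copied file. A minimal sketch of the same pattern in isolation, assuming a local filesystem (the helper name copy_without_partial_reads is hypothetical, not part of the library):

import os
import tempfile
from shutil import copy2, move, rmtree


def copy_without_partial_reads(src: str, dst_dir: str, filename: str) -> str:
    """Copy src into dst_dir/filename without exposing a half-written file."""
    tmp_dir = tempfile.mkdtemp()
    try:
        # Do the full copy outside dst_dir first.
        tmp_file = os.path.join(tmp_dir, filename)
        copy2(src, tmp_file)
        # Then move it into place; when the temp dir and dst_dir are on the
        # same filesystem this is a rename, otherwise shutil falls back to a copy.
        dst_file = os.path.join(dst_dir, filename)
        move(tmp_file, dst_file)
        return dst_file
    finally:
        rmtree(tmp_dir, ignore_errors=True)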
Example #5
    def _save_checkpoint(self, task, filename):
        if getattr(task, "test_only", False):
            return
        assert PathManager.exists(
            self.checkpoint_folder
        ), "Checkpoint folder '{}' deleted unexpectedly".format(
            self.checkpoint_folder)

        for prefix in gfs_prefix_list:
            if self.checkpoint_folder.startswith(prefix):
                logging.warning(
                    "GFS is deprecating... please save checkpoint to manifold!"
                )
                break

        # save checkpoint:
        logging.info("Saving checkpoint to '{}'...".format(
            self.checkpoint_folder))
        checkpoint_file = save_checkpoint(
            self.checkpoint_folder, get_checkpoint_dict(task, self.input_args))

        # make copy of checkpoint that won't be overwritten:
        PathManager.copy(checkpoint_file,
                         f"{self.checkpoint_folder}/{filename}")
Example #6
    def _checkpoint_model(self,
                          task,
                          train_phase_idx,
                          mode_frequency,
                          mode_num,
                          mode="phase"):
        """
        Checkpoint model. Can be called in 3 possible scenarios:
        1. If training becomes NaN, then we checkpoint the model to facilitate debugging
        2. After every N epochs (CHECKPOINT_FREQ), model state is checkpointed.
        3. If user wants to checkpoint during the epoch (ie. after every few training
           iterations, the model state is checkpointed.)

        Args:
            task: Self-supervision task that hold information about training iteration,
                  epoch number etc.
            train_phase_idx (int): current training phase number. Starts from 0
            mode_frequency (int): mode can be "phase" or "iteration". Frequency
                                  of checkpointing for the given mode
            mode_num (int): for the checkpointing mode (phase or iteration), the number
                            of phase or iteration at which checkpointing is being done
        """
        phase_idx = task.phase_idx
        num_epochs = task.num_epochs
        # check if we need to checkpoint this phase
        is_checkpointing_phase = is_checkpoint_phase(mode_num, mode_frequency,
                                                     train_phase_idx,
                                                     num_epochs, mode)
        is_final_train_phase = ((train_phase_idx == (num_epochs - 1))
                                and task.train and mode == "phase")

        # handle checkpoint:
        if task.train and (is_final_train_phase or is_checkpointing_phase):
            # - if the optimizer state is sharded, consolidate it first
            # /!\ All the ranks have to participate
            if hasattr(task.optimizer,
                       "consolidate_state_dict") and mode != "phase":
                logging.info(
                    f"[{mode}: {mode_num}] Consolidating sharded state on all replicas"
                )
                task.optimizer.consolidate_state_dict()

            # Model's state dict may need to be obtained on all ranks if we are running
            # with FSDP since all_gather needs to happen here.
            model_state_dict = None
            if isinstance(task.base_model, FSDP):
                model_state_dict = task.get_classy_state()

            # - save the checkpoint on the primary rank
            if is_primary():
                checkpoint_folder = task.checkpoint_folder
                logging.info(
                    f"[{mode}: {mode_num}] Saving checkpoint to {checkpoint_folder}"
                )
                if model_state_dict is None:
                    model_state_dict = task.get_classy_state()
                # phase_idx is already incremented at the beginning of the phase, but if
                # we are checkpointing at an iteration in the middle of a phase, we should
                # not save the incremented phase_idx, as that would incorrectly imply that
                # the model has already trained for that phase.
                if mode == "iteration":
                    phase_idx = phase_idx - 1
                    model_state_dict[
                        "phase_idx"] = model_state_dict["phase_idx"] - 1
                    if task.train:
                        train_phase_idx = train_phase_idx - 1
                        model_state_dict["train_phase_idx"] = train_phase_idx
                checkpoint_task = {
                    "phase_idx": phase_idx,
                    "iteration": task.iteration,
                    "loss": task.loss.state_dict(),
                    "iteration_num": task.local_iteration_num,
                    "train_phase_idx": train_phase_idx,
                    # TODO (Min): change the key to model_state_dict but we need to be careful
                    #             about backward compatibilities.
                    "classy_state_dict": model_state_dict,
                }
                ckpt_name = f"model_{mode}{mode_num}.torch"
                if is_final_train_phase:
                    ckpt_name = f"model_final_checkpoint_{mode}{mode_num}.torch"
                backend = task.config["CHECKPOINT"]["BACKEND"]
                assert backend == "disk", "Only disk BACKEND supported"
                save_checkpoint(checkpoint_folder,
                                checkpoint_task,
                                checkpoint_file=ckpt_name)
                logging.info(
                    f"Saved checkpoint: {checkpoint_folder}/{ckpt_name}")
                # We create the checkpoint symlink and use this symlink to load
                # checkpoints. This helps ensure that the checkpoints we load are
                # valid. It is particularly useful for resuming training.
                logging.info("Creating symlink...")
                symlink_dest_file = f"{checkpoint_folder}/checkpoint.torch"
                source_file = f"{checkpoint_folder}/{ckpt_name}"
                create_file_symlink(source_file, symlink_dest_file)
                logging.info(f"Created symlink: {symlink_dest_file}")