Example 1
    def _checkpoint_model(self,
                          task,
                          train_phase_idx,
                          mode_frequency,
                          mode_num,
                          mode="phase"):
        """
        Checkpoint model. Can be called in 3 possible scenarios:
        1. If training becomes NaN, then we checkpoint the model to facilitate debugging
        2. After every N epochs (CHECKPOINT_FREQ), model state is checkpointed.
        3. If the user wants to checkpoint during the epoch (i.e. after every few
           training iterations), the model state is checkpointed.

        Args:
            task: Self-supervision task that holds information about the training
                  iteration, epoch number, etc.
            train_phase_idx (int): current training phase number. Starts from 0.
            mode_frequency (int): mode can be "phase" or "iteration". Frequency of
                                  checkpointing for the given mode.
            mode_num (int): for the checkpointing mode (phase or iteration), the phase
                            or iteration number at which checkpointing is being done.
        """
        phase_idx = task.phase_idx
        # num_train_phases = num_epochs * num_phases_per_epoch
        # For OSS use, num_train_phases will be equal to num_epochs
        num_train_phases = task.num_train_phases

        # check if we need to checkpoint this phase
        is_checkpointing_phase = is_checkpoint_phase(
            mode_num, mode_frequency, train_phase_idx, num_train_phases, mode
        )
        is_final_train_phase = ((train_phase_idx == (num_train_phases - 1))
                                and task.train and mode == "phase")

        # handle checkpoint:
        if task.train and (is_final_train_phase or is_checkpointing_phase):
            # - if the optimizer state is sharded, consolidate it first
            # /!\ All the ranks have to participate
            if hasattr(task.optimizer, "consolidate_state_dict") and mode != "phase":
                logging.info(
                    f"[{mode}: {mode_num}] Consolidating sharded state on all replicas"
                )
                task.optimizer.consolidate_state_dict()

            # Depending on whether we are in FSDP mode or not:
            # - without FSDP, save the consolidated checkpoint on the primary rank
            # - with FSDP, save a sharded checkpoint on every rank
            if is_primary() or isinstance(task.base_model, FSDP):
                checkpoint_folder = task.checkpoint_folder
                logging.info(
                    f"[{mode}: {mode_num}] Saving checkpoint to {checkpoint_folder}"
                )
                model_state_dict = task.get_classy_state()

                # phase_idx is already incremented at the beginning of the phase, but
                # if we are checkpointing at an iteration in the middle of the phase,
                # we should not save the incremented phase_idx, as it would incorrectly
                # imply that the model has already trained for that phase.
                if mode == "iteration":
                    model_state_dict["phase_idx"] = model_state_dict["phase_idx"] - 1
                    if task.train:
                        train_phase_idx = train_phase_idx - 1
                        model_state_dict["train_phase_idx"] = train_phase_idx
                    restart_phase = phase_idx - 1
                    restart_iteration = task.iteration

                # When loading from a phase checkpoint:
                else:
                    restart_phase = phase_idx
                    restart_iteration = task.iteration

                checkpoint_content = {
                    "phase_idx": restart_phase,
                    "iteration": restart_iteration,
                    "loss": task.loss.state_dict(),
                    "iteration_num": task.local_iteration_num,
                    "train_phase_idx": train_phase_idx,
                    "classy_state_dict": model_state_dict,
                }

                checkpoint_writer = CheckpointWriter(
                    checkpoint_folder=checkpoint_folder,
                    is_final_train_phase=is_final_train_phase,
                    mode=mode,
                    mode_num=mode_num,
                    backend=task.config["CHECKPOINT"]["BACKEND"],
                )

                if isinstance(task.base_model, FSDP):
                    _, rank = get_machine_local_and_dist_rank()
                    checkpoint_writer.save_sharded_checkpoint(
                        content=checkpoint_content,
                        shard_rank=rank,
                        world_size=self.world_size,
                    )
                else:
                    checkpoint_writer.save_consolidated_checkpoint(checkpoint_content)
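
For reference, here is a minimal sketch of the frequency test that is_checkpoint_phase
performs in both examples. The signature mirrors the call sites above, but the body is
an assumption made for illustration, not VISSL's actual implementation.

def is_checkpoint_phase(mode_num, mode_frequency, train_phase_idx, num_train_phases, mode):
    """Sketch: return True when the current phase/iteration should be checkpointed."""
    if mode == "iteration":
        # checkpoint every mode_frequency iterations (mode_num is the iteration number)
        return mode_frequency > 0 and mode_num % mode_frequency == 0
    # mode == "phase": checkpoint every mode_frequency training phases, and always
    # checkpoint the final phase so the last model state is saved
    return (
        mode_frequency > 0 and (train_phase_idx + 1) % mode_frequency == 0
    ) or train_phase_idx == (num_train_phases - 1)
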
Example 2
    def _checkpoint_model(self,
                          task,
                          train_phase_idx,
                          mode_frequency,
                          mode_num,
                          mode="phase"):
        """
        Checkpoint model. Can be called in 3 possible scenarios:
        1. If training becomes NaN, then we checkpoint the model to facilitate debugging
        2. After every N epochs (CHECKPOINT_FREQ), model state is checkpointed.
        3. If the user wants to checkpoint during the epoch (i.e. after every few
           training iterations), the model state is checkpointed.

        Args:
            task: Self-supervision task that holds information about the training
                  iteration, epoch number, etc.
            train_phase_idx (int): current training phase number. Starts from 0.
            mode_frequency (int): mode can be "phase" or "iteration". Frequency of
                                  checkpointing for the given mode.
            mode_num (int): for the checkpointing mode (phase or iteration), the phase
                            or iteration number at which checkpointing is being done.
        """
        phase_idx = task.phase_idx
        num_epochs = task.num_epochs
        # check if we need to checkpoint this phase
        is_checkpointing_phase = is_checkpoint_phase(
            mode_num, mode_frequency, train_phase_idx, num_epochs, mode
        )
        is_final_train_phase = ((train_phase_idx == (num_epochs - 1))
                                and task.train and mode == "phase")

        # handle checkpoint:
        if task.train and (is_final_train_phase or is_checkpointing_phase):
            # - if the optimizer state is sharded, consolidate it first
            # /!\ All the ranks have to participate
            if hasattr(task.optimizer, "consolidate_state_dict") and mode != "phase":
                logging.info(
                    f"[{mode}: {mode_num}] Consolidating sharded state on all replicas"
                )
                task.optimizer.consolidate_state_dict()

            # The model's state dict may need to be obtained on all ranks when running
            # with FSDP, since the required all_gather must happen on every rank.
            model_state_dict = None
            if isinstance(task.base_model, FSDP):
                model_state_dict = task.get_classy_state()

            # - save the checkpoint on the primary rank
            if is_primary():
                checkpoint_folder = task.checkpoint_folder
                logging.info(
                    f"[{mode}: {mode_num}] Saving checkpoint to {checkpoint_folder}"
                )
                if model_state_dict is None:
                    model_state_dict = task.get_classy_state()
                # phase_idx is already incremented at the beginning of the phase, but
                # if we are checkpointing at an iteration in the middle of the phase,
                # we should not save the incremented phase_idx, as it would incorrectly
                # imply that the model has already trained for that phase.
                if mode == "iteration":
                    phase_idx = phase_idx - 1
                    model_state_dict["phase_idx"] = model_state_dict["phase_idx"] - 1
                    if task.train:
                        train_phase_idx = train_phase_idx - 1
                        model_state_dict["train_phase_idx"] = train_phase_idx
                checkpoint_task = {
                    "phase_idx": phase_idx,
                    "iteration": task.iteration,
                    "loss": task.loss.state_dict(),
                    "iteration_num": task.local_iteration_num,
                    "train_phase_idx": train_phase_idx,
                    # TODO (Min): change the key to model_state_dict but we need to be careful
                    #             about backward compatibilities.
                    "classy_state_dict": model_state_dict,
                }
                ckpt_name = f"model_{mode}{mode_num}.torch"
                if is_final_train_phase:
                    ckpt_name = f"model_final_checkpoint_{mode}{mode_num}.torch"
                backend = task.config["CHECKPOINT"]["BACKEND"]
                assert backend == "disk", "Only disk BACKEND supported"
                save_checkpoint(checkpoint_folder,
                                checkpoint_task,
                                checkpoint_file=ckpt_name)
                logging.info(
                    f"Saved checkpoint: {checkpoint_folder}/{ckpt_name}")
                # We create the checkpoint symlink and use this symlink to load
                # checkpoints. This helps ensure that the checkpoint we load from
                # is valid, which is particularly useful when resuming training.
                logging.info("Creating symlink...")
                symlink_dest_file = f"{checkpoint_folder}/checkpoint.torch"
                source_file = f"{checkpoint_folder}/{ckpt_name}"
                create_file_symlink(source_file, symlink_dest_file)
                logging.info(f"Created symlink: {symlink_dest_file}")
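
The checkpoint.torch symlink makes resuming simple: a loader can always read the same
path regardless of which phase or iteration checkpoint was written last. Below is a
hypothetical resume snippet, assuming the "disk" backend writes a standard torch.save
file; the folder path is a placeholder, and the keys match the checkpoint_task dict
written above.

import torch

# placeholder path standing in for task.checkpoint_folder
checkpoint = torch.load(
    "/path/to/checkpoint_folder/checkpoint.torch", map_location="cpu"
)

# keys written by _checkpoint_model above
print(checkpoint["phase_idx"], checkpoint["train_phase_idx"], checkpoint["iteration"])
model_state = checkpoint["classy_state_dict"]  # model / task state to restore
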