Example #1
    def _checkpoint_model(self,
                          task,
                          train_phase_idx,
                          mode_frequency,
                          mode_num,
                          mode="phase"):
        """
        Checkpoint model. Can be called in 3 possible scenarios:
        1. If training becomes NaN, the model is checkpointed to facilitate debugging.
        2. After every N epochs (CHECKPOINT_FREQ), the model state is checkpointed.
        3. If the user wants to checkpoint during an epoch (i.e., after every few
           training iterations), the model state is checkpointed.

        Args:
            task: Self-supervision task that holds information about the training
                  iteration, epoch number, etc.
            train_phase_idx (int): current training phase number. Starts from 0.
            mode_frequency (int): checkpointing frequency for the given mode
            mode_num (int): for the checkpointing mode (phase or iteration), the
                            phase or iteration number at which checkpointing is done
            mode (str): checkpointing mode, either "phase" or "iteration"
        """
        phase_idx = task.phase_idx
        # num_train_phases = num_epochs * num_phases_per_epoch
        # For OSS use, num_train_phases will be equal to num_epochs
        num_train_phases = task.num_train_phases

        # check if we need to checkpoint this phase
        is_checkpointing_phase = is_checkpoint_phase(mode_num, mode_frequency,
                                                     train_phase_idx,
                                                     num_train_phases, mode)
        is_final_train_phase = ((train_phase_idx == (num_train_phases - 1))
                                and task.train and mode == "phase")

        # handle checkpoint:
        if task.train and (is_final_train_phase or is_checkpointing_phase):
            # - if the optimizer state is sharded, consolidate it first
            # /!\ All the ranks have to participate
            if hasattr(task.optimizer,
                       "consolidate_state_dict") and mode != "phase":
                logging.info(
                    f"[{mode}: {mode_num}] Consolidating sharded state on all replicas"
                )
                task.optimizer.consolidate_state_dict()

            # Depending on whether we are in FSDP mode or not:
            # - without FSDP, save the checkpoint on the primary rank only
            # - with FSDP, save a sharded checkpoint on every rank
            if is_primary() or isinstance(task.base_model, FSDP):
                checkpoint_folder = task.checkpoint_folder
                logging.info(
                    f"[{mode}: {mode_num}] Saving checkpoint to {checkpoint_folder}"
                )
                model_state_dict = task.get_classy_state()

                # phase_idx is already incremented at the beginning of the phase, but
                # if we are checkpointing at an iteration in the middle of the phase,
                # we should not save the incremented phase_idx, as it would incorrectly
                # indicate that the model has already trained for that phase.
                if mode == "iteration":
                    model_state_dict["phase_idx"] -= 1
                    if task.train:
                        train_phase_idx = train_phase_idx - 1
                        model_state_dict["train_phase_idx"] = train_phase_idx
                    restart_phase = phase_idx - 1
                    restart_iteration = task.iteration

                # When loading from a phase checkpoint:
                else:
                    restart_phase = phase_idx
                    restart_iteration = task.iteration

                checkpoint_content = {
                    "phase_idx": restart_phase,
                    "iteration": restart_iteration,
                    "loss": task.loss.state_dict(),
                    "iteration_num": task.local_iteration_num,
                    "train_phase_idx": train_phase_idx,
                    "classy_state_dict": model_state_dict,
                }

                checkpoint_writer = CheckpointWriter(
                    checkpoint_folder=checkpoint_folder,
                    is_final_train_phase=is_final_train_phase,
                    mode=mode,
                    mode_num=mode_num,
                    backend=task.config["CHECKPOINT"]["BACKEND"],
                )

                if isinstance(task.base_model, FSDP):
                    _, rank = get_machine_local_and_dist_rank()
                    checkpoint_writer.save_sharded_checkpoint(
                        content=checkpoint_content,
                        shard_rank=rank,
                        world_size=self.world_size,
                    )
                else:
                    checkpoint_writer.save_consolidated_checkpoint(
                        checkpoint_content)
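
    # A usage sketch (an assumption, not from the original source): a checkpoint
    # hook could call this method at the end of each training phase, where
    # `checkpoint_frequency` is an assumed config value and `train_phase_idx` is
    # tracked by the calling hook:
    #
    #   self._checkpoint_model(
    #       task,
    #       train_phase_idx=train_phase_idx,
    #       mode_frequency=checkpoint_frequency,
    #       mode_num=train_phase_idx,
    #       mode="phase",
    #   )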

    def _worker(gpu_id: int, sync_file: str, world_size: int):
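        """
        Run a small fake FSDP training loop on one GPU, save a sharded
        checkpoint, convert it to consolidated and sliced formats, and check
        that models reloaded from all three formats produce the same outputs.
        """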
        torch.manual_seed(0)
        os.environ["RANK"] = str(gpu_id)
        init_distributed_on_file(world_size=world_size,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.backends.cudnn.deterministic = True

        config = TestCheckpointConversion._create_fsdp_model_config(
            with_fsdp=True)
        model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        optimizer = optim.SGD(model.parameters(), lr=1e-4)

        # Fake inputs
        num_iterations = 5
        batch_size = 3
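        # Seed per rank so each GPU draws different fake inputs and targets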
        torch.manual_seed(gpu_id)
        fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
        fake_targets = torch.randn(size=(num_iterations, batch_size))

        # Fake training loop
        criterion = nn.MSELoss()
        for iteration in range(num_iterations):
            fake_input = fake_inputs[iteration].cuda(gpu_id)
            fake_target = fake_targets[iteration].cuda(gpu_id)
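            # model(...) returns a list; its first entry unpacks into the two head outputs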
            output1, output2 = model(fake_input)[0]
            loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
                output2.sum(axis=-1), fake_target)
            if gpu_id == 0:
                print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save a set of checkpoints, one per shard
        checkpoint_writer = CheckpointWriter(
            checkpoint_folder=".",
            is_final_train_phase=True,
            mode="iteration",
            mode_num=0,
            backend="disk",
        )
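        # Minimal checkpoint content: the trunk's local (sharded) state dict
        # plus the corresponding FSDP shard metadata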
        content = {
            "classy_state_dict": {
                "base_model": {
                    "model": {
                        "trunk": model.trunk.local_state_dict()
                    },
                    "meta": {
                        "trunk": model.trunk.local_metadata_dict()
                    },
                }
            }
        }
        checkpoint_writer.save_sharded_checkpoint(content,
                                                  shard_rank=gpu_id,
                                                  world_size=world_size)
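        # Wait until every rank has written its shard before converting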
        dist.barrier()
        print(os.listdir("."))

        # Convert the sharded checkpoint to consolidated and sliced formats
        if gpu_id == 0:
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch")
        dist.barrier()
        print(os.listdir("."))

        # Now create models initialized from the previous checkpoint and compare them
        fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)
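        # Reload the checkpoint in its three formats (sharded, consolidated,
        # sliced) into fresh FSDP-wrapped models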

        shard_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint.torch", device=torch.device("cpu"))
        shard_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        shard_model = fsdp_wrapper(shard_model, **config.MODEL.FSDP_CONFIG)
        shard_model.init_model_from_weights_params_file(config, shard_cp)

        conso_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_conso.torch", device=torch.device("cpu"))
        conso_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        conso_model = fsdp_wrapper(conso_model, **config.MODEL.FSDP_CONFIG)
        conso_model.init_model_from_weights_params_file(config, conso_cp)

        slice_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_sliced.torch", device=torch.device("cpu"))
        slice_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        slice_model = fsdp_wrapper(slice_model, **config.MODEL.FSDP_CONFIG)
        slice_model.init_model_from_weights_params_file(config, slice_cp)

        # Verify that the sliced and consolidated checkpoints load to equivalent models
        if gpu_id == 0:
            slice_state_dict = slice_model.local_state_dict()
            conso_state_dict = conso_model.local_state_dict()
            assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
            for k in slice_state_dict.keys():
                slice_val = slice_state_dict[k]
                conso_val = conso_state_dict[k]
                assert torch.allclose(
                    slice_val, conso_val
                ), f"Difference for key {k}: {slice_val} VS {conso_val}"
        dist.barrier()

        with torch.no_grad():
            ref_out = model.trunk(fake_test_input)[0]
            shard_out = shard_model.trunk(fake_test_input)[0]
            conso_out = conso_model.trunk(fake_test_input)[0]
            slice_out = slice_model.trunk(fake_test_input)[0]
            assert torch.allclose(
                ref_out, shard_out), f"{ref_out.sum()} vs {shard_out.sum()}"
            assert torch.allclose(
                ref_out, conso_out), f"{ref_out.sum()} vs {conso_out.sum()}"
            assert torch.allclose(
                ref_out, slice_out), f"{ref_out.sum()} vs {slice_out.sum()}"
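

# A minimal launch sketch (an assumption, not part of the original test): how
# `_worker` above might be spawned across GPUs. `run_conversion_test` is a
# hypothetical helper name; it assumes `_worker` is reachable as a plain
# function (or staticmethod) on TestCheckpointConversion and that `world_size`
# GPUs are available.
import tempfile

import torch.multiprocessing as mp


def run_conversion_test(world_size: int = 2):
    # All ranks must point at the same sync file for init_distributed_on_file
    with tempfile.NamedTemporaryFile(delete=False) as sync:
        sync_file = sync.name
    # mp.spawn passes the process index as the first argument, matching gpu_id
    mp.spawn(
        TestCheckpointConversion._worker,
        args=(sync_file, world_size),
        nprocs=world_size,
    )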