Example #1
# Imports assumed for this snippet; module paths follow VISSL's source
# layout at the time of writing and may differ across versions. Both
# functions below are excerpted from class bodies (hence the indentation).
import logging
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from fvcore.common.file_io import PathManager
from vissl.models import build_model
from vissl.utils.checkpoint import (
    CheckpointFormatConverter,
    CheckpointLoader,
    CheckpointWriter,
)
from vissl.utils.fsdp_utils import fsdp_wrapper
from vissl.utils.test_utils import init_distributed_on_file
    def _restore_model_weights(self, model):
        """
        If using a weights file to initialize the model, we load the weights
        and initialize the model. Since the weights file specified
        by user might not be VISSL trained weights, we expose several config
        options like APPEND_PREFIX, etc to allow successful loading of the weights.
        See MODEL.WEIGHTS_INIT description in vissl/config/defaults.yaml for details.
        """
        params_from_file = self.config["MODEL"]["WEIGHTS_INIT"]
        init_weights_path = params_from_file["PARAMS_FILE"]
        assert init_weights_path, "Shouldn't call this when init_weights_path is empty"
        logging.info(f"Initializing model from: {init_weights_path}")

        if PathManager.exists(init_weights_path):
            checkpoint = CheckpointLoader.load_and_broadcast_init_weights(
                checkpoint_path=init_weights_path, device=torch.device("cpu"))
            model.init_model_from_weights_params_file(self.config, checkpoint)
        return model
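
    # For reference, a minimal sketch (hypothetical values) of the
    # MODEL.WEIGHTS_INIT options the method above consumes; the
    # authoritative list lives in vissl/config/defaults.yaml:
    #
    #   MODEL:
    #     WEIGHTS_INIT:
    #       PARAMS_FILE: "/path/to/weights.torch"  # checkpoint to load
    #       APPEND_PREFIX: "trunk.base_model."     # prefix prepended to keys
    #       # plus options such as STATE_DICT_KEY_NAME / REMOVE_PREFIX to
    #       # accommodate non-VISSL weight files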

    def _worker(gpu_id: int, sync_file: str, world_size: int):
        torch.manual_seed(0)
        os.environ["RANK"] = str(gpu_id)
        init_distributed_on_file(world_size=world_size,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.backends.cudnn.deterministic = True

        config = TestCheckpointConversion._create_fsdp_model_config(
            with_fsdp=True)
        model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        optimizer = optim.SGD(model.parameters(), lr=1e-4)

        # Fake inputs
        num_iterations = 5
        batch_size = 3
        torch.manual_seed(gpu_id)
        fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
        fake_targets = torch.randn(size=(num_iterations, batch_size))

        # Fake training loop
        criterion = nn.MSELoss()
        for iteration in range(num_iterations):
            fake_input = fake_inputs[iteration].cuda(gpu_id)
            fake_target = fake_targets[iteration].cuda(gpu_id)
            # The first element of the model's output list unpacks into the
            # two head outputs.
            output1, output2 = model(fake_input)[0]
            loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
                output2.sum(axis=-1), fake_target)
            if gpu_id == 0:
                print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save a sharded checkpoint: one shard file per rank
        checkpoint_writer = CheckpointWriter(
            checkpoint_folder=".",
            is_final_train_phase=True,
            mode="iteration",
            mode_num=0,
            backend="disk",
        )
        content = {
            "classy_state_dict": {
                "base_model": {
                    "model": {
                        "trunk": model.trunk.local_state_dict()
                    },
                    "meta": {
                        "trunk": model.trunk.local_metadata_dict()
                    },
                }
            }
        }
        checkpoint_writer.save_sharded_checkpoint(content,
                                                  shard_rank=gpu_id,
                                                  world_size=world_size)
        dist.barrier()
        print(os.listdir("."))

        # Convert the checkpoint to consolidated and sliced checkpoints
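        # (Context, as an assumption: consolidating gathers all FSDP shards
        # into one monolithic state dict, while slicing stores parameters in
        # separately loadable pieces so big models can be initialized without
        # materializing the full state dict in memory.)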
        if gpu_id == 0:
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch")
        dist.barrier()
        print(os.listdir("."))

        # Now create models initialized from the saved checkpoints and compare them
        fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)

        shard_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint.torch", device=torch.device("cpu"))
        shard_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        shard_model = fsdp_wrapper(shard_model, **config.MODEL.FSDP_CONFIG)
        shard_model.init_model_from_weights_params_file(config, shard_cp)

        conso_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_conso.torch", device=torch.device("cpu"))
        conso_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        conso_model = fsdp_wrapper(conso_model, **config.MODEL.FSDP_CONFIG)
        conso_model.init_model_from_weights_params_file(config, conso_cp)

        slice_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_sliced.torch", device=torch.device("cpu"))
        slice_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        slice_model = fsdp_wrapper(slice_model, **config.MODEL.FSDP_CONFIG)
        slice_model.init_model_from_weights_params_file(config, slice_cp)

        # Verify that the sliced and consolidated checkpoints yield the same weights
        if gpu_id == 0:
            slice_state_dict = slice_model.local_state_dict()
            conso_state_dict = conso_model.local_state_dict()
            assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
            for k in slice_state_dict.keys():
                slice_val = slice_state_dict[k]
                conso_val = conso_state_dict[k]
                assert torch.allclose(
                    slice_val, conso_val
                ), f"Difference for key {k}: {slice_val} VS {conso_val}"
        dist.barrier()

        with torch.no_grad():
            ref_out = model.trunk(fake_test_input)[0]
            shard_out = shard_model.trunk(fake_test_input)[0]
            conso_out = conso_model.trunk(fake_test_input)[0]
            slice_out = slice_model.trunk(fake_test_input)[0]
            assert torch.allclose(
                ref_out, shard_out), f"{ref_out.sum()} vs {shard_out.sum()}"
            assert torch.allclose(
                ref_out, conso_out), f"{ref_out.sum()} vs {conso_out.sum()}"
            assert torch.allclose(
                ref_out, slice_out), f"{ref_out.sum()} vs {slice_out.sum()}"
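
    # A minimal launcher sketch for the worker above; this harness is an
    # assumption, not part of the original test, and the test name is
    # hypothetical. It spawns one process per GPU, sharing a temp file for
    # the distributed rendezvous.
    def test_fsdp_checkpoint_conversion(self):
        import tempfile

        import torch.multiprocessing as mp

        world_size = 2  # needs at least 2 GPUs
        with tempfile.NamedTemporaryFile() as sync_file:
            mp.spawn(
                TestCheckpointConversion._worker,
                args=(sync_file.name, world_size),
                nprocs=world_size,
            )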