Example #1
    def test_fine_tuning_end_to_end_fsdp(self):
        with in_temporary_directory() as pretrain_dir:
            # Run a pre-training to have some weights to begin with
            pretrain_config = self._create_pretraining_config(
                with_fsdp=True, fsdp_flatten_parameters=True)
            run_integration_test(pretrain_config)
            sharded_checkpoint_path = os.path.join(pretrain_dir,
                                                   "checkpoint.torch")
            sliced_checkpoint_path = os.path.join(pretrain_dir, "sliced.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                input_checkpoint_path=sharded_checkpoint_path,
                output_checkpoint_path=sliced_checkpoint_path,
            )

            # Create a separate directory in which to run the fine-tuning
            with in_temporary_directory():
                finetune_config = self._create_finetuning_config(
                    sliced_checkpoint_path,
                    construct_single_param_group_only=False,
                    regularize_bias=False,
                    with_fsdp=True,
                    fsdp_flatten_parameters=False,
                )
                result = run_integration_test(finetune_config)
                accuracies = result.get_accuracies(from_metrics_file=True)
                self.assertEqual(4, len(accuracies))
Example #2
    def test_fsdp_integration_with_linear_eval(self):
        with in_temporary_directory() as pretrain_dir:

            # Start pre-training
            config = self._create_pretraining_config(
                with_fsdp=True,
                with_activation_checkpointing=True,
                with_mixed_precision=False,
                auto_wrap_threshold=0,
            )
            run_integration_test(config)

            # Consolidate the weights
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch")

            # Load the checkpoint and perform a linear evaluation on it
            losses = self.run_linear_eval(
                checkpoint_path=os.path.join(pretrain_dir,
                                             "checkpoint_conso.torch"),
                with_fsdp=True,
                with_mixed_precision=False,
                auto_wrap_threshold=0,
            )
            self.assertEqual(8, len(losses))
            print(losses)
Example #3
def convert_checkpoint(input_path: str, output_path: str, output_type: str):
    assert g_pathmgr.exists(
        input_path), f"Checkpoint input path: {input_path} not found."

    # Make the output directory if it doesn't exist.
    makedir(os.path.split(output_path)[0])

    setup_logging(__name__)
    if output_type == CheckpointType.consolidated.name:
        CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
            input_path, output_path)
    elif output_type == CheckpointType.sliced.name:
        CheckpointFormatConverter.to_sliced_checkpoint(input_path, output_path)
    shutdown_logging()
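
A minimal sketch of how convert_checkpoint could be driven from the command line. The wrapper below is not part of the original module; the argparse flag names are illustrative only, and it assumes CheckpointType exposes the "consolidated" and "sliced" members that the dispatch above relies on.

if __name__ == "__main__":
    # Hypothetical CLI wrapper around convert_checkpoint; flag names are
    # illustrative only.
    import argparse

    parser = argparse.ArgumentParser(description="Convert a checkpoint between formats")
    parser.add_argument("--input_path", required=True,
                        help="Path to the input checkpoint")
    parser.add_argument("--output_path", required=True,
                        help="Where to write the converted checkpoint")
    parser.add_argument("--output_type",
                        default=CheckpointType.sliced.name,
                        choices=[CheckpointType.consolidated.name,
                                 CheckpointType.sliced.name],
                        help="Target checkpoint format")
    args = parser.parse_args()
    convert_checkpoint(args.input_path, args.output_path, args.output_type)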
Example #4
    def test_benchmarking_from_a_consolidated_checkpoint_2(self):
        with in_temporary_directory() as checkpoint_folder:
            # Run a pre-training in FSDP mode and convert to a consolidated checkpoint
            config = self._create_pretraining_config(with_fsdp=True)
            run_integration_test(config)
            sharded_checkpoint_path = os.path.join(checkpoint_folder,
                                                   "checkpoint.torch")
            checkpoint_path = os.path.join(checkpoint_folder,
                                           "checkpoint_consolidated.torch")
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                sharded_checkpoint_path, checkpoint_path)

            # Now, run both DDP and FSDP linear evaluation and compare the traces
            ddp_losses, ddp_accuracies = self.run_benchmarking(checkpoint_path,
                                                               with_fsdp=False)
            fsdp_losses, fsdp_accuracies = self.run_benchmarking(
                checkpoint_path, with_fsdp=True)
            self.assertEqual(ddp_losses, fsdp_losses)
            self.assertEqual(ddp_accuracies, fsdp_accuracies)
Example #5
    def test_fsdp_integration_with_linear_eval(self):
        with in_temporary_directory() as pretrain_dir:

            # Start pre-training
            config = self._create_pretraining_config(
                with_fsdp=True,
                with_activation_checkpointing=True,
                with_mixed_precision=False,
                auto_wrap_threshold=0,
            )
            run_integration_test(config)

            # Consolidate the weights (3 different ways)
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch")
            CheckpointFormatConverter.consolidated_to_sliced_checkpoint(
                "checkpoint_conso.torch", "checkpoint_sliced_2.torch")

            # Load the sharded checkpoint and perform a linear evaluation on it
            ref_losses = self.run_linear_eval(
                checkpoint_path=os.path.join(pretrain_dir, "checkpoint.torch"),
                with_fsdp=True,
                with_mixed_precision=False,
                auto_wrap_threshold=0,
            )
            self.assertEqual(8, len(ref_losses))

            # Then check that the results are the same for the other kind of
            # checkpoints after consolidation has taken place
            for checkpoint_name in [
                    "checkpoint_conso.torch",
                    "checkpoint_sliced.torch",
                    "checkpoint_sliced_2.torch",
            ]:
                losses = self.run_linear_eval(
                    checkpoint_path=os.path.join(pretrain_dir,
                                                 checkpoint_name),
                    with_fsdp=True,
                    with_mixed_precision=False,
                    auto_wrap_threshold=0,
                )
                self.assertEqual(8, len(losses))
                self.assertAlmostEqual(
                    losses[0],
                    ref_losses[0],
                    places=4,
                    msg=f"Failed for {checkpoint_name}",
                )
Example #6
    def test_benchmarking_with_checkpoint_resharding(self):
        with in_temporary_directory() as checkpoint_folder:
            # Run a pre-training in FSDP mode and save a sharded checkpoint
            config = self._create_pretraining_config(with_fsdp=True)
            run_integration_test(config)
            checkpoint_path = os.path.join(checkpoint_folder,
                                           "checkpoint.torch")

            # List the files inside the current working directory
            # to later test what files have been created
            files_before_conversion = set(os.listdir(checkpoint_folder))

            # Transform the sharded checkpoint to a consolidated checkpoint
            eval_checkpoint_path_1 = os.path.join(checkpoint_folder,
                                                  "checkpoint_eval_1.torch")
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                input_checkpoint_path=checkpoint_path,
                output_checkpoint_path=eval_checkpoint_path_1,
            )

            # Transform the sharded checkpoint to a sliced checkpoint
            eval_checkpoint_path_2 = os.path.join(checkpoint_folder,
                                                  "checkpoint_eval_2.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                input_checkpoint_path=checkpoint_path,
                output_checkpoint_path=eval_checkpoint_path_2,
            )

            # Verify the content of the directory after checkpoint conversion
            files_after_conversion = set(os.listdir(checkpoint_folder))
            new_files = files_after_conversion - files_before_conversion
            expected_new_files = {
                "checkpoint_eval_1.torch",
                "checkpoint_eval_2.torch",
                "checkpoint_eval_2_layers",
            }
            self.assertEqual(
                new_files,
                expected_new_files,
                "checkpoint 2 slices should be packaged in a directory",
            )

            # Run a benchmark in FSDP mode and record the losses and accuracies
            eval_losses, eval_accuracies = self.run_benchmarking(
                checkpoint_path, with_fsdp=True)
            self.assertGreater(len(eval_losses), 0)
            self.assertEqual(4, len(eval_accuracies))

            # Check that these losses and accuracies are the same with the
            # consolidated and sliced checkpoints
            for eval_checkpoint in [
                    eval_checkpoint_path_1, eval_checkpoint_path_2
            ]:
                fsdp_losses, fsdp_accuracies = self.run_benchmarking(
                    eval_checkpoint, with_fsdp=True)
                self.assertEqual(fsdp_losses, eval_losses)
                self.assertEqual(fsdp_accuracies, eval_accuracies)

            # Check that, unlike the sharded checkpoint, the consolidated and
            # sliced checkpoints can be used with a different number of GPUs
            for eval_checkpoint in [
                    eval_checkpoint_path_1, eval_checkpoint_path_2
            ]:
                fsdp_losses, fsdp_accuracies = self.run_benchmarking(
                    eval_checkpoint, with_fsdp=True, num_gpu=1)
                self.assertGreater(len(fsdp_losses), 0)
                self.assertEqual(len(fsdp_accuracies), 4)
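
The directory assertions above rely on the sliced checkpoint writing its per-layer slices next to the main file, in a companion "checkpoint_eval_2_layers" folder. A small sketch with a hypothetical helper, assuming only what the test observes through os.listdir (the names of the individual slice files are not asserted), that prints this layout:

import os

def print_checkpoint_folder_layout(checkpoint_folder: str) -> None:
    # List the main checkpoint files and, for sliced checkpoints, the content
    # of the companion "<name>_layers" directory holding the per-layer slices.
    for entry in sorted(os.listdir(checkpoint_folder)):
        print(entry)
        layer_dir = os.path.join(checkpoint_folder, entry)
        if entry.endswith("_layers") and os.path.isdir(layer_dir):
            for layer_file in sorted(os.listdir(layer_dir)):
                print("    " + layer_file)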
Example #7
    def test_knn_fsdp(self):
        with in_temporary_directory() as pretrain_dir:

            # Run a pre-training to have some weights to begin with
            pretrain_config = self._create_pretraining_config(with_fsdp=True)
            results = run_integration_test(pretrain_config)
            losses = results.get_losses()
            print(losses)

            # Convert checkpoint to sliced checkpoint for easy loading
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch"
            )
            checkpoint_path = os.path.join(pretrain_dir, "checkpoint_sliced.torch")

            # Create a directory to contain the extracted features
            with in_temporary_directory() as extract_dir:

                # Extract head features
                extract_config_head = self._create_extract_features_config_head(
                    checkpoint_path=checkpoint_path, with_fsdp=True
                )
                extract_config_head.EXTRACT_FEATURES.OUTPUT_DIR = extract_dir
                run_integration_test(
                    extract_config_head, engine_name="extract_features"
                )

                # Extract trunk features
                extract_config_trunk = self._create_extract_features_config_trunk(
                    checkpoint_path=checkpoint_path, with_fsdp=True
                )
                extract_config_trunk.EXTRACT_FEATURES.OUTPUT_DIR = extract_dir
                run_integration_test(
                    extract_config_trunk, engine_name="extract_features"
                )

                # Verify that we can merge the heads features back
                train_feat = ExtractedFeaturesLoader.load_features(
                    extract_dir, "train", "heads", flatten_features=True
                )
                self.assertEqual(train_feat["features"].shape, torch.Size([200, 128]))
                self.assertEqual(train_feat["targets"].shape, torch.Size([200, 1]))
                self.assertEqual(train_feat["inds"].shape, torch.Size([200]))

                # Verify that we can merge the trunk features back
                train_feat = ExtractedFeaturesLoader.load_features(
                    extract_dir, "train", "res5", flatten_features=True
                )
                self.assertEqual(
                    train_feat["features"].shape, torch.Size([200, 3024 * 2 * 2])
                )
                self.assertEqual(train_feat["targets"].shape, torch.Size([200, 1]))
                self.assertEqual(train_feat["inds"].shape, torch.Size([200]))

                # Run KNN on the res5 layer
                extract_config_trunk.NEAREST_NEIGHBOR.FEATURES.PATH = extract_dir
                top_1_ref, top_5_ref, total_ref = run_knn_at_layer(
                    extract_config_trunk, layer_name="res5"
                )
                top_1_opt, top_5_opt, total_opt = run_knn_at_layer_low_memory(
                    extract_config_trunk, layer_name="res5"
                )
                self.assertEqual(total_ref, total_opt)
                # TODO - investigate: both KNN implementations have a bit of randomness
                #  in their accuracies, so the asserts are inequalities.
                self.assertLessEqual(top_1_ref, 30.0)
                self.assertLessEqual(top_1_opt, 30.0)
                self.assertGreaterEqual(top_1_ref, 29.0)
                self.assertGreaterEqual(top_1_opt, 29.0)
                # self.assertEqual(top_1_ref, top_1_opt)
                # self.assertEqual(top_5_ref, top_5_opt)

                # Run KNN on the head layer
                extract_config_head.NEAREST_NEIGHBOR.FEATURES.PATH = extract_dir
                top_1_ref, top_5_ref, total_ref = run_knn_at_layer(
                    extract_config_head, layer_name="heads"
                )
                top_1_opt, top_5_opt, total_opt = run_knn_at_layer_low_memory(
                    extract_config_head, layer_name="heads"
                )
                self.assertEqual(total_ref, total_opt)
                # TODO - investigate: both KNN implementations have a bit of randomness
                #  in their accuracies, so the asserts are inequalities.
                self.assertLessEqual(top_1_ref, 35.0)
                self.assertLessEqual(top_1_opt, 35.0)
                self.assertGreaterEqual(top_1_ref, 33.0)
                self.assertGreaterEqual(top_1_opt, 33.0)
Example #8
    @staticmethod
    def _worker(gpu_id: int, sync_file: str, world_size: int):
        torch.manual_seed(0)
        os.environ["RANK"] = str(gpu_id)
        init_distributed_on_file(world_size=world_size,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.backends.cudnn.deterministic = True

        config = TestCheckpointConversion._create_fsdp_model_config(
            with_fsdp=True)
        model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        optimizer = optim.SGD(model.parameters(), lr=1e-4)

        # Fake inputs
        num_iterations = 5
        batch_size = 3
        torch.manual_seed(gpu_id)
        fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
        fake_targets = torch.randn(size=(num_iterations, batch_size))

        # Fake training loop
        criterion = nn.MSELoss()
        for iteration in range(num_iterations):
            fake_input = fake_inputs[iteration].cuda(gpu_id)
            fake_target = fake_targets[iteration].cuda(gpu_id)
            output1, output2 = model(fake_input)[0]
            loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
                output2.sum(axis=-1), fake_target)
            if gpu_id == 0:
                print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save a sharded checkpoint: one file per shard
        checkpoint_writer = CheckpointWriter(
            checkpoint_folder=".",
            is_final_train_phase=True,
            mode="iteration",
            mode_num=0,
            backend="disk",
        )
        content = {
            "classy_state_dict": {
                "base_model": {
                    "model": {
                        "trunk": model.trunk.local_state_dict()
                    },
                    "meta": {
                        "trunk": model.trunk.local_metadata_dict()
                    },
                }
            }
        }
        checkpoint_writer.save_sharded_checkpoint(content,
                                                  shard_rank=gpu_id,
                                                  world_size=world_size)
        dist.barrier()
        print(os.listdir("."))

        # Convert the checkpoint to consolidated and sliced checkpoints
        if gpu_id == 0:
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch")
        dist.barrier()
        print(os.listdir("."))

        # Now create models initialized from the previous checkpoint and compare them
        fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)

        shard_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint.torch", device=torch.device("cpu"))
        shard_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        shard_model = fsdp_wrapper(shard_model, **config.MODEL.FSDP_CONFIG)
        shard_model.init_model_from_weights_params_file(config, shard_cp)

        conso_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_conso.torch", device=torch.device("cpu"))
        conso_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        conso_model = fsdp_wrapper(conso_model, **config.MODEL.FSDP_CONFIG)
        conso_model.init_model_from_weights_params_file(config, conso_cp)

        slice_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_sliced.torch", device=torch.device("cpu"))
        slice_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        slice_model = fsdp_wrapper(slice_model, **config.MODEL.FSDP_CONFIG)
        slice_model.init_model_from_weights_params_file(config, slice_cp)

        # Verifying that the models are equivalent
        if gpu_id == 0:
            slice_state_dict = slice_model.local_state_dict()
            conso_state_dict = conso_model.local_state_dict()
            assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
            for k in slice_state_dict.keys():
                slice_val = slice_state_dict[k]
                conso_val = conso_state_dict[k]
                assert torch.allclose(
                    slice_val, conso_val
                ), f"Difference for key {k}: {slice_val} VS {conso_val}"
        dist.barrier()

        with torch.no_grad():
            ref_out = model.trunk(fake_test_input)[0]
            shard_out = shard_model.trunk(fake_test_input)[0]
            conso_out = conso_model.trunk(fake_test_input)[0]
            slice_out = slice_model.trunk(fake_test_input)[0]
            assert torch.allclose(
                ref_out, shard_out), f"{ref_out.sum()} vs {shard_out.sum()}"
            assert torch.allclose(
                ref_out, conso_out), f"{ref_out.sum()} vs {conso_out.sum()}"
            assert torch.allclose(
                ref_out, slice_out), f"{ref_out.sum()} vs {slice_out.sum()}"
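
The worker above is meant to be spawned once per GPU. A minimal launch sketch, assuming two GPUs and a hypothetical test method name; it uses torch.multiprocessing.spawn (which passes the process index as the first argument) and a temporary file for the distributed rendezvous. The repository's own tests may use dedicated helpers for this instead.

    def test_checkpoint_consolidation_on_2_gpus(self):
        # Hypothetical launcher: spawn one _worker process per GPU, sharing a
        # temporary sync file so that init_distributed_on_file can rendezvous.
        import tempfile

        import torch.multiprocessing as mp

        world_size = 2
        with tempfile.NamedTemporaryFile() as sync_file:
            mp.spawn(
                TestCheckpointConversion._worker,
                args=(sync_file.name, world_size),
                nprocs=world_size,
            )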
Example #9
    def test_fine_tuning_end_to_end_fsdp(self):
        with in_temporary_directory() as pretrain_dir:
            # Run a pre-training to have some weights to begin with
            pretrain_config = self._create_pretraining_config(
                with_fsdp=True, fsdp_flatten_parameters=True)
            run_integration_test(pretrain_config)
            sharded_checkpoint_path = os.path.join(pretrain_dir,
                                                   "checkpoint.torch")

            # Consolidate the checkpoint of the FSDP model
            conso_checkpoint_path = os.path.join(pretrain_dir,
                                                 "consolidated.torch")
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                input_checkpoint_path=sharded_checkpoint_path,
                output_checkpoint_path=conso_checkpoint_path,
            )

            # Create a sliced version of the FSDP model checkpoint
            sliced_checkpoint_path = os.path.join(pretrain_dir, "sliced.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                input_checkpoint_path=sharded_checkpoint_path,
                output_checkpoint_path=sliced_checkpoint_path,
            )

            # Create a separate directory in which to run the fine-tuning
            with in_temporary_directory():
                finetune_config = self._create_finetuning_config(
                    sliced_checkpoint_path,
                    construct_single_param_group_only=False,
                    regularize_bias=False,
                    with_fsdp=True,
                    fsdp_flatten_parameters=False,
                    with_partial_head=False,
                )
                result = run_integration_test(finetune_config)
                accuracies = result.get_accuracies(from_metrics_file=True)
                self.assertEqual(4, len(accuracies))

            # Create a separate directory in which we run the fine-tuning
            # with a partial head loading (sliced checkpoint)
            with in_temporary_directory():
                finetune_config = self._create_finetuning_config(
                    sliced_checkpoint_path,
                    construct_single_param_group_only=False,
                    regularize_bias=False,
                    with_fsdp=True,
                    fsdp_flatten_parameters=False,
                    with_partial_head=True,
                )
                result = run_integration_test(finetune_config)
                losses = result.get_losses()
                first_loss_sliced = losses[0]
                accuracies = result.get_accuracies(from_metrics_file=True)
                self.assertEqual(4, len(accuracies))

            # Create a separate directory in which we run the fine-tuning
            # with a partial head loading (consolidated checkpoint)
            with in_temporary_directory():
                finetune_config = self._create_finetuning_config(
                    conso_checkpoint_path,
                    construct_single_param_group_only=False,
                    regularize_bias=False,
                    with_fsdp=True,
                    fsdp_flatten_parameters=False,
                    with_partial_head=True,
                )
                result = run_integration_test(finetune_config)
                losses = result.get_losses()
                self.assertAlmostEqual(first_loss_sliced, losses[0], places=4)
                accuracies = result.get_accuracies(from_metrics_file=True)
                self.assertEqual(4, len(accuracies))
Example #10
    def test_fsdp_extract_label_predictions(self):
        with in_temporary_directory() as pretrain_dir:

            # Start pre-training and create a sliced version of the checkpoint
            config = self._create_pretraining_config(
                with_fsdp=True,
                with_activation_checkpointing=True,
                with_mixed_precision=False,
                auto_wrap_threshold=0,
            )
            run_integration_test(config)
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch")

            with in_temporary_directory() as fine_tune_dir:

                # Run fine-tuning on the sliced checkpoint
                finetune_config = self._create_finetuning_config(
                    checkpoint_path=os.path.join(pretrain_dir,
                                                 "checkpoint_sliced.torch"),
                    auto_wrap_threshold=True,
                    with_fsdp=True,
                    with_partial_head=True,
                    with_mixed_precision=True,
                    with_activation_checkpointing=True,
                )
                finetune_config.OPTIMIZER.num_epochs = 1
                results = run_integration_test(finetune_config)
                fine_tune_accuracy = results.get_final_accuracy(layer_name="0")

                # Create a sliced version of the fine-tuned checkpoint
                CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                    "checkpoint.torch", "checkpoint_sliced.torch")

                # Run label extraction on top of the fine-tuned checkpoint,
                # using both the sharded and the sliced versions
                for checkpoint_name in [
                        "checkpoint.torch", "checkpoint_sliced.torch"
                ]:
                    with in_temporary_directory() as extract_dir:
                        extract_config = (
                            self.
                            _create_extract_label_prediction_finetuned_config(
                                with_fsdp=True,
                                with_mixed_precision=False,
                                auto_wrap_threshold=True,
                            ))
                        extract_config.MODEL.WEIGHTS_INIT.PARAMS_FILE = os.path.join(
                            fine_tune_dir, checkpoint_name)
                        run_integration_test(
                            extract_config,
                            engine_name="extract_label_predictions")

                        accuracy, preds, targets = self._read_accuracy(
                            extract_dir, layer="heads")
                        print(targets)
                        print(preds)
                        print("Accuracy:", accuracy)
                        self.assertAlmostEqual(accuracy * 100.0,
                                               fine_tune_accuracy,
                                               places=3)

            with in_temporary_directory() as lin_eval_dir:

                # Run linear evaluation on the sliced pre-training checkpoint
                config = self._create_linear_evaluation_config(
                    with_fsdp=True,
                    with_mixed_precision=True,
                    auto_wrap_threshold=True)
                config.OPTIMIZER.num_epochs = 1
                config.OPTIMIZER.param_schedulers.lr.values = [0.01]
                config.OPTIMIZER.param_schedulers.lr.milestones = []
                config.MODEL.WEIGHTS_INIT.PARAMS_FILE = os.path.join(
                    pretrain_dir, "checkpoint_sliced.torch")
                results = run_integration_test(config)
                linear_accuracy = results.get_final_accuracy("res5")

                # Create a sliced version of the linear evaluation checkpoint
                CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                    "checkpoint.torch", "checkpoint_sliced.torch")

                # Run label extraction on both the sharded and sliced checkpoints
                for checkpoint_name in [
                        "checkpoint.torch", "checkpoint_sliced.torch"
                ]:
                    with in_temporary_directory() as extract_dir:
                        extract_config = self._create_extract_label_prediction_config(
                            with_fsdp=True,
                            with_mixed_precision=False,
                            auto_wrap_threshold=True,
                        )
                        extract_config.MODEL.WEIGHTS_INIT.PARAMS_FILE = os.path.join(
                            lin_eval_dir, checkpoint_name)
                        run_integration_test(
                            extract_config,
                            engine_name="extract_label_predictions")

                        # Read the predictions and verify that they match
                        # the linear evaluation results
                        accuracy, preds, targets = self._read_accuracy(
                            extract_dir, layer="res5")
                        print(targets)
                        print(preds)
                        print("Accuracy:", accuracy)
                        self.assertAlmostEqual(accuracy * 100.0,
                                               linear_accuracy,
                                               places=3)