def test_mean_teacher_model() -> None:
    """
    Test training and weight updates of the mean teacher model.
    """
    def _get_parameters_of_model(
            model: Union[torch.nn.Module, DataParallelModel]) -> Any:
        """
        Returns the iterator of model parameters
        """
        if isinstance(model, DataParallelModel):
            return model.module.parameters()
        else:
            return model.parameters()

    config = DummyClassification()
    config.num_epochs = 1
    # Set the training batch size to be arbitrarily big to ensure we have only one training step,
    # i.e. one mean teacher update.
    config.train_batch_size = 100
    # Train without mean teacher
    model_train(config)

    # Retrieve the weight after one epoch
    model = create_model_with_temperature_scaling(config)
    print(config.get_path_to_checkpoint(1))
    _ = model_util.load_checkpoint(model, config.get_path_to_checkpoint(1))
    model_weight = next(_get_parameters_of_model(model))

    # Get the starting weight of the mean teacher model
    ml_util.set_random_seed(config.get_effective_random_seed())
    _ = create_model_with_temperature_scaling(config)
    mean_teach_model = create_model_with_temperature_scaling(config)
    initial_weight_mean_teacher_model = next(
        _get_parameters_of_model(mean_teach_model))

    # Now train with mean teacher and check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config)

    # Retrieve weight of mean teacher model saved in the checkpoint
    mean_teacher_model = create_model_with_temperature_scaling(config)
    _ = model_util.load_checkpoint(
        mean_teacher_model,
        config.get_path_to_checkpoint(1, for_mean_teacher_model=True))
    result_weight = next(_get_parameters_of_model(mean_teacher_model))
    # Retrieve the associated student weight
    _ = model_util.load_checkpoint(model, config.get_path_to_checkpoint(1))
    student_model_weight = next(_get_parameters_of_model(model))

    # Assert that the student weight corresponds to the weight obtained from plain training without the
    # mean teacher computation
    assert student_model_weight.allclose(model_weight)

    # Check the update of the parameters
    assert torch.all(alpha * initial_weight_mean_teacher_model +
                     (1 - alpha) * student_model_weight == result_weight)
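
The assertion above verifies the exponential moving average (EMA) rule that defines the mean teacher update. Below is a minimal sketch of that update, assuming teacher and student are two torch.nn.Module instances with identical architectures; update_mean_teacher is an illustrative helper, not the InnerEye implementation.

import torch

def update_mean_teacher(teacher: torch.nn.Module,
                        student: torch.nn.Module,
                        alpha: float = 0.999) -> None:
    """Updates the teacher weights as an EMA of the student weights (illustrative sketch)."""
    with torch.no_grad():
        for teacher_param, student_param in zip(teacher.parameters(),
                                                student.parameters()):
            # teacher = alpha * teacher + (1 - alpha) * student
            teacher_param.mul_(alpha).add_(student_param, alpha=1 - alpha)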
Example 2
def test_recover_training_mean_teacher_model(test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests that training can be recovered from a previous checkpoint.
    """
    config = DummyClassification()
    config.mean_teacher_alpha = 0.999

    # First round of training
    config.num_epochs = 2
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    model_train(config, checkpoint_handler=checkpoint_handler)
    assert len(list(config.checkpoint_folder.rglob("*.*"))) == 1

    # Restart training from previous run
    config.start_epoch = 2
    config.num_epochs = 3
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config.checkpoint_folder / "recovered"
    shutil.copytree(config.checkpoint_folder, checkpoint_root)
    checkpoint_handler.run_recovery = RunRecovery([checkpoint_root])

    model_train(config, checkpoint_handler=checkpoint_handler)
    # remove recovery checkpoints
    shutil.rmtree(checkpoint_root)
    assert len(list(config.checkpoint_folder.rglob("*.*"))) == 2
Example 3
def test_recover_training_mean_teacher_model(
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests that training can be recovered from a previous checkpoint.
    """
    config = DummyClassification()
    config.mean_teacher_alpha = 0.999
    config.autosave_every_n_val_epochs = 1
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))

    original_checkpoint_folder = config.checkpoint_folder

    # First round of training
    config.num_epochs = 4
    model_train_unittest(config, output_folder=test_output_dirs)
    assert len(list(config.checkpoint_folder.glob("*.*"))) == 1
    assert (config.checkpoint_folder /
            LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()

    # Restart training from previous run
    config.num_epochs = 3
    config.set_output_to(test_output_dirs.root_dir / "recovered")
    os.makedirs(str(config.outputs_folder))
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config.checkpoint_folder / "old_run"
    shutil.copytree(str(original_checkpoint_folder), str(checkpoint_root))

    # Create a new checkpoint handler and set run_recovery to the copied checkpoints
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    checkpoint_handler.run_recovery = RunRecovery([checkpoint_root])

    model_train_unittest(config,
                         output_folder=test_output_dirs,
                         checkpoint_handler=checkpoint_handler)
    # remove recovery checkpoints
    shutil.rmtree(checkpoint_root)
    assert len(list(config.checkpoint_folder.glob("*.ckpt"))) == 1
def test_classification_metrics() -> None:
    config = DummyClassification()
    metrics = config._get_metrics_computers()
    logits = [torch.tensor([2.1972, 1.3863, 0.4055]), torch.tensor([-0.8473, 2.1972, -0.4055])]
    posteriors = [torch.sigmoid(logit) for logit in logits]
    labels = [torch.tensor([1, 1, 0]), torch.tensor([0, 0, 0])]
    for logit, posterior, label in zip(logits, posteriors, labels):
        for metric in metrics:
            if isinstance(metric, ScalarMetricsBase) and metric.compute_from_logits:
                metric.update(logit, label)
            else:
                metric.update(posterior, label)
    accuracy_05, accuracy_opt, threshold, fpr, fnr, roc_auc, pr_auc, cross_entropy_with_logits = \
        [metric.compute() for metric in metrics]
    all_labels = torch.cat(labels).numpy()
    all_posteriors = torch.cat(posteriors).numpy()
    expected_accuracy_at_05 = np.mean((all_posteriors > 0.5) == all_labels)
    expected_binary_cross_entropy = log_loss(y_true=all_labels, y_pred=all_posteriors)
    expected_fpr, expected_tpr, expected_thresholds = roc_curve(y_true=all_labels, y_score=all_posteriors)
    expected_roc_auc = auc(expected_fpr, expected_tpr)
    expected_optimal_idx = np.argmax(expected_tpr - expected_fpr)
    expected_optimal_threshold = expected_thresholds[expected_optimal_idx]
    expected_accuracy = np.mean((all_posteriors > expected_optimal_threshold) == all_labels)
    expected_optimal_fpr = expected_fpr[expected_optimal_idx]
    expected_optimal_fnr = 1 - expected_tpr[expected_optimal_idx]
    prec, recall, _ = precision_recall_curve(y_true=all_labels, probas_pred=all_posteriors)
    expected_pr_auc = auc(recall, prec)
    assert accuracy_opt == expected_accuracy
    assert threshold == expected_optimal_threshold
    assert fpr == expected_optimal_fpr
    assert fnr == expected_optimal_fnr
    assert roc_auc == expected_roc_auc
    assert pr_auc == expected_pr_auc
    print(pr_auc, expected_pr_auc)
    # Use default relative tolerance of one part in a million due to floating point arithmetic
    assert cross_entropy_with_logits == pytest.approx(expected_binary_cross_entropy)
    assert accuracy_05 == expected_accuracy_at_05
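
The optimal threshold and accuracy_opt checked above correspond to the ROC operating point that maximises Youden's J statistic, argmax(TPR - FPR). The following standalone sketch reproduces that computation with roughly the same posteriors and labels as in the test above; the variable names are illustrative only.

import numpy as np
from sklearn.metrics import roc_curve

labels = np.array([1, 1, 0, 0, 0, 0])
posteriors = np.array([0.9, 0.8, 0.6, 0.3, 0.9, 0.4])
fpr, tpr, thresholds = roc_curve(y_true=labels, y_score=posteriors)
# Youden's J statistic: choose the threshold that maximises TPR - FPR.
optimal_idx = np.argmax(tpr - fpr)
optimal_threshold = thresholds[optimal_idx]
accuracy_at_optimal = np.mean((posteriors > optimal_threshold) == labels)
print(optimal_threshold, accuracy_at_optimal)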
Example 5
def test_autosave_checkpoints(test_output_dirs: OutputFolderForTests,
                              num_epochs: int) -> None:
    """
    Tests that all autosave checkpoints are cleaned up after training.
    """
    # Lightning does not overwrite checkpoints in-place. Rather, it writes "autosave.ckpt",
    # then "autosave-v1.ckpt" and deletes "autosave.ckpt", then "autosave.ckpt" and deletes "autosave-v1.ckpt".
    # All those checkpoints should be cleaned up after training; only the best checkpoint should remain.
    config = DummyClassification()
    config.autosave_every_n_val_epochs = 1
    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = num_epochs
    model_train_unittest(config, output_folder=test_output_dirs)
    assert len(list(config.checkpoint_folder.glob("*.*"))) == 1
    assert (config.checkpoint_folder /
            LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
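
The autosave behaviour described in the comment above is driven by a PyTorch Lightning ModelCheckpoint callback. The snippet below is only a rough sketch of how such a callback could be configured; the paths shown here are placeholders and the exact wiring inside InnerEye may differ.

from pytorch_lightning.callbacks import ModelCheckpoint

# Illustrative configuration: write a rolling "autosave" checkpoint after every
# validation epoch and keep only the most recent one. Because Lightning does not
# overwrite files in place, it alternates between "autosave.ckpt" and
# "autosave-v1.ckpt", deleting the older of the two.
autosave_callback = ModelCheckpoint(
    dirpath="outputs/checkpoints",
    filename="autosave",
    every_n_epochs=1,
    save_top_k=1,
    save_last=True,  # additionally keeps "last.ckpt" for recovery
)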
Example 6
def test_runner_restart(test_output_dirs: OutputFolderForTests) -> None:
    """
    Tests that starting training from a folder whose checkpoints folder already contains recovery checkpoints is
    recognized as a recovery run. Also checks that the start epoch in the config is updated at loading time.
    """
    model_config = DummyClassification()
    model_config.set_output_to(test_output_dirs.root_dir)
    model_config.num_epochs = FIXED_EPOCH + 2
    # We save all checkpoints - if recovery works as expected, we should have new checkpoints for epochs 4 and 5.
    model_config.recovery_checkpoint_save_interval = 1
    model_config.recovery_checkpoints_save_last_k = -1
    runner = MLRunner(model_config=model_config)
    runner.setup(use_mount_or_download_dataset=False)
    # Epochs are 0-based for saving
    create_model_and_store_checkpoint(model_config,
                                      runner.container.checkpoint_folder /
                                      f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
                                      f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}",
                                      weights_only=False)
    azure_config = get_default_azure_config()
    checkpoint_handler = CheckpointHandler(
        azure_config=azure_config,
        container=runner.container,
        project_root=test_output_dirs.root_dir)
    _, storing_logger = model_train(checkpoint_handler=checkpoint_handler,
                                    container=runner.container)
    # We expect to have 4 checkpoints: FIXED_EPOCH - 1 (recovery), FIXED_EPOCH, FIXED_EPOCH + 1, and best.
    assert len(os.listdir(runner.container.checkpoint_folder)) == 4
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH - 1}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            f"{RECOVERY_CHECKPOINT_FILE_NAME}_epoch="
            f"{FIXED_EPOCH + 1}{CHECKPOINT_SUFFIX}").exists()
    assert (runner.container.checkpoint_folder /
            BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).exists()
    # Check that we really restarted training from epoch FIXED_EPOCH.
    assert list(storing_logger.epochs) == [FIXED_EPOCH,
                                           FIXED_EPOCH + 1]  # type: ignore
def test_recover_training_mean_teacher_model() -> None:
    """
    Tests that training can be recovered from a previous checkpoint.
    """
    config = DummyClassification()
    config.mean_teacher_alpha = 0.999

    # First round of training
    config.num_epochs = 2
    model_train(config)
    assert len(os.listdir(config.checkpoint_folder)) == 1

    # Restart training from previous run
    config.start_epoch = 2
    config.num_epochs = 3
    model_train(config)
    assert len(os.listdir(config.checkpoint_folder)) == 2
def test_classification_metrics() -> None:
    classification_module = ScalarLightning(DummyClassification())
    metrics = classification_module._get_metrics_computers()
    outputs = [torch.tensor([0.9, 0.8, 0.6]), torch.tensor([0.3, 0.9, 0.4])]
    labels = [torch.tensor([1., 1., 0.]), torch.tensor([0., 0., 0.])]
    for output, label in zip(outputs, labels):
        for metric in metrics:
            metric.update(output, label)
    accuracy_05, accuracy_opt, threshold, fpr, fnr, roc_auc, pr_auc, cross_entropy = [
        metric.compute() for metric in metrics
    ]
    all_labels = torch.cat(labels).numpy()
    all_outputs = torch.cat(outputs).numpy()
    expected_accuracy_at_05 = np.mean((all_outputs > 0.5) == all_labels)
    expected_binary_cross_entropy = log_loss(y_true=all_labels,
                                             y_pred=all_outputs)
    expected_fpr, expected_tpr, expected_thresholds = roc_curve(
        y_true=all_labels, y_score=all_outputs)
    expected_roc_auc = auc(expected_fpr, expected_tpr)
    expected_optimal_idx = np.argmax(expected_tpr - expected_fpr)
    expected_optimal_threshold = expected_thresholds[expected_optimal_idx]
    expected_accuracy = np.mean(
        (all_outputs > expected_optimal_threshold) == all_labels)
    expected_optimal_fpr = expected_fpr[expected_optimal_idx]
    expected_optimal_fnr = 1 - expected_tpr[expected_optimal_idx]
    prec, recall, _ = precision_recall_curve(y_true=all_labels,
                                             probas_pred=all_outputs)
    expected_pr_auc = auc(recall, prec)
    assert accuracy_opt == expected_accuracy
    assert threshold == expected_optimal_threshold
    assert fpr == expected_optimal_fpr
    assert fnr == expected_optimal_fnr
    assert roc_auc == expected_roc_auc
    assert pr_auc == expected_pr_auc
    assert cross_entropy == expected_binary_cross_entropy
    assert accuracy_05 == expected_accuracy_at_05
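
The cross-entropy assertion above relies on sklearn's log_loss being the mean negative log-likelihood of the true labels. A quick standalone check of that identity on the same outputs and labels:

import numpy as np
from sklearn.metrics import log_loss

labels = np.array([1., 1., 0., 0., 0., 0.])
outputs = np.array([0.9, 0.8, 0.6, 0.3, 0.9, 0.4])
manual_bce = -np.mean(labels * np.log(outputs) + (1 - labels) * np.log(1 - outputs))
assert np.isclose(manual_bce, log_loss(y_true=labels, y_pred=outputs))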
Example 9
def test_recovery_e2e(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test restarting training: train a small model for 5 epochs, then continue training to epoch 10 from the results
    of the first training run.
    """
    model_config = DummyClassification()
    model_config.set_output_to(test_output_dirs.root_dir)
    num_epochs_1 = 5
    model_config.num_epochs = num_epochs_1
    storing_logger_1, checkpoint_handler = model_train_unittest(
        model_config, output_folder=test_output_dirs)
    # Logger should have results for epochs 0..4
    assert list(storing_logger_1.epochs) == list(range(num_epochs_1))
    # Now restart the job, train to epoch 10
    num_epochs_2 = 10
    model_config.num_epochs = num_epochs_2
    storing_logger_2, _ = model_train_unittest(
        model_config,
        output_folder=test_output_dirs,
        checkpoint_handler=checkpoint_handler)
    # Logger should have results only for epochs 5..9
    assert list(storing_logger_2.epochs) == list(
        range(num_epochs_1, num_epochs_2))
Example 10
def test_recover_testing_from_run_recovery(
        mean_teacher_model: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Checks that inference results are the same whether from a checkpoint in the same run, from a run recovery or from a
    local_weights_path param.
    """
    # Train for 4 epochs
    config = DummyClassification()
    if mean_teacher_model:
        config.mean_teacher_alpha = 0.999
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))
    config.save_start_epoch = 2
    config.save_step_epochs = 2

    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    train_results = model_train(config, checkpoint_handler=checkpoint_handler)
    assert len(train_results.learning_rates_per_epoch) == config.num_epochs

    # Run inference on this
    test_results = model_test(config=config,
                              data_split=ModelExecutionMode.TEST,
                              checkpoint_handler=checkpoint_handler)
    assert isinstance(test_results, InferenceMetricsForClassification)
    assert list(test_results.epochs.keys()) == [config.num_epochs]

    # Mimic using a run recovery and see if it is the same
    config_run_recovery = DummyClassification()
    if mean_teacher_model:
        config_run_recovery.mean_teacher_alpha = 0.999
    config_run_recovery.set_output_to(test_output_dirs.root_dir /
                                      "run_recovery")
    os.makedirs(str(config_run_recovery.outputs_folder))

    checkpoint_handler_run_recovery = get_default_checkpoint_handler(
        model_config=config_run_recovery,
        project_root=test_output_dirs.root_dir)
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config_run_recovery.checkpoint_folder / "recovered"
    shutil.copytree(str(config.checkpoint_folder), str(checkpoint_root))
    checkpoint_handler_run_recovery.run_recovery = RunRecovery(
        [checkpoint_root])
    test_results_run_recovery = model_test(
        config_run_recovery,
        data_split=ModelExecutionMode.TEST,
        checkpoint_handler=checkpoint_handler_run_recovery)
    assert isinstance(test_results_run_recovery,
                      InferenceMetricsForClassification)
    assert list(test_results_run_recovery.epochs.keys()) == [config.num_epochs]
    assert test_results.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_run_recovery.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value]

    # Run inference with the local checkpoints
    config_local_weights = DummyClassification()
    if mean_teacher_model:
        config_local_weights.mean_teacher_alpha = 0.999
    config_local_weights.set_output_to(test_output_dirs.root_dir /
                                       "local_weights_path")
    os.makedirs(str(config_local_weights.outputs_folder))

    local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
    shutil.copyfile(
        str(
            create_checkpoint_path(config.checkpoint_folder,
                                   epoch=config.num_epochs)),
        local_weights_path)
    config_local_weights.local_weights_path = local_weights_path

    checkpoint_handler_local_weights = get_default_checkpoint_handler(
        model_config=config_local_weights,
        project_root=test_output_dirs.root_dir)
    checkpoint_handler_local_weights.discover_and_download_checkpoints_from_previous_runs(
    )
    test_results_local_weights = model_test(
        config_local_weights,
        data_split=ModelExecutionMode.TEST,
        checkpoint_handler=checkpoint_handler_local_weights)
    assert isinstance(test_results_local_weights,
                      InferenceMetricsForClassification)
    assert list(test_results_local_weights.epochs.keys()) == [0]
    assert test_results.epochs[config.num_epochs].values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_local_weights.epochs[0].values()[MetricType.CROSS_ENTROPY.value]
Example 11
def test_mean_teacher_model(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and weight updates of the mean teacher model.
    """
    def _get_parameters_of_model(model: DeviceAwareModule) -> Any:
        """
        Returns the iterator of model parameters
        """
        if isinstance(model, DataParallelModel):
            return model.module.parameters()
        else:
            return model.parameters()

    config = DummyClassification()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)

    config.num_epochs = 1
    # Set the training batch size to be arbitrarily big to ensure we have only one training step,
    # i.e. one mean teacher update.
    config.train_batch_size = 100
    # Train without mean teacher
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve the weight after one epoch
    model_and_info = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=config.get_path_to_checkpoint(epoch=1))
    model_and_info.try_create_model_and_load_from_checkpoint()
    model = model_and_info.model
    model_weight = next(_get_parameters_of_model(model))

    # Get the starting weight of the mean teacher model
    ml_util.set_random_seed(config.get_effective_random_seed())

    model_and_info_mean_teacher = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=None)
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()

    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint(
    )
    mean_teach_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teach_model is not None  # for mypy
    initial_weight_mean_teacher_model = next(
        _get_parameters_of_model(mean_teach_model))

    # Now train with mean teacher and check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve weight of mean teacher model saved in the checkpoint
    model_and_info_mean_teacher = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TEST,
        checkpoint_path=config.get_path_to_checkpoint(1))
    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint(
    )
    mean_teacher_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teacher_model is not None  # for mypy
    result_weight = next(_get_parameters_of_model(mean_teacher_model))
    # Retrieve the associated student weight
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()
    student_model = model_and_info_mean_teacher.model
    student_model_weight = next(_get_parameters_of_model(student_model))

    # Assert that the student weight corresponds to the weight obtained from plain training without the
    # mean teacher computation
    assert student_model_weight.allclose(model_weight)

    # Check the update of the parameters
    assert torch.all(alpha * initial_weight_mean_teacher_model +
                     (1 - alpha) * student_model_weight == result_weight)
Example 12
def test_recover_testing_from_run_recovery(
        mean_teacher_model: bool,
        test_output_dirs: OutputFolderForTests) -> None:
    """
    Checks that inference results are the same whether from a checkpoint in the same run, from a run recovery or from a
    local_weights_path param.
    """
    # Train for 4 epochs
    config = DummyClassification()
    if mean_teacher_model:
        config.mean_teacher_alpha = 0.999
    config.set_output_to(test_output_dirs.root_dir / "original")
    os.makedirs(str(config.outputs_folder))

    train_results, checkpoint_handler = model_train_unittest(
        config, output_folder=test_output_dirs)
    assert len(train_results.train_results_per_epoch()) == config.num_epochs

    # Run inference on this
    test_results = model_test(
        config=config,
        data_split=ModelExecutionMode.TEST,
        checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
    assert isinstance(test_results, InferenceMetricsForClassification)

    # Mimic using a run recovery and see if it is the same
    config_run_recovery = DummyClassification()
    if mean_teacher_model:
        config_run_recovery.mean_teacher_alpha = 0.999
    config_run_recovery.set_output_to(test_output_dirs.root_dir /
                                      "run_recovery")
    os.makedirs(str(config_run_recovery.outputs_folder))

    checkpoint_handler_run_recovery = get_default_checkpoint_handler(
        model_config=config_run_recovery,
        project_root=test_output_dirs.root_dir)
    # make it seem like run recovery objects have been downloaded
    checkpoint_root = config_run_recovery.checkpoint_folder / "recovered"
    shutil.copytree(str(config.checkpoint_folder), str(checkpoint_root))
    checkpoint_handler_run_recovery.run_recovery = RunRecovery(
        [checkpoint_root])
    test_results_run_recovery = model_test(
        config_run_recovery,
        data_split=ModelExecutionMode.TEST,
        checkpoint_paths=checkpoint_handler_run_recovery.
        get_checkpoints_to_test())
    assert isinstance(test_results_run_recovery,
                      InferenceMetricsForClassification)
    assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_run_recovery.metrics.values()[MetricType.CROSS_ENTROPY.value]

    # Run inference with the local checkpoints
    config_local_weights = DummyClassification()
    if mean_teacher_model:
        config_local_weights.mean_teacher_alpha = 0.999
    config_local_weights.set_output_to(test_output_dirs.root_dir /
                                       "local_weights_path")
    os.makedirs(str(config_local_weights.outputs_folder))

    local_weights_path = test_output_dirs.root_dir / "local_weights_file.pth"
    shutil.copyfile(
        str(config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX),
        local_weights_path)
    config_local_weights.local_weights_path = [local_weights_path]

    checkpoint_handler_local_weights = get_default_checkpoint_handler(
        model_config=config_local_weights,
        project_root=test_output_dirs.root_dir)
    checkpoint_handler_local_weights.download_recovery_checkpoints_or_weights()
    test_results_local_weights = model_test(
        config_local_weights,
        data_split=ModelExecutionMode.TEST,
        checkpoint_paths=checkpoint_handler_local_weights.
        get_checkpoints_to_test())
    assert isinstance(test_results_local_weights,
                      InferenceMetricsForClassification)
    assert test_results.metrics.values()[MetricType.CROSS_ENTROPY.value] == \
           test_results_local_weights.metrics.values()[MetricType.CROSS_ENTROPY.value]