コード例 #1
0
def model_config(
    slice_exclusion_rules: List[SliceExclusionRule],
    summed_probability_rules: List[SummedProbabilityRule]
) -> SegmentationModelBase:
    """
    Create a DummyModel configuration that carries the given post-processing rules.

    :param slice_exclusion_rules: Stored as ``slice_exclusion_rules`` on the returned config.
    :param summed_probability_rules: Stored as ``summed_probability_rules`` on the returned config.
    :return: The populated DummyModel instance.
    """
    config = DummyModel()
    config.slice_exclusion_rules, config.summed_probability_rules = (
        slice_exclusion_rules, summed_probability_rules)
    # NOTE(review): ``ground_truth_ids`` is not a parameter of this function; it is resolved from an
    # enclosing/module scope that is not visible here — confirm it is defined where this is used.
    config.ground_truth_ids = ground_truth_ids
    return config
コード例 #2
0
def _test_model_train(output_dirs: OutputFolderForTests,
                      image_channels: Any,
                      ground_truth_ids: Any,
                      no_mask_channel: bool = False) -> None:
    """
    Train DummyModel for two epochs and check the recorded metrics, losses and output files.

    :param output_dirs: Folder for test outputs; its ``root_dir`` is used as the output root of the training run.
    :param image_channels: Value assigned to ``image_channels`` of the training config.
    :param ground_truth_ids: Value assigned to ``ground_truth_ids`` of the training config.
    :param no_mask_channel: If True, ``mask_id`` of the training config is set to None (no mask channel).
    """
    def _check_patch_centers(diagnostics_per_epoch: List[np.ndarray],
                             should_equal: bool) -> None:
        """Assert that the patch centers of each later epoch do (or do not) equal those of the first epoch."""
        # Check the length before indexing, so that an empty list fails with this message and not an IndexError.
        assert len(
            diagnostics_per_epoch
        ) > 1, "Not enough data to check patch centers, need at least 2"
        patch_centers_epoch1 = diagnostics_per_epoch[0]
        for diagnostic in diagnostics_per_epoch[1:]:
            assert np.array_equal(patch_centers_epoch1,
                                  diagnostic) == should_equal

    def _check_voxel_count(results_per_epoch: List[Dict[str, float]],
                           expected_voxel_count_per_epoch: List[float],
                           prefix: str) -> None:
        """Assert the per-epoch voxel counts of "region" and "region_1" against the expected values."""
        assert len(results_per_epoch) == len(expected_voxel_count_per_epoch)
        for epoch, (results, voxel_count) in enumerate(
                zip(results_per_epoch, expected_voxel_count_per_epoch)):
            # In the test data, both structures "region" and "region_1" are read from the same nifti file, hence
            # their voxel counts must be identical.
            for structure in ["region", "region_1"]:
                assert results[f"{MetricType.VOXEL_COUNT.value}/{structure}"] == pytest.approx(voxel_count, abs=1e-2), \
                    f"{prefix} voxel count mismatch for '{structure}' epoch {epoch}"

    def _mean(a: List[float]) -> float:
        return sum(a) / len(a)

    def _mean_list(lists: List[List[float]]) -> List[float]:
        # Mean of each inner list, in order.
        return list(map(_mean, lists))

    logging_to_stdout(log_level=logging.DEBUG)
    train_config = DummyModel()
    train_config.local_dataset = base_path
    train_config.set_output_to(output_dirs.root_dir)
    train_config.image_channels = image_channels
    train_config.ground_truth_ids = ground_truth_ids
    train_config.mask_id = None if no_mask_channel else train_config.mask_id
    train_config.random_seed = 42
    train_config.class_weights = [0.5, 0.25, 0.25]
    train_config.store_dataset_sample = True
    train_config.recovery_checkpoint_save_interval = 1

    if machine_has_gpu:
        expected_train_losses = [0.4553468, 0.454904]
        expected_val_losses = [0.4553881, 0.4553041]
    else:
        expected_train_losses = [0.4553469, 0.4548947]
        expected_val_losses = [0.4553880, 0.4553041]
    loss_absolute_tolerance = 1e-6
    expected_learning_rates = [train_config.l_rate, 5.3589e-4]

    checkpoint_handler = get_default_checkpoint_handler(
        model_config=train_config, project_root=Path(output_dirs.root_dir))
    model_training_result = model_training.model_train(
        train_config, checkpoint_handler=checkpoint_handler)
    assert isinstance(model_training_result, ModelTrainingResults)

    def assert_all_close(metric: str, expected: List[float],
                         **kwargs: Any) -> None:
        """Compare the training metric ``metric`` against ``expected`` via np.allclose (kwargs are forwarded)."""
        actual = model_training_result.get_training_metric(metric)
        assert np.allclose(
            actual, expected, **kwargs
        ), f"Mismatch for {metric}: Got {actual}, expected {expected}"

    # check to make sure training batches are NOT all the same across epochs
    _check_patch_centers(model_training_result.train_diagnostics,
                         should_equal=False)
    # check to make sure validation batches are all the same across epochs
    _check_patch_centers(model_training_result.val_diagnostics,
                         should_equal=True)
    assert_all_close(MetricType.SUBJECT_COUNT.value, [3.0, 3.0])
    assert_all_close(MetricType.LEARNING_RATE.value,
                     expected_learning_rates,
                     rtol=1e-6)

    if is_windows():
        # Randomization comes out slightly different on Windows. Skip the rest of the detailed checks.
        return

    # Simple regression test: Voxel counts should be the same in both epochs on the validation set,
    # and be the same across 'region' and 'region_1' because they derive from the same Nifti files.
    # The following values are read off directly from the results of compute_dice_across_patches in the training loop
    # This checks that averages are computed correctly, and that metric computers are reset after each epoch.
    train_voxels = [[83092.0, 83212.0, 82946.0], [83000.0, 82881.0, 83309.0]]
    val_voxels = [[82765.0, 83212.0], [82765.0, 83212.0]]
    _check_voxel_count(model_training_result.train_results_per_epoch,
                       _mean_list(train_voxels), "Train")
    _check_voxel_count(model_training_result.val_results_per_epoch,
                       _mean_list(val_voxels), "Val")

    actual_train_losses = model_training_result.get_training_metric(
        MetricType.LOSS.value)
    actual_val_losses = model_training_result.get_validation_metric(
        MetricType.LOSS.value)
    print(f"actual_train_losses = {actual_train_losses}")
    print(f"actual_val_losses = {actual_val_losses}")
    assert np.allclose(actual_train_losses,
                       expected_train_losses,
                       atol=loss_absolute_tolerance), "Train losses"
    assert np.allclose(actual_val_losses,
                       expected_val_losses,
                       atol=loss_absolute_tolerance), "Val losses"
    # Check that the metric we track for Hyperdrive runs is actually written.
    assert TrackedMetrics.Val_Loss.value.startswith(VALIDATION_PREFIX)
    tracked_metric = TrackedMetrics.Val_Loss.value[len(VALIDATION_PREFIX):]
    for val_result in model_training_result.val_results_per_epoch:
        assert tracked_metric in val_result

    # The following values are read off directly from the results of compute_dice_across_patches in the
    # training loop. Results are slightly different for CPU, hence use a larger tolerance there.
    dice_tolerance = 1e-4 if machine_has_gpu else 4.5e-4
    train_dice_region = [[0.0, 0.0, 4.0282e-04], [0.0309, 0.0334, 0.0961]]
    train_dice_region1 = [[0.4806, 0.4800, 0.4832], [0.4812, 0.4842, 0.4663]]
    # There appears to be some amount of non-determinism here: When using a tolerance of 1e-4, we get occasional
    # test failures on Linux in the cloud (not on Windows, not on AzureML) Unclear where it comes from. Even when
    # failing here, the losses match up to the expected tolerance.
    assert_all_close("Dice/region",
                     _mean_list(train_dice_region),
                     atol=dice_tolerance)
    assert_all_close("Dice/region_1",
                     _mean_list(train_dice_region1),
                     atol=dice_tolerance)
    # Per epoch, the average across structures is the mean over the concatenated per-batch values of both structures.
    expected_average_dice = [
        _mean(region + region1)  # type: ignore
        for region, region1 in zip(train_dice_region, train_dice_region1)
    ]
    assert_all_close("Dice/AverageAcrossStructures",
                     expected_average_dice,
                     atol=dice_tolerance)

    # check output files/directories
    assert train_config.outputs_folder.is_dir()
    assert train_config.logs_folder.is_dir()

    # Tensorboard event files go into a Lightning subfolder (Pytorch Lightning default)
    lightning_folder = train_config.logs_folder / "Lightning"
    assert lightning_folder.is_dir()
    # glob() returns a lazy generator: it must be materialized, otherwise wrapping it in a list literal
    # always yields a list of length 1, regardless of how many event files exist.
    event_files = list(lightning_folder.glob("events*"))
    assert len(event_files) == 1, f"Expected exactly one event file, got: {event_files}"

    assert train_config.num_epochs == 2
    # Checkpoint folder
    assert train_config.checkpoint_folder.is_dir()
    actual_checkpoints = list(train_config.checkpoint_folder.rglob("*.ckpt"))
    assert len(
        actual_checkpoints) == 2, f"Actual checkpoints: {actual_checkpoints}"
    assert (train_config.checkpoint_folder /
            RECOVERY_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
    assert (train_config.checkpoint_folder /
            BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
    assert (train_config.outputs_folder / DATASET_CSV_FILE_NAME).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.TRAIN]).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.VAL]).is_file()

    # Patch visualization: 3 PNG slices are expected for each of the ``show_patch_sampling`` visualized samples.
    sampling_folder = train_config.outputs_folder / PATCH_SAMPLING_FOLDER
    assert sampling_folder.is_dir()
    assert train_config.show_patch_sampling > 0
    assert len(list(sampling_folder.rglob(
        "*.png"))) == 3 * train_config.show_patch_sampling

    # Time per epoch: Test that we have all these times logged.
    model_training_result.get_training_metric(
        MetricType.SECONDS_PER_EPOCH.value)
    model_training_result.get_validation_metric(
        MetricType.SECONDS_PER_EPOCH.value)
    model_training_result.get_validation_metric(
        MetricType.SECONDS_PER_BATCH.value)
    model_training_result.get_training_metric(
        MetricType.SECONDS_PER_BATCH.value)
コード例 #3
0
def _test_model_train(output_dirs: TestOutputDirectories,
                      image_channels: Any,
                      ground_truth_ids: Any,
                      no_mask_channel: bool = False) -> None:
    """
    Train DummyModel for two epochs and check losses, learning rates and the files written to the output folder.

    :param output_dirs: Folder for test outputs; its ``root_dir`` is used as the output root of the training run.
    :param image_channels: Value assigned to ``image_channels`` of the training config.
    :param ground_truth_ids: Value assigned to ``ground_truth_ids`` of the training config.
    :param no_mask_channel: If True, ``mask_id`` of the training config is set to None (no mask channel).
    """
    def _check_patch_centers(epoch_results: List[MetricsDict],
                             should_equal: bool) -> None:
        """Assert that the patch centers of each later epoch do (or do not) equal those of the first epoch."""
        diagnostics_per_epoch = [
            m.diagnostics[MetricType.PATCH_CENTER.value] for m in epoch_results
        ]
        # With fewer than 2 epochs the comparison loop would not run and the check would pass vacuously.
        assert len(
            diagnostics_per_epoch
        ) > 1, "Not enough data to check patch centers, need at least 2"
        patch_centers_epoch1 = diagnostics_per_epoch[0]
        for diagnostic in diagnostics_per_epoch[1:]:
            assert np.array_equal(patch_centers_epoch1,
                                  diagnostic) == should_equal

    train_config = DummyModel()
    train_config.local_dataset = base_path
    train_config.set_output_to(output_dirs.root_dir)
    train_config.image_channels = image_channels
    train_config.ground_truth_ids = ground_truth_ids
    train_config.mask_id = None if no_mask_channel else train_config.mask_id
    train_config.random_seed = 42
    train_config.class_weights = [0.5, 0.25, 0.25]
    train_config.store_dataset_sample = True

    expected_train_losses = [0.455538, 0.455213]
    expected_val_losses = [0.455190, 0.455139]

    expected_stats = "Epoch\tLearningRate\tTrainLoss\tTrainDice\tValLoss\tValDice\n" \
                     "1\t1.00e-03\t0.456\t0.242\t0.455\t0.000\n" \
                     "2\t5.36e-04\t0.455\t0.247\t0.455\t0.000"

    expected_learning_rates = [[train_config.l_rate], [5.3589e-4]]

    loss_absolute_tolerance = 1e-3
    model_training_result = model_training.model_train(train_config)
    assert isinstance(model_training_result, ModelTrainingResults)

    # check to make sure training batches are NOT all the same across epochs
    _check_patch_centers(model_training_result.train_results_per_epoch,
                         should_equal=False)
    # check to make sure validation batches are all the same across epochs
    _check_patch_centers(model_training_result.val_results_per_epoch,
                         should_equal=True)
    assert isinstance(model_training_result.train_results_per_epoch[0],
                      MetricsDict)
    actual_train_losses = [
        m.get_single_metric(MetricType.LOSS)
        for m in model_training_result.train_results_per_epoch
    ]
    actual_val_losses = [
        m.get_single_metric(MetricType.LOSS)
        for m in model_training_result.val_results_per_epoch
    ]
    print(f"actual_train_losses = {actual_train_losses}")
    print(f"actual_val_losses = {actual_val_losses}")
    assert np.allclose(actual_train_losses,
                       expected_train_losses,
                       atol=loss_absolute_tolerance)
    assert np.allclose(actual_val_losses,
                       expected_val_losses,
                       atol=loss_absolute_tolerance)
    assert np.allclose(model_training_result.learning_rates_per_epoch,
                       expected_learning_rates,
                       rtol=1e-6)

    # check output files/directories
    assert train_config.outputs_folder.is_dir()
    assert train_config.logs_folder.is_dir()

    # The train and val folder should contain Tensorflow event files
    for subfolder in ["train", "val"]:
        event_folder = train_config.logs_folder / subfolder
        assert event_folder.is_dir()
        # glob() returns a lazy generator: it must be materialized, otherwise wrapping it in a list literal
        # always yields a list of length 1, regardless of how many files exist.
        event_files = list(event_folder.glob("*"))
        assert len(event_files) == 1, f"Expected exactly one file in {event_folder}, got: {event_files}"

    # Checkpoint folder
    # With these settings, we should see a checkpoint only at epoch 2:
    # That's the last epoch, and there should always be a checkpoint at the last epoch.
    assert train_config.save_start_epoch == 1
    assert train_config.save_step_epochs == 100
    assert train_config.num_epochs == 2
    assert os.path.isdir(train_config.checkpoint_folder)
    assert os.path.isfile(
        os.path.join(train_config.checkpoint_folder,
                     "2" + CHECKPOINT_FILE_SUFFIX))
    assert (train_config.outputs_folder / DATASET_CSV_FILE_NAME).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.TRAIN]).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.VAL]).is_file()
    assert_file_contents(train_config.outputs_folder / TRAIN_STATS_FILE,
                         expected_stats)

    # Test for saving of example images
    assert os.path.isdir(train_config.example_images_folder)
    example_files = os.listdir(train_config.example_images_folder)
    assert len(example_files) == 3 * 2
コード例 #4
0
def _test_model_train(output_dirs: OutputFolderForTests,
                      image_channels: Any,
                      ground_truth_ids: Any,
                      no_mask_channel: bool = False) -> None:
    """
    Train DummyModel for two epochs via model_train_unittest and check metrics, losses and output files.

    :param output_dirs: Folder for test outputs; its ``root_dir`` is used as the output root of the training run.
    :param image_channels: Value assigned to ``image_channels`` of the training config.
    :param ground_truth_ids: Value assigned to ``ground_truth_ids`` of the training config.
    :param no_mask_channel: If True, ``mask_id`` is set to None (no mask channel) and example images are stored
        (``store_dataset_sample`` is set to this flag).
    """
    def _check_patch_centers(diagnostics_per_epoch: List[np.ndarray],
                             should_equal: bool) -> None:
        """Assert that the patch centers of each later epoch do (or do not) equal those of the first epoch."""
        # Check the length before indexing, so that an empty list fails with this message and not an IndexError.
        assert len(
            diagnostics_per_epoch
        ) > 1, "Not enough data to check patch centers, need at least 2"
        patch_centers_epoch1 = diagnostics_per_epoch[0]
        for diagnostic in diagnostics_per_epoch[1:]:
            assert np.array_equal(patch_centers_epoch1,
                                  diagnostic) == should_equal

    def _check_voxel_count(results_per_epoch: List[Dict[str, float]],
                           expected_voxel_count_per_epoch: List[float],
                           prefix: str) -> None:
        """Assert the per-epoch voxel counts of "region" and "region_1" against the expected values."""
        assert len(results_per_epoch) == len(expected_voxel_count_per_epoch)
        for epoch, (results, voxel_count) in enumerate(
                zip(results_per_epoch, expected_voxel_count_per_epoch)):
            # In the test data, both structures "region" and "region_1" are read from the same nifti file, hence
            # their voxel counts must be identical.
            for structure in ["region", "region_1"]:
                assert results[f"{MetricType.VOXEL_COUNT.value}/{structure}"] == pytest.approx(voxel_count, abs=1e-2), \
                    f"{prefix} voxel count mismatch for '{structure}' epoch {epoch}"

    def _mean(a: List[float]) -> float:
        return sum(a) / len(a)

    def _mean_list(lists: List[List[float]]) -> List[float]:
        # Mean of each inner list, in order.
        return list(map(_mean, lists))

    logging_to_stdout(log_level=logging.DEBUG)
    train_config = DummyModel()
    train_config.local_dataset = base_path
    train_config.set_output_to(output_dirs.root_dir)
    train_config.image_channels = image_channels
    train_config.ground_truth_ids = ground_truth_ids
    train_config.mask_id = None if no_mask_channel else train_config.mask_id
    train_config.random_seed = 42
    train_config.class_weights = [0.5, 0.25, 0.25]
    train_config.store_dataset_sample = no_mask_channel
    train_config.check_exclusive = False

    if machine_has_gpu:
        expected_train_losses = [0.4554231, 0.4550124]
        expected_val_losses = [0.4553894, 0.4553061]
    else:
        expected_train_losses = [0.4554231, 0.4550112]
        expected_val_losses = [0.4553893, 0.4553061]
    loss_absolute_tolerance = 1e-6
    expected_learning_rates = [train_config.l_rate, 5.3589e-4]

    model_training_result, _ = model_train_unittest(train_config,
                                                    output_folder=output_dirs)
    assert isinstance(model_training_result, StoringLogger)
    # Check that all metrics from the BatchTimeCallback are present
    # TODO: re-enable once the BatchTimeCallback is fixed
    # for epoch, epoch_results in model_training_result.results_per_epoch.items():
    #     for prefix in [TRAIN_PREFIX, VALIDATION_PREFIX]:
    #         for metric_type in [BatchTimeCallback.EPOCH_TIME,
    #                             BatchTimeCallback.BATCH_TIME + " avg",
    #                             BatchTimeCallback.BATCH_TIME + " max",
    #                             BatchTimeCallback.EXCESS_LOADING_TIME]:
    #             expected = BatchTimeCallback.METRICS_PREFIX + prefix + metric_type
    #             assert expected in epoch_results, f"Expected {expected} in results for epoch {epoch}"
    #             # Excess loading time can be zero because that only measure batches over the threshold
    #             if metric_type != BatchTimeCallback.EXCESS_LOADING_TIME:
    #                 value = epoch_results[expected]
    #                 assert isinstance(value, float)
    #                 assert value > 0.0, f"Time for {expected} should be > 0"

    actual_train_losses = model_training_result.get_train_metric(
        MetricType.LOSS.value)
    actual_val_losses = model_training_result.get_val_metric(
        MetricType.LOSS.value)
    print(f"actual_train_losses = {actual_train_losses}")
    print(f"actual_val_losses = {actual_val_losses}")

    def assert_all_close(metric: str, expected: List[float],
                         **kwargs: Any) -> None:
        """Compare the training metric ``metric`` against ``expected`` via np.allclose (kwargs are forwarded)."""
        actual = model_training_result.get_train_metric(metric)
        assert np.allclose(
            actual, expected, **kwargs
        ), f"Mismatch for {metric}: Got {actual}, expected {expected}"

    # check to make sure training batches are NOT all the same across epochs
    _check_patch_centers(model_training_result.train_diagnostics,
                         should_equal=False)
    # check to make sure validation batches are all the same across epochs
    _check_patch_centers(model_training_result.val_diagnostics,
                         should_equal=True)
    assert_all_close(MetricType.SUBJECT_COUNT.value, [3.0, 3.0])
    assert_all_close(MetricType.LEARNING_RATE.value,
                     expected_learning_rates,
                     rtol=1e-6)

    if is_windows():
        # Randomization comes out slightly different on Windows. Skip the rest of the detailed checks.
        return

    # Simple regression test: Voxel counts should be the same in both epochs on the validation set,
    # and be the same across 'region' and 'region_1' because they derive from the same Nifti files.
    # The following values are read off directly from the results of compute_dice_across_patches in the training loop
    # This checks that averages are computed correctly, and that metric computers are reset after each epoch.
    train_voxels = [[82765.0, 83212.0, 82740.0], [82831.0, 82647.0, 83255.0]]
    val_voxels = [[82765.0, 83212.0], [82765.0, 83212.0]]
    _check_voxel_count(model_training_result.train_results_per_epoch(),
                       _mean_list(train_voxels), "Train")
    _check_voxel_count(model_training_result.val_results_per_epoch(),
                       _mean_list(val_voxels), "Val")

    assert np.allclose(actual_train_losses,
                       expected_train_losses,
                       atol=loss_absolute_tolerance), "Train losses"
    assert np.allclose(actual_val_losses,
                       expected_val_losses,
                       atol=loss_absolute_tolerance), "Val losses"
    # Check that the metric we track for Hyperdrive runs is actually written.
    assert TrackedMetrics.Val_Loss.value.startswith(VALIDATION_PREFIX)
    tracked_metric = TrackedMetrics.Val_Loss.value[len(VALIDATION_PREFIX):]
    for val_result in model_training_result.val_results_per_epoch():
        assert tracked_metric in val_result

    # The following values are read off directly from the results of compute_dice_across_patches in the
    # training loop. Results are slightly different for GPU, hence use a larger tolerance there.
    dice_tolerance = 1e-3 if machine_has_gpu else 4.5e-4
    train_dice_region = [[0.0, 0.0, 0.0], [0.0376, 0.0343, 0.1017]]
    train_dice_region1 = [[0.4845, 0.4814, 0.4829], [0.4822, 0.4747, 0.4426]]
    # There appears to be some amount of non-determinism here: When using a tolerance of 1e-4, we get occasional
    # test failures on Linux in the cloud (not on Windows, not on AzureML) Unclear where it comes from. Even when
    # failing here, the losses match up to the expected tolerance.
    assert_all_close("Dice/region",
                     _mean_list(train_dice_region),
                     atol=dice_tolerance)
    assert_all_close("Dice/region_1",
                     _mean_list(train_dice_region1),
                     atol=dice_tolerance)
    # Per epoch, the average across structures is the mean over the concatenated per-batch values of both structures.
    expected_average_dice = [
        _mean(region + region1)  # type: ignore
        for region, region1 in zip(train_dice_region, train_dice_region1)
    ]
    assert_all_close("Dice/AverageAcrossStructures",
                     expected_average_dice,
                     atol=dice_tolerance)

    # check output files/directories
    assert train_config.outputs_folder.is_dir()
    assert train_config.logs_folder.is_dir()

    # Tensorboard event files go into a Lightning subfolder (Pytorch Lightning default)
    lightning_folder = train_config.logs_folder / "Lightning"
    assert lightning_folder.is_dir()
    # glob() returns a lazy generator: it must be materialized, otherwise wrapping it in a list literal
    # always yields a list of length 1, regardless of how many event files exist.
    event_files = list(lightning_folder.glob("events*"))
    assert len(event_files) == 1, f"Expected exactly one event file, got: {event_files}"

    assert train_config.num_epochs == 2
    # Checkpoint folder
    assert train_config.checkpoint_folder.is_dir()
    actual_checkpoints = list(train_config.checkpoint_folder.rglob("*.ckpt"))
    assert len(
        actual_checkpoints) == 1, f"Actual checkpoints: {actual_checkpoints}"
    assert (train_config.checkpoint_folder /
            LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX).is_file()
    assert (train_config.outputs_folder / DATASET_CSV_FILE_NAME).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.TRAIN]).is_file()
    assert (train_config.outputs_folder /
            STORED_CSV_FILE_NAMES[ModelExecutionMode.VAL]).is_file()

    # Patch visualization: 3 PNG slices are expected for each of the ``show_patch_sampling`` visualized samples.
    sampling_folder = train_config.outputs_folder / PATCH_SAMPLING_FOLDER
    assert sampling_folder.is_dir()
    assert train_config.show_patch_sampling > 0
    assert len(list(sampling_folder.rglob(
        "*.png"))) == 3 * train_config.show_patch_sampling

    # Test for saving of example images: the folder must exist only when store_dataset_sample is set,
    # and no example images may be found otherwise.
    if train_config.store_dataset_sample:
        assert train_config.example_images_folder.is_dir()
    example_files = list(train_config.example_images_folder.rglob("*.*"))
    # images x epochs x patients
    expected_example_file_count = 3 * 2 * 2 if train_config.store_dataset_sample else 0
    assert len(example_files) == expected_example_file_count
コード例 #5
0
def test_model_test(test_output_dirs: OutputFolderForTests,
                    use_partial_ground_truth: bool,
                    allow_partial_ground_truth: bool) -> None:
    """
    Run segmentation_model_test on DummyModel and check the CSV, image and thumbnail files it writes.

    :param test_output_dirs: The fixture in conftest.py
    :param use_partial_ground_truth: If True, drop some ground truth channels for test subjects 3 and 4.
    :param allow_partial_ground_truth: Value for the config's ``allow_incomplete_labels`` flag. When partial
        ground truth is used but not allowed, the test only checks that building the dataset raises.
    """
    expected_data_dir = full_ml_test_data_path("train_and_test_data")
    seed_everything(42)
    config = DummyModel()
    config.allow_incomplete_labels = allow_partial_ground_truth
    config.set_output_to(test_output_dirs.root_dir)
    dataset_id = "place_holder_dataset_id"
    config.azure_dataset_id = dataset_id
    test_transform = config.get_full_image_sample_transforms().test
    dataset_df = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME))

    if not use_partial_ground_truth:
        dataset_df = dataset_df[dataset_df.subject.isin([1, 2])]
    else:
        config.check_exclusive = False
        config.ground_truth_ids = ["region", "region_1"]
        # As in Tests.ML.pipelines.test.inference.test_evaluate_model_predictions, subjects 3, 4 and 5 form the
        # test set: subject 3 loses its "region" channel, subject 4 loses both "region" and "region_1",
        # and subject 5 keeps all of its ground truth channels.
        subject_column = dataset_df["subject"]
        channel_column = dataset_df["channel"]
        rows_to_drop = (subject_column.eq(3) & channel_column.eq("region")) | \
                       (subject_column.eq(4) & channel_column.isin(["region", "region_1"]))
        dataset_df = dataset_df[~rows_to_drop]
        config.dataset_data_frame = dataset_df
        dataset_df = dataset_df[dataset_df.subject.isin([3, 4, 5])]
        config.train_subject_ids = ['1', '2']
        config.test_subject_ids = ['3', '4', '5']
        config.val_subject_ids = ['6', '7']

    if use_partial_ground_truth and not allow_partial_ground_truth:
        # Building the test dataset has to fail, because subject 3 is missing a ground truth channel.
        with pytest.raises(ValueError) as value_error:
            # noinspection PyTypeHints
            config._datasets_for_inference = {
                ModelExecutionMode.TEST:
                FullImageDataset(config, dataset_df, full_image_sample_transforms=test_transform)
            }  # type: ignore
        assert "Patient 3 does not have channel 'region'" in str(value_error.value)
        return

    # noinspection PyTypeHints
    config._datasets_for_inference = {
        ModelExecutionMode.TEST:
        FullImageDataset(config, dataset_df, full_image_sample_transforms=test_transform)
    }  # type: ignore
    execution_mode = ModelExecutionMode.TEST
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    # Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
    create_model_and_store_checkpoint(
        config,
        config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX)
    checkpoint_handler.additional_training_done()
    inference_results = model_testing.segmentation_model_test(
        config,
        execution_mode=execution_mode,
        checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
    results_dir = config.outputs_folder / get_best_epoch_results_path(execution_mode)
    # Lower-cased, pluralised name of the "total patients" column in the aggregates file.
    patients_column = f"total_{MetricsFileColumns.Patient.value}".lower()
    if not patients_column.endswith("s"):
        patients_column += "s"

    if use_partial_ground_truth and allow_partial_ground_truth:
        num_subjects = len(pd.unique(dataset_df["subject"]))
        assert csv_column_contains_value(
            csv_file_path=results_dir / METRICS_AGGREGATES_FILE,
            column_name=patients_column,
            value=num_subjects,
            contains_only_value=True)
        # Subjects with missing ground truth leave an empty Dice entry in the per-subject metrics.
        assert csv_column_contains_value(
            csv_file_path=results_dir / SUBJECT_METRICS_FILE_NAME,
            column_name=MetricsFileColumns.Dice.value,
            value='',
            contains_only_value=False)
    elif not use_partial_ground_truth:
        aggregates_df = pd.read_csv(results_dir / METRICS_AGGREGATES_FILE)
        # The total-patients column is only added when partial ground truth is used.
        assert patients_column not in aggregates_df.columns

        # With complete ground truth, no subject may have an empty Dice entry.
        assert not csv_column_contains_value(
            csv_file_path=results_dir / SUBJECT_METRICS_FILE_NAME,
            column_name=MetricsFileColumns.Dice.value,
            value='',
            contains_only_value=False)

        assert inference_results.metrics == pytest.approx(0.66606902, abs=1e-6)
        assert config.outputs_folder.is_dir()
        assert results_dir.is_dir()
        patient1 = io_util.load_nifti_image(expected_data_dir / "id1_channel1.nii.gz")
        patient2 = io_util.load_nifti_image(expected_data_dir / "id2_channel1.nii.gz")

        assert_file_contains_string(results_dir / DATASET_ID_FILE, dataset_id)
        assert_file_contains_string(results_dir / GROUND_TRUTH_IDS_FILE, "region")
        for metrics_file_name in (model_testing.SUBJECT_METRICS_FILE_NAME,
                                  model_testing.METRICS_AGGREGATES_FILE):
            assert_text_files_match(results_dir / metrics_file_name,
                                    expected_data_dir / metrics_file_name)
        # Plotting results vary between platforms. Can only check if the file is generated, but not its contents.
        assert (results_dir / model_testing.BOXPLOT_FILE).exists()

        # Each output image must have the input's shape and header and contain only the expected value.
        expected_images = [("posterior_region.nii.gz", 137),
                           (DEFAULT_RESULT_IMAGE_NAME, 1),
                           ("posterior_background.nii.gz", 117)]
        for image_file_name, expected_value in expected_images:
            for subject_folder, patient_image in (("001", patient1), ("002", patient2)):
                assert_nifti_content(results_dir / subject_folder / image_file_name,
                                     get_image_shape(patient_image), patient_image.header,
                                     [expected_value], np.ubyte)

        thumbnails_folder = results_dir / model_testing.THUMBNAILS_FOLDER
        assert thumbnails_folder.is_dir()
        overlays = [png_file for png_file in thumbnails_folder.glob("*.png")
                    if "_region_slice_" in str(png_file)]
        assert len(overlays) == len(dataset_df.subject.unique()), \
            "There should be one overlay/contour file per subject"

        # Writing dataset.csv normally happens at the beginning of training,
        # but this test reads off a saved checkpoint file.
        # Dataset.csv must be present for plot_cross_validation.
        config.write_dataset_files()
        # Test if the metrics files can be picked up correctly by the cross validation code
        result_files = get_config_and_results_for_offline_runs(config).files
        assert len(result_files) == 1
        for result_file in result_files:
            assert result_file.execution_mode == execution_mode
            assert result_file.dataset_csv_file is not None
            assert result_file.dataset_csv_file.exists()
            assert result_file.metrics_file is not None
            assert result_file.metrics_file.exists()