Example #1
0
def test_csv_dataset_as_data_loader(normalize_fn: Any,
                                    full_image_dataset: FullImageDataset,
                                    num_dataload_workers: int) -> None:
    """
    Checks that reading the dataset through its data loader yields the same patient ids, images, labels and
    masks as loading the patients' channels directly.
    :param normalize_fn: Normalization function, also applied when loading the reference samples.
    :param full_image_dataset: The dataset under test.
    :param num_dataload_workers: Number of data loader workers to use.
    """
    batch_size = 2
    # load the original images separately for comparison
    expected_samples = load_train_and_test_data_channels(
        patient_ids=list(range(1, batch_size + 1)),
        normalization_fn=normalize_fn)
    csv_dataset_loader = full_image_dataset.as_data_loader(
        batch_size=batch_size,
        shuffle=True,
        num_dataload_workers=num_dataload_workers)
    for batch in csv_dataset_loader:
        # The last batch can contain fewer than batch_size items, so take the item count from the batch
        # itself rather than assuming every batch is full.
        items_in_batch = len(next(iter(batch.values())))
        for x in range(items_in_batch):
            actual_sample = {k: v[x] for k, v in batch.items()}
            sample = Sample.from_dict(sample=actual_sample)
            # have to do this as the ordering in which the dataloader gives samples is non-deterministic
            expected_sample = expected_samples[sample.patient_id - 1]  # type: ignore

            assert sample.patient_id == expected_sample.patient_id
            assert np.array_equal(sample.image, expected_sample.image)
            assert np.array_equal(sample.labels, expected_sample.labels)
            assert np.array_equal(sample.mask, expected_sample.mask)
Example #2
0
def evaluate_model_predictions(
        process_id: int, config: SegmentationModelBase,
        dataset: FullImageDataset,
        results_folder: Path) -> Tuple[PatientMetadata, MetricsDict]:
    """
    Evaluates the stored segmentation prediction for a single dataset sample. It computes per-class metrics
    against the sample's ground truth labels and plots the predicted contours into the thumbnails folder.
    The function is intended to be used in a parallel for loop, so that each image is processed independently.
    :param process_id: Index of the sample in the dataset that this call evaluates.
    :param config: Segmentation model config object. Supplies the ground truth ids and the image range for plotting.
    :param dataset: Dataset object, used to load the intensity image, labels, and patient metadata.
    :param results_folder: Path to the results folder. The prediction is read from
        DEFAULT_RESULT_IMAGE_NAME in the patient's subfolder; thumbnails go to results_folder / THUMBNAILS_FOLDER.
    :return: A tuple of the sample's patient metadata and the metrics computed per class.
    """
    sample = dataset.get_samples_at_index(index=process_id)[0]
    logging.info("Evaluating predictions for patient %s", sample.patient_id)
    patient_results_folder = get_patient_results_folder(
        results_folder, sample.patient_id)
    segmentation = load_nifti_image(patient_results_folder /
                                    DEFAULT_RESULT_IMAGE_NAME).image
    metrics_per_class = metrics.calculate_metrics_per_class(
        segmentation,
        sample.labels,
        ground_truth_ids=config.ground_truth_ids,
        voxel_spacing=sample.image_spacing,
        patient_id=sample.patient_id)
    thumbnails_folder = results_folder / THUMBNAILS_FOLDER
    # exist_ok: several parallel calls may try to create the same folder.
    thumbnails_folder.mkdir(exist_ok=True)
    plotting.plot_contours_for_all_classes(
        sample,
        segmentation=segmentation,
        foreground_class_names=config.ground_truth_ids,
        result_folder=thumbnails_folder,
        image_range=config.output_range)
    return sample.metadata, metrics_per_class
Example #3
0
def full_image_dataset(default_config: SegmentationModelBase,
                       normalize_fn: Callable) -> FullImageDataset:
    """
    Creates a FullImageDataset over the training split of the given config.
    The only full-image transform applied is normalize_fn.
    """
    splits = default_config.get_dataset_splits()
    normalization = Compose3D([normalize_fn])  # type: ignore
    return FullImageDataset(args=default_config,
                            full_image_sample_transforms=normalization,
                            data_frame=splits.train)
    def create_and_set_torch_datasets(self, for_training: bool = True, for_inference: bool = True) -> None:
        """
        Builds the torch datasets for the model execution modes and stores them on this object.
        :param for_training: If True, create cropping datasets for TRAIN and VAL and store them in
            self._datasets_for_training.
        :param for_inference: If True, create a full-image dataset (using the test-time transforms) for every
            execution mode whose split is non-empty, and store them in self._datasets_for_inference.
        """
        # Imported inside the method, presumably to avoid an import cycle - TODO confirm.
        from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
        from InnerEye.ML.dataset.full_image_dataset import FullImageDataset

        splits = self.get_dataset_splits()
        cropped_transforms = self.get_cropped_image_sample_transforms()
        full_transforms = self.get_full_image_sample_transforms()
        if for_training:
            train_dataset = CroppingDataset(
                self,
                splits.train,
                cropped_sample_transforms=cropped_transforms.train,  # type: ignore
                full_image_sample_transforms=full_transforms.train)  # type: ignore
            val_dataset = CroppingDataset(
                self,
                splits.val,
                cropped_sample_transforms=cropped_transforms.val,  # type: ignore
                full_image_sample_transforms=full_transforms.val)  # type: ignore
            self._datasets_for_training = {
                ModelExecutionMode.TRAIN: train_dataset,
                ModelExecutionMode.VAL: val_dataset,
            }
        if for_inference:
            # Only modes that actually have rows in their split get an inference dataset.
            modes_with_data = [mode for mode in ModelExecutionMode if len(splits[mode]) > 0]
            self._datasets_for_inference = {
                mode: FullImageDataset(self,
                                       splits[mode],
                                       full_image_sample_transforms=full_transforms.test)  # type: ignore
                for mode in modes_with_data
            }
Example #5
0
def visualize_random_crops_for_dataset(
        config: SegmentationModelBase,
        output_folder: Optional[Path] = None) -> None:
    """
    For segmentation models only: visualizes the effect of sampling random patches for training.
    Runs visualize_random_crops on the first config.show_patch_sampling samples of the training split, or on
    all of them if the split is smaller, writing the results to the output folder.
    :param config: The model configuration.
    :param output_folder: The folder in which the visualizations should be written. If not provided, the
        subfolder PATCH_SAMPLING_FOLDER of config.outputs_folder is used.
    """
    training_set = FullImageDataset(config, config.get_dataset_splits().train)
    if not output_folder:
        output_folder = config.outputs_folder / PATCH_SAMPLING_FOLDER
    num_to_show = min(config.show_patch_sampling, len(training_set))
    for index in range(num_to_show):
        sample_at_index = training_set.get_samples_at_index(index=index)[0]
        visualize_random_crops(sample_at_index, config, output_folder=output_folder)
Example #6
0
def test_model_test(test_output_dirs: OutputFolderForTests) -> None:
    """
    End-to-end check of model_testing.segmentation_model_test on subjects 1 and 2, using checkpoints copied
    from the stored test data. It verifies:
    - the metrics files,
    - the posterior and segmentation NIfTI outputs,
    - the thumbnails,
    - that the cross-validation code can pick up the results.
    :param test_output_dirs: Fixture providing the output folder for this test.
    """
    data_dir = full_ml_test_data_path("train_and_test_data")
    test_epoch = 1
    dataset_id = "place_holder_dataset_id"
    mode = ModelExecutionMode.TEST

    config = DummyModel()
    config.set_output_to(test_output_dirs.root_dir)
    config.num_epochs = test_epoch
    assert config.get_test_epochs() == [test_epoch]
    config.azure_dataset_id = dataset_id
    test_transform = config.get_full_image_sample_transforms().test
    all_rows = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME))
    df = all_rows[all_rows.subject.isin([1, 2])]
    # noinspection PyTypeHints
    config._datasets_for_inference = \
        {mode: FullImageDataset(config, df, full_image_sample_transforms=test_transform)}  # type: ignore
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    # Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
    shutil.copytree(str(full_ml_test_data_path("checkpoints")), str(config.checkpoint_folder))
    checkpoint_handler.additional_training_done()
    inference_results = model_testing.segmentation_model_test(config,
                                                              data_split=mode,
                                                              checkpoint_handler=checkpoint_handler)
    assert inference_results.epochs[test_epoch] == pytest.approx(0.66606902, abs=1e-6)

    epoch_dir = config.outputs_folder / get_epoch_results_path(test_epoch, mode)
    assert config.outputs_folder.is_dir()
    assert epoch_dir.is_dir()
    # Maps each subject's results subfolder to that subject's input image (used for shape and header).
    patients = {
        "001": io_util.load_nifti_image(data_dir / "id1_channel1.nii.gz"),
        "002": io_util.load_nifti_image(data_dir / "id2_channel1.nii.gz"),
    }

    assert_file_contains_string(epoch_dir / DATASET_ID_FILE, dataset_id)
    assert_file_contains_string(epoch_dir / GROUND_TRUTH_IDS_FILE, "region")
    for metrics_file in (model_testing.METRICS_FILE_NAME, model_testing.METRICS_AGGREGATES_FILE):
        assert_text_files_match(epoch_dir / metrics_file, data_dir / metrics_file)
    # Plotting results vary between platforms. Can only check if the file is generated, but not its contents.
    assert (epoch_dir / model_testing.BOXPLOT_FILE).exists()

    # Output image name and the expected values passed to assert_nifti_content for each subject.
    expected_images = [("posterior_region.nii.gz", 136),
                       (DEFAULT_RESULT_IMAGE_NAME, 1),
                       ("posterior_background.nii.gz", 118)]
    for image_name, expected_value in expected_images:
        for subject_folder, patient in patients.items():
            assert_nifti_content(epoch_dir / subject_folder / image_name, get_image_shape(patient),
                                 patient.header,
                                 [expected_value], np.ubyte)

    thumbnails_folder = epoch_dir / model_testing.THUMBNAILS_FOLDER
    assert thumbnails_folder.is_dir()
    overlays = [png for png in thumbnails_folder.glob("*.png") if "_region_slice_" in str(png)]
    assert len(overlays) == len(df.subject.unique()), "There should be one overlay/contour file per subject"

    # Writing dataset.csv normally happens at the beginning of training,
    # but this test reads off a saved checkpoint file.
    # Dataset.csv must be present for plot_cross_validation.
    config.write_dataset_files()
    # Test if the metrics files can be picked up correctly by the cross validation code
    result_files = get_config_and_results_for_offline_runs(config).files
    assert len(result_files) == 1
    for result_file in result_files:
        assert result_file.execution_mode == mode
        assert result_file.dataset_csv_file is not None
        assert result_file.dataset_csv_file.exists()
        assert result_file.metrics_file is not None
        assert result_file.metrics_file.exists()
def main(args: CheckPatchSamplingConfig) -> None:
    """
    For each of the first training samples, visualizes where random crops are taken. It runs
    augmentation.random_crop args.number_crop_iterations times and accumulates the patch masks around the
    returned centers into a heatmap. The heatmap and the first image channel of the (possibly padded) sample
    are written as NIfTI files to args.output_folder.
    :param args: Command line configuration: model name, local dataset, output folder, number of samples and
        number of crop iterations.
    :raises ValueError: If a sample has no image header, so that its outputs cannot be written as NIfTI.
    """
    # Identify paths to inputs and outputs
    commandline_args = {
        "train_batch_size": 1,
        "local_dataset": Path(args.local_dataset)
    }
    output_folder = Path(args.output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Create a config file
    config = ModelConfigLoader[SegmentationModelBase](
    ).create_model_config_from_name(args.model_name,
                                    overrides=commandline_args)

    # Set a random seed
    ml_util.set_random_seed(config.random_seed)

    # Get a dataloader object that checks csv
    dataset_splits = config.get_dataset_splits()

    # Load a sample using the full image data loader
    full_image_dataset = FullImageDataset(config, dataset_splits.train)

    # Requesting more samples than the training split holds would fail when indexing; process what exists.
    num_samples = min(args.number_samples, len(full_image_dataset))
    if num_samples < args.number_samples:
        print(f"Only {num_samples} training samples available, {args.number_samples} were requested.")

    for sample_index in range(num_samples):
        sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
            sample=full_image_dataset.get_samples_at_index(
                index=sample_index)[0],
            crop_size=config.crop_size,
            padding_mode=config.padding_mode)
        print("Processing sample: ", sample.patient_id)

        # Validate before the (potentially long) sampling loop, so a missing header fails fast.
        image_header = sample.metadata.image_header
        if not image_header:
            raise ValueError(f"Sample {sample.patient_id} has no image header, cannot write NIfTI outputs.")

        # Exhaustively sample with random crop function.
        # NOTE(review): the heatmap is uint16 and numpy integer arrays wrap on overflow, so a voxel whose
        # accumulated count exceeds 65535 would wrap around. Keep number_crop_iterations well below that.
        heatmap = np.zeros(sample.mask.shape, dtype=np.uint16)
        for _ in range(args.number_crop_iterations):
            _, center_point = augmentation.random_crop(
                sample=sample,
                crop_size=config.crop_size,
                class_weights=config.class_weights)
            heatmap += create_mask_for_patch(output_shape=heatmap.shape,
                                             output_dtype=heatmap.dtype,
                                             center=center_point,
                                             crop_size=config.crop_size)

        patient_id = int(sample.patient_id)
        ct_output_name = str(output_folder / f"{patient_id}_ct.nii.gz")
        heatmap_output_name = str(output_folder / f"{patient_id}_sampled_patches.nii.gz")
        io_util.store_as_nifti(image=heatmap,
                               header=image_header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=sample.image[0],
                               header=image_header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)
def test_model_test(test_output_dirs: OutputFolderForTests,
                    use_partial_ground_truth: bool,
                    allow_partial_ground_truth: bool) -> None:
    """
    Check the CSVs (and image files) output by InnerEye.ML.model_testing.segmentation_model_test.
    NOTE(review): an earlier function in this file is also named test_model_test; this definition shadows it.
    :param test_output_dirs: The fixture in conftest.py
    :param use_partial_ground_truth: Whether to remove some ground truth labels from some test subjects
    :param allow_partial_ground_truth: The value to set the allow_incomplete_labels flag to
    """
    data_dir = full_ml_test_data_path("train_and_test_data")
    seed_everything(42)
    config = DummyModel()
    config.allow_incomplete_labels = allow_partial_ground_truth
    config.set_output_to(test_output_dirs.root_dir)
    dataset_id = "place_holder_dataset_id"
    config.azure_dataset_id = dataset_id
    test_transform = config.get_full_image_sample_transforms().test
    df = pd.read_csv(full_ml_test_data_path(DATASET_CSV_FILE_NAME))

    if not use_partial_ground_truth:
        df = df[df.subject.isin([1, 2])]
    else:
        config.check_exclusive = False
        config.ground_truth_ids = ["region", "region_1"]

        # As in Tests.ML.pipelines.test.inference.test_evaluate_model_predictions, patients 3, 4 and 5 are in
        # the test dataset. Patient 3 is missing ground truth channel "region"; patient 4 is missing all ground
        # truth channels ("region" and "region_1"); patient 5 has no missing ground truth channels.
        rows_to_drop = [(3, "region"), (4, "region"), (4, "region_1")]
        for subject_id, channel_name in rows_to_drop:
            df = df[df["subject"].ne(subject_id) | df["channel"].ne(channel_name)]

        config.dataset_data_frame = df

        df = df[df.subject.isin([3, 4, 5])]

        config.train_subject_ids = ['1', '2']
        config.test_subject_ids = ['3', '4', '5']
        config.val_subject_ids = ['6', '7']

    if use_partial_ground_truth and not allow_partial_ground_truth:
        # Missing ground truth must be rejected when building the dataset.
        with pytest.raises(ValueError) as value_error:
            # noinspection PyTypeHints
            config._datasets_for_inference = {
                ModelExecutionMode.TEST:
                FullImageDataset(config,
                                 df,
                                 full_image_sample_transforms=test_transform)
            }  # type: ignore
        assert "Patient 3 does not have channel 'region'" in str(value_error.value)
        return

    # noinspection PyTypeHints
    config._datasets_for_inference = {
        ModelExecutionMode.TEST:
        FullImageDataset(config,
                         df,
                         full_image_sample_transforms=test_transform)
    }  # type: ignore
    mode = ModelExecutionMode.TEST
    checkpoint_handler = get_default_checkpoint_handler(
        model_config=config, project_root=test_output_dirs.root_dir)
    # Mimic the behaviour that checkpoints are downloaded from blob storage into the checkpoints folder.
    create_model_and_store_checkpoint(
        config,
        config.checkpoint_folder / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX)
    checkpoint_handler.additional_training_done()
    inference_results = model_testing.segmentation_model_test(
        config,
        execution_mode=mode,
        checkpoint_paths=checkpoint_handler.get_checkpoints_to_test())
    epoch_dir = config.outputs_folder / get_best_epoch_results_path(mode)
    patients_column = f"total_{MetricsFileColumns.Patient.value}".lower()
    if not patients_column.endswith("s"):
        patients_column += "s"

    if use_partial_ground_truth:
        # Only reachable when allow_partial_ground_truth is True (the other case returned above).
        assert csv_column_contains_value(
            csv_file_path=epoch_dir / METRICS_AGGREGATES_FILE,
            column_name=patients_column,
            value=len(pd.unique(df["subject"])),
            contains_only_value=True)
        assert csv_column_contains_value(
            csv_file_path=epoch_dir / SUBJECT_METRICS_FILE_NAME,
            column_name=MetricsFileColumns.Dice.value,
            value='',
            contains_only_value=False)
        return

    aggregates_df = pd.read_csv(epoch_dir / METRICS_AGGREGATES_FILE)
    assert patients_column not in aggregates_df.columns  # Only added if using partial ground truth

    assert not csv_column_contains_value(
        csv_file_path=epoch_dir / SUBJECT_METRICS_FILE_NAME,
        column_name=MetricsFileColumns.Dice.value,
        value='',
        contains_only_value=False)

    assert inference_results.metrics == pytest.approx(0.66606902, abs=1e-6)
    assert config.outputs_folder.is_dir()
    assert epoch_dir.is_dir()
    # Maps each subject's results subfolder to that subject's input image (used for shape and header).
    patients = {
        "001": io_util.load_nifti_image(data_dir / "id1_channel1.nii.gz"),
        "002": io_util.load_nifti_image(data_dir / "id2_channel1.nii.gz"),
    }

    assert_file_contains_string(epoch_dir / DATASET_ID_FILE, dataset_id)
    assert_file_contains_string(epoch_dir / GROUND_TRUTH_IDS_FILE, "region")
    for metrics_file in (model_testing.SUBJECT_METRICS_FILE_NAME, model_testing.METRICS_AGGREGATES_FILE):
        assert_text_files_match(epoch_dir / metrics_file, data_dir / metrics_file)
    # Plotting results vary between platforms. Can only check if the file is generated, but not its contents.
    assert (epoch_dir / model_testing.BOXPLOT_FILE).exists()

    # Output image name and the expected values passed to assert_nifti_content for each subject.
    expected_images = [("posterior_region.nii.gz", 137),
                       (DEFAULT_RESULT_IMAGE_NAME, 1),
                       ("posterior_background.nii.gz", 117)]
    for image_name, expected_value in expected_images:
        for subject_folder, patient in patients.items():
            assert_nifti_content(epoch_dir / subject_folder / image_name,
                                 get_image_shape(patient), patient.header, [expected_value],
                                 np.ubyte)

    thumbnails_folder = epoch_dir / model_testing.THUMBNAILS_FOLDER
    assert thumbnails_folder.is_dir()
    overlays = [png for png in thumbnails_folder.glob("*.png") if "_region_slice_" in str(png)]
    assert len(overlays) == len(df.subject.unique()), "There should be one overlay/contour file per subject"

    # Writing dataset.csv normally happens at the beginning of training,
    # but this test reads off a saved checkpoint file.
    # Dataset.csv must be present for plot_cross_validation.
    config.write_dataset_files()
    # Test if the metrics files can be picked up correctly by the cross validation code
    result_files = get_config_and_results_for_offline_runs(config).files
    assert len(result_files) == 1
    for result_file in result_files:
        assert result_file.execution_mode == mode
        assert result_file.dataset_csv_file is not None
        assert result_file.dataset_csv_file.exists()
        assert result_file.metrics_file is not None
        assert result_file.metrics_file.exists()