Example #1
0
def test_sample_construct_copy(random_image_crop: Any, random_mask_crop: Any,
                               random_label_crop: Any) -> None:
    """
    Checks Sample.clone_with_overrides: without overrides the clone must expose the same fields and be of
    the same class as the original; with a metadata override the clone must report the new patient ID.
    """
    original = Sample(image=random_image_crop,
                      mask=random_mask_crop,
                      labels=random_label_crop,
                      metadata=PatientMetadata(patient_id='1'))

    plain_copy = original.clone_with_overrides()
    assert original.get_dict() == plain_copy.get_dict()
    assert type(original) is type(plain_copy)

    relabelled = original.clone_with_overrides(metadata=PatientMetadata(patient_id='2'))
    # The ID is passed as the string '2' but compared with the integer 2; presumably Sample.patient_id
    # converts the metadata ID to int - confirm against the Sample class.
    assert relabelled.patient_id == 2
Example #2
0
def load_train_and_test_data_channels(
        patient_ids: List[int],
        normalization_fn: PhotometricNormalization) -> List[Sample]:
    """
    Loads the samples for the given patients from the "train_and_test_data" test data folder, and applies
    the given photometric normalization to each image.

    :param patient_ids: The patient IDs to load. Every ID must be strictly positive.
    :param normalization_fn: The normalization to apply; its transform is called with each sample's image
        and mask.
    :return: One normalized Sample per patient ID, in the order of patient_ids.
    :raises ValueError: If any patient ID is zero or negative.
    """
    if np.any(np.asarray(patient_ids) <= 0):
        # IDs equal to 0 are rejected as well, so the requirement is "> 0" (not ">= 0").
        raise ValueError("patient_ids must be > 0")

    # Path of the file "id{k}_{y}.nii.gz" for patient k and channel/suffix y inside the test data folder.
    file_name = lambda k, y: full_ml_test_data_path("train_and_test_data") / f"id{k}_{y}.nii.gz"

    # Loads the image channels, mask and ground truth for patient z, before any normalization.
    get_sample = lambda z: io_util.load_images_from_dataset_source(
        dataset_source=PatientDatasetSource(
            metadata=PatientMetadata(patient_id=z),
            image_channels=[file_name(z, c) for c in TEST_CHANNEL_IDS],
            mask_channel=file_name(z, TEST_MASK_ID),
            ground_truth_channels=[file_name(z, TEST_GT_ID)]))

    samples = []
    for patient_id in patient_ids:
        raw = get_sample(patient_id)
        # Only the image is normalized; mask, labels and metadata are carried over unchanged.
        normalized_image = normalization_fn.transform(raw.image, raw.mask)
        samples.append(Sample(image=normalized_image,
                              mask=raw.mask,
                              labels=raw.labels,
                              metadata=raw.metadata))
    return samples
def test_sample(random_image_crop: Any, random_mask_crop: Any, random_label_crop: Any, random_patient_id: Any) -> None:
    """
    Checks that a Sample and a CroppedSample built from the same arrays and metadata return matching values
    from get_dict(), and that the crop-specific fields of the CroppedSample hold the arrays they were given.
    """
    metadata = PatientMetadata(patient_id=42, institution="foo")
    plain = Sample(image=random_image_crop,
                   mask=random_mask_crop,
                   labels=random_label_crop,
                   metadata=metadata)
    cropped = CroppedSample(image=random_image_crop,
                            mask=random_mask_crop,
                            labels=random_label_crop,
                            mask_center_crop=random_mask_crop,
                            labels_center_crop=random_label_crop,
                            metadata=metadata,
                            center_indices=np.zeros((1, 3)))

    plain_dict = plain.get_dict()
    cropped_dict = cropped.get_dict()

    def _matches_both(key: str, expected: Any) -> bool:
        # Both dictionaries must agree with each other, and the cropped one must equal the input array.
        return bool(np.array_equal(plain_dict[key], cropped_dict[key])
                    and np.array_equal(cropped_dict[key], expected))

    shared_fields = [("image", random_image_crop),
                     ("mask", random_mask_crop),
                     ("labels", random_label_crop)]
    for field_name, expected_array in shared_fields:
        assert _matches_both(field_name, expected_array)

    assert np.array_equal(cropped_dict["mask_center_crop"], random_mask_crop)
    assert np.array_equal(cropped_dict["labels_center_crop"], random_label_crop)
    assert plain_dict["metadata"] == cropped_dict["metadata"] == metadata
Example #4
0
def test_custom_collate() -> None:
    """
    Checks that collate_with_metadata gathers the metadata field of every item into a plain list,
    while the other numeric field is collated into a single tensor.
    """
    patient_metadata = PatientMetadata(patient_id='42')
    key = "foo"
    batch = [
        {key: 1, SAMPLE_METADATA_FIELD: "something"},
        {key: 2, SAMPLE_METADATA_FIELD: patient_metadata},
    ]
    collated = collate_with_metadata(batch)
    assert key in collated
    assert SAMPLE_METADATA_FIELD in collated
    # The metadata entries are expected back unchanged, in input order, as a list.
    metadata_list = collated[SAMPLE_METADATA_FIELD]
    assert isinstance(metadata_list, list)
    assert metadata_list == ["something", patient_metadata]
    # The plain integers are expected to be stacked into one tensor.
    stacked = collated[key]
    assert isinstance(stacked, torch.Tensor)
    assert stacked.tolist() == [1, 2]
def test_visualize_patch_sampling_2d(
        test_output_dirs: TestOutputDirectories) -> None:
    """
    Checks patch sampling on a single-slice (2D) image: the sampling heatmap must match the stored reference,
    and the diagnostics must consist of exactly one PNG file and no NIfTI files.
    :param test_output_dirs: Provides the root folder that receives the diagnostic outputs.
    """
    # Seed before anything random happens: both the synthetic image and the crop sampling use the random state.
    set_random_seed(0)
    volume_shape = (1, 20, 30)
    weights = equally_weighted_classes(["fg"])
    model_config = SegmentationModelBase(should_validate=False,
                                         crop_size=(1, 5, 10),
                                         class_weights=weights)
    image = np.random.rand(1, *volume_shape).astype(np.float32) * 1000
    mask = np.ones(volume_shape)
    labels = np.zeros((len(weights),) + volume_shape)
    # Rectangular region for class index 1; class index 0 is set to its complement.
    labels[1, 0, 8:12, 5:25] = 1
    labels[0] = 1 - labels[1]
    out_dir = Path(test_output_dirs.root_dir)
    # No image header is supplied for the 2D case.
    patient = PatientMetadata(patient_id='123', image_header=None)
    sample = Sample(image=image, mask=mask, labels=labels, metadata=patient)
    heatmap = visualize_random_crops(sample, model_config, output_folder=out_dir)

    reference_dir = full_ml_test_data_path("patch_sampling")
    reference_heatmap = reference_dir / "sampling_2d.npy"
    # To update the stored results, uncomment this line:
    # np.save(str(reference_heatmap), heatmap)
    stored_heatmap = np.load(str(reference_heatmap))
    assert np.allclose(heatmap, stored_heatmap), "Patch sampling created a different heatmap."
    assert len(list(out_dir.rglob("*.nii.gz"))) == 0
    assert len(list(out_dir.rglob("*.png"))) == 1
    produced_png = out_dir / "123_sampled_patches.png"
    assert_file_exists(produced_png)
    reference_png = reference_dir / "sampling_2d.png"
    # To update the stored results, uncomment this line:
    # reference_png.write_bytes(produced_png.read_bytes())
    if is_running_on_azure():
        # On the Azure build agents the bounding box of the rendered image differs slightly from local runs,
        # even with equal dpi settings ("Image sizes don't match: actual (685, 469), expected (618, 424)").
        # The cause could not be pinned down, so the byte comparison only runs locally.
        return
    assert_binary_files_match(produced_png, reference_png)
Example #6
0
def test_get_all_metadata(default_config: ModelConfigBase) -> None:
    """
    Checks that PatientMetadata.from_dataframe reads the expected metadata for patients '1' and '2'
    from the training split of the default configuration; in that data each institution equals the patient ID.
    """
    train_df = default_config.get_dataset_splits().train
    for patient in ('1', '2'):
        actual = PatientMetadata.from_dataframe(train_df, patient)
        assert actual == PatientMetadata(patient_id=patient, institution=patient)
def test_visualize_patch_sampling(test_output_dirs: TestOutputDirectories,
                                  labels_to_boundary: bool) -> None:
    """
    Tests if patch sampling and producing diagnostic images works as expected.
    :param test_output_dirs: Provides the root folder that receives the diagnostic outputs.
    :param labels_to_boundary: If true, the ground truth labels are placed close to the image boundary, so that
    crops have to be adjusted inwards. If false, ground truth labels are all far from the image boundaries.
    """
    # Seed before anything random happens: both the synthetic image and the crop sampling use the random state.
    set_random_seed(0)
    shape = (10, 30, 30)
    foreground_classes = ["fg"]
    class_weights = equally_weighted_classes(foreground_classes)
    config = SegmentationModelBase(should_validate=False,
                                   crop_size=(2, 10, 10),
                                   class_weights=class_weights)
    image = np.random.rand(1, *shape).astype(np.float32) * 1000
    mask = np.ones(shape)
    labels = np.zeros((len(class_weights), ) + shape)
    if labels_to_boundary:
        # Generate foreground labels in such a way that a patch centered around a foreground pixel would
        # reach outside of the image.
        labels[1, 4:8, 3:27, 3:27] = 1
    else:
        labels[1, 4:8, 15:18, 15:18] = 1
    labels[0] = 1 - labels[1]
    output_folder = Path(test_output_dirs.root_dir)
    image_header = get_unit_image_header()
    sample = Sample(image=image,
                    mask=mask,
                    labels=labels,
                    metadata=PatientMetadata(patient_id='123',
                                             image_header=image_header))
    expected_folder = full_ml_test_data_path("patch_sampling")
    heatmap = visualize_random_crops(sample,
                                     config,
                                     output_folder=output_folder)
    expected_heatmap = expected_folder / ("sampled_to_boundary.npy"
                                          if labels_to_boundary else
                                          "sampled_center.npy")
    # To update the stored results, uncomment this line:
    # np.save(str(expected_heatmap), heatmap)
    assert np.allclose(heatmap, np.load(
        str(expected_heatmap))), "Patch sampling created a different heatmap."
    f1 = output_folder / "123_ct.nii.gz"
    assert_file_exists(f1)
    f2 = output_folder / "123_sampled_patches.nii.gz"
    assert_file_exists(f2)
    thumbnails = [
        "123_sampled_patches_dim0.png",
        "123_sampled_patches_dim1.png",
        "123_sampled_patches_dim2.png",
    ]
    for f in thumbnails:
        assert_file_exists(output_folder / f)

    expected = expected_folder / ("sampled_to_boundary.nii.gz"
                                  if labels_to_boundary else
                                  "sampled_center.nii.gz")
    # To update test results:
    # shutil.copy(str(f2), str(expected))
    expected_image = io_util.load_nifti_image(expected)
    actual_image = io_util.load_nifti_image(f2)
    # The comparison must be asserted: a bare np.allclose call only returns a bool, which would be discarded
    # and let a mismatching sampled-patches volume pass silently.
    assert np.allclose(expected_image.image, actual_image.image), \
        "Sampled patches volume does not match the stored result."
    if labels_to_boundary:
        for f in thumbnails:
            # Uncomment this line to update test results
            # (expected_folder / f).write_bytes((output_folder / f).read_bytes())
            if not is_running_on_azure():
                # When running on the Azure build agents, it appears that the bounding box of the images
                # is slightly different than on local runs, even with equal dpi settings.
                # Not able to figure out how to make the run results consistent, hence disable in cloud runs.
                assert_binary_files_match(output_folder / f,
                                          expected_folder / f)
Example #8
0

def test_nii_load_zyx(test_output_dirs: OutputFolderForTests) -> None:
    """
    Checks that a NIfTI scan is read into an array of shape (44, 167, 167). The shape is checked both through
    SimpleITK directly and through io_util.load_nifti_image, and the spacing must be approximately (3, 1, 1).
    """
    zyx_shape = (44, 167, 167)
    nifti_path = full_ml_test_data_path("patch_sampling/scan_small.nii.gz")
    sitk_image: sitk.Image = sitk.ReadImage(str(nifti_path))
    # The size reported by SimpleITK is expected to be the reverse of the numpy array shape.
    assert sitk_image.GetSize() == reverse_tuple_float3(zyx_shape)
    assert sitk.GetArrayFromImage(sitk_image).shape == zyx_shape
    loaded = io_util.load_nifti_image(nifti_path)
    assert loaded.image.shape == zyx_shape
    spacing = loaded.header.spacing
    assert spacing is not None
    np.testing.assert_allclose(spacing, (3.0, 1.0, 1.0), rtol=0.1)


@pytest.mark.parametrize("metadata", [None, PatientMetadata(patient_id="0")])
@pytest.mark.parametrize("image_channel", [None, known_nii_path, f"{good_h5_path}|volume|0", good_npy_path])
@pytest.mark.parametrize("ground_truth_channel",
                         [None, known_nii_path, f"{good_h5_path}|segmentation|0|1", good_npy_path])
@pytest.mark.parametrize("mask_channel", [None, known_nii_path, good_npy_path])
def test_load_images_from_dataset_source(
        metadata: Optional[str],
        image_channel: Optional[str],
        ground_truth_channel: Optional[str],
        mask_channel: Optional[str]) -> None:
    """
    Test if images are loaded as expected from channels
    """
    # metadata, image and GT channels must be present. Mask is optional
    if None in [metadata, image_channel, ground_truth_channel]:
        with pytest.raises(Exception):
Example #9
0
    Checks if the shapes of the given tensors is equal, and the values are approximately equal, with a given
    absolute tolerance.
    """
    if isinstance(t2, list):
        t2 = torch.tensor(t2)
    assert t1.shape == t2.shape, "Shapes must match"
    # Alternative is to use torch.allclose here, but that method also checks that datatypes match. This makes
    # writing the test cases more cumbersome.
    v1 = t1.flatten().tolist()
    v2 = t2.flatten().tolist()
    assert v1 == pytest.approx(
        v2, abs=abs
    ), f"Tensor elements don't match with tolerance {abs}: {v1} != {v2}"


# Placeholder patient metadata with patient ID 42, presumably shared by tests that need metadata but do not
# care about its contents. NOTE(review): patient_id is an int here, while most tests pass string IDs
# (e.g. '42') - confirm the intended type.
DummyPatientMetadata = PatientMetadata(patient_id=42)


def get_model_loader(
    namespace: Optional[str] = None
) -> ModelConfigLoader[SegmentationModelBase]:
    """
    Creates a ModelConfigLoader for segmentation model configurations.

    :param namespace: A non-default namespace to search for model configs; None leaves the loader's default.
    :return: A loader specialized to SegmentationModelBase.
    """
    loader_class = ModelConfigLoader[SegmentationModelBase]
    return loader_class(model_configs_namespace=namespace)


def get_default_azure_config() -> AzureConfig:
    """
Example #10
0
    expected = expected_file.read_bytes()
    if actual == expected:
        return
    if actual_file.suffix == ".png" and expected_file.suffix == ".png":
        actual_image = Image.open(actual_file)
        expected_image = Image.open(expected_file)
        actual_size = actual_image.size
        expected_size = expected_image.size
        assert actual_size == expected_size, f"Image sizes don't match: actual {actual_size}, expected {expected_size}"
        assert np.allclose(
            np.array(actual_image),
            np.array(expected_image)), "Image pixel data does not match."
    assert False, f"File contents does not match: len(actual)={len(actual)}, len(expected)={len(expected)}"


# Placeholder patient metadata with patient ID '42', presumably shared by tests that need metadata but do not
# care about its contents.
DummyPatientMetadata = PatientMetadata(patient_id='42')


def get_model_loader(
    namespace: Optional[str] = None
) -> ModelConfigLoader[SegmentationModelBase]:
    """
    Creates a ModelConfigLoader for segmentation model configurations.

    :param namespace: A non-default namespace to search for model configs; None leaves the loader's default.
    :return: A loader specialized to SegmentationModelBase.
    """
    loader_class = ModelConfigLoader[SegmentationModelBase]
    return loader_class(model_configs_namespace=namespace)


def get_default_azure_config() -> AzureConfig:
    """