Ejemplo n.º 1
0
def test_invalid_crop_size(crop_size: Any) -> None:
    """
    Check that random_crop refuses an invalid crop size by raising an exception.

    :param crop_size: the crop size under test; expected to be rejected by random_crop.
    """
    with pytest.raises(Exception):
        full_sample = Sample(metadata=DummyPatientMetadata,
                             image=valid_image_4d,
                             labels=valid_labels,
                             mask=valid_mask)
        random_crop(full_sample, crop_size, valid_class_weights)
Ejemplo n.º 2
0
def test_valid_class_weights(class_weights: List[float]) -> None:
    """
    Produce a large number of crops and make sure the crop center class proportions respect class weights.

    Builds a sample in which only classes 0, 2 and 4 have foreground voxels:
    - class 2 occupies exactly one voxel;
    - class 4 occupies exactly one other voxel;
    - class 0 covers every remaining voxel.

    The test draws many 1x1x1 crops and records the class found at each crop center. It then checks
    that the empirical distribution of those classes matches the class weights renormalised over the
    non-empty classes. When class_weights is None, the expected distribution is uniform over the
    non-empty classes.

    :param class_weights: per-class sampling weights, or None for the default (uniform) behaviour.
    """
    valid_image_4d, valid_labels, valid_mask = create_valid_image()
    num_classes = len(valid_labels)
    image = np.zeros_like(valid_image_4d)
    labels = np.zeros_like(valid_labels)
    class0, class1, class2 = non_empty_classes = [0, 2, 4]
    # Class 0 is set everywhere except two voxels, which are handed to classes 2 and 4 so that each
    # voxel still carries exactly one label.
    labels[class0] = 1
    labels[class0][3, 3, 3] = 0
    labels[class0][3, 2, 3] = 0
    labels[class1][3, 3, 3] = 1
    labels[class2][3, 2, 3] = 1

    mask = np.ones_like(valid_mask)
    sample = Sample(image=image,
                    labels=labels,
                    mask=mask,
                    metadata=DummyPatientMetadata)

    crop_size = (1, 1, 1)
    total_crops = 200
    sampled_label_center_distribution = np.zeros(num_classes)

    # If there is no class that has a non-zero weight and is present in the sample, there is no valid
    # way to select a class, so we expect an exception to be thrown.
    if class_weights is not None and sum(class_weights[c]
                                         for c in non_empty_classes) == 0:
        with pytest.raises(ValueError):
            random_crop(sample, crop_size, class_weights)
        return

    for _ in range(total_crops):
        _, center = random_crop(sample, crop_size, class_weights)
        # Each voxel has exactly one label set, so the index of the 1 is the sampled class.
        sampled_class = list(labels[:, center[0], center[1],
                                    center[2]]).index(1)
        sampled_label_center_distribution[sampled_class] += 1

    sampled_label_center_distribution /= total_crops

    if class_weights is None:
        weight = 1.0 / len(non_empty_classes)
        expected_label_center_distribution = [
            weight if c in non_empty_classes else 0.0
            for c in range(num_classes)
        ]
    else:
        total = sum(class_weights[c] for c in non_empty_classes)
        expected_label_center_distribution = [
            class_weights[c] / total if c in non_empty_classes else 0.0
            for c in range(num_classes)
        ]
    assert np.allclose(sampled_label_center_distribution,
                       expected_label_center_distribution,
                       atol=0.1)
Ejemplo n.º 3
0
def test_invalid_arrays(image: Any, labels: Any, mask: Any,
                        class_weights: Any) -> None:
    """
    Check that random_crop raises for an invalid image, labels, mask or class-weights argument.

    :param image: image array passed to the Sample.
    :param labels: labels array passed to the Sample.
    :param mask: mask array passed to the Sample.
    :param class_weights: class weights passed to random_crop.
    """
    with pytest.raises(Exception):
        candidate = Sample(metadata=DummyPatientMetadata,
                           image=image,
                           labels=labels,
                           mask=mask)
        random_crop(candidate, valid_crop_size, class_weights)
Ejemplo n.º 4
0
def test_random_crop_no_fg() -> None:
    """
    Check that random_crop raises when there is nothing to sample a crop center from.

    Two degenerate inputs are tried in turn: an all-zero mask, and all-zero labels.
    """
    degenerate_cases = (
        (valid_labels, np.zeros_like(valid_mask)),    # empty mask
        (np.zeros_like(valid_labels), valid_mask),    # no foreground labels
    )
    for case_labels, case_mask in degenerate_cases:
        with pytest.raises(Exception):
            random_crop(
                Sample(metadata=DummyPatientMetadata,
                       image=valid_image_4d,
                       labels=case_labels,
                       mask=case_mask), valid_crop_size, valid_class_weights)
Ejemplo n.º 5
0
def test_invalid_arrays(image: Any, labels: Any, mask: Any,
                        class_weights: Any) -> None:
    """
    Check that random_crop raises for an invalid image, labels, mask or class-weights argument.

    The one combination where every argument equals its valid counterpart is skipped, since it is
    not a failure case.

    :param image: image array passed to the Sample.
    :param labels: labels array passed to the Sample.
    :param mask: mask array passed to the Sample.
    :param class_weights: class weights passed to random_crop.
    """
    all_arguments_valid = (np.array_equal(image, valid_image_4d)
                           and np.array_equal(labels, valid_labels)
                           and np.array_equal(mask, valid_mask)
                           and class_weights == valid_class_weights)
    if all_arguments_valid:
        return

    with pytest.raises(Exception):
        candidate = Sample(metadata=DummyPatientMetadata,
                           image=image,
                           labels=labels,
                           mask=mask)
        random_crop(candidate, valid_crop_size, class_weights)
Ejemplo n.º 6
0
def test_valid_full_crop() -> None:
    """
    Check that cropping with the full crop size returns the input unchanged.

    The image, labels, mask and metadata of the returned sample must all equal those of the input.
    """
    patient_metadata = DummyPatientMetadata
    full_sample = Sample(image=valid_image_4d,
                         labels=valid_labels,
                         mask=valid_mask,
                         metadata=patient_metadata)
    cropped, _ = random_crop(sample=full_sample,
                             crop_size=valid_full_crop_size,
                             class_weights=valid_class_weights)

    assert np.array_equal(cropped.image, valid_image_4d)
    assert np.array_equal(cropped.labels, valid_labels)
    assert np.array_equal(cropped.mask, valid_mask)
    assert cropped.metadata == patient_metadata
Ejemplo n.º 7
0
def test_random_crop(crop_size: Any) -> None:
    """
    Check that random_crop returns image, labels and mask of the requested spatial size.

    The labels are altered so that class 0 is set everywhere and class 1 is empty. The empty class
    should be handled gracefully, i.e. ignored when sampling.

    :param crop_size: the spatial crop size to request.
    """
    # Work on a copy: the module-level valid_labels is shared with other tests, and assigning into
    # an alias would leak these edits into them.
    labels = np.copy(valid_labels)
    # Create labels such that there are no foreground voxels in a particular class.
    # This should be handled gracefully (class being ignored from sampling).
    labels[0] = 1
    labels[1] = 0
    sample, _ = random_crop(
        Sample(image=valid_image_4d,
               labels=labels,
               mask=valid_mask,
               metadata=DummyPatientMetadata), crop_size, valid_class_weights)

    expected_img_crop_size = (valid_image_4d.shape[0], *crop_size)
    expected_labels_crop_size = (labels.shape[0], *crop_size)

    assert sample.image.shape == expected_img_crop_size
    assert sample.labels.shape == expected_labels_crop_size
    assert sample.mask.shape == tuple(crop_size)
Ejemplo n.º 8
0
    def create_random_cropped_sample(
            sample: Sample,
            crop_size: TupleInt3,
            center_size: TupleInt3,
            class_weights: Optional[List[float]] = None) -> CroppedSample:
        """
        Take a random crop of a full-size 3D sample and wrap it, together with its central region, in a
        CroppedSample.
        :param sample: full-size 3D sample to crop.
        :param crop_size: spatial size of the crop to extract.
        :param center_size: spatial size of the central region of the crop; intended to match the spatial
                            dimensions of the posteriors the model produces.
        :param class_weights: distribution used to choose the class of the crop center.
        :return: CroppedSample holding the cropped image/mask/labels, their central regions, the crop
                 center indices and the sample metadata.
        """
        cropped, center_indices = random_crop(sample=sample,
                                              crop_size=crop_size,
                                              class_weights=class_weights)

        if center_size == crop_size:
            # The central region is the whole crop, so reuse the cropped arrays directly.
            mask_center = cropped.mask
            labels_center = cropped.labels
        else:
            mask_center = image_util.get_center_crop(image=cropped.mask, crop_shape=center_size)
            channel_count = len(cropped.labels)  # type: ignore
            labels_center = np.zeros(shape=[channel_count] + list(center_size),
                                     dtype=ImageDataType.SEGMENTATION.value)
            # Center-crop each label channel independently into the preallocated array.
            for channel_index, channel in enumerate(cropped.labels):  # type: ignore
                labels_center[channel_index] = image_util.get_center_crop(
                    image=channel, crop_shape=center_size)

        return CroppedSample(image=cropped.image,
                             mask=cropped.mask,
                             labels=cropped.labels,
                             mask_center_crop=mask_center,
                             labels_center_crop=labels_center,
                             center_indices=center_indices,
                             metadata=cropped.metadata)