Пример #1
0
def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset,
                                         num_dataload_workers: int) -> None:
    """
    Build a shuffled data loader from the cropping dataset and check every batch it yields.

    Checks that:
    - the image, mask and labels crops have the configured crop shape;
    - the mask and labels center crops have the configured center shape;
    - each center crop equals image_util.get_center_crop applied to the corresponding full crop.
    """
    batch_size = 2
    loader = cropping_dataset.as_data_loader(
        shuffle=True,
        batch_size=batch_size,
        num_dataload_workers=num_dataload_workers)
    args = cropping_dataset.args
    crop_size = args.crop_size
    center_size = args.center_size

    for batch in loader:
        sample = CroppedSample.from_dict(sample=batch)
        assert sample is not None

        # full-size crop shapes
        assert sample.image.shape == (batch_size, args.number_of_image_channels) + crop_size  # type: ignore
        assert sample.mask.shape == (batch_size,) + crop_size  # type: ignore
        assert sample.labels.shape == (batch_size, args.number_of_classes) + crop_size  # type: ignore

        # center crop shapes
        assert sample.mask_center_crop.shape == (batch_size,) + center_size  # type: ignore
        assert sample.labels_center_crop.shape == (batch_size, args.number_of_classes) + center_size  # type: ignore

        # center crop contents must match a center crop of the full crop
        for b in range(batch_size):
            assert np.array_equal(sample.mask_center_crop[b],
                                  image_util.get_center_crop(image=sample.mask[b], crop_shape=center_size))
            for c, channel_center_crop in enumerate(sample.labels_center_crop[b]):
                assert np.array_equal(channel_center_crop,
                                      image_util.get_center_crop(image=sample.labels[b][c],
                                                                 crop_shape=center_size))
Пример #2
0
def test_get_center_crop() -> None:
    """
    Check that get_center_crop returns the central 2x2x2 region of a random 4x4x4 volume.
    """
    volume = np.random.uniform(size=(4, 4, 4))
    middle = slice(1, 3)  # indices 1 and 2 are the center of an axis of length 4
    assert np.array_equal(image_util.get_center_crop(image=volume, crop_shape=(2, 2, 2)),
                          volume[middle, middle, middle])
Пример #3
0
    def create_random_cropped_sample(
            sample: Sample,
            crop_size: TupleInt3,
            center_size: TupleInt3,
            class_weights: Optional[List[float]] = None) -> CroppedSample:
        """
        Take a random crop of a full-size 3D sample and package it, with its center crops, as a CroppedSample.

        :param sample: the full size 3D sample from which the crop is taken.
        :param crop_size: the size of the crop to extract.
        :param center_size: the size of the center region of the crop (this should be the same as the spatial
                            dimensions of the posteriors that the model produces).
        :param class_weights: the distribution to use for the crop center class; passed through to random_crop.
        :return: CroppedSample holding the cropped image/mask/labels, the mask/labels center crops, the crop
                 center point returned by random_crop, and the cropped sample's metadata.
        """
        cropped, center_point = random_crop(sample=sample, crop_size=crop_size, class_weights=class_weights)
        mask = cropped.mask
        labels = cropped.labels

        if center_size == crop_size:
            # The center region is the entire crop: reuse the cropped arrays directly.
            mask_center, labels_center = mask, labels
        else:
            mask_center = image_util.get_center_crop(image=mask, crop_shape=center_size)
            labels_center = np.zeros(shape=[len(labels)] + list(center_size),  # type: ignore
                                     dtype=ImageDataType.SEGMENTATION.value)
            # center-crop each label channel independently
            for channel, channel_labels in enumerate(labels):  # type: ignore
                labels_center[channel] = image_util.get_center_crop(image=channel_labels,
                                                                    crop_shape=center_size)

        return CroppedSample(image=cropped.image,
                             mask=mask,
                             labels=labels,
                             mask_center_crop=mask_center,
                             labels_center_crop=labels_center,
                             center_indices=center_point,
                             metadata=cropped.metadata)
    def forward(self, patches: np.ndarray) -> torch.Tensor:  # type: ignore
        """
        Simulate a model that only returns the center of each patch.

        The first two axes of ``patches`` are kept; each of the next three axes is shrunk by
        ``2 * self.shrink_by[axis]`` by center-cropping every channel of every patch.

        :param patches: input patches; axes 2, 3 and 4 are treated as the spatial dimensions.
        :return: a float tensor of center crops with shape ``patches.shape[:2] + output_size``.
        """
        spatial_shape = patches.shape[2:]
        output_size = tuple(spatial_shape[axis] - 2 * self.shrink_by[axis] for axis in range(3))
        predictions = torch.zeros(patches.shape[:2] + output_size)
        for patch_index, patch in enumerate(patches):
            for channel_index, channel in enumerate(patch):
                predictions[patch_index, channel_index] = image_util.get_center_crop(image=channel,
                                                                                     crop_shape=output_size)
        return predictions
Пример #5
0
 def from_numpy_crop_and_resize(array: np.ndarray) -> torch.Tensor:
     """
     Convert a numpy array to a torch tensor, optionally resizing and then center-cropping it.

     ``image_size`` and ``center_crop_size`` are read from the enclosing scope; each step is skipped
     when the corresponding value is falsy.

     :param array: the input image. It must be a floating-point array when ``image_size`` is set.
     :return: the (optionally resized and center-cropped) image as a torch tensor.
     :raises ValueError: if resizing is requested for a non-float array. Also raised if the image has a
         singleton first dimension but ``image_size`` or ``center_crop_size`` has a non-singleton first
         dimension.
     """
     if image_size:
         if not issubclass(array.dtype.type, np.floating):
             raise ValueError("Array must be of type float.")
         # A 2D image is stored with a singleton first axis; the target size must keep that axis at 1.
         if array.shape[0] == 1 and image_size[0] != 1:
             raise ValueError(f"Input image is 2D with singleton dimension {array.shape}, but parameter "
                              f"image_size has non-singleton first dimension {image_size}")
         # NOTE(review): `resize` is presumably skimage.transform.resize; confirm at the import site.
         array = resize(array, image_size, anti_aliasing=True)
     t = torch.from_numpy(array)
     if center_crop_size:
         if array.shape[0] == 1 and center_crop_size[0] != 1:
             raise ValueError(f"Input image is 2D with singleton dimension {array.shape}, but parameter "
                              f"center_crop_size has non-singleton first dimension {center_crop_size}")
         return get_center_crop(t, center_crop_size)
     return t
Пример #6
0
def test_get_center_crop_invalid(image: Any, crop_shape: Any) -> None:
    """
    Test that get_center_crop correctly raises an error for invalid arguments.
    """
    # Call get_center_crop directly rather than wrapping it in ``assert``. Asserting on a returned array
    # raises by itself:
    #   - ValueError for the ambiguous truth value of a multi-element array;
    #   - AssertionError for a falsy result.
    # pytest.raises(Exception) would swallow either one, so the test would pass even if invalid input
    # were accepted.
    with pytest.raises(Exception):
        image_util.get_center_crop(image=image, crop_shape=crop_shape)