def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Checks that every batch produced by the cropping dataset's data loader has the
    expected tensor shapes, and that the stored center crops agree with what
    image_util.get_center_crop yields on the full-size crops.
    """
    batch_size = 2
    loader = cropping_dataset.as_data_loader(shuffle=True,
                                             batch_size=batch_size,
                                             num_dataload_workers=num_dataload_workers)
    args = cropping_dataset.args
    for batch in loader:
        sample = CroppedSample.from_dict(sample=batch)
        assert sample is not None
        # full-size crop shapes: image carries an extra channel dimension, labels one per class
        assert sample.image.shape == \
               (batch_size, args.number_of_image_channels) + args.crop_size  # type: ignore
        assert sample.mask.shape == (batch_size,) + args.crop_size  # type: ignore
        assert sample.labels.shape == \
               (batch_size, args.number_of_classes) + args.crop_size  # type: ignore
        # center crop shapes follow the same layout, but with the (smaller) center size
        assert sample.mask_center_crop.shape == (batch_size,) + args.center_size  # type: ignore
        assert sample.labels_center_crop.shape == \
               (batch_size, args.number_of_classes) + args.center_size  # type: ignore
        # contents of the center crops must equal get_center_crop applied to the full crops
        for b in range(batch_size):
            expected_mask = image_util.get_center_crop(image=sample.mask[b],
                                                       crop_shape=args.center_size)
            assert np.array_equal(sample.mask_center_crop[b], expected_mask)
            for c in range(len(sample.labels_center_crop[b])):
                expected_label = image_util.get_center_crop(image=sample.labels[b][c],
                                                            crop_shape=args.center_size)
                assert np.array_equal(sample.labels_center_crop[b][c], expected_label)
def test_get_center_crop() -> None:
    """
    Test to make sure the center crop is extracted correctly from a given image.
    """
    full_image = np.random.uniform(size=(4, 4, 4))
    cropped = image_util.get_center_crop(image=full_image, crop_shape=(2, 2, 2))
    # For a 4x4x4 image and a 2x2x2 crop, the center is the [1:3] slice on each axis.
    assert np.array_equal(cropped, full_image[1:3, 1:3, 1:3])
def create_random_cropped_sample(sample: Sample,
                                 crop_size: TupleInt3,
                                 center_size: TupleInt3,
                                 class_weights: Optional[List[float]] = None) -> CroppedSample:
    """
    Creates an instance of a cropped sample extracted from full 3D images.

    :param sample: the full size 3D sample to use for extracting a cropped sample.
    :param crop_size: the size of the crop to extract.
    :param center_size: the size of the center of the crop (this should be the same
        as the spatial dimensions of the posteriors that the model produces)
    :param class_weights: the distribution to use for the crop center class.
    :return: CroppedSample
    """
    # Draw a random crop from the full-size sample.
    cropped, center = random_crop(sample=sample, crop_size=crop_size, class_weights=class_weights)
    if center_size == crop_size:
        # Nothing further to cut away: the center crops are the full crops.
        mask_center = cropped.mask
        labels_center = cropped.labels
    else:
        mask_center = image_util.get_center_crop(image=cropped.mask, crop_shape=center_size)
        # One center-cropped label map per class channel.
        labels_center = np.zeros(shape=[len(cropped.labels)] + list(center_size),  # type: ignore
                                 dtype=ImageDataType.SEGMENTATION.value)
        for channel in range(len(cropped.labels)):  # type: ignore
            labels_center[channel] = image_util.get_center_crop(image=cropped.labels[channel],
                                                                crop_shape=center_size)
    return CroppedSample(image=cropped.image,
                         mask=cropped.mask,
                         labels=cropped.labels,
                         mask_center_crop=mask_center,
                         labels_center_crop=labels_center,
                         center_indices=center,
                         metadata=cropped.metadata)
def forward(self, patches: np.ndarray) -> torch.Tensor:  # type: ignore
    """
    Mimics a model whose output covers only the center of each input patch:
    each channel is reduced to its center crop, with the spatial extent shrunk
    by self.shrink_by voxels on both sides of every dimension.
    """
    spatial_shape = patches.shape[2:]
    # Output spatial size: remove shrink_by voxels from both ends of each axis.
    output_size = tuple(spatial_shape[d] - 2 * self.shrink_by[d] for d in range(3))
    result = torch.zeros(patches.shape[:2] + output_size)
    for batch_index, patch in enumerate(patches):
        for channel_index, channel in enumerate(patch):
            result[batch_index, channel_index] = \
                image_util.get_center_crop(image=channel, crop_shape=output_size)
    return result
def from_numpy_crop_and_resize(array: np.ndarray) -> torch.Tensor:
    """
    Converts the given array to a torch tensor, first resizing it to image_size
    (if set) and finally taking a center crop of center_crop_size (if set), both
    captured from the enclosing scope. Raises ValueError when resizing a
    non-float array, or when a 2D image (singleton first dimension) is combined
    with a target shape whose first dimension is not 1.
    """
    if image_size:
        if not issubclass(array.dtype.type, np.floating):
            raise ValueError("Array must be of type float.")
        if array.shape[0] == 1 and not image_size[0] == 1:
            raise ValueError(f"Input image is 2D with singleton dimension {array.shape}, but parameter "
                             f"image_shape has non-singleton first dimension {image_size}")
        array = resize(array, image_size, anti_aliasing=True)
    tensor = torch.from_numpy(array)
    if not center_crop_size:
        return tensor
    if array.shape[0] == 1 and not center_crop_size[0] == 1:
        raise ValueError(f"Input image is 2D with singleton dimension {array.shape}, but parameter "
                         f"center_crop_size has non-singleton first dimension {center_crop_size}")
    return get_center_crop(tensor, center_crop_size)
def test_get_center_crop_invalid(image: Any, crop_shape: Any) -> None:
    """
    Test that get_center_crop correctly raises an error for invalid arguments.
    """
    # Invalid image/crop_shape combinations must raise rather than silently succeed.
    with pytest.raises(Exception):
        assert image_util.get_center_crop(image=image, crop_shape=crop_shape)