def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Checks that every batch produced by the cropping dataset's data loader has the
    expected shapes for image, mask, and labels (both full-size and center-cropped),
    and that the center crops match what image_util.get_center_crop computes.
    """
    batch_size = 2
    args = cropping_dataset.args
    loader = cropping_dataset.as_data_loader(shuffle=True,
                                             batch_size=batch_size,
                                             num_dataload_workers=num_dataload_workers)
    for batch in loader:
        sample = CroppedSample.from_dict(sample=batch)
        assert sample is not None
        # Full-size tensors: (batch, [channels/classes,] z, y, x)
        expected_image_shape = (batch_size, args.number_of_image_channels) + args.crop_size
        assert sample.image.shape == expected_image_shape  # type: ignore
        assert sample.mask.shape == (batch_size,) + args.crop_size  # type: ignore
        expected_labels_shape = (batch_size, args.number_of_classes) + args.crop_size
        assert sample.labels.shape == expected_labels_shape  # type: ignore
        # Center crops use center_size instead of crop_size
        assert sample.mask_center_crop.shape == (batch_size,) + args.center_size  # type: ignore
        expected_labels_crop_shape = (batch_size, args.number_of_classes) + args.center_size
        assert sample.labels_center_crop.shape == expected_labels_crop_shape  # type: ignore
        # Verify the center-crop contents element-by-element against the reference helper.
        for batch_index in range(batch_size):
            expected_mask_crop = image_util.get_center_crop(image=sample.mask[batch_index],
                                                            crop_shape=args.center_size)
            assert np.array_equal(sample.mask_center_crop[batch_index], expected_mask_crop)
            for class_index in range(len(sample.labels_center_crop[batch_index])):
                expected_label_crop = image_util.get_center_crop(
                    image=sample.labels[batch_index][class_index],
                    crop_shape=args.center_size)
                assert np.array_equal(sample.labels_center_crop[batch_index][class_index],
                                      expected_label_crop)
def test_cropping_dataset_padding(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Tests that when the requested crop size is larger than the source images, the
    dataset pads the images (here with PaddingMode.Zero) so that every sample coming
    out of the data loader has the full requested crop shape.
    :return:
    """
    # Deliberately oversized crop to force padding; source images are presumed smaller
    # than 300^3 — TODO confirm against the cropping_dataset fixture.
    cropping_dataset.args.crop_size = (300, 300, 300)
    cropping_dataset.args.padding_mode = PaddingMode.Zero
    # Fix: honor the num_dataload_workers parameter instead of hardcoding 1, for
    # consistency with the sibling tests that exercise multi-worker loading.
    loader = cropping_dataset.as_data_loader(shuffle=True,
                                             batch_size=2,
                                             num_dataload_workers=num_dataload_workers)
    for item in loader:
        sample = CroppedSample.from_dict(item)
        # The last three dimensions are the spatial (z, y, x) extents of the crop.
        assert sample.image.shape[-3:] == cropping_dataset.args.crop_size
def test_cropping_dataset_has_reproducible_randomness(cropping_dataset: CroppingDataset,
                                                      num_dataload_workers: int) -> None:
    """
    Checks that setting the random seed before creating the data loader makes the
    randomly chosen crop centers reproducible: three loader passes with the same
    seed must yield identical center_indices.
    """
    cropping_dataset.dataset_indices = [1, 2] * 2
    reference_center_indices = None
    for _ in range(3):
        # Re-seed before each pass so every pass draws the same random crops.
        ml_util.set_random_seed(1)
        loader = cropping_dataset.as_data_loader(shuffle=True,
                                                 batch_size=4,
                                                 num_dataload_workers=num_dataload_workers)
        for batch in loader:
            sample = CroppedSample.from_dict(sample=batch)
            if reference_center_indices is None:
                # First pass establishes the expected values.
                reference_center_indices = sample.center_indices
            else:
                assert np.array_equal(reference_center_indices, sample.center_indices)
def test_cropping_dataset_sample_dtype(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Tests the data type of torch tensors (e.g. image, labels, and mask) created by the dataset generator,
    which are provided as input into the computational graph
    :return:
    """
    loader = cropping_dataset.as_data_loader(shuffle=True,
                                             batch_size=2,
                                             num_dataload_workers=num_dataload_workers)
    for batch in loader:
        sample = CroppedSample.from_dict(batch)
        # Full-size tensors must carry the canonical dtype for their field kind.
        assert sample.image.numpy().dtype == ImageDataType.IMAGE.value
        assert sample.labels.numpy().dtype == ImageDataType.SEGMENTATION.value
        assert sample.mask.numpy().dtype == ImageDataType.MASK.value
        # Center crops inherit the same dtypes as their full-size counterparts.
        assert sample.mask_center_crop.numpy().dtype == ImageDataType.MASK.value
        assert sample.labels_center_crop.numpy().dtype == ImageDataType.SEGMENTATION.value