Ejemplo n.º 1
0
def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset,
                                         num_dataload_workers: int) -> None:
    """
    Check the batches produced by the cropping data loader.

    Every batch must have the expected image / mask / labels shapes and center-crop shapes.
    The stored center crops must equal a center crop recomputed from the full-size mask and labels.
    """
    batch_size = 2
    args = cropping_dataset.args
    loader = cropping_dataset.as_data_loader(
        shuffle=True,
        batch_size=batch_size,
        num_dataload_workers=num_dataload_workers)
    for batch in loader:
        sample = CroppedSample.from_dict(sample=batch)
        assert sample is not None
        # Full-size crops: (batch, [channels|classes], *crop_size)
        assert sample.image.shape == (batch_size, args.number_of_image_channels) + args.crop_size  # type: ignore
        assert sample.mask.shape == (batch_size,) + args.crop_size  # type: ignore
        assert sample.labels.shape == (batch_size, args.number_of_classes) + args.crop_size  # type: ignore
        # Center crops keep the leading dimensions but use center_size spatially
        assert sample.mask_center_crop.shape == (batch_size,) + args.center_size  # type: ignore
        assert sample.labels_center_crop.shape == \
               (batch_size, args.number_of_classes) + args.center_size  # type: ignore

        # The shape checks above guarantee the zips below walk exactly batch_size items
        # and number_of_classes channels per item.
        per_item = zip(sample.mask, sample.mask_center_crop, sample.labels, sample.labels_center_crop)
        for mask, mask_center, labels, labels_center in per_item:
            expected_mask = image_util.get_center_crop(image=mask, crop_shape=args.center_size)
            assert np.array_equal(mask_center, expected_mask)

            for class_labels, class_labels_center in zip(labels, labels_center):
                expected_labels = image_util.get_center_crop(image=class_labels,
                                                             crop_shape=args.center_size)
                assert np.array_equal(class_labels_center, expected_labels)
def test_cropping_dataset_padding(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Tests that zero padding yields image crops of exactly the requested size.

    The dataset is configured with a crop size of (300, 300, 300) and PaddingMode.Zero. Every image
    in every loaded batch must then have (300, 300, 300) as its last three dimensions. The crop size
    is presumably larger than the fixture's images, so padding is required (NOTE(review): confirm
    against the fixture).

    Note: this mutates cropping_dataset.args, so it relies on the fixture not being shared across tests.
    """
    cropping_dataset.args.crop_size = (300, 300, 300)
    cropping_dataset.args.padding_mode = PaddingMode.Zero
    # Honour the parametrized worker count like the sibling tests; previously this was hard-coded to 1,
    # so the padding path was never exercised with the fixture's other worker settings.
    loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2,
                                             num_dataload_workers=num_dataload_workers)

    for item in loader:
        sample = CroppedSample.from_dict(item)
        assert sample.image.shape[-3:] == cropping_dataset.args.crop_size
def test_cropping_dataset_has_reproducible_randomness(cropping_dataset: CroppingDataset,
                                                      num_dataload_workers: int) -> None:
    """
    Tests that re-seeding with ml_util.set_random_seed(1) makes the random crop centers reproducible.

    The loader is built and fully iterated three times, re-seeding before each run. The center indices
    of each batch must match the corresponding batch of the first run.
    """
    # Presumably restricts the dataset to a fixed list of indices (1 and 2, each twice) -- verify
    # against CroppingDataset.
    cropping_dataset.dataset_indices = [1, 2] * 2
    center_indices_per_run = []
    for _ in range(3):
        ml_util.set_random_seed(1)
        loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=4,
                                                 num_dataload_workers=num_dataload_workers)
        center_indices_per_run.append(
            [CroppedSample.from_dict(sample=item).center_indices for item in loader])

    first_run = center_indices_per_run[0]
    # Guard against a vacuous pass when the loader yields nothing.
    assert len(first_run) > 0
    for run in center_indices_per_run[1:]:
        assert len(run) == len(first_run)
        # Compare batch i of this run with batch i of the first run (not always with batch 0).
        for expected, actual in zip(first_run, run):
            assert np.array_equal(expected, actual)
def test_cropping_dataset_sample_dtype(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Check the dtypes of the tensors in every batch produced by the cropping data loader.

    The image, labels, mask, mask center crop and labels center crop are each converted with `.numpy()`.
    Their dtypes must equal the value declared for them in ImageDataType.
    """
    data_loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2,
                                                  num_dataload_workers=num_dataload_workers)
    for batch in data_loader:
        sample = CroppedSample.from_dict(batch)
        # (tensor, expected data type) pairs, checked in order
        expectations = [
            (sample.image, ImageDataType.IMAGE),
            (sample.labels, ImageDataType.SEGMENTATION),
            (sample.mask, ImageDataType.MASK),
            (sample.mask_center_crop, ImageDataType.MASK),
            (sample.labels_center_crop, ImageDataType.SEGMENTATION),
        ]
        for tensor, data_type in expectations:
            assert tensor.numpy().dtype == data_type.value