Example #1
0
    def create_and_set_torch_datasets(self, for_training: bool = True, for_inference: bool = True) -> None:
        """
        Builds the torch datasets that the model works with, and caches them on this object.

        For training, a cropping dataset is created for each of the TRAIN and VAL execution modes. For inference,
        a full image dataset is created for every execution mode whose dataset split is not empty.
        :param for_training: If True, (re-)create the datasets held in `_datasets_for_training`.
        :param for_inference: If True, (re-)create the datasets held in `_datasets_for_inference`.
        """
        from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
        from InnerEye.ML.dataset.full_image_dataset import FullImageDataset

        splits = self.get_dataset_splits()
        cropped_transforms = self.get_cropped_image_sample_transforms()
        image_transforms = self.get_full_image_sample_transforms()

        if for_training:
            train_dataset = CroppingDataset(
                self,
                splits.train,
                cropped_sample_transforms=cropped_transforms.train,  # type: ignore
                full_image_sample_transforms=image_transforms.train)  # type: ignore
            val_dataset = CroppingDataset(
                self,
                splits.val,
                cropped_sample_transforms=cropped_transforms.val,  # type: ignore
                full_image_sample_transforms=image_transforms.val)  # type: ignore
            self._datasets_for_training = {
                ModelExecutionMode.TRAIN: train_dataset,
                ModelExecutionMode.VAL: val_dataset,
            }

        if for_inference:
            inference_datasets = {}
            for mode in ModelExecutionMode:
                # Execution modes without any data do not get a dataset at all.
                if len(splits[mode]) == 0:
                    continue
                inference_datasets[mode] = FullImageDataset(
                    self,
                    splits[mode],
                    full_image_sample_transforms=image_transforms.test)  # type: ignore
            # Only publish the dictionary once all datasets have been created successfully.
            self._datasets_for_inference = inference_datasets
def test_cropping_dataset_has_reproducible_randomness(cropping_dataset: CroppingDataset,
                                                      num_dataload_workers: int) -> None:
    """
    Checks that a shuffled data loader produces the same crop center indices on every pass, provided that the
    random seed is reset to the same value before the loader is created.
    """
    cropping_dataset.dataset_indices = [1, 2] * 2
    reference_centers = None
    for _ in range(3):
        # Re-seed before every pass, so that each pass starts from the same random state.
        ml_util.set_random_seed(1)
        data_loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=4,
                                                      num_dataload_workers=num_dataload_workers)
        for batch in data_loader:
            cropped = CroppedSample.from_dict(sample=batch)
            if reference_centers is not None:
                assert np.array_equal(reference_centers, cropped.center_indices)
            else:
                # The very first batch seen defines the reference that all later batches must match.
                reference_centers = cropped.center_indices
Example #3
0
def test_cropping_dataset_as_data_loader(cropping_dataset: CroppingDataset,
                                         num_dataload_workers: int) -> None:
    """
    Checks the batches that the cropping dataset delivers through a data loader: image, mask and labels must
    have the shapes implied by the configuration, and the center crops of mask and labels must match the
    result of calling `image_util.get_center_crop` on the full crops.
    """
    batch_size = 2
    config = cropping_dataset.args
    data_loader = cropping_dataset.as_data_loader(
        shuffle=True,
        batch_size=batch_size,
        num_dataload_workers=num_dataload_workers)
    for batch in data_loader:
        cropped = CroppedSample.from_dict(sample=batch)
        assert cropped is not None

        # Shapes of the full crops
        per_class_prefix = (batch_size, config.number_of_classes)
        assert cropped.image.shape == \
               (batch_size, config.number_of_image_channels) + config.crop_size  # type: ignore
        assert cropped.mask.shape == (batch_size, ) + config.crop_size  # type: ignore
        assert cropped.labels.shape == per_class_prefix + config.crop_size  # type: ignore

        # Shapes of the center crops
        assert cropped.mask_center_crop.shape == (batch_size, ) + config.center_size  # type: ignore
        assert cropped.labels_center_crop.shape == per_class_prefix + config.center_size  # type: ignore

        # Contents of the center crops, for each batch element and each class
        for b in range(batch_size):
            mask_center = image_util.get_center_crop(
                image=cropped.mask[b],
                crop_shape=config.center_size)
            assert np.array_equal(cropped.mask_center_crop[b], mask_center)

            for c, actual_label_center in enumerate(cropped.labels_center_crop[b]):
                label_center = image_util.get_center_crop(
                    image=cropped.labels[b][c],
                    crop_shape=config.center_size)
                assert np.array_equal(actual_label_center, label_center)
Example #4
0
def test_cropped_sample(use_mask: bool) -> None:
    """
    Checks the center point that `CroppingDataset.create_random_cropped_sample` picks for a small synthetic
    sample. With a mask that has a single non-zero voxel, the center must be exactly that voxel and the crop
    must be the matching sub-volume. With a mask of all ones, the center must be a voxel where the second
    label channel is non-zero.
    :param use_mask: If True, use the single-voxel mask, otherwise use a mask that is 1 everywhere.
    """
    ml_util.set_random_seed(1)
    image_size = [4] * 3
    crop_size = (2, 2, 2)
    center_size = (1, 1, 1)

    # Small synthetic sample for random cropping
    image = np.random.uniform(size=[1] + image_size)
    labels = np.zeros(shape=[2] + image_size)
    labels[0] = 1
    # Channel 1 is set in the two corners (0, 0, 0) and (3, 3, 3), channel 0 is set everywhere else.
    for corner in [(0, 0, 0), (3, 3, 3)]:
        labels[(0,) + corner] = 0
    for corner in [(0, 0, 0), (3, 3, 3)]:
        labels[(1,) + corner] = 1

    expected_center: Optional[List[int]]
    crop_slicer: Optional[slice]
    if not use_mask:
        mask = np.ones(shape=image_size, dtype=ImageDataType.MASK.value)
        expected_center = None
        crop_slicer = None
    else:
        # The mask overlaps with the channel 1 labels in exactly one voxel, so that voxel has to become the
        # center of every crop.
        mask = np.zeros(shape=image_size, dtype=ImageDataType.MASK.value)
        mask[3, 3, 3] = 1
        expected_center = [3, 3, 3]
        crop_slicer = slice(2, 4)

    sample = Sample(image=image,
                    labels=labels,
                    mask=mask,
                    metadata=DummyPatientMetadata)

    for _ in range(100):
        cropped_sample = CroppingDataset.create_random_cropped_sample(
            sample=sample,
            crop_size=crop_size,
            center_size=center_size,
            class_weights=[0, 1])

        if expected_center is None:
            # Without a restricting mask, any voxel with a non-zero channel 1 label is a valid center.
            center = cropped_sample.center_indices
            print("Center point chosen: {}".format(center))
            assert labels[1, center[0], center[1], center[2]] != 0
            continue

        assert list(cropped_sample.center_indices) == expected_center  # type: ignore
        # The same slice is applied along all 3 spatial dimensions.
        spatial_window = (crop_slicer, crop_slicer, crop_slicer)
        with_channels = (slice(None),) + spatial_window
        assert np.array_equal(cropped_sample.image, sample.image[with_channels])
        assert np.array_equal(cropped_sample.labels, sample.labels[with_channels])
        assert np.array_equal(cropped_sample.mask, sample.mask[spatial_window])
def test_cropping_dataset_padding(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Tests that the image crops delivered by the data loader have exactly the configured crop size when the
    padding mode is set to zero padding and the crop size is (300, 300, 300). That crop size is presumably
    larger than the test images, such that padding is required - TODO confirm against the test data.
    """
    cropping_dataset.args.crop_size = (300, 300, 300)
    cropping_dataset.args.padding_mode = PaddingMode.Zero
    # NOTE(review): the loader uses a fixed single worker rather than the `num_dataload_workers` argument.
    # Confirm whether that is intentional before changing it.
    loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2, num_dataload_workers=1)

    for item in loader:
        sample = CroppedSample.from_dict(item)
        # Only the trailing 3 (spatial) dimensions are compared, the leading ones are not checked here.
        assert sample.image.shape[-3:] == cropping_dataset.args.crop_size
def test_create_possibly_padded_sample_for_cropping(crop_size: Any) -> None:
    """
    Checks that `CroppingDataset.create_possibly_padded_sample_for_cropping`, when run with zero padding on a
    sample of spatial size 4 x 4 x 4, returns image, labels and mask that all have the requested crop size in
    their last 3 dimensions.
    :param crop_size: The crop size to request. It is compared against a shape tuple via `==`.
    """
    spatial_shape = [4] * 3
    source_sample = Sample(image=np.random.uniform(size=[1] + spatial_shape),
                           labels=np.zeros(shape=[2] + spatial_shape),
                           mask=np.zeros(shape=spatial_shape, dtype=ImageDataType.MASK.value),
                           metadata=DummyPatientMetadata)

    padded = CroppingDataset.create_possibly_padded_sample_for_cropping(
        sample=source_sample,
        crop_size=crop_size,
        padding_mode=PaddingMode.Zero
    )

    for array in (padded.image, padded.labels, padded.mask):
        assert array.shape[-3:] == crop_size
def test_cropping_dataset_sample_dtype(cropping_dataset: CroppingDataset, num_dataload_workers: int) -> None:
    """
    Checks the data types of the torch tensors that the cropping dataset delivers through a data loader:
    image, labels and mask, plus the center crops of labels and mask, must have the data types given by
    `ImageDataType`.
    """
    data_loader = cropping_dataset.as_data_loader(shuffle=True, batch_size=2,
                                                  num_dataload_workers=num_dataload_workers)
    for batch in data_loader:
        cropped = CroppedSample.from_dict(batch)
        tensors_and_types = [
            (cropped.image, ImageDataType.IMAGE),
            (cropped.labels, ImageDataType.SEGMENTATION),
            (cropped.mask, ImageDataType.MASK),
            (cropped.mask_center_crop, ImageDataType.MASK),
            (cropped.labels_center_crop, ImageDataType.SEGMENTATION),
        ]
        for tensor, data_type in tensors_and_types:
            assert tensor.numpy().dtype == data_type.value
Example #8
0
def cropping_dataset(default_config: SegmentationModelBase,
                     normalize_fn: Callable) -> CroppingDataset:
    """
    Creates a cropping dataset over the training split of the given model configuration.
    :param default_config: The model configuration that provides the dataset splits, and that is handed to
    the dataset as its `args`.
    :param normalize_fn: Not used inside this function.
    :return: A CroppingDataset that reads the training split.
    """
    train_rows = default_config.get_dataset_splits().train
    return CroppingDataset(args=default_config, data_frame=train_rows)
Example #9
0
def visualize_random_crops(sample: Sample, config: SegmentationModelBase,
                           output_folder: Path) -> np.ndarray:
    """
    Simulate the effect of sampling random crops (as is done for training segmentation models), and store the results
    as a Nifti heatmap and as 3 axial/sagittal/coronal slices. The heatmap and the slices are stored in the given
    output folder, with filenames that contain the patient ID as the prefix. The Nifti files are only written
    if the first dimension of the image has a size larger than 1; otherwise only a single thumbnail is written.
    :param sample: The patient information from the dataset, with scans and ground truth labels.
    :param config: The model configuration.
    :param output_folder: The folder into which the heatmap and thumbnails should be written.
    :return: A numpy array that has the same size as the (possibly padded) image, containing how often each voxel
    was contained in a random crop.
    """
    output_folder.mkdir(exist_ok=True, parents=True)
    sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
        sample=sample,
        crop_size=config.crop_size,
        padding_mode=config.padding_mode)
    logging.info(f"Processing sample: {sample.patient_id}")
    # Exhaustively sample with random crop function
    image_channel0 = sample.image[0]
    heatmap = np.zeros(image_channel0.shape, dtype=np.uint16)
    # Number of repeats should fit into the range of UInt16, because we will later save the heatmap as an integer
    # Nifti file of that datatype.
    repeats = 200
    for _ in range(repeats):
        slicers, _ = augmentation.slicers_for_random_crop(
            sample=sample,
            crop_size=config.crop_size,
            class_weights=config.class_weights)
        # Count one more hit for every voxel that falls inside this crop.
        heatmap[slicers[0], slicers[1], slicers[2]] += 1
    is_3dim = heatmap.shape[0] > 1
    header = sample.metadata.image_header
    if not header:
        logging.warning(
            f"No image header found for patient {sample.patient_id}. Using default header."
        )
        header = get_unit_image_header()
    if is_3dim:
        ct_output_name = str(output_folder / f"{sample.patient_id}_ct.nii.gz")
        heatmap_output_name = str(
            output_folder / f"{sample.patient_id}_sampled_patches.nii.gz")
        io_util.store_as_nifti(image=heatmap,
                               header=header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=image_channel0,
                               header=header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)
    # Scale the hit counts such that the most frequently sampled voxel has value 1. The builtin `float` is used
    # here because the `np.float` alias no longer exists in current numpy versions (it was the same type).
    heatmap_scaled = heatmap.astype(dtype=float) / heatmap.max()
    # If the incoming image is effectively a 2D image with degenerate Z dimension, then only plot a single
    # axial thumbnail. Otherwise, plot thumbnails for all 3 dimensions.
    dimensions = list(range(3)) if is_3dim else [0]
    # Center the 3 thumbnails at one of the points where the heatmap attains a maximum. This should ensure that
    # the thumbnails are in an area where many of the organs of interest are located.
    max_heatmap_index = np.unravel_index(
        heatmap.argmax(), heatmap.shape) if is_3dim else (0, 0, 0)
    for dimension in dimensions:
        # Start from an empty figure, so that the previous thumbnail does not leak into this one.
        plt.clf()
        scan_with_transparent_overlay(
            scan=image_channel0,
            overlay=heatmap_scaled,
            dimension=dimension,
            position=max_heatmap_index[dimension] if is_3dim else 0,
            spacing=header.spacing)
        # Construct a filename that has a dimension suffix if we are generating 3 of them. For 2dim images, skip
        # the suffix.
        thumbnail = f"{sample.patient_id}_sampled_patches"
        if is_3dim:
            thumbnail += f"_dim{dimension}"
        thumbnail += ".png"
        resize_and_save(width_inch=5,
                        height_inch=5,
                        filename=output_folder / thumbnail)
    return heatmap
def main(args: CheckPatchSamplingConfig) -> None:
    """
    Writes, for the first `args.number_samples` samples of the training split, a heatmap of randomly sampled
    patches and the first image channel as Nifti files into the output folder.

    The model configuration is created from `args.model_name`, with the training batch size overridden to 1
    and the local dataset set to `args.local_dataset`. For each sample, `args.number_crop_iterations` random
    crops are drawn, and a mask for each of the patches is added up to form the heatmap.
    :param args: The settings for the run: model name, local dataset folder, output folder, number of samples
    and number of crop iterations.
    :raises ValueError: If a sample does not have an image header.
    """
    # Identify paths to inputs and outputs
    overrides = {
        "train_batch_size": 1,
        "local_dataset": Path(args.local_dataset)
    }
    output_folder = Path(args.output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Create the model configuration, with the overrides applied
    config_loader = ModelConfigLoader[SegmentationModelBase]()
    config = config_loader.create_model_config_from_name(args.model_name,
                                                         overrides=overrides)

    # Set a random seed
    ml_util.set_random_seed(config.random_seed)

    # Samples are read as full images from the training split
    training_split = config.get_dataset_splits().train
    full_image_dataset = FullImageDataset(config, training_split)

    for sample_index in range(args.number_samples):
        unpadded_sample = full_image_dataset.get_samples_at_index(index=sample_index)[0]
        sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
            sample=unpadded_sample,
            crop_size=config.crop_size,
            padding_mode=config.padding_mode)
        print("Processing sample: ", sample.patient_id)

        # Exhaustively sample with random crop function. Only the center point of each crop is needed.
        heatmap = np.zeros(sample.mask.shape, dtype=np.uint16)
        for _ in range(args.number_crop_iterations):
            _, crop_center = augmentation.random_crop(
                sample=sample,
                crop_size=config.crop_size,
                class_weights=config.class_weights)
            heatmap += create_mask_for_patch(output_shape=heatmap.shape,
                                             output_dtype=heatmap.dtype,
                                             center=crop_center,
                                             crop_size=config.crop_size)

        # File names use the patient ID converted to an integer.
        patient_number = int(sample.patient_id)
        ct_output_name = str(output_folder / "{}_ct.nii.gz".format(patient_number))
        heatmap_output_name = str(output_folder / "{}_sampled_patches.nii.gz".format(patient_number))

        header = sample.metadata.image_header
        if not header:
            raise ValueError("None header expected some header")

        # Write the heatmap first, then the first channel of the image.
        outputs = [
            (heatmap, heatmap_output_name, heatmap.dtype),
            (sample.image[0], ct_output_name, sample.image.dtype),
        ]
        for volume, file_name, image_type in outputs:
            io_util.store_as_nifti(image=volume,
                                   header=header,
                                   file_name=file_name,
                                   image_type=image_type,
                                   scale=False)