def test_store_as_nifti_fail(test_output_dirs: TestOutputDirectories, image_type: Any, scale: Any, input_range: Any,
                             output_range: Any) -> None:
    """
    Check that io_util.store_as_nifti raises an exception for the given image type, scaling flag and
    intensity ranges.

    NOTE(review): the arguments are presumably supplied by a pytest parametrization (not visible here) with
    invalid combinations - confirm against the parametrization.
    """
    # Read as a row-major 3x3 matrix, every row of this direction tuple is (1, 0, 0).
    odd_header = ImageHeader(origin=(1, 1, 1), direction=(1, 0, 0, 1, 0, 0, 1, 0, 0), spacing=(1, 2, 4))
    random_volume = np.random.random_sample((dim_z, dim_y, dim_x))
    with pytest.raises(Exception):
        # The output path is requested inside the raises-block, exactly as before.
        target_path = test_output_dirs.create_file_or_folder_path(default_image_name)
        io_util.store_as_nifti(random_volume, odd_header, target_path,
                               image_type, scale, input_range, output_range)
def test_store_as_nifti(test_output_dirs: TestOutputDirectories, image_type: Any, scale: Any, input_range: Any,
                        output_range: Any) \
        -> None:
    """
    Store a random image with io_util.store_as_nifti and check the result.

    The written file is checked for the expected shape, header, set of voxel values and dtype. Loading the file
    back must also preserve the voxel spacing.
    """
    image = np.random.random_sample((dim_z, dim_y, dim_x))
    spacingzyx = (1, 2, 3)
    path_image = test_output_dirs.create_file_or_folder_path(default_image_name)
    header = ImageHeader(origin=(1, 1, 1), direction=(1, 0, 0, 0, 1, 0, 0, 0, 1), spacing=spacingzyx)
    io_util.store_as_nifti(image, header, path_image,
                           image_type, scale, input_range, output_range)
    if scale:
        # Presumably mirrors the intensity rescaling store_as_nifti applies when scale is set, so that the
        # expected voxel values match what was written - confirm against io_util.
        linear_transform = LinearTransform.transform(data=image, input_range=input_range, output_range=output_range)
        image = linear_transform.astype(image_type)  # type: ignore
    # Reuse path_image instead of requesting the path a second time, so the content check is guaranteed to
    # look at the same file that was written above and loaded below.
    assert_nifti_content(path_image,
                         image.shape, header, list(np.unique(image.astype(image_type))), image_type)

    loaded_image = io_util.load_nifti_image(path_image, image_type)
    assert loaded_image.header.spacing == spacingzyx
# Example #3
# 0
def create_smaller_image(image_size: TupleInt3, source_image_dir: Path,
                         target_image_dir: Path, image_file_name: str) -> None:
    """
    Load an image from source_image_dir and create another random image in target_image_dir with the same header
    and the target size. Voxel values are drawn uniformly from the integers between the minimum and maximum of
    the source image (inclusive), and the result is stored as np.short.

    :param image_size: Target image size.
    :param source_image_dir: Source image directory.
    :param target_image_dir: Target image directory.
    :param image_file_name: Common image file name, used both for reading and for writing.
    :return: None.
    """
    source_image = io_util.load_nifti_image(source_image_dir / image_file_name)
    source_image_data = source_image.image
    # Convert to Python ints before adding 1. If the data has an integer numpy dtype, max + 1 on a numpy scalar
    # can wrap around (e.g. 32767 + 1 in int16), which would make randint fail with low >= high. This also
    # hands randint integer bounds, as it documents, when the data is floating point.
    min_data_val = int(np.min(source_image_data))
    max_data_val = int(np.max(source_image_data))

    image = np.random.randint(low=min_data_val,
                              high=max_data_val + 1,
                              size=image_size)
    # NOTE(review): values outside the int16 range cannot be represented as np.short - confirm that the source
    # data range fits.
    io_util.store_as_nifti(image, source_image.header,
                           target_image_dir / image_file_name, np.short)
# Example #4
# 0
def visualize_random_crops(sample: Sample, config: SegmentationModelBase,
                           output_folder: Path) -> np.ndarray:
    """
    Simulate the effect of sampling random crops (as is done for training segmentation models), and store the results
    as a Nifti heatmap and as 3 axial/sagittal/coronal slices. The heatmap and the slices are stored in the given
    output folder, with filenames that contain the patient ID as the prefix.
    :param sample: The patient information from the dataset, with scans and ground truth labels.
    :param config: The model configuration.
    :param output_folder: The folder into which the heatmap and thumbnails should be written.
    :return: A numpy array (dtype uint16) that has the same size as the first image channel of the (possibly
    padded) sample, containing how often each voxel was contained in a randomly sampled crop.
    """
    output_folder.mkdir(exist_ok=True, parents=True)
    sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
        sample=sample,
        crop_size=config.crop_size,
        padding_mode=config.padding_mode)
    logging.info(f"Processing sample: {sample.patient_id}")
    # Exhaustively sample with random crop function
    image_channel0 = sample.image[0]
    heatmap = np.zeros(image_channel0.shape, dtype=np.uint16)
    # Number of repeats should fit into the range of UInt16, because we will later save the heatmap as an integer
    # Nifti file of that datatype.
    repeats = 200
    for _ in range(repeats):
        slicers, _ = augmentation.slicers_for_random_crop(
            sample=sample,
            crop_size=config.crop_size,
            class_weights=config.class_weights)
        heatmap[slicers[0], slicers[1], slicers[2]] += 1
    # A size-1 first axis is treated as a 2D image.
    is_3dim = heatmap.shape[0] > 1
    header = sample.metadata.image_header
    if not header:
        logging.warning(
            f"No image header found for patient {sample.patient_id}. Using default header."
        )
        header = get_unit_image_header()
    if is_3dim:
        ct_output_name = str(output_folder / f"{sample.patient_id}_ct.nii.gz")
        heatmap_output_name = str(
            output_folder / f"{sample.patient_id}_sampled_patches.nii.gz")
        io_util.store_as_nifti(image=heatmap,
                               header=header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=image_channel0,
                               header=header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)
    # np.float was a deprecated alias of the builtin float (i.e. float64) and was removed in NumPy 1.24;
    # use np.float64 explicitly.
    heatmap_scaled = heatmap.astype(dtype=np.float64) / heatmap.max()
    # If the incoming image is effectively a 2D image with degenerate Z dimension, then only plot a single
    # axial thumbnail. Otherwise, plot thumbnails for all 3 dimensions.
    dimensions = list(range(3)) if is_3dim else [0]
    # Center the 3 thumbnails at one of the points where the heatmap attains a maximum. This should ensure that
    # the thumbnails are in an area where many of the organs of interest are located.
    max_heatmap_index = np.unravel_index(
        heatmap.argmax(), heatmap.shape) if is_3dim else (0, 0, 0)
    for dimension in dimensions:
        plt.clf()
        scan_with_transparent_overlay(
            scan=image_channel0,
            overlay=heatmap_scaled,
            dimension=dimension,
            position=max_heatmap_index[dimension] if is_3dim else 0,
            spacing=header.spacing)
        # Construct a filename that has a dimension suffix if we are generating 3 of them. For 2dim images, skip
        # the suffix.
        thumbnail = f"{sample.patient_id}_sampled_patches"
        if is_3dim:
            thumbnail += f"_dim{dimension}"
        thumbnail += ".png"
        resize_and_save(width_inch=5,
                        height_inch=5,
                        filename=output_folder / thumbnail)
    return heatmap
def main(args: CheckPatchSamplingConfig) -> None:
    """
    Sample random crops for the first training samples of a model and write the results as Nifti files.

    For each of the first args.number_samples samples of the training split of the model named args.model_name,
    this writes two files into args.output_folder: the first image channel, and a heatmap counting how often
    each voxel fell inside a randomly sampled patch.

    :param args: Command line configuration (model name, local dataset, output folder, number of samples, number
        of crop iterations per sample).
    :raises ValueError: If the number of crop iterations does not fit into the uint16 heatmap, or if a sample has
        no image header (needed to write the Nifti files).
    """
    # The heatmap is accumulated as uint16. Assuming each patch mask adds at most 1 per voxel per iteration, more
    # iterations than the uint16 maximum could wrap around silently, so reject them up front.
    max_iterations = np.iinfo(np.uint16).max
    if args.number_crop_iterations > max_iterations:
        raise ValueError(f"number_crop_iterations must be at most {max_iterations} to fit the uint16 heatmap, "
                         f"but got {args.number_crop_iterations}")

    # Identify paths to inputs and outputs
    commandline_args = {
        "train_batch_size": 1,
        "local_dataset": Path(args.local_dataset)
    }
    output_folder = Path(args.output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Create a config file
    config = ModelConfigLoader[SegmentationModelBase](
    ).create_model_config_from_name(args.model_name,
                                    overrides=commandline_args)

    # Set a random seed
    ml_util.set_random_seed(config.random_seed)

    # Get a dataloader object that checks csv
    dataset_splits = config.get_dataset_splits()

    # Load a sample using the full image data loader
    full_image_dataset = FullImageDataset(config, dataset_splits.train)

    for sample_index in range(args.number_samples):
        sample = CroppingDataset.create_possibly_padded_sample_for_cropping(
            sample=full_image_dataset.get_samples_at_index(
                index=sample_index)[0],
            crop_size=config.crop_size,
            padding_mode=config.padding_mode)
        print("Processing sample: ", sample.patient_id)

        # Check for the header before the (potentially long) sampling loop rather than after it.
        header = sample.metadata.image_header
        if not header:
            raise ValueError(f"Expected an image header for patient {sample.patient_id}, but got None")

        # Exhaustively sample with random crop function
        heatmap = np.zeros(sample.mask.shape, dtype=np.uint16)
        for _ in range(args.number_crop_iterations):
            _, center_point = augmentation.random_crop(
                sample=sample,
                crop_size=config.crop_size,
                class_weights=config.class_weights)
            patch_mask = create_mask_for_patch(output_shape=heatmap.shape,
                                               output_dtype=heatmap.dtype,
                                               center=center_point,
                                               crop_size=config.crop_size)
            heatmap += patch_mask

        # int() keeps the original file naming (e.g. it drops leading zeros). It requires a numeric patient ID.
        patient_id = int(sample.patient_id)
        ct_output_name = str(output_folder / f"{patient_id}_ct.nii.gz")
        heatmap_output_name = str(output_folder / f"{patient_id}_sampled_patches.nii.gz")
        io_util.store_as_nifti(image=heatmap,
                               header=header,
                               file_name=heatmap_output_name,
                               image_type=heatmap.dtype,
                               scale=False)
        io_util.store_as_nifti(image=sample.image[0],
                               header=header,
                               file_name=ct_output_name,
                               image_type=sample.image.dtype,
                               scale=False)