Example #1
def assert_sitk_img_equivalence(img: SimpleITK.Image,
                                img_ref: SimpleITK.Image):
    assert img.GetDimension() == img_ref.GetDimension()
    assert img.GetSize() == img_ref.GetSize()
    assert img.GetOrigin() == img_ref.GetOrigin()
    assert img.GetSpacing() == img_ref.GetSpacing()
    assert (img.GetNumberOfComponentsPerPixel() ==
            img_ref.GetNumberOfComponentsPerPixel())
    assert img.GetPixelIDValue() == img_ref.GetPixelIDValue()
    assert img.GetPixelIDTypeAsString() == img_ref.GetPixelIDTypeAsString()
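A minimal self-contained check using the helper above; the 16x16 float image is arbitrary, and SimpleITK is imported under its full name as in the snippet:

import SimpleITK

img = SimpleITK.Image(16, 16, SimpleITK.sitkFloat32)
img_copy = SimpleITK.Image(img)  # copy constructor keeps size, spacing, origin and pixel type
assert_sitk_img_equivalence(img_copy, img)  # passes: all compared properties are identical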
Example #2
    def _assert_3d(self, image: sitk.Image, is_vector=False):
        self.assertEqual(self.properties_3d.size, image.GetSize())

        if is_vector:
            self.assertEqual(self.no_vector_components,
                             image.GetNumberOfComponentsPerPixel())
        else:
            self.assertEqual(1, image.GetNumberOfComponentsPerPixel())

        self.assertEqual(self.origin_spacing_3d, image.GetOrigin())
        self.assertEqual(self.origin_spacing_3d, image.GetSpacing())
        self.assertEqual(self.direction_3d, image.GetDirection())
Example #3
    def _image_as_numpy_array(image: sitk.Image, mask: np.ndarray = None):
        """Gets an image as numpy array where each row is a voxel and each column is a feature.

        Args:
            image (sitk.Image): The image.
            mask (np.ndarray): A mask defining which voxels to return. True marks background voxels to exclude, False marks voxels to keep.

        Returns:
            np.ndarray: An array where each row is a voxel and each column is a feature.
        """

        number_of_components = image.GetNumberOfComponentsPerPixel(
        )  # the number of features for this image
        no_voxels = np.prod(image.GetSize())
        image = sitk.GetArrayFromImage(image)

        if mask is not None:
            no_voxels = np.size(mask) - np.count_nonzero(mask)

            if number_of_components == 1:
                masked_image = np.ma.masked_array(image, mask=mask)
            else:
                # image is a vector image, make a vector mask
                vector_mask = np.expand_dims(
                    mask, axis=3)  # shape is now (z, y, x, 1)
                vector_mask = np.repeat(
                    vector_mask, number_of_components,
                    axis=3)  # shape is now (z, y, x, number_of_components)
                masked_image = np.ma.masked_array(image, mask=vector_mask)

            image = masked_image[~masked_image.mask]

        return image.reshape((no_voxels, number_of_components))
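A hedged usage sketch, assuming the helper is reachable as a plain function (in the source it appears to be a static method of a feature-extraction class); the image size and mask are arbitrary:

import numpy as np
import SimpleITK as sitk

image = sitk.Image(4, 3, 2, sitk.sitkFloat32)  # size (x=4, y=3, z=2), 24 voxels
mask = np.zeros((2, 3, 4), dtype=bool)         # numpy array order is (z, y, x)
mask[0] = True                                 # exclude the first slice (True = background)
features = _image_as_numpy_array(image, mask)
print(features.shape)                          # (12, 1): one row per kept voxel, one feature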
Example #4
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    data = sitk.GetArrayFromImage(image).transpose()
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        data = data[np.newaxis]  # add channels dimension
    input_spatial_dims = image.GetDimension()
    if input_spatial_dims == 2:
        data = data[..., np.newaxis]
    if not keepdim:
        data = ensure_4d(data, num_spatial_dims=input_spatial_dims)
    assert data.shape[0] == num_components
    assert data.shape[1:1 + input_spatial_dims] == image.GetSize()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, W, H, 1)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[:2, :2] = rotation_2d
        spacing = *spacing, 1
        origin = *origin, 0
    else:
        raise RuntimeError(f'Direction not understood: {direction}')
    rotation = np.dot(FLIP_XY, rotation)
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
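The function above depends on two helpers that are not shown, FLIP_XY (the LPS-to-RAS flip) and ensure_4d. The stand-ins below are assumptions that only cover this simple case, so treat this as a sketch rather than the library's actual code; the imports and stand-ins must be in scope before the function definition above is executed, because its annotations reference Tuple and np:

from typing import Tuple

import numpy as np
import SimpleITK as sitk

FLIP_XY = np.diag((-1, -1, 1))  # assumed: flips x and y to go from LPS (ITK) to RAS (NIfTI)


def ensure_4d(data, num_spatial_dims=None):
    # assumed stand-in: pad trailing singleton axes until the array is 4-D
    while data.ndim < 4:
        data = data[..., np.newaxis]
    return data


image = sitk.Image(5, 4, 3, sitk.sitkFloat32)
image.SetSpacing((1.0, 1.0, 2.5))
data, affine = sitk_to_nib(image)
print(data.shape)  # (1, 5, 4, 3): channels first, then x, y, z
print(affine)      # 4x4 RAS affine built from direction, spacing, and origin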
Example #5
    def assert_img_properties(img: SimpleITK.Image,
                              internal_image: SimpleITKImage):
        color_space = {
            1: ColorSpace.GRAY,
            3: ColorSpace.RGB,
            4: ColorSpace.RGBA,
        }

        assert internal_image.color_space == color_space.get(
            img.GetNumberOfComponentsPerPixel())
        if img.GetDimension() == 4:
            assert internal_image.timepoints == img.GetSize()[-1]
        else:
            assert internal_image.timepoints is None
        if img.GetDepth():
            assert internal_image.depth == img.GetDepth()
            assert internal_image.voxel_depth_mm == img.GetSpacing()[2]
        else:
            assert internal_image.depth is None
            assert internal_image.voxel_depth_mm is None

        assert internal_image.width == img.GetWidth()
        assert internal_image.height == img.GetHeight()
        assert internal_image.voxel_width_mm == approx(img.GetSpacing()[0])
        assert internal_image.voxel_height_mm == approx(img.GetSpacing()[1])
Example #6
    def preprocess_image(self, reg_image: sitk.Image) -> None:
        """
        Run full intensity and spatial preprocessing. Creates the `reg_image` attribute

        Parameters
        ----------
        reg_image: sitk.Image
            Raw form of image to be preprocessed

        """

        reg_image = self.preprocess_reg_image_intensity(
            reg_image, self.preprocessing)

        if reg_image.GetDepth() >= 1:
            raise ValueError(
                "preprocessing did not result in a single image plane\n"
                "multi-channel or 3D image return")

        if reg_image.GetNumberOfComponentsPerPixel() > 1:
            raise ValueError(
                "preprocessing did not result in a single image plane\n"
                "multi-component / RGB(A) image returned")

        reg_image, pre_reg_transforms = self.preprocess_reg_image_spatial(
            reg_image, self.preprocessing, self.pre_reg_transforms)

        if len(pre_reg_transforms) > 0:
            self.pre_reg_transforms = pre_reg_transforms

        self._reg_image = reg_image
Example #7
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    data = sitk.GetArrayFromImage(image).transpose()
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        data = data[np.newaxis]  # add channels dimension
    input_spatial_dims = image.GetDimension()
    if not keepdim:
        data = ensure_4d(data, False, num_spatial_dims=input_spatial_dims)
    assert data.shape[0] == num_components
    assert data.shape[-input_spatial_dims:] == image.GetSize()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, 1, H, W)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[1:3, 1:3] = rotation_2d
        spacing = 1, *spacing
        origin = 0, *origin
    else:
        raise RuntimeError(f'Direction not understood: {direction}')
    rotation = np.dot(FLIP_XY, rotation)
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
Example #8
def check_if_image_is_rgb(image: sitk.Image):
    """Check if an image is RGB by looking for a 3-component 8 bit pixel image"""
    components = image.GetNumberOfComponentsPerPixel()
    pixel_type = image.GetPixelIDTypeAsString()

    if components == 3 and pixel_type == 'vector of 8-bit unsigned integer':
        image_is_rgb = True
    else:
        image_is_rgb = False

    return image_is_rgb
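A quick sanity check; the image sizes are arbitrary, and the pixel type string is the one SimpleITK reports for sitkVectorUInt8 images:

import SimpleITK as sitk

rgb = sitk.Image(64, 64, sitk.sitkVectorUInt8, 3)  # 3-component 8-bit image
gray = sitk.Image(64, 64, sitk.sitkUInt8)          # scalar 8-bit image
print(check_if_image_is_rgb(rgb))   # True
print(check_if_image_is_rgb(gray))  # False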
Example #9
def plot_2d_segmentation_series(path: str,
                                file_name_suffix: str,
                                image: sitk.Image,
                                ground_truth: sitk.Image,
                                segmentation: sitk.Image,
                                alpha: float = 0.5,
                                label: int = 1,
                                file_extension: str = '.png') -> None:
    """Plots an image with an overlaid mask, which indicates under-, correct-, and over-segmentation.

    Args:
        path (str): The output directory path.
        file_name_suffix (str): The output file name suffix.
        image (sitk.Image): The image.
        ground_truth (sitk.Image): The ground truth.
        segmentation (sitk.Image): The segmentation.
        alpha (float): The alpha blending value, between 0 (transparent) and 1 (opaque).
        label (int): The ground truth and segmentation label.
        file_extension (str): The output file extension (with or without dot).

    Examples:
        >>> img_t2 = sitk.ReadImage('your/path/image.mha')
        >>> ground_truth = sitk.ReadImage('your/path/ground_truth.mha')
        >>> segmentation = sitk.ReadImage('your/path/segmentation.mha')
        >>> plot_2d_segmentation_series('/your/path/', 'mysegmentation', img_t2, ground_truth, segmentation)
    """

    if not image.GetSize() == ground_truth.GetSize() == segmentation.GetSize():
        raise ValueError(
            'image, ground_truth, and segmentation must have equal size')
    if not image.GetDimension() == 3:
        raise ValueError('only 3-dimensional images supported')
    if not image.GetNumberOfComponentsPerPixel() == 1:
        raise ValueError('only scalar images supported')

    img_arr = sitk.GetArrayFromImage(image)
    gt_arr = sitk.GetArrayFromImage(ground_truth)
    seg_arr = sitk.GetArrayFromImage(segmentation)

    os.makedirs(path, exist_ok=True)
    file_extension = file_extension if file_extension.startswith(
        '.') else '.' + file_extension

    for slice_idx in range(img_arr.shape[0]):
        full_file_path = os.path.join(
            path, file_name_suffix + str(slice_idx) + file_extension)
        plot_2d_segmentation(full_file_path,
                             img_arr[slice_idx, ...],
                             gt_arr[slice_idx, ...],
                             seg_arr[slice_idx, ...],
                             alpha=alpha,
                             label=label)
Example #10
    def __init__(self, image: sitk.Image):
        """Initializes a new instance of the ImageInformation class.

        Args:
            image (sitk.Image): The image whose properties to hold.
        """
        self.size = image.GetSize()
        self.origin = image.GetOrigin()
        self.spacing = image.GetSpacing()
        self.direction = image.GetDirection()
        self.dimensions = image.GetDimension()
        self.number_of_components_per_pixel = image.GetNumberOfComponentsPerPixel()
        self.pixel_id = image.GetPixelID()
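A brief usage sketch, assuming the __init__ above belongs to a class named ImageInformation, as its docstring states:

import SimpleITK as sitk

info = ImageInformation(sitk.Image(64, 64, 16, sitk.sitkInt16))
print(info.size)                            # (64, 64, 16)
print(info.spacing)                         # (1.0, 1.0, 1.0)
print(info.number_of_components_per_pixel)  # 1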
Example #11
    def __init__(self, image: sitk.Image):
        """Represents ITK image properties.

        Holds common ITK image meta-data such as the size, origin, spacing, and direction.

        See Also:
            SimpleITK provides `itk::simple::Image::CopyInformation`_ to copy image information.

        .. _itk::simple::Image::CopyInformation:
            https://itk.org/SimpleITKDoxygen/html/classitk_1_1simple_1_1Image.html#afa8a4757400c414e809d1767ee616bd0

        Args:
            image (sitk.Image): The image whose properties to hold.
        """
        self.size = image.GetSize()
        self.origin = image.GetOrigin()
        self.spacing = image.GetSpacing()
        self.direction = image.GetDirection()
        self.dimensions = image.GetDimension()
        self.number_of_components_per_pixel = image.GetNumberOfComponentsPerPixel()
        self.pixel_id = image.GetPixelID()
Example #12
def get_reference_image(
    floating_sitk: sitk.Image,
    spacing: TypeTripletFloat,
) -> sitk.Image:
    old_spacing = np.array(floating_sitk.GetSpacing())
    new_spacing = np.array(spacing)
    old_size = np.array(floating_sitk.GetSize())
    new_size = old_size * old_spacing / new_spacing
    new_size = np.ceil(new_size).astype(np.uint16)
    new_size[old_size == 1] = 1  # keep singleton dimensions
    new_origin_index = 0.5 * (new_spacing / old_spacing - 1)
    new_origin_lps = floating_sitk.TransformContinuousIndexToPhysicalPoint(
        new_origin_index)
    reference = sitk.Image(
        new_size.tolist(),
        floating_sitk.GetPixelID(),
        floating_sitk.GetNumberOfComponentsPerPixel(),
    )
    reference.SetDirection(floating_sitk.GetDirection())
    reference.SetSpacing(new_spacing.tolist())
    reference.SetOrigin(new_origin_lps)
    return reference
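A hedged usage sketch: the reference image carries only geometry (its pixel values are zero), so the usual follow-up is to resample the floating image onto it. TypeTripletFloat is assumed to be a plain alias for a 3-tuple of floats; because annotations are evaluated when the def runs, the imports and the alias must be defined before the function above.

from typing import Tuple

import numpy as np
import SimpleITK as sitk

TypeTripletFloat = Tuple[float, float, float]  # assumed alias used in the signature above

floating = sitk.Image(10, 10, 10, sitk.sitkFloat32)
floating.SetSpacing((2.0, 2.0, 2.0))

reference = get_reference_image(floating, (1.0, 1.0, 1.0))  # resampling grid at 1 mm isotropic
print(reference.GetSize())     # (20, 20, 20)
resampled = sitk.Resample(floating, reference)
print(resampled.GetSpacing())  # (1.0, 1.0, 1.0)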
Example #13
def smooth_probabilities(image: sitk.Image, variance=1.0) -> sitk.Image:
    """ Gaussian smoothing of the probability image

    Args:
        image (sitk.Image): vector image with probabilities for each pixel and label
        variance (Float): Variance of the gaussian smoothing

    Returns (sitk.Image): smoothed vector image

    """
    # get number of labels
    n_labels = image.GetNumberOfComponentsPerPixel()

    # generate filter to extract probabilities of each label
    extract_label_filter = sitk.VectorIndexSelectionCastImageFilter()

    # generate gaussian filter to smooth the probability images
    gauss_filter = sitk.DiscreteGaussianImageFilter()
    gauss_filter.SetVariance(variance)

    # generate vector to store the different label probabilities image
    images_probabilities = []

    # extract the probabilities for a single label and smooth the "image"
    for label in range(n_labels):
        probabilities = extract_label_filter.Execute(image, label,
                                                     sitk.sitkFloat32)
        images_probabilities.append(gauss_filter.Execute(probabilities))

    # compose the single images back to a vector image
    compose_filter = sitk.ComposeImageFilter()
    img_out = compose_filter.Execute(images_probabilities)

    # arr_image = sitk.GetArrayFromImage(img_out)
    # plt.imshow(arr_image[:, :, 90], cmap='jet')
    # plt.show()

    return img_out
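A small end-to-end sketch: two per-label probability maps are composed into a vector image with sitk.Compose and then smoothed. The values are arbitrary, and it assumes a SimpleITK release that accepts the Execute(image, index, outputPixelType) overload used in the function above:

import SimpleITK as sitk

prob_a = sitk.Image(32, 32, sitk.sitkFloat32) + 0.7  # constant probability map for label 0
prob_b = sitk.Image(32, 32, sitk.sitkFloat32) + 0.3  # constant probability map for label 1
probabilities = sitk.Compose(prob_a, prob_b)         # 2-component vector image

smoothed = smooth_probabilities(probabilities, variance=2.0)
print(smoothed.GetNumberOfComponentsPerPixel())      # 2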
Example #14
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    data = sitk.GetArrayFromImage(image).transpose()
    data = check_uint_to_int(data)
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        data = data[np.newaxis]  # add channels dimension
    input_spatial_dims = image.GetDimension()
    if input_spatial_dims == 2:
        data = data[..., np.newaxis]
    elif input_spatial_dims == 4:  # probably a bad NIfTI (1, sx, sy, sz, c)
        # Try to fix it
        num_components = data.shape[-1]
        data = data[0]
        data = data.transpose(3, 0, 1, 2)
        input_spatial_dims = 3
    if not keepdim:
        data = ensure_4d(data, num_spatial_dims=input_spatial_dims)
    assert data.shape[0] == num_components
    affine = get_ras_affine_from_sitk(image)
    return data, affine
Example #15
    def write(self, segmentation: sitk.Image,
              source_images: List[pydicom.Dataset]) -> pydicom.Dataset:
        """Writes a DICOM-SEG dataset from a segmentation image and the
        corresponding DICOM source images.

        Args:
            segmentation: A `SimpleITK.Image` with integer labels and a single
                component per spatial location.
            source_images: A list of `pydicom.Dataset` which are the
                source images for the segmentation image.

        Returns:
            A `pydicom.Dataset` instance with all necessary information and
            meta information for writing the dataset to disk.
        """
        if segmentation.GetDimension() != 3:
            raise ValueError("Only 3D segmentation data is supported")

        if segmentation.GetNumberOfComponentsPerPixel() > 1:
            raise ValueError("Multi-class segmentations can only be "
                             "represented with a single component per voxel")

        if segmentation.GetPixelID() not in [
                sitk.sitkUInt8,
                sitk.sitkUInt16,
                sitk.sitkUInt32,
                sitk.sitkUInt64,
        ]:
            raise ValueError("Unsigned integer data type required")

        # TODO Add further checks if source images are from the same series
        slice_to_source_images = self._map_source_images_to_segmentation(
            segmentation, source_images)

        # Compute unique labels and their respective bounding boxes
        label_statistics_filter = sitk.LabelStatisticsImageFilter()
        label_statistics_filter.Execute(segmentation, segmentation)
        unique_labels = set(
            [x for x in label_statistics_filter.GetLabels() if x != 0])
        if len(unique_labels) == 0:
            raise ValueError("Segmentation does not contain any labels")

        # Check if all present labels were declared in the DICOM template
        declared_segments = set(
            [x.SegmentNumber for x in self._template.SegmentSequence])
        missing_declarations = unique_labels.difference(declared_segments)
        if missing_declarations:
            missing_segment_numbers = ", ".join(
                [str(x) for x in missing_declarations])
            message = (
                f"Skipping segment(s) {missing_segment_numbers}, since their "
                "declaration is missing in the DICOM template")
            if not self._skip_missing_segment:
                raise ValueError(message)
            logger.warning(message)
        labels_to_process = unique_labels.intersection(declared_segments)
        if not labels_to_process:
            raise ValueError("No segments found for encoding as DICOM-SEG")

        # Compute bounding boxes for each present label and optionally restrict
        # the volume to serialize to the joined maximum extent
        bboxs = {
            x: label_statistics_filter.GetBoundingBox(x)
            for x in labels_to_process
        }
        if self._inplane_cropping:
            min_x, min_y, _ = np.min([x[::2] for x in bboxs.values()],
                                     axis=0).tolist()
            max_x, max_y, _ = (
                np.max([x[1::2]
                        for x in bboxs.values()], axis=0) + 1).tolist()
            logger.info(
                "Serializing cropped image planes starting at coordinates "
                f"({min_x}, {min_y}) with size ({max_x - min_x}, {max_y - min_y})"
            )
        else:
            min_x, min_y = 0, 0
            max_x, max_y = segmentation.GetWidth(), segmentation.GetHeight()
            logger.info(
                f"Serializing image planes at full size ({max_x}, {max_y})")

        # Create target dataset for storing serialized data
        result = SegmentationDataset(
            reference_dicom=source_images[0] if source_images else None,
            rows=max_y - min_y,
            columns=max_x - min_x,
            segmentation_type=SegmentationType.BINARY,
        )
        dimension_organization = DimensionOrganizationSequence()
        dimension_organization.add_dimension("ReferencedSegmentNumber",
                                             "SegmentIdentificationSequence")
        dimension_organization.add_dimension("ImagePositionPatient",
                                             "PlanePositionSequence")
        result.add_dimension_organization(dimension_organization)
        writer_utils.copy_segmentation_template(
            target=result,
            template=self._template,
            segments=labels_to_process,
            skip_missing_segment=self._skip_missing_segment,
        )
        writer_utils.set_shared_functional_groups_sequence(
            target=result, segmentation=segmentation)

        # FIX - Use ImageOrientationPatient value from DICOM source rather than the segmentation
        result.SharedFunctionalGroupsSequence[0].PlaneOrientationSequence[
            0].ImageOrientationPatient = source_images[
                0].ImageOrientationPatient

        buffer = sitk.GetArrayFromImage(segmentation)
        for segment in labels_to_process:
            logger.info(f"Processing segment {segment}")

            if self._skip_empty_slices:
                bbox = bboxs[segment]
                min_z, max_z = bbox[4], bbox[5] + 1
            else:
                min_z, max_z = 0, segmentation.GetDepth()
            logger.info(
                "Total number of slices that will be processed for segment "
                f"{segment} is {max_z - min_z} (inclusive from {min_z} to {max_z})"
            )

            skipped_slices = []
            for slice_idx in range(min_z, max_z):
                frame_index = (min_x, min_y, slice_idx)
                frame_position = segmentation.TransformIndexToPhysicalPoint(
                    frame_index)
                frame_data = np.equal(
                    buffer[slice_idx, min_y:max_y, min_x:max_x], segment)
                if self._skip_empty_slices and not frame_data.any():
                    skipped_slices.append(slice_idx)
                    continue

                frame_fg_item = result.add_frame(
                    data=frame_data.astype(np.uint8),
                    referenced_segment=segment,
                    referenced_images=slice_to_source_images[slice_idx],
                )

                frame_fg_item.FrameContentSequence = [pydicom.Dataset()]
                frame_fg_item.FrameContentSequence[0].DimensionIndexValues = [
                    segment,  # Segment number
                    slice_idx - min_z + 1,  # Slice index within cropped volume
                ]
                frame_fg_item.PlanePositionSequence = [pydicom.Dataset()]
                frame_fg_item.PlanePositionSequence[0].ImagePositionPatient = [
                    f"{x:e}" for x in frame_position
                ]

            if skipped_slices:
                logger.info(f"Skipped empty slices for segment {segment}: "
                            f'{", ".join([str(x) for x in skipped_slices])}')

        # Encode all frames into a bytearray
        if self._inplane_cropping or self._skip_empty_slices:
            num_encoded_bytes = len(result.PixelData)
            max_encoded_bytes = (segmentation.GetWidth() *
                                 segmentation.GetHeight() *
                                 segmentation.GetDepth() *
                                 len(result.SegmentSequence) // 8)
            savings = (1 - num_encoded_bytes / max_encoded_bytes) * 100
            logger.info(
                f"Optimized frame data length is {num_encoded_bytes:,}B "
                f"instead of {max_encoded_bytes:,}B (saved {savings:.2f}%)")

        result.SegmentsOverlap = "NO"

        return result
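A heavily hedged usage sketch, assuming this write method lives on a writer class configured like pydicom-seg's MultiClassWriter (a DICOM template plus the cropping and skipping flags referenced above); the paths, the metainfo file, and the list of source slices are placeholders:

import pydicom
import pydicom_seg
import SimpleITK as sitk

segmentation = sitk.ReadImage('your/path/segmentation.nii.gz')    # unsigned integer label image
source_images = [pydicom.dcmread(f) for f in source_dicom_files]  # placeholder: sorted slice paths

template = pydicom_seg.template.from_dcmqi_metainfo('your/path/metainfo.json')
writer = pydicom_seg.MultiClassWriter(
    template=template,
    inplane_cropping=True,       # crop frames to the joint bounding box of all labels
    skip_empty_slices=True,      # drop slices that contain no labels
    skip_missing_segment=False,  # fail if a label is not declared in the template
)
dcm = writer.write(segmentation, source_images)
dcm.save_as('your/path/segmentation.seg.dcm')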