Esempio n. 1
0
def _write_dicom_slice(writer: sitk.ImageFileWriter,
                       series_tag_values: Dict[DicomTags, str],
                       image: sitk.Image, folder: Path, i: int) -> None:
    """
    Extract one slice from a volume and save it as its own DICOM file.

    :param writer: sitk ImageFileWriter used to write the file.
    :param series_tag_values: Tags shared by the series, keyed by DicomTags
        member; they are copied, never modified.
    :param image: Volume from which slice ``i`` is taken (third index axis).
    :param folder: Destination folder; the file is named ``<i>.dcm``.
    :param i: Index of the slice to write.
    """
    slice_label = str(i)

    # Physical position of voxel (0, 0, i), joined with backslashes.
    slice_origin = image.TransformIndexToPhysicalPoint((0, 0, i))
    position_text = '\\'.join(str(component) for component in slice_origin)

    # Start from the series-level tags, then set the per-slice ones on top.
    tags = series_tag_values.copy()
    tags[DicomTags.InstanceNumber] = slice_label
    tags[DicomTags.ImagePositionPatient] = position_text

    extracted = image[:, :, i]
    for key in tags:
        extracted.SetMetaData(key.value, tags[key])

    target_path = folder / (slice_label + '.dcm')
    writer.SetFileName(str(target_path))
    writer.Execute(extracted)
Esempio n. 2
0
def create_circle_mask_itk(
    image_itk: sitk.Image,
    world_centers: Sequence[Sequence[float]],
    world_rads: Sequence[float],
    ndim: int = 3,
) -> sitk.Image:
    """
    Build a label mask in which every (center, radius) pair is drawn as a
    filled circle/sphere measured in world coordinates.

    The n-th pair (counting from 1) is written with label value ``n``; later
    pairs overwrite earlier ones where they overlap.

    Args:
        image_itk: reference image; supplies the array shape and the
            index <-> physical point transforms
        world_centers: center points in world coordinates (x, y, z)
        world_rads: radius for each center, in world units
        ndim: number of spatial dimensions; if the image array has more
            axes than this, only its first entry along axis 0 is used

    Returns:
        sitk.Image: uint8 mask passed through `copy_meta_data_itk` together
        with `image_itk`
    """
    reference_np = sitk.GetArrayFromImage(image_itk)
    smallest_spacing = min(image_itk.GetSpacing())

    if reference_np.ndim > ndim:
        reference_np = reference_np[0]
    mask_np = np.zeros_like(reference_np).astype(np.uint8)

    for label, (world_center, world_rad) in enumerate(
            zip(world_centers, world_rads), start=1):
        # Half-width of the search box in voxels, with a 50% safety margin.
        half_extent = (world_rad / smallest_spacing) * 1.5

        # sitk returns (x, y, z); numpy arrays are indexed in reverse order.
        index_center = image_itk.TransformPhysicalPointToContinuousIndex(
            world_center)
        axis_ranges = []
        for axis, position in enumerate(reversed(index_center)):
            lower = max(0, int(position - half_extent))
            upper = min(mask_np.shape[axis], int(position + half_extent))
            axis_ranges.append(range(lower, upper))

        # Test every voxel of the box against the radius in world space.
        target = np.array(world_center)
        for voxel in product(*axis_ranges):
            # back to (x, y, z) order for sitk
            physical = image_itk.TransformIndexToPhysicalPoint(voxel[::-1])
            if np.linalg.norm(np.array(physical) - target) <= world_rad:
                mask_np[voxel] = label
        assert mask_np.max() == label

    mask_itk = sitk.GetImageFromArray(mask_np)
    return copy_meta_data_itk(image_itk, mask_itk)
Esempio n. 3
0
def rotate(image: sitk.Image,
           rotation_centre: Sequence[float],
           angles: Union[float, Sequence[float]],
           interpolation: str = "linear") -> sitk.Image:
    """Rotate an image around a given centre.

    Parameters
    ----------
    image
        The image to rotate. Must be 2D or 3D.

    rotation_centre
        The centre of rotation in image (index) coordinates. Fractional
        values are allowed; a numpy array is accepted as well.

    angles
        The angle(s) of rotation in radians: a single value for a 2D image,
        or the rotations around the x, y and z axes for a 3D image.

    interpolation
        Interpolation mode, forwarded unchanged to `resample`.

    Returns
    -------
    sitk.Image
        The rotated image, resampled with the spacing of the input image.

    Raises
    ------
    ValueError
        If the image is neither 2D nor 3D.
    """
    # Use the continuous-index transform so that non-integer centres (as the
    # signature advertises) are accepted; plain floats also cover numpy
    # scalars and arrays. Integer centres map to the same physical point.
    centre_index = [float(c) for c in rotation_centre]
    rotation_centre = image.TransformContinuousIndexToPhysicalPoint(
        centre_index)

    dimension = image.GetDimension()
    if dimension == 2:
        rotation = sitk.Euler2DTransform(
            rotation_centre,
            angles,
            (0., 0.)  # no translation
        )
    elif dimension == 3:
        x_angle, y_angle, z_angle = angles

        rotation = sitk.Euler3DTransform(
            rotation_centre,
            x_angle,  # the angle of rotation around the x-axis, in radians -> coronal rotation
            y_angle,  # the angle of rotation around the y-axis, in radians -> sagittal rotation
            z_angle,  # the angle of rotation around the z-axis, in radians -> axial rotation
            (0., 0., 0.)  # no translation
        )
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            f"Only 2D and 3D images can be rotated, got {dimension}D")

    return resample(image,
                    spacing=image.GetSpacing(),
                    interpolation=interpolation,
                    transform=rotation)
Esempio n. 4
0
    def write(self, segmentation: sitk.Image,
              source_images: List[pydicom.Dataset]) -> pydicom.Dataset:
        """Writes a DICOM-SEG dataset from a segmentation image and the
        corresponding DICOM source images.

        Every non-zero label that is declared in the writer's template is
        serialized as binary frames, one frame per (segment, slice). Depending
        on the writer's configuration, frames are cropped in-plane to the
        joint bounding box of all labels and/or empty slices are skipped.

        Args:
            segmentation: A `SimpleITK.Image` with integer labels and a single
                component per spatial location. Must be 3D with an unsigned
                integer pixel type.
            source_images: A list of `pydicom.Dataset` which are the
                source images for the segmentation image. May be empty, in
                which case no reference DICOM is used and the plane
                orientation is left as derived from the segmentation.

        Returns:
            A `pydicom.Dataset` instance with all necessary information and
            meta information for writing the dataset to disk.

        Raises:
            ValueError: If the segmentation is not 3D, has more than one
                component per voxel, is not of an unsigned integer type,
                contains no non-zero labels, contains labels missing from the
                template while skipping is disabled, or if no label remains
                to be encoded.
        """
        if segmentation.GetDimension() != 3:
            raise ValueError("Only 3D segmentation data is supported")

        if segmentation.GetNumberOfComponentsPerPixel() > 1:
            raise ValueError("Multi-class segmentations can only be "
                             "represented with a single component per voxel")

        if segmentation.GetPixelID() not in [
                sitk.sitkUInt8,
                sitk.sitkUInt16,
                sitk.sitkUInt32,
                sitk.sitkUInt64,
        ]:
            raise ValueError("Unsigned integer data type required")

        # TODO Add further checks if source images are from the same series
        slice_to_source_images = self._map_source_images_to_segmentation(
            segmentation, source_images)

        # Compute unique labels and their respective bounding boxes
        label_statistics_filter = sitk.LabelStatisticsImageFilter()
        label_statistics_filter.Execute(segmentation, segmentation)
        unique_labels = {
            x for x in label_statistics_filter.GetLabels() if x != 0
        }
        if not unique_labels:
            raise ValueError("Segmentation does not contain any labels")

        # Check if all present labels were declared in the DICOM template
        declared_segments = {
            x.SegmentNumber for x in self._template.SegmentSequence
        }
        missing_declarations = unique_labels.difference(declared_segments)
        if missing_declarations:
            missing_segment_numbers = ", ".join(
                [str(x) for x in missing_declarations])
            message = (
                f"Skipping segment(s) {missing_segment_numbers}, since their "
                "declaration is missing in the DICOM template")
            if not self._skip_missing_segment:
                raise ValueError(message)
            logger.warning(message)
        labels_to_process = unique_labels.intersection(declared_segments)
        if not labels_to_process:
            raise ValueError("No segments found for encoding as DICOM-SEG")

        # Compute bounding boxes for each present label and optionally restrict
        # the volume to serialize to the joined maximum extent.
        # The code below reads each box as (x_min, x_max, y_min, y_max,
        # z_min, z_max) with inclusive upper bounds (hence the "+ 1").
        bboxs = {
            x: label_statistics_filter.GetBoundingBox(x)
            for x in labels_to_process
        }
        if self._inplane_cropping:
            min_x, min_y, _ = np.min([x[::2] for x in bboxs.values()],
                                     axis=0).tolist()
            max_x, max_y, _ = (
                np.max([x[1::2]
                        for x in bboxs.values()], axis=0) + 1).tolist()
            logger.info(
                "Serializing cropped image planes starting at coordinates "
                f"({min_x}, {min_y}) with size ({max_x - min_x}, {max_y - min_y})"
            )
        else:
            min_x, min_y = 0, 0
            max_x, max_y = segmentation.GetWidth(), segmentation.GetHeight()
            logger.info(
                f"Serializing image planes at full size ({max_x}, {max_y})")

        # Create target dataset for storing serialized data
        result = SegmentationDataset(
            reference_dicom=source_images[0] if source_images else None,
            rows=max_y - min_y,
            columns=max_x - min_x,
            segmentation_type=SegmentationType.BINARY,
        )
        dimension_organization = DimensionOrganizationSequence()
        dimension_organization.add_dimension("ReferencedSegmentNumber",
                                             "SegmentIdentificationSequence")
        dimension_organization.add_dimension("ImagePositionPatient",
                                             "PlanePositionSequence")
        result.add_dimension_organization(dimension_organization)
        writer_utils.copy_segmentation_template(
            target=result,
            template=self._template,
            segments=labels_to_process,
            skip_missing_segment=self._skip_missing_segment,
        )
        writer_utils.set_shared_functional_groups_sequence(
            target=result, segmentation=segmentation)

        # FIX - Use ImageOrientationPatient value from DICOM source rather than
        # the segmentation. Only possible when a source image exists; an empty
        # list is tolerated when the dataset is created, so it must not crash
        # here either.
        if source_images:
            result.SharedFunctionalGroupsSequence[0].PlaneOrientationSequence[
                0].ImageOrientationPatient = source_images[
                    0].ImageOrientationPatient

        # numpy view of the labels, indexed as [z, y, x]
        buffer = sitk.GetArrayFromImage(segmentation)
        for segment in labels_to_process:
            logger.info(f"Processing segment {segment}")

            # max_z is an exclusive bound (it feeds range() below)
            if self._skip_empty_slices:
                bbox = bboxs[segment]
                min_z, max_z = bbox[4], bbox[5] + 1
            else:
                min_z, max_z = 0, segmentation.GetDepth()
            logger.info(
                "Total number of slices that will be processed for segment "
                f"{segment} is {max_z - min_z} (inclusive from {min_z} to {max_z})"
            )

            skipped_slices = []
            for slice_idx in range(min_z, max_z):
                # Physical position of the first voxel of the (cropped) frame
                frame_index = (min_x, min_y, slice_idx)
                frame_position = segmentation.TransformIndexToPhysicalPoint(
                    frame_index)
                # Binary frame: True where the voxel carries this label
                frame_data = np.equal(
                    buffer[slice_idx, min_y:max_y, min_x:max_x], segment)
                if self._skip_empty_slices and not frame_data.any():
                    skipped_slices.append(slice_idx)
                    continue

                frame_fg_item = result.add_frame(
                    data=frame_data.astype(np.uint8),
                    referenced_segment=segment,
                    referenced_images=slice_to_source_images[slice_idx],
                )

                frame_fg_item.FrameContentSequence = [pydicom.Dataset()]
                frame_fg_item.FrameContentSequence[0].DimensionIndexValues = [
                    segment,  # Segment number
                    slice_idx - min_z + 1,  # Slice index within cropped volume
                ]
                frame_fg_item.PlanePositionSequence = [pydicom.Dataset()]
                frame_fg_item.PlanePositionSequence[0].ImagePositionPatient = [
                    f"{x:e}" for x in frame_position
                ]

            if skipped_slices:
                logger.info(f"Skipped empty slices for segment {segment}: "
                            f'{", ".join([str(x) for x in skipped_slices])}')

        # Report how much smaller the stored pixel data is compared to
        # serializing every segment at full volume size (1 bit per voxel).
        if self._inplane_cropping or self._skip_empty_slices:
            num_encoded_bytes = len(result.PixelData)
            max_encoded_bytes = (segmentation.GetWidth() *
                                 segmentation.GetHeight() *
                                 segmentation.GetDepth() *
                                 len(result.SegmentSequence) // 8)
            # Tiny volumes (< 8 voxel-bits) would yield 0 and a division by
            # zero; the statistic is informational only, so skip it then.
            if max_encoded_bytes:
                savings = (1 - num_encoded_bytes / max_encoded_bytes) * 100
                logger.info(
                    f"Optimized frame data length is {num_encoded_bytes:,}B "
                    f"instead of {max_encoded_bytes:,}B (saved {savings:.2f}%)")

        result.SegmentsOverlap = "NO"

        return result