Пример #1
0
def verify_same_geometry(img_1: sitk.Image, img_2: sitk.Image):
    """Check whether two images share the same physical geometry.

    Origin, spacing, direction and size are compared element-wise with
    ``np.isclose``. For every property that differs, a message and both
    values are printed.

    Args:
        img_1: first image.
        img_2: second image.

    Returns:
        bool: True if origin, spacing, direction and size all match, False
        otherwise (including when the images have a different number of
        dimensions).
    """
    checks = (
        ("the origin does not match between the images:",
         img_1.GetOrigin(), img_2.GetOrigin()),
        ("the spacing does not match between the images",
         img_1.GetSpacing(), img_2.GetSpacing()),
        ("the direction does not match between the images",
         img_1.GetDirection(), img_2.GetDirection()),
        ("the size does not match between the images",
         img_1.GetSize(), img_2.GetSize()),
    )

    all_match = True
    for message, value_1, value_2 in checks:
        # np.isclose raises on tuples of different lengths (e.g. a 2-D vs a
        # 3-D image), so treat a length difference as a plain mismatch.
        same = (len(value_1) == len(value_2)
                and bool(np.all(np.isclose(value_1, value_2))))
        if not same:
            # Report every mismatching property, not just the first one.
            print(message)
            print(value_1)
            print(value_2)
            all_match = False
    return all_match
Пример #2
0
def assert_sitk_img_equivalence(img: SimpleITK.Image,
                                img_ref: SimpleITK.Image):
    """Assert that ``img`` matches ``img_ref`` in geometry and pixel type.

    The following are compared for exact equality, in this order: dimension,
    size, origin, spacing, direction, number of components per pixel, pixel
    ID value and pixel type string.

    Raises:
        AssertionError: on the first differing property; the message names
            the property and shows both values.
    """
    checks = (
        ("dimension", img.GetDimension, img_ref.GetDimension),
        ("size", img.GetSize, img_ref.GetSize),
        ("origin", img.GetOrigin, img_ref.GetOrigin),
        ("spacing", img.GetSpacing, img_ref.GetSpacing),
        # Direction is part of the image geometry; two images with the same
        # origin/spacing but different orientation are not equivalent.
        ("direction", img.GetDirection, img_ref.GetDirection),
        ("components per pixel", img.GetNumberOfComponentsPerPixel,
         img_ref.GetNumberOfComponentsPerPixel),
        ("pixel ID", img.GetPixelIDValue, img_ref.GetPixelIDValue),
        ("pixel type", img.GetPixelIDTypeAsString,
         img_ref.GetPixelIDTypeAsString),
    )
    for name, get_value, get_ref_value in checks:
        value, ref_value = get_value(), get_ref_value()
        assert value == ref_value, f"{name} mismatch: {value} != {ref_value}"
Пример #3
0
def compatible_metadata(image1: sitk.Image,
                        image2: sitk.Image,
                        check_size: bool = True,
                        check_spacing: bool = True,
                        check_origin: bool = True,
                        tolerance: float = 1e-4) -> bool:
    """ Compares the metadata of two images and determines if all checks are successful or not.

    Size is compared exactly. Spacing and origin are compared per axis with an
    absolute tolerance, and images with a different number of dimensions never
    match. Every failed check is printed.

    @param image1: first image
    @param image2: second image
    @param check_size: if true, check if the sizes of the images are equal
    @param check_spacing: if true, check if the spacing of the images are equal
    @param check_origin: if true, check if the origin of the images are equal
    @param tolerance: largest accepted absolute per-axis difference for spacing
        and origin (default 1e-4)
    @return: true, if images are equal in all given checks, false if one of them failed
    """

    def differs(values1, values2) -> bool:
        # zip() stops at the shorter sequence, so compare the lengths
        # explicitly; otherwise a 2-D and a 3-D image could pass.
        if len(values1) != len(values2):
            return True
        return any(abs(v1 - v2) > tolerance for v1, v2 in zip(values1, values2))

    all_parameters_equal = True

    if check_size:
        size1 = image1.GetSize()
        size2 = image2.GetSize()
        if size1 != size2:
            all_parameters_equal = False
            print(f'Images do not have the same size ({size1} != {size2})')

    if check_spacing:
        spacing1 = image1.GetSpacing()
        spacing2 = image2.GetSpacing()
        if differs(spacing1, spacing2):
            all_parameters_equal = False
            print(
                f'Images do not have the same spacing ({spacing1} != {spacing2})'
            )

    if check_origin:
        origin1 = image1.GetOrigin()
        origin2 = image2.GetOrigin()
        if differs(origin1, origin2):
            all_parameters_equal = False
            print(
                f'Images do not have the same origin ({origin1} != {origin2})')

    return all_parameters_equal
Пример #4
0
def split_4d_itk(img_itk: sitk.Image) -> List[sitk.Image]:
    """
    Split a 4D ITK image into a list of 3D images.

    One 3D image is produced per entry of the outermost axis of
    ``sitk.GetArrayFromImage(img_itk)``. Each image gets the 4D spatial
    metadata with the last axis dropped: spacing and origin lose their last
    entry, and the 4x4 direction loses its last row and column. The images are
    built with ``create_itk_image_spatial_props`` (defined elsewhere).

    Args:
        img_itk: 4D input image

    Returns:
        List[sitk.Image]: 3d output images
    """
    volumes = sitk.GetArrayFromImage(img_itk)
    spacing_4d = img_itk.GetSpacing()
    origin_4d = img_itk.GetOrigin()
    direction_4d = np.array(img_itk.GetDirection()).reshape(4, 4)

    spacing_3d = tuple(spacing_4d[:-1])
    assert len(spacing_3d) == 3
    origin_3d = tuple(origin_4d[:-1])
    assert len(origin_3d) == 3
    direction_3d = tuple(direction_4d[:-1, :-1].reshape(-1))
    assert len(direction_3d) == 9

    return [
        create_itk_image_spatial_props(volume, spacing_3d, origin_3d,
                                       direction_3d)
        for volume in volumes
    ]
Пример #5
0
def image_resample(image: sitk.Image, new_spacing=None):
    '''
    Resample an image to a new voxel spacing using SimpleITK's resampler.

    Typical use: ``image = sitk.ReadImage(image_path); image_resample(image)``.

    Origin and direction are kept. The output size per axis is
    ``np.round(size * old_spacing / new_spacing)``, so the physical extent is
    approximately preserved. Linear interpolation is used, and voxels outside
    the input get the value 0.

    Args:
        image: input image (any dimension).
        new_spacing: target spacing, one value per image dimension. Defaults
            to 1.0 along every axis.

    Returns:
        sitk.Image: the resampled image.

    Raises:
        ValueError: if ``new_spacing`` does not have one entry per dimension.
    '''
    origin_spacing = image.GetSpacing()  # source resolution
    origin_size = image.GetSize()

    if new_spacing is None:
        new_spacing = [1.0] * image.GetDimension()  # target resolution
    new_spacing = [float(s) for s in new_spacing]
    if len(new_spacing) != len(origin_spacing):
        raise ValueError(
            f'new_spacing must have {len(origin_spacing)} entries, '
            f'got {len(new_spacing)}')

    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(sitk.sitkLinear)
    resample.SetDefaultPixelValue(0)

    resample.SetOutputSpacing(new_spacing)
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetOutputDirection(image.GetDirection())

    # Size of the new image: keep the physical extent per axis.
    new_size = [
        int(np.round(size * (spacing / target)))
        for size, spacing, target in zip(origin_size, origin_spacing,
                                         new_spacing)
    ]
    resample.SetSize(new_size)

    return resample.Execute(image)
Пример #6
0
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Turn a SimpleITK image into a channels-first array plus a 4x4 affine.

    The array is the transposed ``GetArrayFromImage`` output, so the spatial
    axes are in (x, y, ...) order. A leading channel axis is added for
    single-component images and a trailing singleton axis for 2-D images.
    Unless ``keepdim`` is set, the array is then passed to ``ensure_4d``
    (defined elsewhere).

    The affine combines direction, spacing and origin, left-multiplied by
    ``FLIP_XY`` (defined elsewhere). A 2-D direction is embedded in the
    top-left of a 3x3 identity, with spacing 1 and origin 0 for the extra axis.

    Raises:
        RuntimeError: if the direction has neither 9 nor 4 entries.
    """
    voxels = sitk.GetArrayFromImage(image).transpose()
    n_components = image.GetNumberOfComponentsPerPixel()
    n_spatial = image.GetDimension()

    if n_components == 1:
        voxels = voxels[np.newaxis]  # channel axis in front
    if n_spatial == 2:
        voxels = voxels[..., np.newaxis]
    if not keepdim:
        voxels = ensure_4d(voxels, num_spatial_dims=n_spatial)
    assert voxels.shape[0] == n_components
    assert voxels.shape[1:1 + n_spatial] == image.GetSize()

    spacing = np.array(image.GetSpacing())
    origin = image.GetOrigin()
    direction = np.array(image.GetDirection())
    n_direction = len(direction)
    if n_direction == 9:
        rotation = direction.reshape(3, 3)
    elif n_direction == 4:
        # 2-D: the in-plane axes come first, matching the (1, W, H, 1) layout.
        rotation = np.eye(3)
        rotation[:2, :2] = direction.reshape(2, 2)
        spacing = (*spacing, 1)
        origin = (*origin, 0)
    else:
        raise RuntimeError(f'Direction not understood: {direction}')

    affine = np.eye(4)
    affine[:3, :3] = np.dot(FLIP_XY, rotation) * spacing
    affine[:3, 3] = np.dot(FLIP_XY, origin)
    return voxels, affine
Пример #7
0
def centre_of_mass(image: sitk.Image) -> np.ndarray:
    r""" Compute the centre of mass of the image.

    A real-valued image represents the distribution of mass, and its
    centre of mass is defined as :math:`\frac{1}{\sum_p I(p)} \sum_p p I(p)`.

    The centre is first computed as a continuous voxel index. It is then mapped
    to world coordinates with the full image geometry (origin, spacing and
    direction) via ``Image.TransformContinuousIndexToPhysicalPoint``.

    Parameters
    ----------
    image : sitk.Image
        Input scalar image; its intensities must not sum to zero.

    Returns
    -------
    np.ndarray
        World coordinates (x, y, z) of the centre of mass.

    Raises
    ------
    ZeroDivisionError
        If the intensities sum to zero.
    """
    # Array axes are in (z, y, x) order, i.e. reversed w.r.t. ITK indices.
    data = np.asarray(sitk.GetArrayViewFromImage(image), dtype=np.float64)
    total = data.sum()
    if total == 0:
        raise ZeroDivisionError("Weights sum to zero, can't be normalized")

    # Marginalise over all other axes rather than materialising a coordinate
    # grid: memory is O(sum of axis sizes) instead of O(ndim * voxels).
    axes = range(data.ndim)
    centre_zyx = [
        np.dot(np.arange(n),
               data.sum(axis=tuple(a for a in axes if a != axis))) / total
        for axis, n in enumerate(data.shape)
    ]
    centre_index = [float(c) for c in reversed(centre_zyx)]

    # world = origin + direction @ (spacing * index)
    return np.array(
        image.TransformContinuousIndexToPhysicalPoint(centre_index))
Пример #8
0
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a channels-first array and a 4x4 affine.

    Args:
        image: scalar or multi-component 2-D/3-D image.
        keepdim: if False, the array is passed through ``ensure_4d`` (defined
            elsewhere) with ``num_spatial_dims`` set to the image dimension.

    Returns:
        (data, affine): ``data`` has the components on axis 0, followed by the
        spatial axes of the transposed ``GetArrayFromImage`` output. ``affine``
        is built from direction, spacing and origin, left-multiplied by
        ``FLIP_XY`` (defined elsewhere). For 2-D images an extra leading axis
        is prepended to the affine, with spacing 1 and origin 0.

    Raises:
        RuntimeError: if the direction has neither 9 nor 4 entries.
    """
    data = sitk.GetArrayFromImage(image).transpose()
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        data = data[np.newaxis]  # add channels dimension
    input_spatial_dims = image.GetDimension()
    if not keepdim:
        data = ensure_4d(data, False, num_spatial_dims=input_spatial_dims)
    assert data.shape[0] == num_components
    assert data.shape[-input_spatial_dims:] == image.GetSize()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, 1, H, W)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[1:3, 1:3] = rotation_2d
        spacing = 1, *spacing
        origin = 0, *origin
    else:
        # Without this branch `rotation` would be unbound below and the
        # caller would get an UnboundLocalError instead of a clear message.
        raise RuntimeError(f'Direction not understood: {direction}')
    rotation = np.dot(FLIP_XY, rotation)
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
 def __init__(self, data: sitk.Image, segmentation: sitk.Image, bb_size):
     """Crop ``data`` to a box of ``bb_size`` voxels around ``segmentation``.

     The box is first placed so that ``segmentation`` is centred in it. It is
     then shifted per axis so that it starts at index >= 0 and ends within
     ``data``.

     Sets:
         self.segmentation: the given segmentation image.
         self.offset: per-axis distance (in voxels) from the box start to the
             segmentation origin's index in ``data``.
         self.data: the cropped region of ``data``.

     Args:
         data: image to crop. The final slicing indexes exactly three axes,
             so a 3-D image is assumed.
         segmentation: image whose origin, mapped into ``data`` index space
             via ``TransformPhysicalPointToIndex``, anchors the box.
         bb_size: box size in voxels, one entry per axis.

     Raises:
         AssertionError: if the shifted box still does not fit inside ``data``
             (e.g. ``bb_size`` is larger than ``data`` along some axis).
     """
     self.segmentation = segmentation
     # place bb so that the segmentation is centered
     # (offset = half the margin between box size and segmentation size)
     self.offset = (
         (np.array(bb_size) - np.array(self.segmentation.GetSize())) /
         2).astype(int)
     # adjust offset if resulting bb is out of bounds
     segmentation_origin_in_data = data.TransformPhysicalPointToIndex(
         segmentation.GetOrigin())
     # box start = seg_or - off; if it would be negative, shift the box so it
     # starts at index 0 (off = seg_or)
     self.offset = [
         seg_or if seg_or - off < 0 else off
         for seg_or, off in zip(segmentation_origin_in_data, self.offset)
     ]
     # if the box end (start + bb_sz) exceeds the data size, grow the offset by
     # the overflow, which moves the box start back by the same amount
     self.offset = [
         off + ((seg_or - off + bb_sz) - data_sz)
         if seg_or - off + bb_sz > data_sz else off for seg_or, off, bb_sz,
         data_sz in zip(segmentation_origin_in_data, self.offset, bb_size,
                        data.GetSize())
     ]
     cropped_origin = np.array(segmentation_origin_in_data) - self.offset
     # both shifts above can conflict when the box is larger than the data;
     # these asserts catch that case
     assert all(cr_or >= 0 for cr_or in cropped_origin), \
         f"Data size: {data.GetSize()}, BB size: {bb_size}, BB origin: {cropped_origin}"
     assert all(cr_or + bb_s <= data_s for cr_or, bb_s, data_s in zip(cropped_origin, bb_size, data.GetSize())), \
         f"Data size: {data.GetSize()}, BB size: {bb_size}, BB origin: {cropped_origin}"
     self.data = data[cropped_origin[0]:cropped_origin[0] + bb_size[0],
                      cropped_origin[1]:cropped_origin[1] + bb_size[1],
                      cropped_origin[2]:cropped_origin[2] + bb_size[2]]
Пример #10
0
        def slice_by_slice(image: sitk.Image, *args, **kwargs):
            """Apply ``func`` to every 2-D slice of ``image`` and reassemble.

            ``func`` and ``inplace`` are free variables from the enclosing
            scope, which is not visible here. Presumably this is the wrapper
            built by a decorator around ``func`` (TODO confirm).

            Images with at most 2 dimensions are passed to ``func`` directly.
            Otherwise the image is cut into 2-D slices (the first two axes are
            kept; every axis from index 2 upwards is iterated), and
            ``func(slice, *args, **kwargs)`` is called on each slice. Then:

            * if ``inplace`` is truthy, each result is pasted back into
              ``image`` at the slice position;
            * otherwise the results are stacked back together with
              ``JoinSeriesImageFilter``, one higher axis at a time, using that
              axis's spacing and origin component.

            Returns:
                The processed image (``image`` itself in the in-place case).
            """

            dim = image.GetDimension()
            iter_dim = 2

            if dim <= iter_dim:
                image = func(image, *args, **kwargs)
                return image

            # size 0 on an axis makes ExtractImageFilter collapse that axis,
            # so every extraction yields a 2-D slice
            extract_size = list(image.GetSize())
            extract_size[iter_dim:] = itertools.repeat(0, dim - iter_dim)

            extract_index = [0] * dim
            # write-back index: full range on the first two axes, the slice
            # position on the higher axes (filled in per iteration)
            paste_idx = [slice(None, None)] * dim

            extractor = sitk.ExtractImageFilter()
            extractor.SetSize(extract_size)
            if inplace:
                for high_idx in itertools.product(
                        *[range(s) for s in image.GetSize()[iter_dim:]]):
                    extract_index[iter_dim:] = high_idx
                    extractor.SetIndex(extract_index)

                    paste_idx[iter_dim:] = high_idx
                    image[paste_idx] = func(extractor.Execute(image), *args,
                                            **kwargs)

            else:
                img_list = []
                # itertools.product order: the last axis varies fastest
                for high_idx in itertools.product(
                        *[range(s) for s in image.GetSize()[iter_dim:]]):
                    extract_index[iter_dim:] = high_idx
                    extractor.SetIndex(extract_index)

                    # NOTE(review): paste_idx is updated but never used in
                    # this branch
                    paste_idx[iter_dim:] = high_idx

                    img_list.append(
                        func(extractor.Execute(image), *args, **kwargs))

                # Rebuild one axis per pass, lowest first. Items that differ
                # only in their index along axis d are `step` positions apart,
                # where step is the product of the sizes of the axes after d.
                # Joining each such group removes axis d from the list.
                for d in range(iter_dim, dim):
                    step = reduce((lambda x, y: x * y),
                                  image.GetSize()[d + 1:], 1)

                    join_series_filter = sitk.JoinSeriesImageFilter()
                    join_series_filter.SetSpacing(image.GetSpacing()[d])
                    join_series_filter.SetOrigin(image.GetOrigin()[d])

                    img_list = [
                        join_series_filter.Execute(img_list[i::step])
                        for i in range(step)
                    ]

                assert len(img_list) == 1
                image = img_list[0]

            return image
Пример #11
0
def resample_sitk_image(sitk_image: sitk.Image,
                        new_size,
                        interpolator="gaussian",
                        fill_value=0) -> sitk.Image:
    """
    Resample an image onto a grid with ``new_size`` voxels per axis.

    Modified version from:
        https://github.com/jonasteuwen/SimpleITK-examples/blob/master/examples/resample_isotropically.py

    Origin and direction are kept. The output spacing is
    ``old_spacing * old_size / new_size`` per axis, so the physical extent
    is preserved. An identity transform is used.

    Args:
        sitk_image: image to resample, or a path that is read with
            ``sitk.ReadImage``.
        new_size: number of voxels per output axis (converted to ``int``).
        interpolator: key of ``_SITK_INTERPOLATOR_DICT``. If falsy (e.g.
            ``None``), ``"linear"`` is used; that inference is only allowed for
            8-bit unsigned and 16/32-bit signed integer images. 8-bit unsigned
            images are always resampled with ``"nearest"`` (treated as masks).
        fill_value: value of output voxels that map outside the input.

    Returns:
        sitk.Image: the resampled image.

    Raises:
        NotImplementedError: if ``interpolator`` is falsy and the pixel type is
            not one the interpolator can be inferred for.
        AssertionError: if ``interpolator`` is not a known key.
    """

    # if pass a path to image
    if isinstance(sitk_image, str):
        sitk_image = sitk.ReadImage(sitk_image)

    # Resolve the default before validating; if the assertion ran first, a
    # falsy interpolator would be rejected and this branch could never run.
    if not interpolator:
        interpolator = "linear"
        pixelid = sitk_image.GetPixelIDValue()

        if pixelid not in [1, 2, 4]:
            raise NotImplementedError(
                "Set `interpolator` manually, "
                "can only infer for 8-bit unsigned or 16, 32-bit signed integers"
            )

    assert (interpolator in _SITK_INTERPOLATOR_DICT.keys()
            ), "`interpolator` should be one of {}".format(
                _SITK_INTERPOLATOR_DICT.keys())

    #  8-bit unsigned int
    if sitk_image.GetPixelIDValue() == 1:
        # if binary mask interpolate it as nearest
        interpolator = "nearest"

    sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]
    # SetSize expects plain unsigned integers
    new_size = [int(s) for s in new_size]

    # new spacing based on the desired output shape
    new_spacing = tuple(
        np.array(sitk_image.GetSpacing()) * np.array(sitk_image.GetSize()) /
        np.array(new_size))

    # setup image resampler - SimpleITK 2.0
    resample_filter = sitk.ResampleImageFilter()
    resample_filter.SetOutputSpacing(new_spacing)
    resample_filter.SetSize(new_size)
    resample_filter.SetOutputDirection(sitk_image.GetDirection())
    resample_filter.SetOutputOrigin(sitk_image.GetOrigin())
    resample_filter.SetTransform(sitk.Transform())
    resample_filter.SetInterpolator(sitk_interpolator)
    resample_filter.SetDefaultPixelValue(fill_value)
    # run it
    return resample_filter.Execute(sitk_image)
Пример #12
0
def match_world_info(source: sitk.Image,
                     target: sitk.Image,
                     spacing: Union[bool, Tuple[int], List[int]] = True,
                     origin: Union[bool, Tuple[int], List[int]] = True,
                     direction: Union[bool, Tuple[int], List[int]] = True):
    """Copy world information (eg spacing, origin, direction) from one
    image object to another.

    This matching is sometimes necessary for slight differences in
    metadata perhaps from founding that may prevent ITK filters from executing.

    ``target`` is modified in place; nothing is returned.

    Args:
        source (:obj:`sitk.Image`): Source object whose relevant metadata
            will be copied into ``target``.
        target (:obj:`sitk.Image`): Target object whose corresponding
            metadata will be overwritten by that of ``source``.
        spacing: True to copy the spacing from ``source`` to ``target``, or
            the spacing to set in ``target`` (a sequence or NumPy array);
            False, None or an empty sequence leaves it unchanged. Defaults
            to True.
        origin: Same as ``spacing``, for the origin.
        direction: Same as ``spacing``, for the direction.

    """

    def _is_requested(value) -> bool:
        # A plain ``if value:`` raises for NumPy arrays ("truth value of an
        # array ... is ambiguous"), so test the "skip" markers explicitly.
        return value is not None and value is not False and len(value) > 0

    # get the world info from the source if not already set
    if spacing is True:
        spacing = source.GetSpacing()
    if origin is True:
        origin = source.GetOrigin()
    if direction is True:
        direction = source.GetDirection()

    # set the values in the target
    _logger.debug(
        "Adjusting spacing from %s to %s, origin from %s to %s, "
        "direction from %s to %s", target.GetSpacing(), spacing,
        target.GetOrigin(), origin, target.GetDirection(), direction)
    if _is_requested(spacing):
        target.SetSpacing(spacing)
    if _is_requested(origin):
        target.SetOrigin(origin)
    if _is_requested(direction):
        target.SetDirection(direction)
Пример #13
0
    def set_origin_image(self, origin_img: sitk.Image) -> None:
        """Remember the reference image that resampling should align to.

        Stores the image itself plus its size, spacing, direction and origin
        as instance attributes.

        Args:
            origin_img (sitk.Image): reference image.
        """
        self.origin_img = origin_img
        geometry = (origin_img.GetSize(), origin_img.GetSpacing(),
                    origin_img.GetDirection(), origin_img.GetOrigin())
        (self.origin_size, self.origin_spacing,
         self.origin_direction, self.origin_origin) = geometry
Пример #14
0
def sitk_to_nib(image: sitk.Image) -> Tuple[np.ndarray, np.ndarray]:
    """Return the transposed voxel array and a 4x4 affine for a 3-D image.

    The affine's upper-left block is ``FLIP_XY @ direction`` with column j
    scaled by spacing j. Its translation is ``FLIP_XY @ origin``. ``FLIP_XY``
    is defined elsewhere; presumably it is the ITK (LPS) to NiBabel (RAS)
    flip (TODO confirm).
    """
    voxels = sitk.GetArrayFromImage(image).transpose()
    direction = np.array(image.GetDirection()).reshape(3, 3)

    affine = np.eye(4)
    affine[:3, :3] = np.dot(FLIP_XY, direction) * np.array(image.GetSpacing())
    affine[:3, 3] = np.dot(FLIP_XY, image.GetOrigin())
    return voxels, affine
Пример #15
0
def display_info(img: sitk.Image) -> None:
    """Print the origin, size, spacing and direction of an image.

    Args:
        img (sitk.Image): image to describe.
    """
    print('img information :')
    fields = (
        ('\t Origin    :', img.GetOrigin),
        ('\t Size      :', img.GetSize),
        ('\t Spacing   :', img.GetSpacing),
        ('\t Direction :', img.GetDirection),
    )
    for label, read_value in fields:
        print(label, read_value())
Пример #16
0
    def _assert_3d(self, image: sitk.Image, is_vector=False):
        """Check ``image`` against the expected 3-D test properties.

        The size must equal ``properties_3d.size``. The component count must be
        ``no_vector_components`` for vector images and 1 otherwise. Origin and
        spacing must both equal ``origin_spacing_3d``, and the direction must
        equal ``direction_3d``.
        """
        self.assertEqual(self.properties_3d.size, image.GetSize())

        if not is_vector:
            expected_components = 1
        else:
            expected_components = self.no_vector_components
        self.assertEqual(expected_components,
                         image.GetNumberOfComponentsPerPixel())

        # Checked one pair at a time so the first mismatch is what gets reported.
        geometry_checks = (
            ('origin_spacing_3d', image.GetOrigin),
            ('origin_spacing_3d', image.GetSpacing),
            ('direction_3d', image.GetDirection),
        )
        for expected_attr, read_actual in geometry_checks:
            self.assertEqual(getattr(self, expected_attr), read_actual())
Пример #17
0
def sitk_to_itk(image: sitk.Image) -> Any:
    r""" Convert a SimpleITK image into an ITK image.

    .. note::
        The pixel data is copied (deep copy). Spacing and origin are
        transferred as well.

    Parameters
    ----------
    image : sitk.Image
        Input image.

    Returns
    -------
    any
        Image in ITK format.
    """

    if 'itk' not in sys.modules:
        raise Exception(
            'sitk_to_itk: itk module is required to use this feature.')

    array_view = sitk.GetArrayViewFromImage(image)

    if array_view.ndim >= 4:
        # itk.GetImageFromArray() seems not to handle vector images properly,
        # so allocate a 3-D image of 3-component vectors and copy by hand.
        region = itk.ImageRegion[3]()
        region.SetSize(image.GetSize())
        region.SetIndex((0, 0, 0))

        vector_image_type = itk.Image[itk.Vector[itk_float_type, 3], 3]
        converted = vector_image_type.New()
        converted.SetRegions(region)
        converted.Allocate()
        itk.GetArrayViewFromImage(converted)[:] = array_view[:]
    else:
        converted = itk.GetImageFromArray(array_view)

    # NOTE(review): the direction matrix is not transferred - confirm that
    # callers only pass identity-direction images.
    converted.SetSpacing(image.GetSpacing())
    converted.SetOrigin(image.GetOrigin())

    return converted
Пример #18
0
    def __init__(self, image: sitk.Image):
        """Snapshot the geometry and pixel-type properties of an image.

        Args:
            image (sitk.Image): The image whose properties to hold.
        """
        # grid geometry
        self.size = image.GetSize()
        self.origin = image.GetOrigin()
        self.spacing = image.GetSpacing()
        self.direction = image.GetDirection()
        self.dimensions = image.GetDimension()
        # pixel layout
        self.number_of_components_per_pixel = (
            image.GetNumberOfComponentsPerPixel())
        self.pixel_id = image.GetPixelID()
Пример #19
0
def rgb_to_grayscale_img(image: sitk.Image, white_light_filter_value=0.9):
    """Convert an RGB to grayscale image by extracting the average intensity, filtering out white light >0.9 max

    The colour components are averaged per pixel. Pixels whose averaged
    intensity exceeds ``white_light_filter_value`` times the maximum of the
    input array are set to 0.

    Args:
        image: multi-component (e.g. RGB) image; the components form the
            last axis of ``sitk.GetArrayFromImage(image)``.
        white_light_filter_value: fraction of the input maximum above which
            pixels are zeroed (default 0.9).

    Returns:
        sitk.Image: scalar image with the spacing, origin and direction of
        ``image``.
    """
    array = sitk.GetArrayFromImage(image)

    # Components are the last array axis for both 2-D (y, x, c) and 3-D
    # (z, y, x, c) vector images; a fixed axis 2 is only right in 2-D.
    grayscale_array = np.average(array, axis=-1)
    grayscale_array[grayscale_array > white_light_filter_value *
                    np.max(array)] = 0

    grayscale_image = sitk.GetImageFromArray(grayscale_array)
    grayscale_image.SetSpacing(image.GetSpacing())
    grayscale_image.SetOrigin(image.GetOrigin())
    grayscale_image.SetDirection(image.GetDirection())

    return grayscale_image
Пример #20
0
    def get_mask(ground_truth: sitk.Image,
                 ground_truth_labels: list,
                 label_percentages: list,
                 background_mask: sitk.Image = None) -> sitk.Image:
        """Gets a training mask.

        For each label, a random subset of its voxels is marked with 1. The
        subset size is ``int(count * percentage)`` and the selection uses
        ``np.random.shuffle``, i.e. the global NumPy random state. All other
        voxels are 0. Works for images of any dimension.

        Args:
            ground_truth (sitk.Image): The ground truth image.
            ground_truth_labels (list of int): The ground truth labels,
                where 0=background, 1=label1, 2=label2, ..., e.g. [0, 1]
            label_percentages (list of float): The percentage of voxels of a corresponding label to extract as mask,
                e.g. [0.2, 0.2].
            background_mask (sitk.Image): A mask, where intensity 0 indicates voxels to exclude independent of the label.

        Returns:
            sitk.Image: The training mask (uint8) with the origin, direction
            and spacing of ``ground_truth``.

        Raises:
            IndexError: if ``label_percentages`` has fewer entries than
                ``ground_truth_labels``.
        """

        # initialize mask
        ground_truth_array = sitk.GetArrayFromImage(ground_truth)
        mask_array = np.zeros(ground_truth_array.shape, dtype=np.uint8)

        # exclude background
        if background_mask is not None:
            excluded = np.logical_not(sitk.GetArrayFromImage(background_mask))
            # NaN never equals a label, so excluded voxels drop out of every
            # selection below; NaN needs a float array.
            ground_truth_array = ground_truth_array.astype(float)
            ground_truth_array[excluded] = np.nan

        for label_idx, label in enumerate(ground_truth_labels):
            # (N, ndim) array of voxel indices carrying this label
            indices = np.transpose(np.where(ground_truth_array == label))
            np.random.shuffle(indices)

            no_mask_items = int(indices.shape[0] *
                                label_percentages[label_idx])

            # Vectorised scatter of the first `no_mask_items` shuffled indices
            # (a no-op when it is 0); independent of the image dimension.
            mask_array[tuple(indices[:no_mask_items].T)] = 1

        mask = sitk.GetImageFromArray(mask_array)
        mask.SetOrigin(ground_truth.GetOrigin())
        mask.SetDirection(ground_truth.GetDirection())
        mask.SetSpacing(ground_truth.GetSpacing())

        return mask
Пример #21
0
    def roi2mask(self, mask_img: sitk.Image, pet_img: sitk.Image) -> sitk.Image:
        """
        Threshold the PET image inside every ROI to build a binary segmentation.

        For each ROI channel of ``mask_img``, ``self.calculate_threshold`` is
        called on the PET values inside the ROI. ROI voxels whose PET value is
        >= that threshold are set to 1. Empty ROIs are skipped. An exception
        raised while thresholding a ROI is printed and that ROI is skipped.

        Args:
            :param mask_img: sitk.Image, raw mask (i.e ROI); array layout
                [z,y,x] or [z,y,x,C] with one ROI per channel
            :param pet_img: sitk.Image, the corresponding pet scan ([z,y,x])

        :return: sitk.Image, int8 segmentation with the origin, direction and
            spacing of ``mask_img``
        """
        spacing = mask_img.GetSpacing()
        origin = mask_img.GetOrigin()
        direction = tuple(mask_img.GetDirection())
        rois = sitk.GetArrayFromImage(mask_img)  # [z,y,x] or [z,y,x,C]
        pet = sitk.GetArrayFromImage(pet_img)  # [z,y,x]

        # ROI axis first -> [C,z,y,x]
        if rois.ndim == 3:
            rois = rois[np.newaxis]
        else:
            rois = np.transpose(rois, (3, 0, 1, 2))

        segmentation = np.zeros(rois.shape[1:], dtype=np.int8)  # [z,y,x]

        for roi in rois:
            inside = roi > 0
            roi_values = pet[inside]
            if roi_values.size == 0:
                continue  # empty ROI
            try:
                threshold = self.calculate_threshold(roi_values)
                segmentation[np.where((pet >= threshold) & inside)] = 1
            except Exception as e:
                print(e)
                print(sys.exc_info()[0])

        result = sitk.GetImageFromArray(segmentation)
        result.SetOrigin(origin)
        result.SetDirection(direction)
        result.SetSpacing(spacing)
        return result
Пример #22
0
    def execute(self,
                image: sitk.Image,
                params: fltr.IFilterParams = None) -> sitk.Image:
        """Executes a atlas coordinates feature extractor on an image.

        Each output voxel holds the physical position of the corresponding
        input voxel, ``origin + direction @ (spacing * index)``. This is the
        mapping used by ``sitk.Image.TransformIndexToPhysicalPoint``.

        Args:
            image (sitk.Image): The image.
            params (fltr.IFilterParams): The parameters (unused).

        Returns:
            sitk.Image: The atlas coordinates image
            (a vector image with 3 components, which represent the physical x, y, z coordinates in mm).

        Raises:
            ValueError: If image is not 3-D.
        """

        if image.GetDimension() != 3:
            raise ValueError('image needs to be 3-D')

        x, y, z = image.GetSize()

        # create matrix with homogenous indices in axis 3
        coords = np.zeros((x, y, z, 4))
        coords[..., 0] = np.arange(x)[:, np.newaxis, np.newaxis]
        coords[..., 1] = np.arange(y)[np.newaxis, :, np.newaxis]
        coords[..., 2] = np.arange(z)[np.newaxis, np.newaxis, :]
        coords[..., 3] = 1

        # reshape such that each voxel is one row (z index varies fastest)
        lin_coords = np.reshape(
            coords, [coords.shape[0] * coords.shape[1] * coords.shape[2], 4])

        # Homogeneous index-to-physical transform. GetDirection() is the
        # row-major 3x3 direction matrix; multiplying by the spacing scales
        # column j by the spacing of index axis j (direction @ diag(spacing)).
        tfm = np.eye(4)
        tfm[:3, :3] = (np.reshape(image.GetDirection(), (3, 3)) *
                       np.asarray(image.GetSpacing()))
        tfm[:3, 3] = image.GetOrigin()

        atlas_coords = (tfm @ np.transpose(lin_coords))[0:3, :]
        # back to the (z, y, x, component) layout expected by
        # GetImageFromArray; Fortran order undoes the (x, y, z) row order
        atlas_coords = np.reshape(np.transpose(atlas_coords), [z, y, x, 3],
                                  'F')

        img_out = sitk.GetImageFromArray(atlas_coords)
        img_out.CopyInformation(image)

        return img_out
Пример #23
0
    def roi2mask(self, mask_img: sitk.Image, pet_img: sitk.Image) -> sitk.Image:
        """
        Build a soft segmentation from the ROI by averaging four thresholded masks.

        ``self.__roi_seg`` is called with the thresholds 'otsu', '0.41', '2.5'
        and '4.0', in that order. The resulting masks are averaged voxel-wise.

        Args:
            :param mask_img: sitk.Image, raw mask (i.e ROI)
            :param pet_img: sitk.Image, the corresponding pet scan

        :return: sitk.Image, the averaged segmentation with the origin,
            direction and spacing of ``mask_img``
        """
        origin = mask_img.GetOrigin()
        spacing = mask_img.GetSpacing()
        direction = tuple(mask_img.GetDirection())
        mask_array = sitk.GetArrayFromImage(mask_img)
        pet_array = sitk.GetArrayFromImage(pet_img)

        # put the ROI channel axis first
        mask_array = (np.expand_dims(mask_array, axis=0)
                      if len(mask_array.shape) == 3
                      else np.transpose(mask_array, (3, 0, 1, 2)))

        thresholds = ('otsu', '0.41', '2.5', '4.0')
        masks = [
            self.__roi_seg(mask_array, pet_array, threshold=threshold)
            for threshold in thresholds
        ]
        averaged = np.mean(np.stack(masks, axis=3), axis=3)

        result = sitk.GetImageFromArray(averaged)
        result.SetOrigin(origin)
        result.SetDirection(direction)
        result.SetSpacing(spacing)

        return result
Пример #24
0
    def __call__(self, img: sitk.Image) -> sitk.Image:
        """Resample ``img`` to ``self.new_spacing`` with linear interpolation.

        Origin and direction are kept. The output size per axis is
        ``floor(size * spacing / new_spacing)``, computed with a tiny tolerance
        against floating-point error.
        """
        size, spacing, origin, direction = img.GetSize(), img.GetSpacing(), img.GetOrigin(), img.GetDirection()

        scale = [ns / s for ns, s in zip(self.new_spacing, spacing)]
        # sz / sc can land just below an exact integer because of floating-
        # point rounding (e.g. spacing 1.2 -> 0.4); plain int() would then
        # drop a whole slice. The epsilon absorbs that error only.
        new_size = [int(sz / sc + 1e-6) for sz, sc in zip(size, scale)]

        resampler = sitk.ResampleImageFilter()
        resampler.SetSize(new_size)
        resampler.SetInterpolator(sitk.sitkLinear)
        resampler.SetOutputDirection(direction)
        # Keep the original origin: an origin rescaled by `scale` produced
        # misaligned images.
        resampler.SetOutputOrigin(origin)
        resampler.SetOutputSpacing(self.new_spacing)

        return resampler.Execute(img)
Пример #25
0
def scale_image(image: sitk.Image,
                new_size: Tuple[int, ...],
                interpolator: int = sitk.sitkLinear,
                smoothing_sigma: float = 2.0) -> sitk.Image:
    r""" Scale an image in a grid of given size.

    Origin and direction are kept. The output spacing is
    ``spacing * size / new_size`` per axis, so the physical extent is
    preserved. The input is smoothed with a recursive Gaussian
    (anti-aliasing) before resampling.

    Parameters
    ----------
    image : sitk.Image
        An input image.
    new_size : Tuple[int, ...]
        The new size, one integer per image dimension. A list is accepted
        as well.
    interpolator : int
        A SimpleITK interpolator enum value.
    smoothing_sigma : float
        Sigma of the anti-aliasing Gaussian (default 2.0). Values <= 0
        disable smoothing.

    Returns
    -------
    sitk.Image
        Resized image.

    Raises
    ------
    Exception
        If ``image`` is not a SimpleITK image, ``new_size`` is not a
        tuple/list, or its length differs from the image dimension.
    """

    if not isinstance(image, sitk.Image):
        raise Exception("unsupported image object type")

    if not isinstance(new_size, (tuple, list)):
        raise Exception("new_size must be a tuple of integers")
    new_size = tuple(new_size)

    size = image.GetSize()
    if len(new_size) != len(size):
        raise Exception("new_size must match the image dimensionality")

    spacing = tuple(
        s * x / nx for s, x, nx in zip(image.GetSpacing(), size, new_size))

    resampler = sitk.ResampleImageFilter()
    resampler.SetSize(new_size)
    resampler.SetOutputSpacing(spacing)
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetInterpolator(interpolator)

    if smoothing_sigma > 0:
        # Anti-aliasing: suppress detail the new grid cannot represent.
        image = sitk.SmoothingRecursiveGaussian(image, smoothing_sigma)

    return resampler.Execute(image)
Пример #26
0
def copy_meta_data_itk(source: sitk.Image, target: sitk.Image) -> sitk.Image:
    """
    Copy meta data between files

    Not implemented: this function always raises ``NotImplementedError``
    ("Does not work!"). The statements after the ``raise`` are unreachable.
    They would set origin, direction and spacing of ``target`` from ``source``
    in place and return ``target``.

    Args:
        source: source file
        target: target file

    Returns:
        sitk.Image: never returns (see above)

    Raises:
        NotImplementedError: always.
    """
    # for i in source.GetMetaDataKeys():
    #     target.SetMetaData(i, source.GetMetaData(i))
    raise NotImplementedError("Does not work!")
    # Unreachable below this line.
    target.SetOrigin(source.GetOrigin())
    target.SetDirection(source.GetDirection())
    target.SetSpacing(source.GetSpacing())
    return target
Пример #27
0
def compute_dice(b1: sitk.Image, b2: sitk.Image):
    """Return the Dice coefficient between two label images.

    ``b2`` is first resampled onto the grid of ``b1`` (size, origin, spacing,
    direction). This uses an identity transform and nearest-neighbour
    interpolation and keeps ``b2``'s pixel type; voxels outside ``b2`` become
    0. The result is ``LabelOverlapMeasuresImageFilter.GetDiceCoefficient()``
    for ``(b1, resampled b2)``.
    """
    reference = b1
    aligned = sitk.Resample(
        b2,
        reference.GetSize(),
        sitk.Transform(),
        sitk.sitkNearestNeighbor,
        reference.GetOrigin(),
        reference.GetSpacing(),
        reference.GetDirection(),
        0,
        b2.GetPixelID(),
    )

    overlap = sitk.LabelOverlapMeasuresImageFilter()
    overlap.Execute(reference, aligned)
    return overlap.GetDiceCoefficient()
Пример #28
0
    def remove_small_roi(cls, binary_img: sitk.Image,
                         pet_img: sitk.Image,
                         min_volume_ml: float = 30.0) -> sitk.Image:
        """function to remove ROI under ``min_volume_ml`` (30 ml by default) on a binary sitk.Image

        Connected components of ``binary_img`` are labelled. A component is
        removed when its voxel count times the voxel volume is below
        ``min_volume_ml``. The voxel volume is the product of the first three
        PET spacings times 1e-3, i.e. ml when the spacing is in mm.

        Args:
            binary_img (sitk.Image): [sitk.Image of size (z,y,x)]
            pet_img (sitk.Image): [sitk.Image of the PET, size (z,y,x)]
            min_volume_ml (float): smallest component volume that is kept.

        Raises:
            Exception: [raise Exception if not a 3D binary mask]; this
                includes an all-zero mask, whose maximum is not 1.

        Returns:
            [sitk.Image]: [Return cleaned image] as uint8, with the origin,
            spacing and direction of ``pet_img``
        """

        binary_array = sitk.GetArrayFromImage(binary_img)
        if len(binary_array.shape) != 3 or int(np.max(binary_array)) != 1:
            raise Exception(
                "Not a 3D binary mask, need to transform into 3D binary mask")

        pet_spacing = pet_img.GetSpacing()
        labelled_img = sitk.ConnectedComponent(binary_img)
        stats = sitk.LabelIntensityStatisticsImageFilter()
        stats.Execute(labelled_img, pet_img)

        volume_voxel = pet_spacing[0] * pet_spacing[1] * pet_spacing[
            2] * 10**(-3)  # in ml
        small_labels = [
            label for label in range(1, stats.GetNumberOfLabels() + 1)
            if stats.GetNumberOfPixels(label) * volume_voxel < float(min_volume_ml)
        ]

        # Vectorised removal instead of a per-voxel Python loop.
        labelled_array = sitk.GetArrayFromImage(labelled_img)
        labelled_array[np.isin(labelled_array, small_labels)] = 0
        new_binary_array = (labelled_array != 0).astype(np.uint8)

        new_binary_img = sitk.GetImageFromArray(new_binary_array)
        new_binary_img.SetOrigin(pet_img.GetOrigin())
        new_binary_img.SetSpacing(pet_spacing)
        new_binary_img.SetDirection(pet_img.GetDirection())
        return new_binary_img
Пример #29
0
def sitk_to_nib(image: sitk.Image) -> Tuple[np.ndarray, np.ndarray]:
    """Return the transposed voxel array and a 4x4 affine for a 2-D/3-D image.

    The affine is built from direction, spacing and origin, left-multiplied by
    ``FLIP_XY`` (defined elsewhere). For 2-D images the 2x2 direction is
    embedded in the lower-right of a 3x3 identity, and a leading axis with
    spacing 1 and origin 0 is prepended.

    Raises:
        RuntimeError: if the direction has neither 9 nor 4 entries.
    """
    data = sitk.GetArrayFromImage(image).transpose()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, 1, H, W)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[1:3, 1:3] = rotation_2d
        spacing = 1, *spacing
        origin = 0, *origin
    else:
        # Without this branch `rotation` would be unbound below and the
        # caller would get an UnboundLocalError instead of a clear message.
        raise RuntimeError(f'Direction not understood: {direction}')
    rotation = np.dot(FLIP_XY, rotation)
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
Пример #30
0
def sitk_copy_metadata(img_source: sitk.Image, img_target: sitk.Image) -> sitk.Image:
    """
    Copy metadata (spacing, origin, direction) from source to target image

    Deprecated: this function always raises ``RuntimeError("Deprecated")``.
    The statements after the ``raise`` are unreachable. They would set
    spacing, origin and direction of ``img_target`` from ``img_source`` in
    place and return ``img_target``.

    Args
        img_source: source image
        img_target: target image

    Returns:
        SimpleITK.Image: never returns (see above)

    Raises:
        RuntimeError: always.
    """
    raise RuntimeError("Deprecated")
    # Unreachable below this line.
    spacing = img_source.GetSpacing()
    img_target.SetSpacing(spacing)

    origin = img_source.GetOrigin()
    img_target.SetOrigin(origin)

    direction = img_source.GetDirection()
    img_target.SetDirection(direction)
    return img_target