예제 #1
0
def verify_same_geometry(img_1: sitk.Image, img_2: sitk.Image):
    """Check whether two images share the same physical geometry.

    Origin, spacing and direction are compared with ``np.isclose`` default
    tolerances; sizes are compared exactly. Every mismatching property is
    printed together with both values.

    Args:
        img_1: First image.
        img_2: Second image.

    Returns:
        bool: ``True`` if origin, spacing, direction and size all match,
        ``False`` otherwise (including when the dimensions differ).
    """
    if img_1.GetDimension() != img_2.GetDimension():
        # np.isclose cannot broadcast vectors of different length and would
        # raise, so report the mismatch explicitly instead.
        print("the dimension does not match between the images")
        print(img_1.GetDimension())
        print(img_2.GetDimension())
        return False

    ori1, spacing1, direction1, size1 = img_1.GetOrigin(), img_1.GetSpacing(
    ), img_1.GetDirection(), img_1.GetSize()
    ori2, spacing2, direction2, size2 = img_2.GetOrigin(), img_2.GetSpacing(
    ), img_2.GetDirection(), img_2.GetSize()

    same_ori = bool(np.all(np.isclose(ori1, ori2)))
    if not same_ori:
        print("the origin does not match between the images:")
        print(ori1)
        print(ori2)

    same_spac = bool(np.all(np.isclose(spacing1, spacing2)))
    if not same_spac:
        print("the spacing does not match between the images")
        print(spacing1)
        print(spacing2)

    same_dir = bool(np.all(np.isclose(direction1, direction2)))
    if not same_dir:
        print("the direction does not match between the images")
        print(direction1)
        print(direction2)

    # Sizes are integers: compare exactly. A relative tolerance would treat
    # e.g. 200000 and 200001 voxels as equal.
    same_size = tuple(size1) == tuple(size2)
    if not same_size:
        print("the size does not match between the images")
        print(size1)
        print(size2)

    return same_ori and same_spac and same_dir and same_size
예제 #2
0
    def execute(self,
                image: sitk.Image,
                params: fltr.IFilterParams = None) -> sitk.Image:
        """Executes a neighborhood feature extractor on an image.

        For every voxel, ``self.function`` is applied to the block of
        ``self.kernel`` voxels (given in x, y, z order) that starts at that
        voxel. The image is padded symmetrically at the upper end of each
        axis so that blocks near the border stay full-sized.

        Args:
            image (sitk.Image): The 3-D image.
            params (fltr.IFilterParams): The parameters (unused).

        Returns:
            sitk.Image: A float32 image if ``self.function`` returns a
            scalar, or a vector float32 image with one component per element
            of the returned 1-D array. It has the same geometry as ``image``.

        Raises:
            ValueError: If image is not 3-D, or if ``self.function`` does not
                return a scalar or a 1-D np.ndarray with at least two
                elements.
        """

        if image.GetDimension() != 3:
            raise ValueError('image needs to be 3-D')

        # Probe the function once to learn whether it returns a scalar or a
        # vector, and allocate the output image accordingly.
        function_output = self.function(np.array([1, 2, 3]))
        if np.isscalar(function_output):
            is_vector = False
            img_out = sitk.Image(image.GetSize(), sitk.sitkFloat32)
        elif not isinstance(function_output, np.ndarray):
            raise ValueError(
                'function must return a scalar or a 1-D np.ndarray')
        elif function_output.ndim > 1:
            raise ValueError(
                'function must return a scalar or a 1-D np.ndarray')
        elif function_output.shape[0] <= 1:
            raise ValueError(
                'function must return a scalar or a 1-D np.ndarray with at least two elements'
            )
        else:
            is_vector = True
            img_out = sitk.Image(image.GetSize(), sitk.sitkVectorFloat32,
                                 function_output.shape[0])

        img_out_arr = sitk.GetArrayFromImage(img_out)
        img_arr = sitk.GetArrayFromImage(image)
        z, y, x = img_arr.shape

        # self.kernel is indexed (x, y, z); NumPy arrays are (z, y, x).
        z_offset = self.kernel[2]
        y_offset = self.kernel[1]
        x_offset = self.kernel[0]
        # Pad only at the upper end: each window starts at its voxel and
        # extends forward along every axis.
        pad = ((0, z_offset), (0, y_offset), (0, x_offset))
        img_arr_padded = np.pad(img_arr, pad, 'symmetric')

        for xx in range(x):
            for yy in range(y):
                for zz in range(z):
                    val = self.function(img_arr_padded[zz:zz + z_offset,
                                                       yy:yy + y_offset,
                                                       xx:xx + x_offset])
                    img_out_arr[zz, yy, xx] = val

        # Pass isVector explicitly so a (z, y, x, n) array is built as a 3-D
        # vector image rather than depending on the SimpleITK version's
        # default; CopyInformation requires the dimensions to match.
        img_out = sitk.GetImageFromArray(img_out_arr, isVector=is_vector)
        img_out.CopyInformation(image)

        return img_out
예제 #3
0
def fitting_index(image: sitk.Image,
                  norm: Union[int, str] = 2.0,
                  centre: Tuple[float, float, float] = None,
                  radius: float = None,
                  padding: bool = True) -> float:
    r""" Compute the fitting index of an input object.

    The fitting index of order `p` is the Jaccard coefficient between the
    input object and a p-ball, by default centred in the object's centroid.

    Parameters
    ----------
    image : sitk.Image
        Input binary image of the object (must be `sitk.sitkUInt8`; the
        centroid is taken from label 1).
    norm : Union[int,str]
        Order of the Minkowski norm ('inf' or 'max' to use the Chebyshev norm).
    centre : Tuple[float, float, float]
        Forces the p-ball to be centred in a specific point.
    radius : float
        Forces the radius of the p-ball; computed with
        `isovolumteric_radius(image, norm)` when omitted.
    padding : bool
        If `True`, pad each side by a quarter of the image size so that the
        ball can entirely fit within the volume.

    Returns
    -------
    float
        Value of the fitting index.

    Raises
    ------
    Exception
        If the pixel type is not `sitk.sitkUInt8`.
    """
    if not image.GetPixelID() == sitk.sitkUInt8:
        raise Exception('Unsupported %s pixel type' %
                        image.GetPixelIDTypeAsString())

    if centre is None:
        # Default centre: centroid of label 1.
        shape_stats = sitk.LabelShapeStatisticsImageFilter()
        shape_stats.Execute(image)
        centre = shape_stats.GetCentroid(1)

    if padding:
        # A quarter of the size per side leaves room for an isovolumetric
        # 1-ball around an object that touches the boundary.
        pad_width = tuple(extent // 4 for extent in image.GetSize())
        image = sitk.ConstantPad(image, pad_width, pad_width, 0)
        image.SetOrigin((0, 0, 0))
        # NOTE(review): the centroid above is a physical point, but it is
        # shifted here by a pad counted in voxels; this is only consistent for
        # unit spacing and a zero original origin — confirm.
        centre = tuple(c + p for c, p in zip(centre, pad_width))

    radius = isovolumteric_radius(image, norm) if radius is None else radius

    ball = drawing.create_sphere(radius,
                                 size=image.GetSize(),
                                 centre=centre,
                                 norm=norm) > 0
    return jaccard(image, ball)
예제 #4
0
def zoom(image: sitk.Image,
         scale_factor: Union[float, Sequence[float]],
         interpolation: str = "linear",
         anti_alias: bool = True,
         anti_alias_sigma: Optional[float] = None) -> sitk.Image:
    """Rescale image, preserving its spatial extent.

    The output keeps the size, spacing and geometry of `image`. Its content
    is scaled by `scale_factor` about the geometric centre of the image. A
    separate scale factor for each dimension can be given as a sequence.

    Parameters
    ----------
    image
        The image to rescale.

    scale_factor
        If a single number (int or float), each dimension will be scaled by
        that factor. If a sequence, each dimension will be scaled by the
        corresponding element.

    interpolation, optional
        The interpolation method to use. Valid options are:
        - "linear" for bi/trilinear interpolation (default)
        - "nearest" for nearest neighbour interpolation
        - "bspline" for order-3 b-spline interpolation

    anti_alias, optional
        Whether to smooth the image with a Gaussian kernel before resampling
        to avoid aliasing artifacts. Forwarded to `resample`, which decides
        when the smoothing is applied.

    anti_alias_sigma, optional
        The standard deviation of the Gaussian kernel used for anti-aliasing.

    Returns
    -------
    sitk.Image
        The rescaled image.
    """
    dimension = image.GetDimension()

    if np.isscalar(scale_factor):
        # Accept ints and NumPy scalars as well as Python floats.
        scale_factor = (float(scale_factor), ) * dimension
    else:
        scale_factor = tuple(float(s) for s in scale_factor)

    # Voxel centres lie at integer continuous indices, so the geometric
    # centre of the grid is (size - 1) / 2 rather than size / 2.
    centre_idx = (np.array(image.GetSize(), dtype=np.float64) - 1) / 2
    centre = image.TransformContinuousIndexToPhysicalPoint(
        centre_idx.tolist())

    transform = sitk.ScaleTransform(dimension, scale_factor)
    transform.SetCenter(centre)

    return resample(image,
                    spacing=image.GetSpacing(),
                    interpolation=interpolation,
                    anti_alias=anti_alias,
                    anti_alias_sigma=anti_alias_sigma,
                    transform=transform,
                    output_size=image.GetSize())
예제 #5
0
        def slice_by_slice(image: sitk.Image, *args, **kwargs):
            """Apply ``func`` independently to every 2-D slice of ``image``.

            ``func`` and ``inplace`` come from the enclosing scope (not
            visible here). Images with at most two dimensions are passed to
            ``func`` directly. Otherwise one 2-D slice spanning the first two
            axes is extracted per combination of indices along the remaining
            axes, and ``func(slice, *args, **kwargs)`` is applied to it.

            If ``inplace`` is truthy, each result is pasted back into
            ``image``, which is then returned. Otherwise the results are
            stacked back into a new image with ``sitk.JoinSeriesImageFilter``,
            using the input spacing and origin along each re-joined axis.

            Returns:
                sitk.Image: The processed image.
            """

            dim = image.GetDimension()
            # Number of leading axes that form one slice.
            iter_dim = 2

            if dim <= iter_dim:
                image = func(image, *args, **kwargs)
                return image

            # Extraction size 0 on the higher axes: ExtractImageFilter
            # collapses those axes, producing a 2-D slice.
            extract_size = list(image.GetSize())
            extract_size[iter_dim:] = itertools.repeat(0, dim - iter_dim)

            extract_index = [0] * dim
            # Full slices on the slice axes; fixed indices are filled in for
            # the higher axes in the loops below.
            paste_idx = [slice(None, None)] * dim

            extractor = sitk.ExtractImageFilter()
            extractor.SetSize(extract_size)
            if inplace:
                for high_idx in itertools.product(
                        *[range(s) for s in image.GetSize()[iter_dim:]]):
                    extract_index[iter_dim:] = high_idx
                    extractor.SetIndex(extract_index)

                    paste_idx[iter_dim:] = high_idx
                    image[paste_idx] = func(extractor.Execute(image), *args,
                                            **kwargs)

            else:
                # itertools.product varies the last axis fastest, so
                # img_list is ordered with the last axis index changing
                # fastest.
                img_list = []
                for high_idx in itertools.product(
                        *[range(s) for s in image.GetSize()[iter_dim:]]):
                    extract_index[iter_dim:] = high_idx
                    extractor.SetIndex(extract_index)

                    # Computed but not used on this branch.
                    paste_idx[iter_dim:] = high_idx

                    img_list.append(
                        func(extractor.Execute(image), *args, **kwargs))

                # Re-join one axis at a time, innermost (axis iter_dim)
                # first. Slices sharing the same indices on the axes after d
                # are `step` apart in img_list; each pass joins them along
                # axis d, shrinking the list by a factor of size[d].
                for d in range(iter_dim, dim):
                    step = reduce((lambda x, y: x * y),
                                  image.GetSize()[d + 1:], 1)

                    join_series_filter = sitk.JoinSeriesImageFilter()
                    join_series_filter.SetSpacing(image.GetSpacing()[d])
                    join_series_filter.SetOrigin(image.GetOrigin()[d])

                    img_list = [
                        join_series_filter.Execute(img_list[i::step])
                        for i in range(step)
                    ]

                # After the last axis everything is joined into one image.
                assert len(img_list) == 1
                image = img_list[0]

            return image
예제 #6
0
def assert_sitk_img_equivalence(img: SimpleITK.Image,
                                img_ref: SimpleITK.Image):
    """Assert that two SimpleITK images have identical metadata.

    Compares, with exact equality and stopping at the first mismatch:
    dimension, size, origin, spacing, number of components per pixel, pixel
    ID value and pixel type string. Pixel data and direction are not
    compared.
    """
    for getter_name in ("GetDimension",
                        "GetSize",
                        "GetOrigin",
                        "GetSpacing",
                        "GetNumberOfComponentsPerPixel",
                        "GetPixelIDValue",
                        "GetPixelIDTypeAsString"):
        actual = getattr(img, getter_name)()
        expected = getattr(img_ref, getter_name)()
        assert actual == expected
예제 #7
0
def crop(image: sitk.Image, crop_centre: Sequence[float],
         size: Union[int, Sequence[int], np.ndarray]) -> sitk.Image:
    """Crop an image to the desired size around a given centre.

    Note that the cropped image might be smaller than size in a particular
    direction if the cropping window exceeds image boundaries. Images of any
    dimensionality are supported.

    Parameters
    ----------
    image
        The image to crop.

    crop_centre
        The centre of the cropping window in image (index) coordinates, one
        value per image dimension.

    size
        The size of the cropping window along each dimension in pixels. If
        int, the same size is used in all directions. Alternatively, a
        sequence can be passed to specify the size along each dimension
        (x, y, z, ...). Passing 0 at any position will keep the original size
        along that dimension.

    Returns
    -------
    sitk.Image
        The cropped image.

    Raises
    ------
    ValueError
        If `crop_centre` lies outside the image.
    """
    crop_centre = np.asarray(crop_centre, dtype=np.float64)
    original_size = np.asarray(image.GetSize())

    if isinstance(size, (int, np.integer)):
        size = np.full(len(original_size), size)
    else:
        size = np.asarray(size)

    if (crop_centre < 0).any() or (crop_centre > original_size).any():
        raise ValueError(
            f"Crop centre outside image boundaries. Image size = {original_size}, crop centre = {crop_centre}"
        )

    # Window bounds, clipped to the image; size 0 keeps the full extent.
    min_coords = np.clip(
        np.floor(crop_centre - size / 2).astype(np.int64), 0, original_size)
    min_coords = np.where(size == 0, 0, min_coords)

    max_coords = np.clip(
        np.floor(crop_centre + size / 2).astype(np.int64), 0, original_size)
    max_coords = np.where(size == 0, original_size, max_coords)

    # One slice per axis instead of a hard-coded (x, y, z) unpack, so 2-D
    # (and higher-dimensional) images work as well.
    window = tuple(
        slice(int(lo), int(hi)) for lo, hi in zip(min_coords, max_coords))
    return image[window]
예제 #8
0
def level_set_cut_v2(image: sitk.Image, seed: list,
                     pred_image_name: str) -> (sitk.Image, int):
    """Grow a threshold level-set segmentation from each seed and combine them.

    For every seed, the seed voxel is dilated with ``sitk.BinaryDilate(..., 3)``.
    The intensity statistics of that region give a window
    ``mean ± 1.8 * sigma``. A ``ThresholdSegmentationLevelSetImageFilter``
    initialised from the region's signed distance map is then run, and its
    positive region is added to the output mask. Seeds whose window has a
    bound equal to 0 are skipped.

    Args:
        image: Input image; must be a ``sitk.sitkUInt8`` image.
        seed: Iterable of seed points, each an iterable of index coordinates
            (converted to ``int``).
        pred_image_name: Name used in log messages only.

    Returns:
        ``(mask, 1)`` if at least one seed was segmented, otherwise
        ``(mask, -1)``. ``mask`` is a ``sitk.sitkUInt8`` image with the
        geometry of ``image``; overlapping seed results are summed.
    """
    assert isinstance(image,
                      sitk.Image) and image.GetPixelID() == sitk.sitkUInt8
    # Materialise the seeds. A lazy ``map`` object was consumed by the length
    # computation in the log message, so the loop below never ran.
    seed = [tuple(int(c) for c in each) for each in seed]
    ft = sitk.Image(image.GetSize(), sitk.sitkUInt8)
    ft.CopyInformation(image)

    logger.info(pred_image_name + ' have ' + str(len(seed)) +
                ' lesion region(s)')

    stats = sitk.LabelStatisticsImageFilter()

    factor = 1.8  # half-width of the threshold window, in standard deviations
    lsFilter = sitk.ThresholdSegmentationLevelSetImageFilter()
    lsFilter.SetMaximumRMSError(0.02)
    lsFilter.SetNumberOfIterations(500)
    lsFilter.SetCurvatureScaling(.5)
    lsFilter.SetPropagationScaling(1)
    lsFilter.ReverseExpansionDirectionOn()

    ex_flag = False
    for each in seed:
        # Initial region: the seed voxel grown by a binary dilation.
        tmp_seg = sitk.Image(image.GetSize(), sitk.sitkUInt8)
        tmp_seg.CopyInformation(image)
        tmp_seg[each] = 1
        tmp_seg = sitk.BinaryDilate(tmp_seg, 3)
        assert isinstance(tmp_seg, sitk.Image)
        init_ls = sitk.SignedMaurerDistanceMap(tmp_seg,
                                               insideIsPositive=True,
                                               useImageSpacing=True)

        # Intensity window estimated from the seed region (label 1).
        stats.Execute(image, tmp_seg)
        lower_threshold = stats.GetMean(1) - factor * stats.GetSigma(1)
        upper_threshold = stats.GetMean(1) + factor * stats.GetSigma(1)
        logger.info('the lower_threshold and upper_threshold :' +
                    str(lower_threshold) + ' ' + str(upper_threshold))
        if lower_threshold == 0 or upper_threshold == 0:
            logger.warning('Threshold Error. Ignoring...')
            continue
        ex_flag = True
        lsFilter.SetLowerThreshold(lower_threshold)
        lsFilter.SetUpperThreshold(upper_threshold)
        ls = lsFilter.Execute(init_ls, sitk.Cast(image, sitk.sitkFloat32))
        assert isinstance(ls, sitk.Image)
        # `ls` is a float level-set image and cannot be added to the UInt8
        # mask. Add its positive region instead; this is presumably the
        # object, given the positive-inside initialisation and the reversed
        # expansion direction (as in SimpleITK's level-set example).
        ft += ls > 0

    return (ft, 1) if ex_flag else (ft, -1)
예제 #9
0
def plot_2d_segmentation_series(path: str,
                                file_name_suffix: str,
                                image: sitk.Image,
                                ground_truth: sitk.Image,
                                segmentation: sitk.Image,
                                alpha: float = 0.5,
                                label: int = 1,
                                file_extension: str = '.png') -> None:
    """Plots an image with an overlaid mask, which indicates under-, correct-, and over-segmentation.

    One file is written per slice along the first NumPy axis (z). Each is
    named ``file_name_suffix + <slice index> + file_extension`` inside
    ``path``, which is created if missing.

    Args:
        path (str): The output directory path.
        file_name_suffix (str): Text placed before the slice index in each file name.
        image (sitk.Image): The image.
        ground_truth (sitk.Image): The ground truth.
        segmentation (sitk.Image): The segmentation.
        alpha (float): The alpha blending value, between 0 (transparent) and 1 (opaque).
        label (int): The ground truth and segmentation label.
        file_extension (str): The output file extension (with or without dot).

    Raises:
        ValueError: If the sizes differ, the image is not 3-D, or it is not scalar.

    Examples:
        >>> img_t2 = sitk.ReadImage('your/path/image.mha')
        >>> ground_truth = sitk.ReadImage('your/path/ground_truth.mha')
        >>> segmentation = sitk.ReadImage('your/path/segmentation.mha')
        >>> plot_2d_segmentation_series('/your/path/', 'mysegmentation', img_t2, ground_truth, segmentation)
    """
    sizes_match = (image.GetSize() == ground_truth.GetSize() ==
                   segmentation.GetSize())
    if not sizes_match:
        raise ValueError(
            'image, ground_truth, and segmentation must have equal size')
    if image.GetDimension() != 3:
        raise ValueError('only 3-dimensional images supported')
    if image.GetNumberOfComponentsPerPixel() != 1:
        raise ValueError('only scalar images supported')

    img_volume, gt_volume, seg_volume = (
        sitk.GetArrayFromImage(volume)
        for volume in (image, ground_truth, segmentation))

    os.makedirs(path, exist_ok=True)
    if not file_extension.startswith('.'):
        file_extension = '.' + file_extension

    # NumPy arrays are (z, y, x): iterating yields one axial slice at a time.
    slice_triples = zip(img_volume, gt_volume, seg_volume)
    for slice_idx, (img_slice, gt_slice, seg_slice) in enumerate(slice_triples):
        out_file = os.path.join(
            path, file_name_suffix + str(slice_idx) + file_extension)
        plot_2d_segmentation(out_file,
                             img_slice,
                             gt_slice,
                             seg_slice,
                             alpha=alpha,
                             label=label)
예제 #10
0
def compatible_metadata(image1: sitk.Image,
                        image2: sitk.Image,
                        check_size: bool = True,
                        check_spacing: bool = True,
                        check_origin: bool = True) -> bool:
    """ Compares the metadata of two images and determines if all checks are successful or not.
    Spacing and origin are compared with a small absolute tolerance (0.0001);
    sizes are compared exactly. Vectors of different length (images of
    different dimension) never match. Every failed check is printed.

    @param image1: first image
    @param image2: second image
    @param check_size: if true, check if the sizes of the images are equal
    @param check_spacing: if true, check if the spacing of the images are equal
    @param check_origin: if true, check if the origin of the images are equal
    @return: true, if images are equal in all given checks, false if one of them failed
    """
    tolerance = 1e-4

    def _all_close(values1, values2) -> bool:
        # Check lengths first: zip() would silently ignore the extra entries
        # of the longer vector and report a 2-D/3-D pair as matching.
        return len(values1) == len(values2) and all(
            abs(v1 - v2) <= tolerance for v1, v2 in zip(values1, values2))

    all_parameters_equal = True

    if check_size:
        size1 = image1.GetSize()
        size2 = image2.GetSize()
        if size1 != size2:
            all_parameters_equal = False
            print(f'Images do not have the same size ({size1} != {size2})')

    if check_spacing:
        spacing1 = image1.GetSpacing()
        spacing2 = image2.GetSpacing()
        if not _all_close(spacing1, spacing2):
            all_parameters_equal = False
            print(
                f'Images do not have the same spacing ({spacing1} != {spacing2})'
            )

    if check_origin:
        origin1 = image1.GetOrigin()
        origin2 = image2.GetOrigin()
        if not _all_close(origin1, origin2):
            all_parameters_equal = False
            print(
                f'Images do not have the same origin ({origin1} != {origin2})')

    return all_parameters_equal
예제 #11
0
def image_resample(image: sitk.Image, new_spacing=None):
    '''Resample an image to a new voxel spacing with SimpleITK.

    Typical use::

        image = sitk.ReadImage(image_path)
        resampled = image_resample(image)

    The output keeps the input's origin and direction, uses linear
    interpolation and fills voxels that fall outside the input with 0. Its
    size is round(size * old_spacing / new_spacing) along each axis, which
    approximately preserves the physical extent.

    Args:
        image: The image to resample (any dimension).
        new_spacing: Target spacing, one value per image dimension. Defaults
            to 1.0 along every axis (the previous hard-coded [1, 1, 1] for
            3-D images).

    Returns:
        sitk.Image: The resampled image.

    Raises:
        ValueError: If new_spacing does not have one entry per dimension.
    '''
    dimension = image.GetDimension()
    if new_spacing is None:
        new_spacing = [1.0] * dimension
    new_spacing = [float(s) for s in new_spacing]
    if len(new_spacing) != dimension:
        raise ValueError(
            f'new_spacing must have {dimension} entries, got {len(new_spacing)}')

    origin_spacing = image.GetSpacing()  # source spacing
    origin_size = image.GetSize()

    resampler = sitk.ResampleImageFilter()
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetDefaultPixelValue(0)

    resampler.SetOutputSpacing(new_spacing)
    resampler.SetOutputOrigin(image.GetOrigin())
    resampler.SetOutputDirection(image.GetDirection())

    # Size of the new image: keep size * spacing (physical extent) constant.
    new_size = [
        int(np.round(size * (old / new)))
        for size, old, new in zip(origin_size, origin_spacing, new_spacing)
    ]
    resampler.SetSize(new_size)

    return resampler.Execute(image)
예제 #12
0
    def assert_img_properties(img: SimpleITK.Image,
                              internal_image: SimpleITKImage):
        """Assert that ``internal_image`` mirrors the metadata of ``img``.

        Checks the colour space (from the number of components per pixel),
        the time points (last size entry of 4-D images), depth, width, height
        and voxel spacing. All spacings are compared with ``approx`` so that
        floating-point round-off does not cause spurious failures.
        """
        # Components per pixel -> expected colour space (None otherwise).
        color_space = {
            1: ColorSpace.GRAY,
            3: ColorSpace.RGB,
            4: ColorSpace.RGBA,
        }

        assert internal_image.color_space == color_space.get(
            img.GetNumberOfComponentsPerPixel())
        if img.GetDimension() == 4:
            assert internal_image.timepoints == img.GetSize()[-1]
        else:
            assert internal_image.timepoints is None
        if img.GetDepth():
            assert internal_image.depth == img.GetDepth()
            # approx, like the width/height spacings below: exact float
            # equality is brittle.
            assert internal_image.voxel_depth_mm == approx(
                img.GetSpacing()[2])
        else:
            assert internal_image.depth is None
            assert internal_image.voxel_depth_mm is None

        assert internal_image.width == img.GetWidth()
        assert internal_image.height == img.GetHeight()
        assert internal_image.voxel_width_mm == approx(img.GetSpacing()[0])
        assert internal_image.voxel_height_mm == approx(img.GetSpacing()[1])
    def __call__(self, x: sitk.Image) -> sitk.Image:
        """Apply the transform.

        Rotates ``x`` about its geometric centre by an angle drawn uniformly
        from ``[-self.max_angle, self.max_angle)`` around the z-axis (the
        angle is passed to ``sitk.Euler3DTransform``, which expects radians).
        The result is resampled onto the grid of ``x`` with linear
        interpolation; points mapped from outside ``x`` get
        ``self.fill_value``.

        Parameters
        ----------
        x
            3-D image to transform.

        Returns
        -------
        sitk.Image
            The transformed image.
        """
        angle = -self.max_angle + 2 * self.max_angle * torch.rand(1).item()
        # Voxel centres lie at integer continuous indices, so the geometric
        # centre of the grid is (size - 1) / 2 rather than size / 2.
        rotation_centre = (np.array(x.GetSize(), dtype=np.float64) - 1) / 2
        rotation_centre = x.TransformContinuousIndexToPhysicalPoint(
            rotation_centre.tolist())

        rotation = sitk.Euler3DTransform(
            rotation_centre,
            0,  # angle of rotation about the x-axis, in radians
            0,  # angle of rotation about the y-axis, in radians
            angle,  # about the z-axis, in radians (axial if z is the slice axis)
            (0., 0., 0.)  # no translation
        )
        return sitk.Resample(x, x, rotation, sitk.sitkLinear, self.fill_value)
 def __init__(self, data: sitk.Image, segmentation: sitk.Image, bb_size):
     """Crop ``data`` to a box of size ``bb_size`` containing ``segmentation``.

     The box is first placed so that ``segmentation`` sits in its centre.
     It is then shifted, per axis, so that it does not extend past either end
     of ``data``.

     Attributes set:
         segmentation: The given segmentation image.
         offset: Per-axis voxel offset of the segmentation origin inside the
             box.
         data: The ``bb_size`` region of ``data`` along its first three axes.

     Raises:
         AssertionError: If the box does not fit inside ``data``.
     """
     self.segmentation = segmentation
     # Voxel index of the segmentation's origin within ``data``.
     seg_origin_idx = data.TransformPhysicalPointToIndex(
         segmentation.GetOrigin())
     # Offsets that would centre the segmentation inside the box.
     centred_offset = ((np.array(bb_size) - np.array(segmentation.GetSize())) /
                       2).astype(int)

     offsets = []
     for seg_or, off, bb_sz, data_sz in zip(seg_origin_idx, centred_offset,
                                            bb_size, data.GetSize()):
         if seg_or - off < 0:
             # The box would start before index 0: start it at 0 instead.
             off = seg_or
         overflow = (seg_or - off + bb_sz) - data_sz
         if overflow > 0:
             # The box would run past the end of the data: shift it back.
             off = off + overflow
         offsets.append(off)
     self.offset = offsets

     cropped_origin = np.array(seg_origin_idx) - self.offset
     assert all(cr_or >= 0 for cr_or in cropped_origin), \
         f"Data size: {data.GetSize()}, BB size: {bb_size}, BB origin: {cropped_origin}"
     assert all(cr_or + bb_s <= data_s for cr_or, bb_s, data_s in zip(cropped_origin, bb_size, data.GetSize())), \
         f"Data size: {data.GetSize()}, BB size: {bb_size}, BB origin: {cropped_origin}"
     self.data = data[cropped_origin[0]:cropped_origin[0] + bb_size[0],
                      cropped_origin[1]:cropped_origin[1] + bb_size[1],
                      cropped_origin[2]:cropped_origin[2] + bb_size[2]]
예제 #15
0
def deformation_to_displacement(deformation: sitk.Image) -> sitk.Image:
    r""" Convert a deformation field to a displacement field.

    A deformation field :math:`D` is given by the sum of the identity
    transform and a displacement field :math:`d`:

    .. math::
        D(x) = x + d(x)

    The identity is expressed in voxel indices: each voxel's (x, y, z) index
    is subtracted from the first three components of its vector. Spacing and
    origin are not used.

    Parameters
    ----------
    deformation :
        Input 3D deformation field with (at least) three components per voxel.

    Returns
    -------
    sitk.Image
        Displacement field associated to the deformation.
    """

    a = sitk.GetArrayFromImage(deformation)

    # The NumPy array is indexed (z, y, x, component). Subtract every voxel's
    # own index in one vectorised step per axis instead of a per-voxel
    # Python loop.
    zz, yy, xx = np.indices(a.shape[:3], dtype=a.dtype)
    a[..., 0] -= xx
    a[..., 1] -= yy
    a[..., 2] -= zz

    displacement = sitk.GetImageFromArray(a)
    displacement.CopyInformation(deformation)

    return displacement
예제 #16
0
def regularise(jacobian: sitk.Image, epsilon: float = 1e-5) -> sitk.Image:
    r""" Regularise the Jacobian, removing singularities.

    Given a 3D scalar image, replace all the entries that are smaller
    than `epsilon` with `epsilon`. The work is done by
    `_disptools.regularise` on a copy cast to `sitk_float_type`.

    Parameters
    ----------
    jacobian : sitk.Image
        Input Jacobian map.
    epsilon  : float
        Lower threshold for the Jacobian.

    Returns
    -------
    sitk.Image
        Thresholded Jacobian.

    Raises
    ------
    Exception
        If the image is not three-dimensional.
    """
    jacobian = sitk.Cast(jacobian, sitk_float_type)

    if len(jacobian.GetSize()) != 3:
        raise Exception("Wrong jacobian dimensionality")

    # The C extension writes into the copy's pixel buffer through a NumPy
    # view.
    # NOTE(review): sitk.Image(...) may share its buffer with `jacobian`
    # until modified through the SimpleITK API, and writing through
    # GetArrayViewFromImage bypasses that — confirm `jacobian` is unaffected.
    regularised = sitk.Image(jacobian)
    _disptools.regularise(sitk.GetArrayViewFromImage(regularised), epsilon)
    return regularised
예제 #17
0
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a NumPy array and a 4x4 affine matrix.

    The pixel array is transposed (axis order reversed relative to
    ``sitk.GetArrayFromImage``), with the component/channel axis first.
    Unless ``keepdim`` is true, the array is then passed through
    ``ensure_4d``.

    The affine combines direction, spacing and origin, pre-multiplied by
    ``FLIP_XY`` (presumably converting ITK's LPS convention to RAS — confirm).
    For 2-D images the rotation is embedded in the last two axes of a 3-D
    affine, with spacing 1 and origin 0 on the added first axis.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The data array and the 4x4 affine.

    Raises:
        RuntimeError: If the direction has neither 4 (2-D) nor 9 (3-D) entries.
    """
    data = sitk.GetArrayFromImage(image).transpose()
    num_components = image.GetNumberOfComponentsPerPixel()
    if num_components == 1:
        data = data[np.newaxis]  # add channels dimension
    input_spatial_dims = image.GetDimension()
    if not keepdim:
        data = ensure_4d(data, False, num_spatial_dims=input_spatial_dims)
    assert data.shape[0] == num_components
    assert data.shape[-input_spatial_dims:] == image.GetSize()
    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()
    if len(direction) == 9:
        rotation = direction.reshape(3, 3)
    elif len(direction) == 4:  # ignore first dimension if 2D (1, 1, H, W)
        rotation_2d = direction.reshape(2, 2)
        rotation = np.eye(3)
        rotation[1:3, 1:3] = rotation_2d
        spacing = 1, *spacing
        origin = 0, *origin
    else:
        # Without this branch `rotation` would be unbound below
        # (UnboundLocalError) for any other dimensionality.
        raise RuntimeError(f'Direction not understood: {direction}')
    rotation = np.dot(FLIP_XY, rotation)
    rotation_zoom = rotation * spacing
    translation = np.dot(FLIP_XY, origin)
    affine = np.eye(4)
    affine[:3, :3] = rotation_zoom
    affine[:3, 3] = translation
    return data, affine
예제 #18
0
def sitk_to_nib(
    image: sitk.Image,
    keepdim: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a SimpleITK image into a NumPy array and a 4x4 affine matrix.

    The pixel array is transposed (axis order reversed relative to
    ``sitk.GetArrayFromImage``) and gets a leading channel axis. 2-D images
    get a trailing singleton spatial axis. Unless ``keepdim`` is true, the
    array is then passed through ``ensure_4d``.

    The affine combines direction, spacing and origin, pre-multiplied by
    ``FLIP_XY`` (presumably converting ITK's LPS convention to RAS — confirm).
    For 2-D images the rotation fills the first two axes of a 3-D affine,
    with spacing 1 and origin 0 on the added third axis.

    Raises:
        RuntimeError: If the direction has neither 4 (2-D) nor 9 (3-D) entries.
    """
    array = sitk.GetArrayFromImage(image).transpose()
    n_channels = image.GetNumberOfComponentsPerPixel()
    spatial_dims = image.GetDimension()

    if n_channels == 1:
        array = array[np.newaxis]  # channels axis
    if spatial_dims == 2:
        array = array[..., np.newaxis]  # singleton third spatial axis
    if not keepdim:
        array = ensure_4d(array, num_spatial_dims=spatial_dims)
    assert array.shape[0] == n_channels
    assert array.shape[1:1 + spatial_dims] == image.GetSize()

    spacing = np.array(image.GetSpacing())
    direction = np.array(image.GetDirection())
    origin = image.GetOrigin()

    if direction.size == 9:
        rotation = direction.reshape(3, 3)
    elif direction.size == 4:  # 2-D image laid out as (1, W, H, 1)
        rotation = np.eye(3)
        rotation[:2, :2] = direction.reshape(2, 2)
        spacing = (*spacing, 1)
        origin = (*origin, 0)
    else:
        raise RuntimeError(f'Direction not understood: {direction}')

    affine = np.eye(4)
    affine[:3, :3] = np.dot(FLIP_XY, rotation) * spacing
    affine[:3, 3] = np.dot(FLIP_XY, origin)
    return array, affine
예제 #19
0
    def execute(self,
                image: sitk.Image,
                params: FilterParams = None) -> sitk.Image:
        """Compute HOG features for every voxel of a 3-D image.

        Args:
            image: The 3-D input image.
            params: Unused.

        Returns:
            sitk.Image: A vector image with one component per feature channel
            returned by ``self.hog_module``, with the geometry of ``image``.
        """
        # Wrap the (z, y, x) NumPy array as a PyTorch tensor.
        image_arr = torch.from_numpy(sitk.GetArrayFromImage(image))

        # Compute the 3D-HOG features using PyTorch.
        features = self.hog_module(image_arr)

        # Detach from the computational graph, move to host memory and
        # convert to a np.ndarray.
        features_np = features.detach().cpu().numpy()

        # Release the tensor (possibly on the GPU) before continuing.
        del features
        torch.cuda.empty_cache()

        # Move the feature axis last; assumes the squeezed module output is
        # (features, z, y, x) — TODO confirm.
        features_np = np.squeeze(features_np)
        features_np = np.transpose(features_np, (1, 2, 3, 0))

        # Crop the border (presumably added by the HOG blocks) so the output
        # matches the input size.
        image_size = image.GetSize()
        offset = self.hog_module.Get_block_size() // 2

        features_np = features_np[offset:image_size[2] + offset,
                                  offset:image_size[1] + offset,
                                  offset:image_size[0] + offset]

        # isVector=True: the last axis holds feature components. Passing it
        # explicitly avoids depending on the SimpleITK version's default for
        # 4-D arrays; CopyInformation needs a 3-D image.
        img_out = sitk.GetImageFromArray(features_np, isVector=True)
        img_out.CopyInformation(image)
        return img_out
예제 #20
0
def get_largest_segment(image: sitk.Image,
                        extract_background=False) -> sitk.Image:
    """Keep only the largest connected component of every label.

    For every integer value between the image minimum and maximum, the
    largest fully-connected component of the pixels holding that value keeps
    its label; all other pixels become 0. For label 0, the largest component
    is written with the value ``max + 1`` if ``extract_background`` is true,
    and left at 0 otherwise.

    Args:
        image: Integer-valued label image.
        extract_background: Whether to mark the largest background component.

    Returns:
        sitk.Image: Image with the size, pixel type and geometry of ``image``.
    """
    # Range of label values to process.
    img_statistic = sitk.StatisticsImageFilter()
    img_statistic.Execute(image)
    min_val = int(img_statistic.GetMinimum())
    max_val = int(img_statistic.GetMaximum())

    pixel_id = image.GetPixelID()
    img_out = sitk.Image(image.GetSize(), pixel_id)
    img_out.CopyInformation(image)

    connected_comp_filter = sitk.ConnectedComponentImageFilter()
    connected_comp_filter.FullyConnectedOn()

    for label in range(min_val, max_val + 1):
        components = connected_comp_filter.Execute(image == label)
        # RelabelComponent numbers components by decreasing size, so label 1
        # is the largest. The comparison yields UInt8; cast it to the output
        # pixel type so the addition also works for non-UInt8 label images.
        largest = sitk.Cast(sitk.RelabelComponent(components) == 1, pixel_id)
        if label == 0:
            # Temporary new label for the largest background component.
            largest = largest * ((max_val + 1) * extract_background)
        else:
            largest = largest * label
        img_out = img_out + largest

    return img_out
예제 #21
0
    def _image_as_numpy_array(image: sitk.Image, mask: np.ndarray = None):
        """Gets an image as numpy array where each row is a voxel and each column is a feature.

        Args:
            image (sitk.Image): The image.
            mask (np.ndarray): A mask with the shape of the image array
                (z, y, x) defining which voxels to return. True (non-zero) is
                background and is dropped; False marks a voxel to keep.

        Returns:
            np.ndarray: A plain array of shape (number of voxels kept,
            number of components). Each row is a voxel, in (z, y, x) order,
            and each column is a feature.
        """
        number_of_components = image.GetNumberOfComponentsPerPixel()
        no_voxels = np.prod(image.GetSize())
        image_arr = sitk.GetArrayFromImage(image)

        if mask is not None:
            keep = ~np.asarray(mask, dtype=bool)
            no_voxels = np.count_nonzero(keep)
            # Boolean indexing with the (z, y, x) mask selects whole voxels,
            # i.e. all components of a vector image at once. It returns a
            # plain ndarray, with no repeated vector mask or masked array.
            image_arr = image_arr[keep]

        return image_arr.reshape((no_voxels, number_of_components))
예제 #22
0
def resample_image_to_voxel_size(image: sitk.Image,
                                 new_spacing,
                                 interpolator,
                                 fill_value=0) -> sitk.Image:
    """
        resample a 3d image to a given voxel size (dx, dy, dz)
    :param image: 3D sitk image
    :param new_spacing: voxel size (dx, dy, dz)
    :param interpolator: interpolator name, forwarded to resample_sitk_image
    :param fill_value: pixel value when a transformed pixel is outside of the image.
    :return: the resampled image, of size ceil(size * spacing / new_spacing) per axis
    """

    # Compute the image shape that keeps the physical extent at the new
    # voxel size.
    orig_size = np.array(image.GetSize())
    orig_spacing = np.array(image.GetSpacing())
    new_spacing = np.array(new_spacing)
    new_size = orig_size * (orig_spacing / new_spacing)

    # Round up to whole voxels. SimpleITK expects a list of Python ints; the
    # builtin int replaces the np.int alias removed in NumPy 1.24.
    new_size = [int(s) for s in np.ceil(new_size)]

    return resample_sitk_image(image, new_size, interpolator, fill_value)
예제 #23
0
def downsample_image(image: sitk.Image, input_pixel_size, output_pixel_size):
    """Resample an image from one isotropic pixel size to another.

    The input is treated as having spacing ``input_pixel_size`` along every
    axis, which overrides whatever spacing it carries. It is resampled with
    linear interpolation onto a grid of spacing ``output_pixel_size`` that
    covers the same extent. The output keeps the input's origin and direction.

    Args:
        image (sitk.Image): The image to resample. It is not modified; the
            spacing override is applied to a copy.
        input_pixel_size: Spacing to assume for every axis of ``image``.
        output_pixel_size: Spacing of every axis of the output.

    Returns:
        sitk.Image: The resampled image.
    """
    dimension = image.GetDimension()

    # Override the spacing on a copy, so the caller's image is left untouched.
    source = sitk.Image(image)
    source.SetSpacing([input_pixel_size for _ in range(dimension)])

    size = [
        int(math.ceil(i / output_pixel_size * input_pixel_size))
        for i in source.GetSize()
    ]
    # Keep the input's origin and direction. A fixed zero origin or identity
    # direction would sample a different physical region whenever the input's
    # origin is non-zero or its direction is not identity.
    return sitk.Resample(source, size, sitk.Transform(), sitk.sitkLinear,
                         source.GetOrigin(),
                         [output_pixel_size for _ in range(dimension)],
                         source.GetDirection())
예제 #24
0
def get_raw_liver_image(raw: sitk.Image, liver: sitk.Image) -> sitk.Image:
    """Copy the pixels of ``raw`` that lie inside the liver mask into a new image.

    Every index at which ``liver`` equals 1 receives the value of ``raw`` at
    that index. All other pixels keep the initial value of the freshly
    created output image (zero in SimpleITK). The output has pixel type UInt8
    and takes its meta-information from ``raw`` via ``CopyInformation``.

    Args:
        raw (sitk.Image): Source image; every index within its size is visited.
        liver (sitk.Image): Mask image. It is read at every index of ``raw``,
            so it must be at least as large as ``raw`` along each axis.

    Returns:
        sitk.Image: The masked copy of ``raw`` as an 8-bit unsigned image.
    """
    import itertools

    size = raw.GetSize()
    output = sitk.Image(size, sitk.sitkUInt8)
    output.CopyInformation(raw)
    # Visit every index regardless of image dimension. product() varies the
    # last axis fastest, so the first axis is the outermost loop.
    for index in itertools.product(*(range(extent) for extent in size)):
        if liver.GetPixel(*index) == 1:
            output.SetPixel(*index, raw.GetPixel(*index))
    return output
예제 #25
0
def resample_sitk_image(sitk_image: sitk.Image,
                        new_size,
                        interpolator="gaussian",
                        fill_value=0) -> sitk.Image:
    """Resample an image onto a grid with ``new_size`` voxels per axis.

    The physical extent (``spacing * size``), origin and direction of the
    input are kept; only the number of voxels, and therefore the spacing,
    changes.

    modified version from:
        https://github.com/jonasteuwen/SimpleITK-examples/blob/master/examples/resample_isotropically.py

    Args:
        sitk_image: The image, or a path to an image file readable by
            ``sitk.ReadImage``.
        new_size: Number of voxels along each output axis (integers,
            including NumPy integers).
        interpolator: Key of ``_SITK_INTERPOLATOR_DICT``. If falsy (e.g.
            ``None``), "linear" is inferred. Inference is only allowed for
            8-bit unsigned and 16/32-bit signed integer images. 8-bit unsigned
            images are always resampled with "nearest", whatever is passed.
        fill_value: Value of output voxels that map outside the input.

    Returns:
        sitk.Image: The resampled image.

    Raises:
        NotImplementedError: If ``interpolator`` is falsy and the pixel type
            is not one the inference supports.
        ValueError: If ``interpolator`` is not a key of
            ``_SITK_INTERPOLATOR_DICT``.
        TypeError: If an entry of ``new_size`` is not an integer.
    """
    import operator

    # if pass a path to image
    if isinstance(sitk_image, str):
        sitk_image = sitk.ReadImage(sitk_image)

    pixel_id = sitk_image.GetPixelIDValue()

    # Infer a default before validating, so that a falsy `interpolator`
    # reaches this branch instead of failing the membership check.
    if not interpolator:
        interpolator = "linear"
        # 1, 2, 4 are the pixel IDs named in the error message below:
        # 8-bit unsigned, 16-bit signed and 32-bit signed integers.
        if pixel_id not in [1, 2, 4]:
            raise NotImplementedError(
                "Set `interpolator` manually, "
                "can only infer for 8-bit unsigned or 16, 32-bit signed integers"
            )

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if interpolator not in _SITK_INTERPOLATOR_DICT.keys():
        raise ValueError("`interpolator` should be one of {}".format(
            _SITK_INTERPOLATOR_DICT.keys()))

    #  8-bit unsigned int: treated as a (binary) mask, interpolate as nearest
    if pixel_id == 1:
        interpolator = "nearest"

    sitk_interpolator = _SITK_INTERPOLATOR_DICT[interpolator]

    # SetSize needs plain Python ints. operator.index accepts NumPy integers
    # but still rejects floats instead of silently truncating them.
    new_size = [operator.index(s) for s in new_size]

    # new spacing keeps the physical extent (spacing * size) unchanged
    new_spacing = tuple(
        float(s) for s in
        np.array(sitk_image.GetSpacing()) * np.array(sitk_image.GetSize()) /
        np.array(new_size))

    # setup image resampler - SimpleITK 2.0
    resample_filter = sitk.ResampleImageFilter()
    resample_filter.SetOutputSpacing(new_spacing)
    resample_filter.SetSize(new_size)
    resample_filter.SetOutputDirection(sitk_image.GetDirection())
    resample_filter.SetOutputOrigin(sitk_image.GetOrigin())
    resample_filter.SetTransform(sitk.Transform())
    resample_filter.SetInterpolator(sitk_interpolator)
    resample_filter.SetDefaultPixelValue(fill_value)

    return resample_filter.Execute(sitk_image)
예제 #26
0
def resize(image: sitk.Image,
           size: Union[int, Sequence[int], np.ndarray],
           interpolation: str = "linear",
           anti_alias: bool = True,
           anti_alias_sigma: Optional[float] = None) -> sitk.Image:
    """Resize image to a given size by resampling coordinates.

    The new spacing is chosen so that the physical extent
    (``spacing * size``) of the image is preserved.

    Parameters
    ----------
    image
        The image to be resized.

    size
        The new image size. If a single number, the same size is used in all
        directions. Alternatively, a sequence of numbers can be passed to
        specify the size along each dimension. Passing 0 at any position, or
        as the single number, will keep the original size along that
        dimension (or all dimensions).

    interpolation, optional
        The interpolation method to use. Valid options are:
        - "linear" for bi/trilinear interpolation (default)
        - "nearest" for nearest neighbour interpolation
        - "bspline" for order-3 b-spline interpolation

    anti_alias, optional
        Whether to smooth the image with a Gaussian kernel before resampling.
        Only used when downsampling, i.e. when `size < image.GetSize()`.
        This should be used to avoid aliasing artifacts.

    anti_alias_sigma, optional
        The standard deviation of the Gaussian kernel used for anti-aliasing.

    Returns
    -------
    sitk.Image
        The resized image.
    """

    original_size = np.array(image.GetSize())
    original_spacing = np.array(image.GetSpacing())

    # Expand a scalar to one entry per axis, so the "0 keeps the original size"
    # rule applies to it too. Without this, a scalar 0 would divide by zero
    # below and give infinite spacing.
    size = np.broadcast_to(np.asarray(size, dtype=np.float64),
                           original_size.shape)
    new_size = np.where(size == 0, original_size, size)

    # Spacing that keeps the physical extent unchanged at the new size.
    new_spacing = original_spacing * original_size / new_size

    return resample(image,
                    new_spacing,
                    anti_alias=anti_alias,
                    anti_alias_sigma=anti_alias_sigma,
                    interpolation=interpolation)
예제 #27
0
def display_info(img: sitk.Image) -> None:
    """Print the origin, size, spacing and direction of a sitk.Image.

    A header line comes first, then one tab-indented line per property.

    Args:
        img (sitk.Image): the image whose geometry is printed
    """
    print('img information :')
    labelled_getters = (
        ('\t Origin    :', img.GetOrigin),
        ('\t Size      :', img.GetSize),
        ('\t Spacing   :', img.GetSpacing),
        ('\t Direction :', img.GetDirection),
    )
    for label, getter in labelled_getters:
        print(label, getter())
예제 #28
0
    def set_origin_image(self, origin_img: sitk.Image) -> None:
        """Store the reference sitk.Image on which resampling is meant to happen.

        Besides the image itself, its size, spacing, direction and origin are
        cached on the instance as ``origin_size``, ``origin_spacing``,
        ``origin_direction`` and ``origin_origin``.

        Args:
            origin_img (sitk.Image): the reference image
        """
        self.origin_img = origin_img
        (self.origin_size,
         self.origin_spacing,
         self.origin_direction,
         self.origin_origin) = (origin_img.GetSize(),
                                origin_img.GetSpacing(),
                                origin_img.GetDirection(),
                                origin_img.GetOrigin())
예제 #29
0
    def _assert_3d(self, image: sitk.Image, is_vector=False):
        """Assert that ``image`` matches this test case's 3-D fixture geometry.

        Checks, in order: size, number of components per pixel
        (``self.no_vector_components`` for vector images, otherwise 1),
        origin, spacing and direction.

        Args:
            image (sitk.Image): The image under test.
            is_vector (bool): Whether ``image`` is expected to be a vector image.
        """
        check = self.assertEqual

        check(self.properties_3d.size, image.GetSize())

        expected_components = self.no_vector_components if is_vector else 1
        check(expected_components, image.GetNumberOfComponentsPerPixel())

        # Origin and spacing are both compared against the single
        # `origin_spacing_3d` fixture value.
        check(self.origin_spacing_3d, image.GetOrigin())
        check(self.origin_spacing_3d, image.GetSpacing())
        check(self.direction_3d, image.GetDirection())
예제 #30
0
def sitk_to_itk(image: sitk.Image) -> Any:
    r""" Function to convert an image object from SimpleITK to ITK.

    .. note::
        Data is copied to the new object (deep copy).

    Spacing, origin and direction are copied as well. Scalar images go
    through ``itk.GetImageFromArray``. Vector images are only supported when
    they are 3-D with 3 components per pixel.

    Parameters
    ----------
    image : sitk.Image
        Input image.

    Returns
    -------
    any
        Image in ITK format.

    Raises
    ------
    ImportError
        If the ``itk`` module has not been imported.
    ValueError
        If ``image`` is a vector image that is not 3-D with 3 components.
    """

    if 'itk' not in sys.modules:
        # ImportError subclasses Exception, so existing handlers still catch it.
        raise ImportError(
            'sitk_to_itk: itk module is required to use this feature.')

    a = sitk.GetArrayViewFromImage(image)
    dimension = image.GetDimension()
    n_components = image.GetNumberOfComponentsPerPixel()

    # Branch on the pixel type, not on the array rank. A 2-D vector image
    # also yields a rank-3 array and would otherwise be read as a 3-D scalar
    # volume.
    if n_components == 1:
        result = itk.GetImageFromArray(a)

    else:
        # NOTE: This workaround is implemented this way since it
        # seems that itk.GetImageFromArray() is not working properly
        # with vector images. It hard-codes 3-D images with 3-component
        # vectors, so reject anything else explicitly.
        if dimension != 3 or n_components != 3:
            raise ValueError(
                'sitk_to_itk: vector images must be 3-D with 3 components '
                'per pixel.')

        region = itk.ImageRegion[3]()
        region.SetSize(image.GetSize())
        region.SetIndex((0, 0, 0))

        PixelType = itk.Vector[itk_float_type, 3]
        ImageType = itk.Image[PixelType, 3]

        # Create new image
        result = ImageType.New()
        result.SetRegions(region)
        result.Allocate()

        # Copy image data
        b = itk.GetArrayViewFromImage(result)
        b[:] = a[:]

    result.SetSpacing(image.GetSpacing())
    result.SetOrigin(image.GetOrigin())
    # SimpleITK stores the direction as a flat row-major tuple. Reshape it to
    # a square matrix so the ITK image carries the same orientation.
    direction = np.asarray(image.GetDirection(),
                           dtype=np.float64).reshape(dimension, dimension)
    result.SetDirection(itk.GetMatrixFromArray(direction))

    return result