示例#1
0
def apply_transform_pandas(fixed_image: sitk.Image,
                           moving_image: sitk.Image,
                           reference_path,
                           index=None):
    """Resample ``moving_image`` onto ``fixed_image`` with a stored 2D affine.

    The affine parameters come from ``Transforms.csv``, located in the same
    directory as ``reference_path``. The row is looked up in the ``'Image'``
    column, keyed by ``index`` when it is given and by ``reference_path.name``
    otherwise.

    Note: ``moving_image`` is modified in place. Its origin is overwritten
    with the stored X/Y origin (truncated to ``int``) before resampling.

    Args:
        fixed_image: Image whose grid defines the output.
        moving_image: Image to be resampled.
        reference_path: Path-like object exposing ``.parent`` and ``.name``.
            It locates the CSV and provides the default row key.
        index: Optional explicit row key; defaults to ``reference_path.name``.

    Returns:
        sitk.Image: ``moving_image`` resampled with linear interpolation,
        default value 0.0, and the moving image's pixel type.
    """
    csv_file = os.path.join(reference_path.parent, 'Transforms.csv')
    row_key = reference_path.name if index is None else index
    params = blk.read_pandas_row(csv_file, row_key, 'Image')

    affine = sitk.AffineTransform(2)
    # NOTE(review): SetMatrix below replaces the whole matrix, so this
    # rotation appears to have no effect on the final transform. Confirm
    # whether it can be dropped.
    affine.Rotate(0, 1, params['Rotation'], pre=True)

    matrix_keys = ('Matrix Top Left', 'Matrix Top Right',
                   'Matrix Bottom Left', 'Matrix Bottom Right')
    affine.SetMatrix([params[key] for key in matrix_keys])
    affine.SetTranslation([params['X Translation'], params['Y Translation']])

    moving_image.SetOrigin((int(params['X Origin']),
                            int(params['Y Origin'])))

    return sitk.Resample(moving_image, fixed_image, affine, sitk.sitkLinear,
                         0.0, moving_image.GetPixelID())
示例#2
0
def fitting_index(image: sitk.Image,
                  norm: Union[int, str] = 2.0,
                  centre: Tuple[float, float, float] = None,
                  radius: float = None,
                  padding: bool = True) -> float:
    r"""Compute the fitting index of an input object.

    The fitting index of order `p` is the Jaccard coefficient between the
    input object and a p-ball centred on the object's centre of mass.

    Parameters
    ----------
    image : sitk.Image
        Binary input image of the object. It must have pixel type
        ``sitkUInt8``, and label 1 is treated as the foreground.
    norm : Union[int,str]
        Order of the Minkowski norm ('inf' or 'max' selects the Chebyshev
        norm).
    centre : Tuple[float, float, float]
        Optional fixed centre for the p-ball. Defaults to the centroid of
        label 1.
    radius : float
        Optional fixed radius for the p-ball. By default it is computed by
        ``isovolumteric_radius`` on the (possibly padded) image.
    padding : bool
        If `True`, pad the image by a quarter of its size on every side so
        that the ball is sure to fit entirely within the volume.

    Returns
    -------
    float
        Value of the fitting index.

    Raises
    ------
    Exception
        If the pixel type of `image` is not ``sitkUInt8``.
    """
    is_uint8 = image.GetPixelID() == sitk.sitkUInt8
    if not is_uint8:
        raise Exception('Unsupported %s pixel type' %
                        image.GetPixelIDTypeAsString())

    if centre is None:
        # Default centre: centroid of label 1.
        # NOTE(review): SimpleITK reports this centroid in physical
        # coordinates, but the padding shift below is in voxels, and the ball
        # is drawn from the image size alone. The two only agree for unit
        # spacing and zero origin. Confirm the expected inputs.
        shape_stats = sitk.LabelShapeStatisticsImageFilter()
        shape_stats.Execute(image)
        centre = shape_stats.GetCentroid(1)

    if padding:
        # A quarter of each extent on every side leaves room for an
        # isovolumetric 1-ball, which reaches further than a sphere
        # touching the object boundary.
        margin = tuple(extent // 4 for extent in image.GetSize())
        image = sitk.ConstantPad(image, margin, margin, 0)
        image.SetOrigin((0, 0, 0))
        centre = tuple(c + m for c, m in zip(centre, margin))

    if radius is None:
        radius = isovolumteric_radius(image, norm)

    ball = drawing.create_sphere(radius, size=image.GetSize(), centre=centre,
                                 norm=norm) > 0
    return jaccard(image, ball)
示例#3
0
def copy_meta_data_itk(source: sitk.Image, target: sitk.Image) -> sitk.Image:
    """
    Copy meta data between files (currently disabled).

    The function is disabled: it unconditionally raises, and ``target`` is
    never modified. Its author marked it "Does not work!". The abandoned
    implementation copied origin, direction and spacing from ``source`` to
    ``target``. Copying the per-key metadata dictionary
    (``GetMetaDataKeys``/``SetMetaData``) was also sketched but left
    commented out.

    Args:
        source: source file
        target: target file

    Returns:
        sitk.Image: never returns.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("Does not work!")
示例#4
0
def match_world_info(source: sitk.Image,
                     target: sitk.Image,
                     spacing: Union[bool, Tuple[float, ...], List[float]] = True,
                     origin: Union[bool, Tuple[float, ...], List[float]] = True,
                     direction: Union[bool, Tuple[float, ...], List[float]] = True):
    """Copy world information (eg spacing, origin, direction) from one
    image object to another.

    This matching is sometimes necessary for slight differences in
    metadata, perhaps from rounding, that may prevent ITK filters from
    executing.

    Args:
        source (:obj:`sitk.Image`): Source object whose relevant metadata
            will be copied into ``target``.
        target (:obj:`sitk.Image`): Target object whose corresponding
            metadata will be overwritten by that of ``source``. It is
            modified in place.
        spacing: True to copy the spacing from ``source`` to ``target``, or
            the spacing to set in ``target`` (any sequence, including a
            NumPy array). False, None or an empty sequence leaves the
            spacing unchanged. Defaults to True.
        origin: True to copy the origin from ``source`` to ``target``, or
            the origin to set in ``target``. False, None or an empty
            sequence leaves it unchanged. Defaults to True.
        direction: True to copy the direction from ``source`` to ``target``,
            or the direction to set in ``target``. False, None or an empty
            sequence leaves it unchanged. Defaults to True.

    """
    def _is_requested(value) -> bool:
        # ``False``/``None`` mean "leave unchanged". Plain truthiness is
        # avoided because NumPy arrays raise ValueError in a boolean context.
        if value is None or value is False:
            return False
        try:
            return len(value) > 0
        except TypeError:
            # Non-sized values fall back to ordinary truthiness.
            return bool(value)

    # get the world info from the source if not already set
    if spacing is True:
        spacing = source.GetSpacing()
    if origin is True:
        origin = source.GetOrigin()
    if direction is True:
        direction = source.GetDirection()

    # set the values in the target
    _logger.debug(
        "Adjusting spacing from %s to %s, origin from %s to %s, "
        "direction from %s to %s", target.GetSpacing(), spacing,
        target.GetOrigin(), origin, target.GetDirection(), direction)
    if _is_requested(spacing):
        target.SetSpacing(spacing)
    if _is_requested(origin):
        target.SetOrigin(origin)
    if _is_requested(direction):
        target.SetDirection(direction)
示例#5
0
def sitk_copy_metadata(img_source: sitk.Image, img_target: sitk.Image) -> sitk.Image:
    """
    Copy metadata (spacing, origin, direction) from source to target image.

    Deprecated: this function always raises ``RuntimeError("Deprecated")``
    and never modifies ``img_target``. The unreachable implementation that
    used to follow the ``raise`` has been removed.

    Args:
        img_source: source image
        img_target: target image

    Returns:
        SimpleITK.Image: never returns.

    Raises:
        RuntimeError: always, because the function is deprecated.
    """
    raise RuntimeError("Deprecated")
示例#6
0
def copy_geometry(image: sitk.Image, ref: sitk.Image):
    """Copy origin, direction and spacing from ``ref`` onto ``image``.

    ``image`` is modified in place and also returned, for call chaining.
    """
    # Apply each geometry attribute in the same order: origin, direction,
    # then spacing.
    for apply_value, read_value in ((image.SetOrigin, ref.GetOrigin),
                                    (image.SetDirection, ref.GetDirection),
                                    (image.SetSpacing, ref.GetSpacing)):
        apply_value(read_value())
    return image
示例#7
0
def calc_surface_height_map(img: sitk.Image):
    """Fit upper and lower surface height maps to a 3D volume.

    Two height fields, ``u`` (upper) and ``l`` (lower), are indexed over the
    (y, x) plane of the array and give a position along array axis 0. They
    are refined for 8000 iterations of momentum gradient steps driven by
    three terms:
      * edge-gradient volumes;
      * a smoothing convolution (``k_surface_band_c``);
      * a coupling term that keeps the surfaces apart.

    Note: the input ``img`` is modified in place; its spacing is set to
    (1, 1, 1) and its origin to (0, 0, 0). CUDA is required, because several
    ``.cuda()`` calls are unconditional. Relies on module-level names not
    shown here: ``fill_blank``, ``k_grad``, ``k_grad2``, ``k_surface_band``,
    ``k_surface_band_c``, ``torch``, ``F``, ``np`` and ``cv2``.

    Args:
        img: 3D SimpleITK image.

    Returns:
        tuple: ``(umap, lmap, u_surface, l_surface)`` as SimpleITK images.
        ``umap``/``lmap`` are float32 height maps (``u + 1`` and ``l - 1``)
        resized to the input's XY size. ``u_surface``/``l_surface`` are uint8
        images of the original intensities sampled at those heights, mapped
        through ``(log(v) - 4.6) * 39.4`` and clipped to [0, 255].
    """
    img.SetSpacing([1, 1, 1])
    img.SetOrigin([0, 0, 0])
    # Keep the full-resolution intensities for sampling the final surfaces.
    src_img = sitk.GetArrayFromImage(img)
    src_img = torch.FloatTensor(np.float32(src_img))
    size_ = img.GetSize()
    # Halve the resolution in x and y (z unchanged) for the optimisation.
    tf = sitk.AffineTransform(3)
    tf.Scale([2, 2, 1])
    proc_size = [int(size_[0] / 2), int(size_[1] / 2), size_[2]]
    img = sitk.Resample(img, proc_size, tf)
    # Keep the gradient volumes on the GPU only below ~9e8 voxels.
    gpu_load_grad = False
    if proc_size[0] * proc_size[1] * proc_size[2] < 900000000:
        gpu_load_grad = True
    img = fill_blank(img)
    img = sitk.Cast(img, sitk.sitkFloat32)
    # Log-intensity mapping to the 0..255 range.
    img = (sitk.Log(img) - 4.6) * 39.4
    img = sitk.Clamp(img, sitk.sitkFloat32, 0, 255)

    def get_edge_grad(img_: sitk.Image, ul):
        # Build the edge-response volume for one surface.
        # ul == 1 keeps the positive k_grad response; ul == -1 keeps the
        # negative response. The result is smoothed with k_grad2, multiplied
        # by ul, and returned as a torch tensor (on the GPU if gpu_load_grad).
        grad_m = sitk.Convolution(img_, k_grad)
        if ul == 1:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, 0, 65535)
        else:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, -65535, 0)
        grad_m = ul * sitk.Convolution(grad_m, k_grad2)
        grad_m = sitk.GetArrayFromImage(grad_m)
        grad_m = torch.FloatTensor(grad_m)
        if gpu_load_grad:
            grad_m = grad_m.cuda()
        return grad_m

    u_grad_m = get_edge_grad(img, 1)
    l_grad_m = get_edge_grad(img, -1)
    # NOTE(review): this byte tensor is only referenced by the commented-out
    # debug code in the loop below.
    img = torch.Tensor(sitk.GetArrayFromImage(img)).byte()
    # NOTE(review): the gpu_load_grad guard is commented out, so this always
    # moves to CUDA.
    #if gpu_load_grad:
    img = img.cuda()
    # Random initial heights: u in [Z/2, Z) and l in [Z, 1.5*Z), where Z is
    # the size of array axis 0. l starts out of range and is brought into
    # range by the clamp after the first step.
    u = (torch.rand(1, 1, u_grad_m.shape[1], u_grad_m.shape[2]) *
         (u_grad_m.shape[0]) / 2) + l_grad_m.shape[0] / 2
    l = (l_grad_m.shape[0] / 1 +
         torch.rand(1, 1, l_grad_m.shape[1], l_grad_m.shape[2]) *
         (l_grad_m.shape[0]) / 2)

    # Step size, momentum, and per-iteration multiplicative lr decay.
    lr = 0.002
    momentum = 0.9
    lr_decay = 0.0002
    u_grad = torch.zeros(1, 1, u_grad_m.shape[1], u_grad_m.shape[2])
    l_grad = torch.zeros(1, 1, l_grad_m.shape[1], l_grad_m.shape[2])
    # NOTE(review): this assignment is dead; k is overwritten with
    # k_surface_band_c a few lines below (presumably a CUDA copy of the same
    # kernel - confirm).
    k = k_surface_band
    #if gpu_load_grad:
    u = u.cuda()
    l = l.cuda()
    u_grad = u_grad.cuda()
    l_grad = l_grad.cuda()
    k = k_surface_band_c
    # gu/gl keep the last integer heights; they are not used after the loop.
    gu, gl = None, None
    for i in range(8000):

        def calc_grad(s, grad, grad_m, ul):
            # One momentum update for height map s (shape (1, 1, H, W)).
            # Integer heights are clamped to valid indices along axis 0.
            g = torch.clamp(s[0], 0, grad_m.shape[0] - 1).long()
            # Smoothing term: reflect-pad, then convolve with kernel k.
            u_plane_bend = F.pad(s, (1, 1, 1, 1), mode='reflect')
            u_plane_bend = F.conv2d(u_plane_bend, k, padding=0)[0].data
            # Edge term: edge response sampled at the current heights.
            if gpu_load_grad:
                u_edge = torch.gather(grad_m, 0, g)
            else:
                u_edge = torch.gather(grad_m, 0, g.cpu()).cuda()
            # Coupling term: l and u are read from the enclosing scope.
            # clamp(l - u - 75, -1000, 1) pushes the surfaces apart strongly
            # when the gap l - u is below 75, and pulls them together
            # weakly (capped at 1) when it is above 75.
            grad = u_edge + \
                   3000 * u_plane_bend + \
                   0.05 * ul * torch.clamp(l - u - 75, -1000, 1) + \
                   momentum * grad
            return grad, g

        u_grad, gu = calc_grad(u, u_grad, u_grad_m, 1)
        l_grad, gl = calc_grad(l, l_grad, l_grad_m, -1)
        u += lr * u_grad
        l += lr * l_grad
        u = torch.clamp(u, 0, u_grad_m.shape[0] - 1)
        l = torch.clamp(l, 0, u_grad_m.shape[0] - 1)
        lr *= (1 - lr_decay)
        #v = torch.gather(img, 0, gu)[0].cpu().numpy()
        #print(torch.mean(u))
        #if i % 10 == 0:
        #    cv2.imshow('s', np.uint8(v * 5))
        #    cv2.waitKey(1)
        # NOTE(review): prints the iteration counter on every one of the
        # 8000 iterations.
        print(i)

    # Shift u down and l up by one index, then upsample to the full XY size.
    # NOTE(review): u may reach Z - 1 (so u + 1 == Z) and l may reach 0
    # (so l - 1 == -1). Either would be an out-of-range index in the gathers
    # below. Confirm this cannot happen in practice.
    umap = np.float32((u + 1).cpu().numpy()[0][0])
    lmap = np.float32((l - 1).cpu().numpy()[0][0])
    umap = cv2.resize(umap, (size_[0], size_[1]))
    lmap = cv2.resize(lmap, (size_[0], size_[1]))
    # Sample the original full-resolution intensities at each height.
    u_surface = torch.gather(src_img, 0,
                             torch.Tensor(np.array(
                                 [umap])).long())[0].cpu().numpy()
    l_surface = torch.gather(src_img, 0,
                             torch.Tensor(np.array(
                                 [lmap])).long())[0].cpu().numpy()
    u_surface = np.clip((np.log(u_surface) - 4.6) * 39.4, 0, 255)
    l_surface = np.clip((np.log(l_surface) - 4.6) * 39.4, 0, 255)

    umap = sitk.GetImageFromArray(umap)
    lmap = sitk.GetImageFromArray(lmap)
    u_surface = sitk.GetImageFromArray(np.uint8(u_surface))
    l_surface = sitk.GetImageFromArray(np.uint8(l_surface))

    return umap, lmap, u_surface, l_surface
def align_surfaces(prev_surface,
                   next_surface,
                   ref_img: sitk.Image = None,
                   prev_points=None,
                   next_points=None,
                   outside_brightness=2,
                   nonrigid=True,
                   ref_size=None,
                   ref_scale=1,
                   use_rigidity_mask=False,
                   **kwargs):
    """Align surface images to a reference and return displacement fields.

    First, one surface is rigidly aligned to ``ref_img`` with
    ``tp_align_surface_rigid.txt``: ``prev_surface`` if it is given,
    otherwise ``next_surface``. When both surfaces are given and
    ``nonrigid`` is true, a second registration is then run between the
    (rebound) ``prev_surface`` and ``next_surface``. That step uses a rigid
    + B-spline parameter pair, or the ``*_manual`` pair plus point sets when
    both ``prev_points`` and ``next_points`` are given. All transforms are
    converted to ``sitkVectorFloat64`` displacement fields on the grid of
    ``ref_img``.

    Args:
        prev_surface: Surface image or None; only slice ``[:, :, 0]`` is used.
        next_surface: Surface image or None; only slice ``[:, :, 0]`` is used.
        ref_img: Optional reference image; slice ``[:, :, 0]`` is used.
            It defaults to ``prev_surface``, or ``next_surface`` if that is
            None, and in that case ``ref_scale`` is forced to 1.
        prev_points: Optional landmark input passed to
            ``read_transformix_input_points``. Used only together with
            ``next_points``.
        next_points: Same as ``prev_points``, for the next surface.
        outside_brightness: Passed to ``fill_outside``. When
            ``use_rigidity_mask`` is set, the mask lower threshold is
            ``outside_brightness + 1``.
        nonrigid: If false, skip the second registration.
        ref_size: Optional 2D output size. ``ref_img`` is given spacing
            ``ref_scale``, centred in a canvas of this size, and resampled
            onto a grid with origin (0, 0) and spacing (1, 1).
        ref_scale: Spacing assigned to ``ref_img`` before centring. Only
            used with ``ref_size``.
        use_rigidity_mask: If True, a mask derived from ``next_surface`` is
            passed to the second registration.
        **kwargs: Accepted but unused.

    Returns:
        tuple: ``(prev_df, next_df)``.
        If either surface is None, the result is ``(prev_df1, next_df1)``,
        with None for the surface that was not aligned.
        If ``nonrigid`` is false, ``prev_df1`` is returned in both positions.
        Otherwise ``next_df`` is the displacement field of the second
        registration.
    """
    prev_df1, next_df1 = None, None
    # Reduce the inputs to 2D by taking the first slice along the third axis.
    if prev_surface is not None:
        prev_surface = prev_surface[:, :, 0]
    if next_surface is not None:
        next_surface = next_surface[:, :, 0]
    if ref_img is not None:
        ref_img = ref_img[:, :, 0]
    # NOTE(review): s1/s2 are passed to get_transform_points, which ignores
    # its s_ argument.
    s1, s2 = prev_surface, next_surface
    if ref_img is None:
        # NOTE(review): ref_img now aliases the surface image, so the
        # SetSpacing/SetOrigin calls below also change that surface.
        ref_img = prev_surface
        if prev_surface is None:
            ref_img = next_surface
        ref_scale = 1
    if ref_size is not None:
        # Centre the (scaled) reference in a ref_size canvas, then resample
        # onto a unit-spacing grid with origin (0, 0).
        origin = [(ref_size[i] - ref_img.GetSize()[i] * ref_scale) / 2
                  for i in range(2)]
        ref_img.SetSpacing([ref_scale for i in range(2)])
        ref_img.SetOrigin(origin)
        ref_img = sitk.Resample(ref_img, ref_size, sitk.Transform(),
                                sitk.sitkLinear, [0, 0], [1, 1])
    # Rigid alignment of one surface to the reference. The surface variable
    # is rebound to get_align_transform's first return value (presumably
    # the aligned image - confirm).
    if prev_surface is not None:
        prev_surface = fill_outside(prev_surface, outside_brightness)
        prev_surface, prev_transform1 = get_align_transform(
            ref_img, prev_surface,
            [os.path.join(PARAMETER_DIR, 'tp_align_surface_rigid.txt')])
        prev_df1 = sitk.TransformToDisplacementField(prev_transform1,
                                                     sitk.sitkVectorFloat64,
                                                     ref_img.GetSize(),
                                                     ref_img.GetOrigin(),
                                                     ref_img.GetSpacing(),
                                                     ref_img.GetDirection())
    else:
        next_surface = fill_outside(next_surface, outside_brightness)
        next_surface, next_transform1 = get_align_transform(
            ref_img, next_surface,
            [os.path.join(PARAMETER_DIR, 'tp_align_surface_rigid.txt')])
        next_df1 = sitk.TransformToDisplacementField(next_transform1,
                                                     sitk.sitkVectorFloat64,
                                                     ref_img.GetSize(),
                                                     ref_img.GetOrigin(),
                                                     ref_img.GetSpacing(),
                                                     ref_img.GetDirection())
    if prev_surface is None or next_surface is None:
        return prev_df1, next_df1
    if not nonrigid:
        # NOTE(review): returns prev_df1 twice (next_df1 is None here).
        # Confirm this duplication is intended, rather than
        # (prev_df1, next_df1).
        return prev_df1, prev_df1

    rigidity_mask = None
    if use_rigidity_mask == True:
        # Mask of pixels with values >= outside_brightness + 1 in
        # next_surface, cleaned with a binary opening.
        rigidity_mask = sitk.BinaryThreshold(next_surface,
                                             outside_brightness + 1)
        rigidity_mask = sitk.BinaryMorphologicalOpening(rigidity_mask)
    prev_surface = fill_outside(prev_surface, outside_brightness)
    next_surface = fill_outside(next_surface, outside_brightness)

    if prev_points is not None and next_points is not None:

        def get_transform_points(points, transform, file, s_):
            # Map the points through the inverse of `transform`, then write
            # them to `file` as 2D points. s_ is unused.
            tf = transform.GetInverse()
            points = read_transformix_input_points(points)
            points = [tf.TransformPoint(p[:2]) for p in points]
            write_transformix_input_points(file, points, 2)

        # Landmark-assisted registration. The point files live only for the
        # duration of this temporary directory.
        with tempfile.TemporaryDirectory() as ELASTIX_TEMP:
            get_transform_points(prev_points, prev_transform1,
                                 os.path.join(ELASTIX_TEMP, 'prev.txt'), s1)
            get_transform_points(next_points,
                                 sitk.Transform(2, sitk.sitkIdentity),
                                 os.path.join(ELASTIX_TEMP, 'next.txt'), s2)
            _, transform2 = get_align_transform(
                prev_surface,
                next_surface, [
                    os.path.join(PARAMETER_DIR,
                                 'tp_align_surface_rigid_manual.txt'),
                    os.path.join(PARAMETER_DIR,
                                 'tp_align_surface_bspline_manual.txt')
                ],
                fixed_points=os.path.join(ELASTIX_TEMP, 'prev.txt'),
                moving_points=os.path.join(ELASTIX_TEMP, 'next.txt'),
                rigidity_mask=rigidity_mask)
    else:
        _, transform2 = get_align_transform(
            prev_surface,
            next_surface, [
                os.path.join(PARAMETER_DIR, 'tp_align_surface_rigid.txt'),
                os.path.join(PARAMETER_DIR, 'tp_align_surface_bspline3.txt')
            ],
            rigidity_mask=rigidity_mask)
    #sitk.WriteImage(prev_surface, 'F:/chaoyu/test/1_.mha')
    #sitk.WriteImage(_, 'F:/chaoyu/test/2_.mha')
    prev_df = prev_df1
    # Second-registration transform as a displacement field on ref_img's grid.
    next_df = sitk.TransformToDisplacementField(transform2,
                                                sitk.sitkVectorFloat64,
                                                ref_img.GetSize(),
                                                ref_img.GetOrigin(),
                                                ref_img.GetSpacing(),
                                                ref_img.GetDirection())
    return prev_df, next_df
示例#9
0
def registration(fixed_image: sitk.Image,     # noqa: C901
                 moving_image: sitk.Image,
                 *,
                 do_fft_initialization=True,
                 do_affine2d=False,
                 do_affine3d=True,
                 ignore_spacing=True,
                 sigma=1.0,
                 auto_mask=False,
                 samples_per_parameter=5000,
                 expand=None) -> sitk.Transform:
    """Robust multi-phase registration for multi-panel confocal microscopy images.

    The fixed and moving image are expected to have the same molecular labeling and to cover the same imaged region.

    The available phases are:
      - FFT initialization for translation estimation
      - 2D affine, which can correct rotational acquisition problems. This phase is done on z-projections to \
       optimize a 2D similarity transform followed by a 2D affine
      - 3D affine robust multi-level registration

    When ``ignore_spacing`` is set, the spacing/origin normalization is applied to copies of the images, so the \
    caller's image objects keep their original geometry.


    :param fixed_image: a scalar SimpleITK 3D Image
    :param moving_image: a scalar SimpleITK 3D Image
    :param do_fft_initialization: perform FFT based cross correlation to initialize the translation
    :param do_affine2d: perform registration on 2D images from z-projection
    :param do_affine3d: multi-level affine transform
    :param ignore_spacing: internally adjust spacing magnitude to be near 1 to avoid numeric stability issues with \
    micro sized spacing
    :param sigma: scalar to change the amount of Gaussian smoothing performed
    :param auto_mask: ignore zero valued pixels connected to the image border
    :param samples_per_parameter: the number of image samples to use per transform parameter at full resolution
    :param expand: Perform super-sampling of the fixed image to increase the number of z-slices by an integer \
    factor. When not given, super-sampling is automatically performed along any dimension of the fixed image with \
    fewer than 5 voxels. The moving image is always automatically super-sampled along any dimension with fewer \
    than 5 voxels.
    :return: A SimpleITK transform mapping points from the fixed image to the moving. This may be a CompositeTransform.

    """
    # Identity transform will be returned if all registration steps are disabled by
    # the calling function.
    result = sitk.Transform()

    initial_translation_3d = True

    moving_mask = None
    fixed_mask = None

    number_of_samples_per_parameter = samples_per_parameter

    expand_factors = None

    if expand:
        expand_factors = [1, 1, expand]

    if fixed_image.GetPixelID() != sitk.sitkFloat32:
        fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)

    # expand the image if it has fewer than 5 voxels in any dimension
    # (-(-5 // s) is ceil(5 / s), i.e. 1 for s >= 5)
    if not expand_factors:
        expand_factors = [-(-5//s) for s in fixed_image.GetSize()]

    if any([e != 1 for e in expand_factors]):
        _logger.warning("Fixed image under sized in at lease one dimension!"
                        "\tApplying expand factors {0} to image size.".format(expand_factors))
        fixed_image = sitk.Expand(fixed_image, expandFactors=expand_factors)

    if moving_image.GetPixelID() != sitk.sitkFloat32:
        moving_image = sitk.Cast(moving_image, sitk.sitkFloat32)

    expand_factors = [-(-5//s) for s in moving_image.GetSize()]
    if any([e != 1 for e in expand_factors]):
        _logger.warning("WARNING: Moving image under sized in at lease one dimension!"
                        "\tApplying expand factors {0} to image size.".format(expand_factors))
        moving_image = sitk.Expand(moving_image, expandFactors=expand_factors)

    if auto_mask:
        fixed_mask = imgf.make_auto_mask(fixed_image)
        moving_mask = imgf.make_auto_mask(moving_image)

    if ignore_spacing:

        #
        # FORCE THE SPACING magnitude to be normalized near 1.0
        #

        spacing_magnitude = imgf.spacing_average_magnitude(fixed_image)

        _logger.info("Adjusting image spacing by {0}...".format(1.0/spacing_magnitude))

        # The images may still be the caller's objects (no Cast/Expand
        # happened), so copy them before changing their geometry.
        fixed_image = sitk.Image(fixed_image)
        moving_image = sitk.Image(moving_image)

        fixed_new_spacing = [s / spacing_magnitude for s in fixed_image.GetSpacing()]
        _logger.info("\tFixed Image Spacing: {0}->{1}".format(fixed_image.GetSpacing(), fixed_new_spacing))
        fixed_image.SetSpacing(fixed_new_spacing)
        fixed_image.SetOrigin([o/spacing_magnitude for o in fixed_image.GetOrigin()])

        moving_new_spacing = [s / spacing_magnitude for s in moving_image.GetSpacing()]
        _logger.info("\tMoving Image Spacing: {0}->{1}".format(moving_image.GetSpacing(), moving_new_spacing))
        moving_image.SetSpacing(moving_new_spacing)
        moving_image.SetOrigin([o/spacing_magnitude for o in moving_image.GetOrigin()])

        # Each mask must follow the geometry of the image it was built from.
        if moving_mask is not None:
            moving_mask.SetSpacing(moving_new_spacing)
            moving_mask.SetOrigin([o/spacing_magnitude for o in moving_mask.GetOrigin()])

        if fixed_mask is not None:
            fixed_mask.SetSpacing(fixed_new_spacing)
            fixed_mask.SetOrigin([o / spacing_magnitude for o in fixed_mask.GetOrigin()])

    #
    #
    # Do FFT based translation initialization
    #
    #
    initial_translation = None
    if do_fft_initialization:
        initial_translation = imgf.fft_initialization(moving_image,
                                                      fixed_image,
                                                      bin_shrink=8,
                                                      projection=(not initial_translation_3d))
        result = sitk.TranslationTransform(len(initial_translation), initial_translation)

    #
    # Do 2D registration first
    #
    if do_affine2d:
        result = register_as_2d_affine(fixed_image, moving_image,
                                       sigma_base=sigma,
                                       initial_translation=initial_translation,
                                       fixed_image_mask=fixed_mask,
                                       moving_image_mask=moving_mask)

    if do_affine3d:

        _logger.info("Initializing Affine Registration...")

        if do_affine2d:
            # set the FFT xcoor initial z translation
            if do_fft_initialization and len(initial_translation) >= 3:

                # take the z-translation from the FFT
                translation = list(result.GetTranslation())
                translation[2] = initial_translation[2]
                result.SetTranslation(translation)

                _logger.info("Initialized 3D affine with z-translation... {0}".format(translation))

            affine = result
        else:
            affine = sitk.CenteredTransformInitializer(fixed_image,
                                                       moving_image,
                                                       sitk.AffineTransform(3),
                                                       sitk.CenteredTransformInitializerFilter.GEOMETRY)
            affine = sitk.AffineTransform(affine)

            if do_fft_initialization:
                if len(initial_translation) >= 3:
                    affine.SetTranslation(list(initial_translation))
                    _logger.info("Initialized 3D affine with z-translation... {0}".format(initial_translation))
                else:
                    affine.SetTranslation(list(initial_translation)+[0, ])

        affine_result = register_3d(fixed_image, moving_image,
                                    initial_transform=affine,
                                    sigma_base=sigma,
                                    fixed_image_mask=fixed_mask,
                                    moving_image_mask=moving_mask,
                                    number_of_samples_per_parameter=number_of_samples_per_parameter)

        result = affine_result

    if ignore_spacing:

        # Compose the scaling Transform into a single affine transform

        # The spacing of the images was modified to do the registration, so the appropriate scaling must be applied
        # to map points into (and back out of) the space the registration was done in.

        scale = spacing_magnitude

        scale_transform = sitk.ScaleTransform(3)
        scale_transform.SetScale([scale]*3)

        result = sitk.CompositeTransform([sitk.Transform(scale_transform),
                                          result,
                                          scale_transform.GetInverse()])

        # If result was a composite transform, then we have nested composite
        # transforms, which need to be flattened for writing.

        result.FlattenTransform()

        _logger.info(result)

    return result