Пример #1
0
def downsample_image(image: sitk.Image, input_pixel_size, output_pixel_size):
    """Resample an image from one isotropic pixel size to another.

    The spacing of ``image`` is first overwritten with ``input_pixel_size``
    on every axis. Note that this modifies ``image`` in place. The image is
    then linearly resampled onto an isotropic grid with ``output_pixel_size``
    spacing, origin 0 on every axis and identity direction.

    Works for images of any dimension; previously only 3D images were
    supported.

    Args:
        image: Image to resample (modified in place, see above).
        input_pixel_size: Physical size of one input pixel, same on all axes.
        output_pixel_size: Physical size of one output pixel, same on all axes.

    Returns:
        sitk.Image: The resampled image. Along each axis the output has
        ``ceil(n * input_pixel_size / output_pixel_size)`` pixels, where
        ``n`` is the input size along that axis.
    """
    dim = image.GetDimension()
    image.SetSpacing([input_pixel_size] * dim)
    size = [
        int(math.ceil(n / output_pixel_size * input_pixel_size))
        for n in image.GetSize()
    ]
    # An explicit identity transform of the image's own dimension. For 3D
    # this is the same as the default ``sitk.Transform()``.
    identity = sitk.Transform(dim, sitk.sitkIdentity)
    return sitk.Resample(image, size, identity, sitk.sitkLinear,
                         [0] * dim, [output_pixel_size] * dim)
Пример #2
0
def copy_meta_data_itk(source: sitk.Image, target: sitk.Image) -> sitk.Image:
    """
    Copy meta data between files (disabled).

    This helper is disabled. It always raises ``NotImplementedError`` and
    never modifies ``target``. To copy geometry between images, use the
    SimpleITK ``GetOrigin``/``SetOrigin``, ``GetDirection``/``SetDirection``
    and ``GetSpacing``/``SetSpacing`` accessors directly. When both images
    have the same size, ``sitk.Image.CopyInformation`` also works.

    Args:
        source: source file
        target: target file

    Raises:
        NotImplementedError: always.
    """
    # The statements that used to follow this raise could never execute and
    # have been removed. They copied the metadata dictionary and the
    # origin/direction/spacing.
    raise NotImplementedError("Does not work!")
Пример #3
0
    def preprocess_reg_image_intensity(
            self, image: sitk.Image,
            preprocessing: ImagePreproParams) -> sitk.Image:
        """
        Preprocess image intensity data to single channel image.

        Image-type specific overrides are applied first and written back
        onto ``preprocessing``, so the caller sees the adjusted flags:

        * ``image_type.value == "FL"``: intensity inversion is switched off.
        * ``image_type.value == "BF"``: max-intensity projection and contrast
          enhancement are switched off. Inversion is switched on when
          ``self.is_rgb`` is true.

        The enabled built-in steps then run in a fixed order: max-intensity
        projection, contrast enhancement, inversion. Next, every callable in
        ``preprocessing.custom_processing`` runs in iteration order. Finally
        the 2D spacing is set to ``self.image_res``.

        Parameters
        ----------
        image: sitk.Image
            reg_image to be preprocessed
        preprocessing: ImagePreproParams
            Parameters of the preprocessing; may be modified in place

        Returns
        -------
        image: sitk.Image
            Preprocessed single-channel image
        """
        modality = preprocessing.image_type.value
        if modality == "FL":
            preprocessing.invert_intensity = False
        elif modality == "BF":
            preprocessing.max_int_proj = False
            preprocessing.contrast_enhance = False
            if self.is_rgb:
                preprocessing.invert_intensity = True

        processed = image
        if preprocessing.max_int_proj:
            processed = sitk_max_int_proj(processed)
        if preprocessing.contrast_enhance:
            processed = contrast_enhance(processed)
        if preprocessing.invert_intensity:
            processed = sitk_inv_int(processed)

        # User-supplied steps: mapping of step name -> callable(image) -> image
        custom_steps = preprocessing.custom_processing
        if custom_steps:
            for step_name, step_fn in custom_steps.items():
                print(f"performing preprocessing: {step_name}")
                processed = step_fn(processed)

        processed.SetSpacing((self.image_res, self.image_res))
        return processed
Пример #4
0
def match_world_info(source: sitk.Image,
                     target: sitk.Image,
                     spacing: Union[bool, Tuple[float, ...], List[float]] = True,
                     origin: Union[bool, Tuple[float, ...], List[float]] = True,
                     direction: Union[bool, Tuple[float, ...], List[float]] = True):
    """Copy world information (eg spacing, origin, direction) from one
    image object to another.

    This matching is sometimes necessary to fix slight metadata differences,
    perhaps from rounding, that may prevent ITK filters from executing.

    Args:
        source (:obj:`sitk.Image`): Source object whose relevant metadata
            will be copied into ``target``.
        target (:obj:`sitk.Image`): Target object whose corresponding
            metadata will be overwritten by that of ``source``. Modified
            in place.
        spacing: True to copy the spacing from ``source`` to ``target``, or
            the spacing to set in ``target``. False, None or an empty
            sequence leaves the target's spacing unchanged. Defaults to True.
        origin: True to copy the origin from ``source`` to ``target``, or
            the origin to set in ``target``. False, None or an empty
            sequence leaves the target's origin unchanged. Defaults to True.
        direction: True to copy the direction from ``source`` to ``target``,
            or the direction to set in ``target``. False, None or an empty
            sequence leaves the target's direction unchanged. Defaults to
            True.

    """
    # get the world info from the source if not already set
    if spacing is True:
        spacing = source.GetSpacing()
    if origin is True:
        origin = source.GetOrigin()
    if direction is True:
        direction = source.GetDirection()

    def _should_set(value):
        # False/None/empty mean "leave unchanged". Check the length
        # explicitly instead of the truth value, so that NumPy arrays are
        # accepted: their truth value is ambiguous for more than one element.
        if value is None or value is False:
            return False
        try:
            return len(value) > 0
        except TypeError:
            return bool(value)

    # set the values in the target
    _logger.debug(
        "Adjusting spacing from %s to %s, origin from %s to %s, "
        "direction from %s to %s", target.GetSpacing(), spacing,
        target.GetOrigin(), origin, target.GetDirection(), direction)
    if _should_set(spacing):
        target.SetSpacing(spacing)
    if _should_set(origin):
        target.SetOrigin(origin)
    if _should_set(direction):
        target.SetDirection(direction)
Пример #5
0
def _make_3d(img_sitk: sitk.Image, spacing_z=1) -> sitk.Image:
    """Make a 2D image into a 3D image with a single plane.

    The 2D spacing, origin and direction are carried over to the new image.
    The added z-axis gets ``spacing_z`` spacing and origin 0, and is
    orthogonal to the original plane. Multi-component pixels (e.g. RGB)
    stay vector pixels; they are not turned into an extra image axis.

    Args:
        img_sitk: Image in SimpleITK format.
        spacing_z: Z-axis spacing; defaults to 1.

    Returns:
        ``img_sitk`` converted to a 3D image if previously 2D; any other
        image is returned unchanged.

    """
    spacing = img_sitk.GetSpacing()
    if len(spacing) != 2:
        return img_sitk

    is_vector = img_sitk.GetNumberOfComponentsPerPixel() > 1
    # prepend an additional axis for 2D images to make them 3D
    # (NumPy arrays from SimpleITK are ordered z, y, x[, component])
    img_np = sitk.GetArrayFromImage(img_sitk)[None]
    img_3d = sitk.GetImageFromArray(img_np, isVector=is_vector)
    img_3d.SetSpacing(list(spacing) + [spacing_z])

    # The NumPy round trip discards the physical geometry, so restore it
    # and extend it by one axis.
    img_3d.SetOrigin(list(img_sitk.GetOrigin()) + [0.0])
    d = img_sitk.GetDirection()  # 2x2 direction matrix, row-major
    img_3d.SetDirection([d[0], d[1], 0.0,
                         d[2], d[3], 0.0,
                         0.0, 0.0, 1.0])
    return img_3d
Пример #6
0
def sitk_copy_metadata(img_source: sitk.Image, img_target: sitk.Image) -> sitk.Image:
    """
    Copy metadata (spacing, origin, direction) from source to target image.

    Deprecated: always raises ``RuntimeError`` and never modifies
    ``img_target``. Copy the geometry with the SimpleITK
    ``GetSpacing``/``SetSpacing``, ``GetOrigin``/``SetOrigin`` and
    ``GetDirection``/``SetDirection`` accessors instead.

    Args:
        img_source: source image
        img_target: target image

    Raises:
        RuntimeError: always, because the function is deprecated.
    """
    # The statements that used to follow this raise could never execute and
    # have been removed. They copied spacing, origin and direction.
    raise RuntimeError("Deprecated")
Пример #7
0
def copy_geometry(image: sitk.Image, ref: sitk.Image):
    """Overwrite the physical geometry of ``image`` with that of ``ref``.

    Origin, direction and spacing are copied, in that order. Pixel data and
    image size are untouched. ``image`` is modified in place and also
    returned for convenience.
    """
    accessor_pairs = (
        (ref.GetOrigin, image.SetOrigin),
        (ref.GetDirection, image.SetDirection),
        (ref.GetSpacing, image.SetSpacing),
    )
    for read_ref, write_image in accessor_pairs:
        write_image(read_ref())
    return image
Пример #8
0
def calc_surface_height_map(img: sitk.Image):
    """Estimate upper and lower surface height maps of a 3D volume.

    Two height surfaces ``u`` and ``l`` hold one z position per (y, x)
    column. They are fitted by 8000 iterations of a momentum-style update
    (``s += lr * grad``, with ``lr`` decaying geometrically). Each gradient
    combines three terms:

    * an edge term gathered from edge-gradient volumes;
    * a bending term: the surface convolved with the module-level kernel
      ``k_surface_band_c``;
    * a coupling term on ``l - u``.

    The fit runs on a copy of the volume downsampled 2x in x and y. The
    resulting maps are resized back to the input's x/y size.

    NOTE(review): ``img`` is modified in place. Its spacing is reset to
    (1, 1, 1) and its origin to (0, 0, 0).
    NOTE(review): a CUDA device is required. The byte tensor ``img``, the
    surfaces and their gradient buffers are moved to the GPU unconditionally
    (the ``gpu_load_grad`` guards are commented out). Only the
    edge-gradient volumes honour ``gpu_load_grad``.
    NOTE(review): prints the iteration counter on each of the 8000
    iterations.

    Args:
        img: 3D image. SimpleITK sizes are (x, y, z), so the arrays/tensors
            derived from it are indexed (z, y, x).

    Returns:
        tuple: ``(umap, lmap, u_surface, l_surface)`` as SimpleITK images.

        * ``umap`` / ``lmap``: float32 height maps of ``u + 1`` and
          ``l - 1``.
        * ``u_surface`` / ``l_surface``: the source intensities sampled at
          those heights. They are transformed as
          ``(log(v) - 4.6) * 39.4``, clipped to [0, 255] and cast to uint8.
    """
    # NOTE(review): mutates the caller's image
    img.SetSpacing([1, 1, 1])
    img.SetOrigin([0, 0, 0])
    # Full-resolution intensities (CPU), used at the end to sample the surfaces
    src_img = sitk.GetArrayFromImage(img)
    src_img = torch.FloatTensor(np.float32(src_img))
    size_ = img.GetSize()
    # Resample through a [2, 2, 1] scale transform onto a grid of half the
    # x/y size, i.e. downsample 2x in-plane for the optimisation
    tf = sitk.AffineTransform(3)
    tf.Scale([2, 2, 1])
    proc_size = [int(size_[0] / 2), int(size_[1] / 2), size_[2]]
    img = sitk.Resample(img, proc_size, tf)
    # Keep the edge-gradient volumes on the GPU only when the downsampled
    # volume has fewer than 9e8 voxels
    gpu_load_grad = False
    if proc_size[0] * proc_size[1] * proc_size[2] < 900000000:
        gpu_load_grad = True
    # fill_blank is defined elsewhere; afterwards intensities are
    # log-scaled and clamped to [0, 255]
    img = fill_blank(img)
    img = sitk.Cast(img, sitk.sitkFloat32)
    img = (sitk.Log(img) - 4.6) * 39.4
    img = sitk.Clamp(img, sitk.sitkFloat32, 0, 255)

    def get_edge_grad(img_: sitk.Image, ul):
        # Convolve with module-level kernel k_grad. Keep only the positive
        # response (ul == 1) or only the negative one (otherwise). Then
        # convolve with k_grad2, multiply by ul and return as a torch tensor.
        grad_m = sitk.Convolution(img_, k_grad)
        if ul == 1:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, 0, 65535)
        else:
            grad_m = sitk.Clamp(grad_m, sitk.sitkFloat32, -65535, 0)
        grad_m = ul * sitk.Convolution(grad_m, k_grad2)
        grad_m = sitk.GetArrayFromImage(grad_m)
        grad_m = torch.FloatTensor(grad_m)
        if gpu_load_grad:
            grad_m = grad_m.cuda()
        return grad_m

    u_grad_m = get_edge_grad(img, 1)
    l_grad_m = get_edge_grad(img, -1)
    img = torch.Tensor(sitk.GetArrayFromImage(img)).byte()
    # NOTE(review): guard commented out, so CUDA is always required here;
    # ``img`` is only used by the commented-out debug code below
    #if gpu_load_grad:
    img = img.cuda()
    # Random initial heights, shape (1, 1, y, x), with D = number of z-slices:
    # u starts in [D/2, D); l starts in [D, 1.5*D), i.e. past the last slice,
    # and is clamped into range after the first update.
    # NOTE(review): the ``/ 1`` in l may have been meant as ``/ 2`` — TODO confirm
    u = (torch.rand(1, 1, u_grad_m.shape[1], u_grad_m.shape[2]) *
         (u_grad_m.shape[0]) / 2) + l_grad_m.shape[0] / 2
    l = (l_grad_m.shape[0] / 1 +
         torch.rand(1, 1, l_grad_m.shape[1], l_grad_m.shape[2]) *
         (l_grad_m.shape[0]) / 2)

    lr = 0.002
    momentum = 0.9
    lr_decay = 0.0002
    u_grad = torch.zeros(1, 1, u_grad_m.shape[1], u_grad_m.shape[2])
    l_grad = torch.zeros(1, 1, l_grad_m.shape[1], l_grad_m.shape[2])
    # NOTE(review): this assignment is overwritten a few lines below
    k = k_surface_band
    # NOTE(review): guard commented out; these moves always require CUDA
    #if gpu_load_grad:
    u = u.cuda()
    l = l.cuda()
    u_grad = u_grad.cuda()
    l_grad = l_grad.cuda()
    k = k_surface_band_c
    gu, gl = None, None
    for i in range(8000):

        # Redefined every iteration. It reads the enclosing ``u`` and ``l``
        # at call time; both calls below happen before either surface is
        # updated, so they see the same (u, l).
        def calc_grad(s, grad, grad_m, ul):
            # Integer z index of the surface per column, clamped to the volume
            g = torch.clamp(s[0], 0, grad_m.shape[0] - 1).long()
            # Bending term: reflect-pad the surface and convolve it with k
            u_plane_bend = F.pad(s, (1, 1, 1, 1), mode='reflect')
            u_plane_bend = F.conv2d(u_plane_bend, k, padding=0)[0].data
            # Edge term: gradient value at the surface's z index in each column
            if gpu_load_grad:
                u_edge = torch.gather(grad_m, 0, g)
            else:
                u_edge = torch.gather(grad_m, 0, g.cpu()).cuda()
            # Weighted sum of edge, bending, (l - u) coupling and momentum
            grad = u_edge + \
                   3000 * u_plane_bend + \
                   0.05 * ul * torch.clamp(l - u - 75, -1000, 1) + \
                   momentum * grad
            return grad, g

        u_grad, gu = calc_grad(u, u_grad, u_grad_m, 1)
        l_grad, gl = calc_grad(l, l_grad, l_grad_m, -1)
        u += lr * u_grad
        l += lr * l_grad
        # Keep both surfaces inside the z range of the volume
        u = torch.clamp(u, 0, u_grad_m.shape[0] - 1)
        l = torch.clamp(l, 0, u_grad_m.shape[0] - 1)
        lr *= (1 - lr_decay)
        #v = torch.gather(img, 0, gu)[0].cpu().numpy()
        #print(torch.mean(u))
        #if i % 10 == 0:
        #    cv2.imshow('s', np.uint8(v * 5))
        #    cv2.waitKey(1)
        # NOTE(review): prints on every iteration
        print(i)

    # Offset the maps by +1 / -1 and resize to full x/y resolution;
    # cv2.resize takes (width, height) = (size_x, size_y)
    umap = np.float32((u + 1).cpu().numpy()[0][0])
    lmap = np.float32((l - 1).cpu().numpy()[0][0])
    umap = cv2.resize(umap, (size_[0], size_[1]))
    lmap = cv2.resize(lmap, (size_[0], size_[1]))
    # Sample full-resolution intensities along z at the (truncated) heights
    u_surface = torch.gather(src_img, 0,
                             torch.Tensor(np.array(
                                 [umap])).long())[0].cpu().numpy()
    l_surface = torch.gather(src_img, 0,
                             torch.Tensor(np.array(
                                 [lmap])).long())[0].cpu().numpy()
    # Same log scaling as applied to the downsampled volume above
    u_surface = np.clip((np.log(u_surface) - 4.6) * 39.4, 0, 255)
    l_surface = np.clip((np.log(l_surface) - 4.6) * 39.4, 0, 255)

    umap = sitk.GetImageFromArray(umap)
    lmap = sitk.GetImageFromArray(lmap)
    u_surface = sitk.GetImageFromArray(np.uint8(u_surface))
    l_surface = sitk.GetImageFromArray(np.uint8(l_surface))

    return umap, lmap, u_surface, l_surface
Пример #9
0
    def preprocess_reg_image_spatial(
        self,
        image: sitk.Image,
        preprocessing: ImagePreproParams,
        imported_transforms=None,
    ) -> Tuple[sitk.Image, List[Dict]]:
        """
        Spatial preprocessing of the reg_image.

        These steps run in this order, each only when requested in
        ``preprocessing``:

        1. downsampling (``sitk.Shrink`` by ``downsampling``);
        2. counter-clockwise rotation by ``rot_cc``;
        3. flipping by ``flip``;
        4. cropping to ``mask_bbox``.

        When ``crop_to_mask_bbox`` is set, a mask is present and no
        ``mask_bbox`` was given, the bounding box is computed from
        ``self._mask`` and stored on ``preprocessing.mask_bbox``.

        If ``self._mask`` is set, it gets the same spatial operations and is
        replaced in place. Cropping also sets
        ``self.original_size_transform``.

        Parameters
        ----------
        image: sitk.Image
            reg_image to be preprocessed. Its spacing is overwritten in
            place when downsampling.
        preprocessing: ImagePreproParams
            Spatial preprocessing parameters
        imported_transforms:
            Not implemented yet; ignored.

        Returns
        -------
        image: sitk.Image
            Spatially preprocessed image ready for registration
        transforms: list of transforms
            Pre-initial transformations (rotation, flip, translation) in
            the order they were applied
        """
        transforms = []
        original_size = image.GetSize()

        def _apply_tform(img, tform, res):
            # Convert one pre-initial transform into SimpleITK transforms.
            # Resample ``img`` with them, and ``self._mask`` too if present.
            composite_transform, _, final_tform = prepare_wsireg_transform_data(
                {"initial": [tform]})
            img = transform_plane(img, final_tform, composite_transform)
            if self._mask is not None:
                self._mask.SetSpacing((res, res))
                self._mask = transform_plane(self._mask, final_tform,
                                             composite_transform)
            return img

        if preprocessing.downsampling > 1:
            print("performing downsampling by factor: {}".format(
                preprocessing.downsampling))
            shrink_factors = (preprocessing.downsampling,
                              preprocessing.downsampling)
            image.SetSpacing((self.image_res, self.image_res))
            image = sitk.Shrink(image, shrink_factors)

            if self._mask is not None:
                self._mask.SetSpacing((self.image_res, self.image_res))
                self._mask = sitk.Shrink(self._mask, shrink_factors)

            image_res = image.GetSpacing()[0]
        else:
            image_res = self.image_res

        if float(preprocessing.rot_cc) != 0.0:
            print(f"rotating counter-clockwise {preprocessing.rot_cc}")
            rot_tform = gen_rigid_tform_rot(image, image_res,
                                            preprocessing.rot_cc)
            image = _apply_tform(image, rot_tform, image_res)
            transforms.append(rot_tform)

        if preprocessing.flip:
            print(f"flipping image {preprocessing.flip.value}")
            flip_tform = gen_aff_tform_flip(image, image_res,
                                            preprocessing.flip.value)
            image = _apply_tform(image, flip_tform, image_res)
            transforms.append(flip_tform)

        # Explicit None check, consistent with the other mask checks, rather
        # than relying on the image's truthiness.
        if self._mask is not None and preprocessing.crop_to_mask_bbox:
            if preprocessing.mask_bbox is None:
                print("computing mask bounding box")
                preprocessing.mask_bbox = compute_mask_to_bbox(self._mask)

        if preprocessing.mask_bbox:
            print("cropping to mask")
            bbox = preprocessing.mask_bbox
            translation_transform = gen_rigid_translation(
                image,
                image_res,
                bbox.X,
                bbox.Y,
                bbox.WIDTH,
                bbox.HEIGHT,
            )
            image = _apply_tform(image, translation_transform, image_res)
            self.original_size_transform = gen_rig_to_original(
                original_size, deepcopy(translation_transform))
            transforms.append(translation_transform)

        return image, transforms
def align_surfaces(prev_surface,
                   next_surface,
                   ref_img: sitk.Image = None,
                   prev_points=None,
                   next_points=None,
                   outside_brightness=2,
                   nonrigid=True,
                   ref_size=None,
                   ref_scale=1,
                   use_rigidity_mask=False,
                   **kwargs):
    """Align surface images to a reference grid and to each other.

    Every given surface, and ``ref_img``, is reduced to its first plane with
    ``[:, :, 0]``. One surface is then rigidly aligned to the reference with
    ``get_align_transform`` and 'tp_align_surface_rigid.txt'. That surface is
    ``prev_surface`` if given, otherwise ``next_surface``.

    When both surfaces are given and ``nonrigid`` is true, ``next_surface`` is
    then registered to the rigidly aligned ``prev_surface``:

    * with both point sets given: the "*_manual.txt" rigid + B-spline
      parameter files, constrained by the two point files;
    * otherwise: 'tp_align_surface_rigid.txt' + 'tp_align_surface_bspline3.txt'.

    Args:
        prev_surface: Previous surface image, or None.
        next_surface: Next surface image, or None. At least one of the two
            must be given (``fill_outside`` is called on ``next_surface``
            when ``prev_surface`` is None).
        ref_img: Reference image. Defaults to ``prev_surface``, or to
            ``next_surface`` when ``prev_surface`` is None; in that case
            ``ref_scale`` is forced to 1.
        prev_points: Points read with ``read_transformix_input_points``.
            They are mapped through the inverse of the rigid prev->ref
            transform before use. Used only when ``next_points`` is also
            given.
        next_points: Points read the same way and used unchanged (identity
            transform). Used only when ``prev_points`` is also given.
        outside_brightness: Passed to ``fill_outside``. Also sets the
            rigidity-mask threshold: ``outside_brightness + 1`` is the
            lower bound.
        nonrigid: When False and both surfaces are given, skip the
            next->prev registration.
        ref_size: When given, the reference gets spacing ``ref_scale`` and
            an origin that centres it in a ``ref_size`` grid. It is then
            resampled onto that grid with unit spacing and origin 0.
        ref_scale: Spacing assigned to the reference before that resampling.
        use_rigidity_mask: When ``== True``, build a rigidity mask from
            ``next_surface``: threshold, then binary opening.
        **kwargs: Unused.

    Returns:
        tuple: ``(prev_df, next_df)``, displacement fields of type
        ``sitkVectorFloat64`` on the reference grid.

        * Only ``prev_surface`` given: ``(rigid prev field, None)``.
        * Only ``next_surface`` given: ``(None, rigid next field)``.
        * Both given, ``nonrigid`` False: the rigid prev field twice.
          NOTE(review): ``next_surface`` is never aligned in this case;
          confirm that returning the prev field for it is intended.
        * Both given, ``nonrigid`` True: ``(rigid prev field, field of the
          next->prev transform)``.
    """
    prev_df1, next_df1 = None, None
    # Keep only the first plane of each 3D input
    if prev_surface is not None:
        prev_surface = prev_surface[:, :, 0]
    if next_surface is not None:
        next_surface = next_surface[:, :, 0]
    if ref_img is not None:
        ref_img = ref_img[:, :, 0]
    # NOTE(review): s1/s2 are passed to get_transform_points, which ignores them
    s1, s2 = prev_surface, next_surface
    if ref_img is None:
        ref_img = prev_surface
        if prev_surface is None:
            ref_img = next_surface
        ref_scale = 1
    if ref_size is not None:
        # Centre the reference in a ref_size grid (unit spacing, origin 0).
        # NOTE(review): when ref_img was defaulted above, it is the same
        # object as prev_surface (or next_surface), so the SetSpacing /
        # SetOrigin calls below also modify that surface.
        origin = [(ref_size[i] - ref_img.GetSize()[i] * ref_scale) / 2
                  for i in range(2)]
        ref_img.SetSpacing([ref_scale for i in range(2)])
        ref_img.SetOrigin(origin)
        ref_img = sitk.Resample(ref_img, ref_size, sitk.Transform(),
                                sitk.sitkLinear, [0, 0], [1, 1])
    # Rigidly align one surface to the reference and express that transform
    # as a displacement field on the reference grid. The first value
    # returned by get_align_transform replaces the surface (presumably the
    # aligned image — confirm in get_align_transform).
    if prev_surface is not None:
        prev_surface = fill_outside(prev_surface, outside_brightness)
        prev_surface, prev_transform1 = get_align_transform(
            ref_img, prev_surface,
            [os.path.join(PARAMETER_DIR, 'tp_align_surface_rigid.txt')])
        prev_df1 = sitk.TransformToDisplacementField(prev_transform1,
                                                     sitk.sitkVectorFloat64,
                                                     ref_img.GetSize(),
                                                     ref_img.GetOrigin(),
                                                     ref_img.GetSpacing(),
                                                     ref_img.GetDirection())
    else:
        next_surface = fill_outside(next_surface, outside_brightness)
        next_surface, next_transform1 = get_align_transform(
            ref_img, next_surface,
            [os.path.join(PARAMETER_DIR, 'tp_align_surface_rigid.txt')])
        next_df1 = sitk.TransformToDisplacementField(next_transform1,
                                                     sitk.sitkVectorFloat64,
                                                     ref_img.GetSize(),
                                                     ref_img.GetOrigin(),
                                                     ref_img.GetSpacing(),
                                                     ref_img.GetDirection())
    if prev_surface is None or next_surface is None:
        return prev_df1, next_df1
    if not nonrigid:
        return prev_df1, prev_df1

    # From here on both surfaces are present and prev_surface is the
    # rigidly aligned one; next_surface has not been filled or aligned yet.
    rigidity_mask = None
    if use_rigidity_mask == True:
        rigidity_mask = sitk.BinaryThreshold(next_surface,
                                             outside_brightness + 1)
        rigidity_mask = sitk.BinaryMorphologicalOpening(rigidity_mask)
    prev_surface = fill_outside(prev_surface, outside_brightness)
    next_surface = fill_outside(next_surface, outside_brightness)

    if prev_points is not None and next_points is not None:

        # Read points, map their first two coordinates through the inverse
        # of ``transform`` and write them as 2D transformix points to
        # ``file``. ``s_`` is unused.
        def get_transform_points(points, transform, file, s_):
            tf = transform.GetInverse()
            points = read_transformix_input_points(points)
            points = [tf.TransformPoint(p[:2]) for p in points]
            write_transformix_input_points(file, points, 2)

        with tempfile.TemporaryDirectory() as ELASTIX_TEMP:
            get_transform_points(prev_points, prev_transform1,
                                 os.path.join(ELASTIX_TEMP, 'prev.txt'), s1)
            get_transform_points(next_points,
                                 sitk.Transform(2, sitk.sitkIdentity),
                                 os.path.join(ELASTIX_TEMP, 'next.txt'), s2)
            _, transform2 = get_align_transform(
                prev_surface,
                next_surface, [
                    os.path.join(PARAMETER_DIR,
                                 'tp_align_surface_rigid_manual.txt'),
                    os.path.join(PARAMETER_DIR,
                                 'tp_align_surface_bspline_manual.txt')
                ],
                fixed_points=os.path.join(ELASTIX_TEMP, 'prev.txt'),
                moving_points=os.path.join(ELASTIX_TEMP, 'next.txt'),
                rigidity_mask=rigidity_mask)
    else:
        _, transform2 = get_align_transform(
            prev_surface,
            next_surface, [
                os.path.join(PARAMETER_DIR, 'tp_align_surface_rigid.txt'),
                os.path.join(PARAMETER_DIR, 'tp_align_surface_bspline3.txt')
            ],
            rigidity_mask=rigidity_mask)
    #sitk.WriteImage(prev_surface, 'F:/chaoyu/test/1_.mha')
    #sitk.WriteImage(_, 'F:/chaoyu/test/2_.mha')
    prev_df = prev_df1
    # The next->prev transform, sampled on the reference grid
    next_df = sitk.TransformToDisplacementField(transform2,
                                                sitk.sitkVectorFloat64,
                                                ref_img.GetSize(),
                                                ref_img.GetOrigin(),
                                                ref_img.GetSpacing(),
                                                ref_img.GetDirection())
    return prev_df, next_df
Пример #11
0
def registration(fixed_image: sitk.Image,     # noqa: C901
                 moving_image: sitk.Image,
                 *,
                 do_fft_initialization=True,
                 do_affine2d=False,
                 do_affine3d=True,
                 ignore_spacing=True,
                 sigma=1.0,
                 auto_mask=False,
                 samples_per_parameter=5000,
                 expand=None) -> sitk.Transform:
    """Robust multi-phase registration for multi-panel confocal microscopy images.

    The fixed and moving image are expected to have the same molecular labeling and to show the same imaged region.

    The phases available are:
      - fft initialization for translation estimation
      - 2D affine which can correct rotational acquisition problems, this phase is done on z-projections to optimize a \
       2D similarity transform followed by 2D affine
      - 3D affine robust multi-level registration

    The caller's images are never modified. Casting, expansion and spacing
    normalization all work on internal images.

    :param fixed_image: a scalar SimpleITK 3D Image
    :param moving_image: a scalar SimpleITK 3D Image
    :param do_fft_initialization: perform FFT based cross correlation to initialize translation
    :param do_affine2d: perform registration on 2D images from z-projection
    :param do_affine3d: multi-level affine transform
    :param ignore_spacing: internally adjust spacing magnitude to be near 1 to avoid numeric stability issues with \
    micro sized spacing
    :param sigma: scalar to change the amount of Gaussian smoothing performed
    :param auto_mask: ignore zero valued pixels connected to the image border
    :param samples_per_parameter: the number of image samples to use per transform parameter at full resolution
    :param expand: Expand (super-sample) the fixed image along z by this integer factor. When not given, any axis of \
    the fixed image with fewer than 5 pixels is expanded to at least 5. The moving image is always expanded this way \
    when under sized, independent of ``expand``.
    :return: A SimpleITK transform mapping points from the fixed image to the moving. This may be a CompositeTransform.

    """
    # Identity transform will be returned if all registration steps are disabled by
    # the calling function.
    result = sitk.Transform()

    initial_translation_3d = True

    moving_mask = None
    fixed_mask = None

    number_of_samples_per_parameter = samples_per_parameter

    # Remember the caller's objects so that the in-place metadata changes
    # below never leak back to them.
    input_fixed_image, input_moving_image = fixed_image, moving_image

    expand_factors = None

    if expand:
        expand_factors = [1, 1, expand]

    if fixed_image.GetPixelID() != sitk.sitkFloat32:
        fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)

    # Without an explicit ``expand``, expand every axis with fewer than 5
    # pixels by ceil(5 / size), so that each axis has at least 5 pixels.
    if not expand_factors:
        expand_factors = [-(-5//s) for s in fixed_image.GetSize()]

    if any(e != 1 for e in expand_factors):
        _logger.warning("Fixed image under sized in at lease one dimension!"
                        "\tApplying expand factors {0} to image size.".format(expand_factors))
        fixed_image = sitk.Expand(fixed_image, expandFactors=expand_factors)

    if moving_image.GetPixelID() != sitk.sitkFloat32:
        moving_image = sitk.Cast(moving_image, sitk.sitkFloat32)

    expand_factors = [-(-5//s) for s in moving_image.GetSize()]
    if any(e != 1 for e in expand_factors):
        _logger.warning("WARNING: Moving image under sized in at lease one dimension!"
                        "\tApplying expand factors {0} to image size.".format(expand_factors))
        moving_image = sitk.Expand(moving_image, expandFactors=expand_factors)

    if auto_mask:
        fixed_mask = imgf.make_auto_mask(fixed_image)
        moving_mask = imgf.make_auto_mask(moving_image)

    if ignore_spacing:

        #
        # FORCE THE SPACING magnitude to be normalized near 1.0
        #

        # SetSpacing/SetOrigin modify images in place. Cast/Expand already
        # returned new objects when they ran; otherwise work on a copy so
        # that the caller's images keep their geometry. This also avoids
        # normalizing twice when the same image is passed as both inputs.
        if fixed_image is input_fixed_image:
            fixed_image = sitk.Image(fixed_image)
        if moving_image is input_moving_image:
            moving_image = sitk.Image(moving_image)

        spacing_magnitude = imgf.spacing_average_magnitude(fixed_image)

        _logger.info("Adjusting image spacing by {0}...".format(1.0/spacing_magnitude))

        fixed_new_spacing = [s / spacing_magnitude for s in fixed_image.GetSpacing()]
        _logger.info("\tFixed Image Spacing: {0}->{1}".format(fixed_image.GetSpacing(), fixed_new_spacing))
        fixed_image.SetSpacing(fixed_new_spacing)
        fixed_image.SetOrigin([o/spacing_magnitude for o in fixed_image.GetOrigin()])

        moving_new_spacing = [s / spacing_magnitude for s in moving_image.GetSpacing()]
        _logger.info("\tMoving Image Spacing: {0}->{1}".format(moving_image.GetSpacing(), moving_new_spacing))
        moving_image.SetSpacing(moving_new_spacing)
        moving_image.SetOrigin([o/spacing_magnitude for o in moving_image.GetOrigin()])

        # Each mask must keep the geometry of the image it was made from.
        if moving_mask is not None:
            moving_mask.SetSpacing(moving_new_spacing)
            moving_mask.SetOrigin([o/spacing_magnitude for o in moving_mask.GetOrigin()])

        if fixed_mask is not None:
            # Previously this used the moving image's spacing by mistake.
            fixed_mask.SetSpacing(fixed_new_spacing)
            fixed_mask.SetOrigin([o / spacing_magnitude for o in fixed_mask.GetOrigin()])

    #
    #
    # Do FFT based translation initialization
    #
    #
    initial_translation = None
    if do_fft_initialization:
        initial_translation = imgf.fft_initialization(moving_image,
                                                      fixed_image,
                                                      bin_shrink=8,
                                                      projection=(not initial_translation_3d))
        result = sitk.TranslationTransform(len(initial_translation), initial_translation)

    #
    # Do 2D registration first
    #
    if do_affine2d:
        result = register_as_2d_affine(fixed_image, moving_image,
                                       sigma_base=sigma,
                                       initial_translation=initial_translation,
                                       fixed_image_mask=fixed_mask,
                                       moving_image_mask=moving_mask)

    if do_affine3d:

        _logger.info("Initializing Affine Registration...")

        if do_affine2d:
            # set the FFT xcoor initial z translation
            if do_fft_initialization and len(initial_translation) >= 3:

                # take the z-translation from the FFT
                translation = list(result.GetTranslation())
                translation[2] = initial_translation[2]
                result.SetTranslation(translation)

                _logger.info("Initialized 3D affine with z-translation... {0}".format(translation))

            affine = result
        else:
            affine = sitk.CenteredTransformInitializer(fixed_image,
                                                       moving_image,
                                                       sitk.AffineTransform(3),
                                                       sitk.CenteredTransformInitializerFilter.GEOMETRY)
            affine = sitk.AffineTransform(affine)

            if do_fft_initialization:
                if len(initial_translation) >= 3:
                    affine.SetTranslation(list(initial_translation))
                    _logger.info("Initialized 3D affine with z-translation... {0}".format(initial_translation))
                else:
                    affine.SetTranslation(list(initial_translation)+[0, ])

        affine_result = register_3d(fixed_image, moving_image,
                                    initial_transform=affine,
                                    sigma_base=sigma,
                                    fixed_image_mask=fixed_mask,
                                    moving_image_mask=moving_mask,
                                    number_of_samples_per_parameter=number_of_samples_per_parameter)

        result = affine_result

    if ignore_spacing:

        # Compose the scaling Transform into a single affine transform

        # The spacing of the images was modified for the registration, so the
        # result must be wrapped in the appropriate scaling to map between the
        # original physical space and the space the registration was done in.

        scale = spacing_magnitude

        scale_transform = sitk.ScaleTransform(3)
        scale_transform.SetScale([scale]*3)

        result = sitk.CompositeTransform([sitk.Transform(scale_transform),
                                          result,
                                          scale_transform.GetInverse()])

        # if result was a composite transform then we have nested composite
        # transforms which need to be flattened for writing.

        result.FlattenTransform()

        _logger.info(result)

    return result