Example #1
0
def generate_field_shift(mask_image,
                         vector_shift=(10, 10, 10),
                         gaussian_smooth=5):
    """
    Translate a structure (given as a binary mask) by a constant vector.

    A uniform displacement field is built from the negated, axis-reversed shift vector. The
    field is restricted to the union of the original mask and a first-pass shifted mask,
    optionally smoothed with a Gaussian kernel, and finally used to resample the mask.

    Args:
        mask_image ([SimpleITK.Image]): The binary mask to shift.
        vector_shift (tuple, optional): The displacement vector applied to the entire binary mask.
                                        Convention: (+/-, +/-, +/-) = (sup/inf, post/ant,
                                        left/right) shift.
                                        Defined in millimetres.
                                        Defaults to (10, 10, 10).
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the
                                                deformation vector field. A scalar is repeated
                                                for all three axes; a value for which
                                                np.any() is False disables smoothing.
                                                Defaults to 5.

    Returns:
        [SimpleITK.Image]: The binary mask following the shift.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the shift.
        [SimpleITK.Image]: The displacement vector field representing the shift.
    """

    def _warp_mask(field):
        # Wrap the field in a displacement transform and resample the mask through it
        # (nearest neighbour keeps the result binary).
        field_tfm = sitk.DisplacementFieldTransform(
            sitk.Cast(field, sitk.sitkVectorFloat64))
        warped = apply_transform(mask_image,
                                 transform=field_tfm,
                                 default_value=0,
                                 interpolator=sitk.sitkNearestNeighbor)
        return field_tfm, warped

    # The array shape of the mask fixes the grid of the displacement field
    grid_shape = sitk.GetArrayFromImage(mask_image).shape

    # Constant field: every voxel holds the negated shift, with the component order reversed
    reversed_shift = vector_shift[::-1]
    uniform_field_arr = np.zeros(grid_shape + (3, )) - np.array([[[reversed_shift]]])

    dvf_template = sitk.GetImageFromArray(uniform_field_arr)
    dvf_template.CopyInformation(mask_image)

    # First pass: find where the mask lands so the field can be limited to the region
    # covered by either the original or the shifted structure
    _, first_pass_shift = _warp_mask(dvf_template)
    dvf_template = sitk.Mask(dvf_template, mask_image | first_pass_shift)

    # Optional smoothing of the (masked) field
    if np.any(gaussian_smooth):
        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth, ) * 3

        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template,
                                                       gaussian_smooth)

    # Second pass: the final transform and shifted mask come from the processed field
    dvf_tfm, mask_image_shift = _warp_mask(dvf_template)

    return mask_image_shift, dvf_tfm, dvf_template
Example #2
0
def apply_augmentation(image, augmentation, masks=None):
    """
    Apply one or more deformable augmentations to an image (and optionally to masks).

    Each augmentation's `augment()` is called once; the resulting transforms are composed
    and the resulting displacement fields are summed.

    Args:
        image ([SimpleITK.Image]): The image to deform. It is resampled with linear
                                   interpolation, using the image minimum as default value.
        augmentation (DeformableAugment | iterable): A single DeformableAugment or an iterable
                                                     of DeformableAugment objects.
        masks (list, optional): Masks to deform with the same composite transform, using
                                nearest neighbour interpolation and a default value of 0.
                                Defaults to None (no masks).

    Raises:
        AttributeError: If image is not a SimpleITK.Image, if augmentation is neither a
                        DeformableAugment nor an iterable, or if any element of the iterable
                        is not a DeformableAugment.

    Returns:
        If masks is non-empty: (deformed image, list of deformed masks, summed field).
        Otherwise: (deformed image, summed field).
    """
    if not isinstance(image, sitk.Image):
        raise AttributeError("image should be a SimpleITK.Image")

    if isinstance(augmentation, DeformableAugment):
        augmentation = [augmentation]

    if not isinstance(augmentation, Iterable):
        # Note the trailing space: the two literals are concatenated into one message
        raise AttributeError(
            "augmentation must be a DeformableAugment or an iterable (such as list) of "
            "DeformableAugment's")

    # Avoid a shared mutable default argument
    if masks is None:
        masks = []

    transforms = []
    dvf = None
    for aug in augmentation:

        if not isinstance(aug, DeformableAugment):
            raise AttributeError(
                "Each augmentation must be of type DeformableAugment")

        tfm, field = aug.augment()
        transforms.append(tfm)

        if dvf is None:
            dvf = field
        else:
            # Build a new image rather than accumulating in place, so the field object
            # returned by the first augmentation is never modified
            dvf = dvf + field

    transform = sitk.CompositeTransform(transforms)
    del transforms

    image_deformed = apply_transform(
        image,
        transform=transform,
        default_value=int(sitk.GetArrayViewFromImage(image).min()),
        interpolator=sitk.sitkLinear,
    )

    masks_deformed = [
        apply_transform(mask,
                        transform=transform,
                        default_value=0,
                        interpolator=sitk.sitkNearestNeighbor)
        for mask in masks
    ]

    # The return shape depends on whether masks were supplied (kept for compatibility)
    if masks:
        return image_deformed, masks_deformed, dvf

    return image_deformed, dvf
Example #3
0
def generate_field_radial_bend(
        reference_image,
        body_mask,
        reference_point,
        axis_of_rotation=[0, 0, -1],
        scale=0.1,
        mask_bend_from_reference_point=("z", "inf"),
        gaussian_smooth=5,
):
    """
    Generates a synthetic field characterised by radial bending.
    Typically, this field would be used to simulate a moving head and so masking is important.

    For every voxel inside the (optionally truncated) body mask, the displacement is the cross
    product of the offset from the reference point with the normalised rotation axis,
    multiplied by scale.

    Args:
        reference_image ([SimpleITK.Image]): The image to be deformed.
        body_mask ([SimpleITK.Image]): A binary mask in which the deformation field will be defined
        reference_point ([tuple]): The point (z,y,x) about which the rotation field is defined.
                                   Its entries are used as array indices when masking.
        axis_of_rotation (tuple, optional): The axis of rotation (z,y,x). It is normalised
                                            internally. Defaults to [0, 0, -1].
        scale (float, optional): The deformation vector length at each point will equal scale
                                 multiplied by the distance to that point from reference_point.
                                 If False, the field is left at zero. Defaults to 0.1.
        mask_bend_from_reference_point (tuple, optional): The dimension (z=axial, y=coronal,
                                                          x=sagittal) and limit (inf/sup, post/ant,
                                                          left/right) for masking the vector field,
                                                          relative to the reference point. Pass
                                                          False to disable. Defaults
                                                          to ("z", "inf").
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the
                                                deformation vector field. Defaults to 5.

    Returns:
        [SimpleITK.Image]: The reference image following the bend.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the bend.
        [SimpleITK.Image]: The displacement vector field representing the bend.
    """

    mask_arr = sitk.GetArrayFromImage(body_mask)

    # Zero one side of the mask (relative to the reference point) along the chosen dimension
    if mask_bend_from_reference_point is not False:
        dimension = mask_bend_from_reference_point[0]

        if dimension == "z":
            limit = mask_bend_from_reference_point[1]
            if limit == "inf":
                mask_arr[:reference_point[0], :, :] = 0
            elif limit == "sup":
                mask_arr[reference_point[0]:, :, :] = 0

        elif dimension == "y":
            limit = mask_bend_from_reference_point[1]
            if limit == "post":
                mask_arr[:, reference_point[1]:, :] = 0
            elif limit == "ant":
                mask_arr[:, :reference_point[1], :] = 0

        elif dimension == "x":
            limit = mask_bend_from_reference_point[1]
            if limit == "left":
                mask_arr[:, :, reference_point[2]:] = 0
            elif limit == "right":
                mask_arr[:, :, :reference_point[2]] = 0

    # Index tuple of the voxels that remain inside the mask; reused below for assignment
    inside_idx = np.where(mask_arr)
    offsets_from_ref = np.array(inside_idx) - np.array(reference_point)[:, None]

    # Normalise the normal vector (axis_of_rotation)
    unit_axis = np.array(axis_of_rotation)
    unit_axis = unit_axis / np.linalg.norm(unit_axis)

    # Both operands are reversed from (z,y,x) to (x,y,z) component order before the cross product
    bend_vectors = np.cross(offsets_from_ref[::-1].T, unit_axis[::-1])

    # Start from an all-zero 3-component field on the reference image grid
    field_arr = sitk.GetArrayFromImage(
        sitk.Image(reference_image.GetSize(), sitk.sitkVectorFloat64, 3))

    if scale is not False:
        field_arr[inside_idx] = bend_vectors * scale

    dvf_template = sitk.GetImageFromArray(field_arr)
    dvf_template.CopyInformation(reference_image)

    # Optional smoothing; a scalar kernel scale is repeated for the three axes
    if np.any(gaussian_smooth):
        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth, ) * 3

        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template,
                                                       gaussian_smooth)

    dvf_tfm = sitk.DisplacementFieldTransform(
        sitk.Cast(dvf_template, sitk.sitkVectorFloat64))

    # Pad with the image minimum so regions pulled from outside the image stay at background
    background_value = int(sitk.GetArrayViewFromImage(reference_image).min())
    reference_image_bend = apply_transform(
        reference_image,
        transform=dvf_tfm,
        default_value=background_value,
        interpolator=sitk.sitkLinear,
    )

    return reference_image_bend, dvf_tfm, dvf_template
Example #4
0
def generate_field_expand(
    mask,
    bone_mask=False,
    expand=3,
    gaussian_smooth=5,
    use_internal_deformation=True,
):
    """
    Expands (or shrinks) a structure defined using a binary mask.

    The mask is dilated and/or eroded with a ball kernel, and deformable registration between
    the original and the modified mask is used to obtain the displacement field.

    Args:
        mask ([SimpleITK.Image]): The binary mask to expand.
        bone_mask ([SimpleITK.Image, optional]): A binary mask defining regions where we expect
                                                 restricted deformations. It is added to both the
                                                 original and the modified mask. Defaults to
                                                 False (not used).
        expand (int | tuple, optional): The expansion applied to the entire binary mask.
                                        Convention: (z,y,x) size of expansion kernel.
                                        Defined in millimetres; negative values shrink.
                                        Defaults to 3.
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the
                                                deformation vector field. Defaults to 5.
        use_internal_deformation (bool, optional): If True, both masks are passed through
                                                   convert_mask_to_reg_structure before
                                                   registration. Defaults to True.

    Returns:
        [SimpleITK.Image]: The binary mask following the expansion.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion.
        [SimpleITK.Image]: The displacement vector field representing the expansion.
    """

    def _kernel_radius(kernel):
        # Morphology filters take non-negative integer radii (fractions are truncated)
        return np.abs(kernel).astype(int).tolist()

    mask_original = mask if bone_mask is False else mask + bone_mask

    # A scalar expansion applies equally along z, y and x
    if not hasattr(expand, "__iter__"):
        expand = (expand, ) * 3

    # Convert millimetres to voxels using the (z,y,x) spacing, then re-order to (x,y,z)
    kernel_vox = np.array(expand) / np.array(mask.GetSpacing()[::-1])
    kernel_vox = kernel_vox[::-1]

    if np.all(kernel_vox <= 0):
        # No positive component: erode
        print("All factors negative: shrinking only.")
        mask_expand = sitk.BinaryErode(mask, _kernel_radius(kernel_vox),
                                       sitk.sitkBall)

    elif np.all(kernel_vox >= 0):
        # No negative component: dilate
        print("All factors positive: expansion only.")
        mask_expand = sitk.BinaryDilate(mask, _kernel_radius(kernel_vox),
                                        sitk.sitkBall)

    else:
        # Mixed signs: dilate along the positive axes first, then erode along the negative ones
        print("Mixed factors: shrinking and expansion.")
        grow_part = kernel_vox * (kernel_vox > 0)
        shrink_part = kernel_vox * (kernel_vox < 0)

        mask_expand = sitk.BinaryDilate(mask, _kernel_radius(grow_part),
                                        sitk.sitkBall)
        mask_expand = sitk.BinaryErode(mask_expand,
                                       _kernel_radius(shrink_part),
                                       sitk.sitkBall)

    if bone_mask is not False:
        mask_expand = mask_expand + bone_mask

    # Choose the images that drive the registration
    if use_internal_deformation:
        reg_source = convert_mask_to_reg_structure(mask_original)
        reg_target = convert_mask_to_reg_structure(mask_expand)
    else:
        reg_source = mask_original
        reg_target = mask_expand

    # Use DIR to find the deformation from the original to the modified structure
    demons_output = fast_symmetric_forces_demons_registration(
        reg_target,
        reg_source,
        isotropic_resample=True,
        resolution_staging=[4, 2],
        iteration_staging=[10, 10],
        ncores=8,
    )
    _, _, dvf_template = demons_output

    # Optional smoothing; a scalar kernel scale is repeated for the three axes
    if np.any(gaussian_smooth):
        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth, ) * 3

        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template,
                                                       gaussian_smooth)

    dvf_tfm = sitk.DisplacementFieldTransform(
        sitk.Cast(dvf_template, sitk.sitkVectorFloat64))

    # The returned mask is the input mask (without bone) resampled through the final field
    mask_symmetric_expand = apply_transform(
        mask,
        transform=dvf_tfm,
        default_value=0,
        interpolator=sitk.sitkNearestNeighbor)

    return mask_symmetric_expand, dvf_tfm, dvf_template
Example #5
0
def linear_registration(
    fixed_image,
    moving_image,
    fixed_structure=None,
    moving_structure=None,
    reg_method="similarity",
    metric="mean_squares",
    optimiser="gradient_descent",
    shrink_factors=[8, 2, 1],
    smooth_sigmas=[4, 2, 0],
    sampling_rate=0.25,
    final_interp=2,
    number_of_iterations=50,
    default_value=None,
    verbose=False,
):
    """
    Initial linear registration between two images.
    The images are not required to be in the same space.
    There are several transforms available, with options for the metric and optimiser to be used.
    Note the default_value, which should be set to match the image modality.

    Args:
        fixed_image ([SimpleITK.Image]): The fixed (target/primary) image.
        moving_image ([SimpleITK.Image]): The moving (secondary) image.
        fixed_structure ([SimpleITK.Image], optional): If given (and truthy), used as the metric
                                                       mask for the fixed image.
                                                       Defaults to None.
        moving_structure ([SimpleITK.Image], optional): If given (and truthy), used as the metric
                                                        mask for the moving image.
                                                        Defaults to None.
        reg_method (str | SimpleITK.Transform, optional): The linear transformation model to be
                                    used for image registration, either by name or as a
                                    transform object that is optimised directly.
                                    Available names (case-insensitive):
                                     - translation
                                     - rigid
                                     - similarity
                                     - affine
                                     - scaleversor
                                     - scaleskewversor
                                    Defaults to "similarity".
        metric (str, optional): The metric to be optimised during image registration.
                                Available options (case-insensitive):
                                 - correlation
                                 - mean_squares
                                 - mattes_mi
                                 - joint_hist_mi
                                Any other name leaves the registration method's metric untouched.
                                Defaults to "mean_squares".
        optimiser (str, optional): The optimiser algorithm used for image registration.
                                   Available options (case-insensitive):
                                    - lbfgsb
                                      (limited-memory Broyden–Fletcher–Goldfarb–Shanno (bounded).)
                                    - exhaustive (not currently recommended)
                                    - gradient_descent
                                    - gradient_descent_line_search
                                   Any other name leaves the optimiser untouched.
                                   Defaults to "gradient_descent".
        shrink_factors (list, optional): The multi-resolution downsampling factors.
                                         Defaults to [8, 2, 1].
        smooth_sigmas (list, optional): The multi-resolution smoothing kernel scale (Gaussian),
                                        specified in physical units. Defaults to [4, 2, 0].
        sampling_rate (float, optional): The fraction of voxels sampled during each iteration.
                                         Defaults to 0.25.
        final_interp (int, optional): The interpolator used for the final resampling.
                                      Defaults to 2.
        number_of_iterations (int, optional): Number of iterations in each multi-resolution step.
                                              Defaults to 50.
        default_value (int, optional): Default voxel value. If None, -1000 is used when the
                                       moving image minimum is <= -1000 (CT-like), else 0.
        verbose (bool, optional): Print image registration process information. Defaults to False.

    Raises:
        ValueError: If reg_method is an unknown name, or is neither a string nor a transform.

    Returns:
        [SimpleITK.Image]: The registered moving (secondary) image.
        [SimpleITK.Transform]: The linear transformation (initial and optimised, composed).
    """

    # Work in float32; remember the moving image's pixel type so it can be restored at the end
    fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)

    moving_image_type = moving_image.GetPixelIDValue()
    moving_image = sitk.Cast(moving_image, sitk.sitkFloat32)

    # Initial alignment from the centered transform initialiser (Euler3DTransform)
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_image, moving_image, sitk.Euler3DTransform(), False)

    # Set up image registration method
    registration = sitk.ImageRegistrationMethod()

    registration.SetShrinkFactorsPerLevel(shrink_factors)
    registration.SetSmoothingSigmasPerLevel(smooth_sigmas)
    registration.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    registration.SetMovingInitialTransform(initial_transform)

    # Metric selection; names without an entry are ignored
    metric_setters = {
        "correlation": registration.SetMetricAsCorrelation,
        "mean_squares": registration.SetMetricAsMeanSquares,
        "mattes_mi": registration.SetMetricAsMattesMutualInformation,
        "joint_hist_mi": registration.SetMetricAsJointHistogramMutualInformation,
    }
    set_metric = metric_setters.get(metric.lower())
    if set_metric is not None:
        set_metric()

    registration.SetInterpolator(sitk.sitkLinear)
    registration.SetMetricSamplingPercentage(sampling_rate)
    registration.SetMetricSamplingStrategy(
        sitk.ImageRegistrationMethod.REGULAR)

    # Only necessary when the transform mixes parameters with different units
    # (e.g. rigid: rotation in radians, translation in mm); it can safely be left on
    registration.SetOptimizerScalesFromPhysicalShift()

    if moving_structure:
        registration.SetMetricMovingMask(moving_structure)

    if fixed_structure:
        registration.SetMetricFixedMask(fixed_structure)

    # Transform model: either looked up by name or supplied directly
    accepted_transform_types = (
        sitk.CompositeTransform,
        sitk.Transform,
        sitk.TranslationTransform,
        sitk.Similarity3DTransform,
        sitk.AffineTransform,
        sitk.VersorRigid3DTransform,
        sitk.ScaleVersor3DTransform,
        sitk.ScaleSkewVersor3DTransform,
    )

    if isinstance(reg_method, str):
        transform_factories = {
            "translation": lambda: sitk.TranslationTransform(3),
            "similarity": lambda: sitk.Similarity3DTransform(),
            "affine": lambda: sitk.AffineTransform(3),
            "rigid": lambda: sitk.VersorRigid3DTransform(),
            "scaleversor": lambda: sitk.ScaleVersor3DTransform(),
            "scaleskewversor": lambda: sitk.ScaleSkewVersor3DTransform(),
        }
        make_transform = transform_factories.get(reg_method.lower())
        if make_transform is None:
            raise ValueError(
                "You have selected a registration method that does not exist.\n Please select from"
                " Translation, Similarity, Affine, Rigid, ScaleVersor, ScaleSkewVersor"
            )
        registration.SetInitialTransform(make_transform())
    elif isinstance(reg_method, accepted_transform_types):
        registration.SetInitialTransform(reg_method)
    else:
        raise ValueError(
            "'reg_method' must be either a string (see docs for acceptable registration names), "
            "or a custom sitk.CompositeTransform.")

    # Optimiser selection; unrecognised names are ignored
    optimiser_name = optimiser.lower()
    if optimiser_name == "gradient_descent":
        registration.SetOptimizerAsGradientDescent(
            learningRate=1.0, numberOfIterations=number_of_iterations)
    elif optimiser_name == "gradient_descent_line_search":
        registration.SetOptimizerAsGradientDescentLineSearch(
            learningRate=1.0, numberOfIterations=number_of_iterations)
    elif optimiser_name == "lbfgsb":
        registration.SetOptimizerAsLBFGSB(
            gradientConvergenceTolerance=1e-5,
            numberOfIterations=number_of_iterations,
            maximumNumberOfCorrections=50,
            maximumNumberOfFunctionEvaluations=1024,
            costFunctionConvergenceFactor=1e7,
            trace=verbose,
        )
    elif optimiser_name == "exhaustive":
        # This isn't well implemented
        # Needs some work to give options for sampling rates
        # Use is not currently recommended
        samples = [10, 10, 10, 10, 10, 10]
        registration.SetOptimizerAsExhaustive(samples)

    if verbose:
        registration.AddCommand(
            sitk.sitkIterationEvent,
            lambda: registration_command_iteration(registration),
        )

    output_transform = registration.Execute(fixed=fixed_image,
                                            moving=moving_image)

    # Combine initial and optimised transform
    combined_transform = sitk.CompositeTransform(
        [initial_transform, output_transform])

    # Pick a padding value when none was given: -1000 for CT-like images, otherwise 0
    if default_value is None:
        moving_minimum = sitk.GetArrayViewFromImage(moving_image).min()
        default_value = -1000 if moving_minimum <= -1000 else 0

    registered_image = apply_transform(
        input_image=moving_image,
        reference_image=fixed_image,
        transform=combined_transform,
        default_value=default_value,
        interpolator=final_interp,
    )

    # Restore the pixel type the moving image had on entry
    registered_image = sitk.Cast(registered_image, moving_image_type)

    return registered_image, combined_transform
Example #6
0
def run_segmentation(img, settings=MUTLIATLAS_SETTINGS_DEFAULTS):
    """Runs the atlas-based segmentation algorithm

    The pipeline reads the atlas images and labels, crops the target image using quick
    registrations of (up to eight) atlases, performs linear and then deformable registration
    of every atlas to the cropped target, fuses the propagated labels, pastes the results back
    into the original image space and optionally post-processes the binary structures.

    Args:
        img (sitk.Image): The target image to segment.
        settings (dict, optional): Dictionary containing settings for algorithm.
                                   Defaults to MUTLIATLAS_SETTINGS_DEFAULTS.

    Returns:
        tuple: (results, results_prob)
            results (dict): Structure name -> binary segmentation in the space of img.
            results_prob (dict): Structure name -> probability map in the space of img.
    """

    results = {}
    results_prob = {}

    # Initialisation - Read in atlases
    # - image files
    # - structure files
    #
    #     Atlas structure:
    #     'ID': 'Original': 'CT Image'    : sitk.Image
    #                         'Struct A'    : sitk.Image
    #                         'Struct B'    : sitk.Image
    #             'RIR'     : 'CT Image'    : sitk.Image
    #                         'Transform'   : transform parameter map
    #                         'Struct A'    : sitk.Image
    #                         'Struct B'    : sitk.Image
    #             'DIR'     : 'CT Image'    : sitk.Image
    #                         'Transform'   : displacement field transform
    #                         'Weight Map'  : sitk.Image
    #                         'Struct A'    : sitk.Image
    #                         'Struct B'    : sitk.Image

    logger.info("")
    # Settings
    atlas_path = settings["atlas_settings"]["atlas_path"]
    atlas_id_list = settings["atlas_settings"]["atlas_id_list"]
    atlas_structure_list = settings["atlas_settings"]["atlas_structure_list"]

    atlas_image_format = settings["atlas_settings"]["atlas_image_format"]
    atlas_label_format = settings["atlas_settings"]["atlas_label_format"]

    crop_atlas_to_structures = settings["atlas_settings"]["crop_atlas_to_structures"]
    crop_atlas_expansion_mm = settings["atlas_settings"]["crop_atlas_expansion_mm"]

    atlas_set = {}
    for atlas_id in atlas_id_list:
        atlas_set[atlas_id] = {}
        atlas_set[atlas_id]["Original"] = {}

        image = sitk.ReadImage(f"{atlas_path}/{atlas_image_format.format(atlas_id)}")

        structures = {
            struct: sitk.ReadImage(f"{atlas_path}/{atlas_label_format.format(atlas_id, struct)}")
            for struct in atlas_structure_list
        }

        if crop_atlas_to_structures:
            logger.info(f"Automatically cropping atlas: {atlas_id}")

            # np.prod replaces np.product, which was removed in NumPy 2.0
            original_volume = np.prod(image.GetSize())

            crop_box_size, crop_box_index = label_to_roi(
                structures.values(), expansion_mm=crop_atlas_expansion_mm
            )

            image = crop_to_roi(image, size=crop_box_size, index=crop_box_index)

            final_volume = np.prod(image.GetSize())

            logger.info(f"  > Volume reduced by factor {original_volume/final_volume:.2f}")

            for struct in atlas_structure_list:
                structures[struct] = crop_to_roi(
                    structures[struct], size=crop_box_size, index=crop_box_index
                )

        atlas_set[atlas_id]["Original"]["CT Image"] = image

        for struct in atlas_structure_list:
            atlas_set[atlas_id]["Original"][struct] = structures[struct]

    # Step 1 - Automatic cropping
    # - a quick registration is used to register (up to eight of) the atlases
    # - the averaged registered images define a bounding box, which is expanded to ensure
    #   the entire volume of interest is enclosed
    # - the target image is cropped

    expansion_mm = settings["auto_crop_target_image_settings"]["expansion_mm"]

    quick_reg_settings = {
        "reg_method": "similarity",
        "shrink_factors": [8],
        "smooth_sigmas": [0],
        "sampling_rate": 0.75,
        "default_value": -1000,
        "number_of_iterations": 25,
        "final_interp": sitk.sitkLinear,
        "metric": "mean_squares",
        "optimiser": "gradient_descent_line_search",
    }

    registered_crop_images = []

    logger.info("Running initial Translation tranform to crop image volume")

    # Slicing already caps at the list length, so at most eight atlases are used
    for atlas_id in atlas_id_list[:8]:

        logger.info(f"  > atlas {atlas_id}")

        # Register the atlases
        atlas_set[atlas_id]["RIR"] = {}
        atlas_image = atlas_set[atlas_id]["Original"]["CT Image"]

        reg_image, _ = linear_registration(
            img,
            atlas_image,
            **quick_reg_settings,
        )

        registered_crop_images.append(sitk.Cast(reg_image, sitk.sitkFloat32))

        del reg_image

    # Voxels whose mean registered intensity exceeds the -1000 padding value define the ROI
    combined_image = sum(registered_crop_images) / len(registered_crop_images) > -1000

    crop_box_size, crop_box_index = label_to_roi(combined_image, expansion_mm=expansion_mm)

    img_crop = crop_to_roi(img, crop_box_size, crop_box_index)

    logger.info("Calculated crop box:")
    logger.info(f"  > {crop_box_index}")
    logger.info(f"  > {crop_box_size}")
    logger.info(f"  > Vol reduction = {np.prod(img.GetSize())/np.prod(crop_box_size):.2f}")

    # Step 2 - Rigid registration of target images
    # - Individual atlas images are registered to the target
    # - The transformation is used to propagate the labels onto the target
    linear_registration_settings = settings["linear_registration_settings"]

    logger.info(
        f"Running {linear_registration_settings['reg_method']} tranform to align atlas images"
    )

    for atlas_id in atlas_id_list:
        # Register the atlases

        logger.info(f"  > atlas {atlas_id}")

        atlas_set[atlas_id]["RIR"] = {}

        target_reg_image = img_crop
        atlas_reg_image = atlas_set[atlas_id]["Original"]["CT Image"]

        _, initial_tfm = linear_registration(
            target_reg_image,
            atlas_reg_image,
            **linear_registration_settings,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["RIR"]["Transform"] = initial_tfm

        atlas_set[atlas_id]["RIR"]["CT Image"] = apply_transform(
            input_image=atlas_set[atlas_id]["Original"]["CT Image"],
            reference_image=img_crop,
            transform=initial_tfm,
            default_value=-1000,
            interpolator=sitk.sitkLinear,
        )

        for struct in atlas_structure_list:
            input_struct = atlas_set[atlas_id]["Original"][struct]
            atlas_set[atlas_id]["RIR"][struct] = apply_transform(
                input_image=input_struct,
                reference_image=img_crop,
                transform=initial_tfm,
                default_value=0,
                interpolator=sitk.sitkNearestNeighbor,
            )

        # Drop the reference to the original atlas data; it is not used after this step
        atlas_set[atlas_id]["Original"] = None

    # Step 3 - Deformable image registration
    # - Using Fast Symmetric Diffeomorphic Demons

    # Settings
    deformable_registration_settings = settings["deformable_registration_settings"]

    logger.info("Running DIR to refine atlas image registration")

    for atlas_id in atlas_id_list:

        logger.info(f"  > atlas {atlas_id}")

        # Register the atlases
        atlas_set[atlas_id]["DIR"] = {}

        atlas_reg_image = atlas_set[atlas_id]["RIR"]["CT Image"]
        target_reg_image = img_crop

        _, dir_tfm, _ = fast_symmetric_forces_demons_registration(
            target_reg_image,
            atlas_reg_image,
            **deformable_registration_settings,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["DIR"]["Transform"] = dir_tfm

        atlas_set[atlas_id]["DIR"]["CT Image"] = apply_transform(
            input_image=atlas_set[atlas_id]["RIR"]["CT Image"],
            transform=dir_tfm,
            default_value=-1000,
            interpolator=sitk.sitkLinear,
        )

        for struct in atlas_structure_list:
            input_struct = atlas_set[atlas_id]["RIR"][struct]
            atlas_set[atlas_id]["DIR"][struct] = apply_transform(
                input_image=input_struct,
                transform=dir_tfm,
                default_value=0,
                interpolator=sitk.sitkNearestNeighbor,
            )

        # Drop the reference to the linearly registered data; only "DIR" is used from here on
        atlas_set[atlas_id]["RIR"] = None

    # Step 4 - Label Fusion
    vote_type = settings["label_fusion_settings"]["vote_type"]
    vote_params = settings["label_fusion_settings"]["vote_params"]

    # Compute weight maps
    for atlas_id in list(atlas_set.keys()):
        atlas_image = atlas_set[atlas_id]["DIR"]["CT Image"]
        weight_map = compute_weight_map(
            img_crop, atlas_image, vote_type=vote_type, vote_params=vote_params
        )
        atlas_set[atlas_id]["DIR"]["Weight Map"] = weight_map

    combined_label_dict = combine_labels(atlas_set, atlas_structure_list)

    # Step 5 - Paste the cropped structure into the original image space
    logger.info("Generating binary segmentations.")
    template_img_binary = sitk.Cast((img * 0), sitk.sitkUInt8)
    template_img_prob = sitk.Cast((img * 0), sitk.sitkFloat64)

    for structure_name in atlas_structure_list:

        probability_map = combined_label_dict[structure_name]

        # Per-structure threshold if configured, otherwise 0.5
        if structure_name in settings["label_fusion_settings"]["optimal_threshold"]:
            optimal_threshold = settings["label_fusion_settings"]["optimal_threshold"][
                structure_name
            ]
        else:
            optimal_threshold = 0.5

        binary_struct = process_probability_image(probability_map, optimal_threshold)

        # Un-crop binary structure
        paste_img_binary = sitk.Paste(
            template_img_binary,
            binary_struct,
            binary_struct.GetSize(),
            (0, 0, 0),
            crop_box_index,
        )
        results[structure_name] = paste_img_binary

        # Un-crop probability map
        paste_prob_img = sitk.Paste(
            template_img_prob,
            probability_map,
            probability_map.GetSize(),
            (0, 0, 0),
            crop_box_index,
        )
        results_prob[structure_name] = paste_prob_img

    # Step 6 - Post-processing
    postprocessing_settings = settings["postprocessing_settings"]

    if postprocessing_settings["run_postprocessing"]:
        logger.info("Running post-processing.")

        # Remove any smaller components and perform morphological closing (hole filling)
        # The closing kernel is converted from millimetres to voxels per axis
        binaryfillhole_img = [
            int(postprocessing_settings["binaryfillhole_mm"] / sp) for sp in img.GetSpacing()
        ]

        for structure_name in postprocessing_settings["structures_for_binaryfillhole"]:

            if structure_name not in results:
                continue

            contour_s = results[structure_name]
            # Keep only the largest connected component (label 1 after relabelling)
            contour_s = sitk.RelabelComponent(sitk.ConnectedComponent(contour_s)) == 1
            contour_s = sitk.BinaryMorphologicalClosing(contour_s, binaryfillhole_img)
            results[structure_name] = contour_s

        # Only structures that were actually segmented can take part in overlap correction;
        # unknown names are skipped, mirroring the fill-hole loop above
        overlap_structures = [
            s
            for s in postprocessing_settings["structures_for_overlap_correction"]
            if s in results
        ]

        if len(overlap_structures) >= 2:
            # Remove any overlaps
            input_overlap = {s: results[s] for s in overlap_structures}
            output_overlap = correct_volume_overlap(input_overlap)

            for s in overlap_structures:
                results[s] = output_overlap[s]

    logger.info("Done!")

    return results, results_prob
Example #7
0
def bspline_registration(
    fixed_image,
    moving_image,
    fixed_structure=False,
    moving_structure=False,
    resolution_staging=[8, 4, 2],
    smooth_sigmas=[4, 2, 1],
    sampling_rate=0.1,
    optimiser="LBFGS",
    metric="mean_squares",
    initial_grid_spacing=64,
    grid_scale_factors=[1, 2, 4],
    interp_order=3,
    default_value=-1000,
    number_of_iterations=20,
    isotropic_resample=False,
    initial_isotropic_size=1,
    number_of_histogram_bins_mi=30,
    verbose=False,
    ncores=8,
):
    """
    B-Spline image registration using ITK

    IMPORTANT - THIS IS UNDER ACTIVE DEVELOPMENT

    Args:
        fixed_image ([SimpleITK.Image]): The fixed (target/primary) image.
        moving_image ([SimpleITK.Image]): The moving (secondary) image.
        fixed_structure ([SimpleITK.Image] | bool, optional): If defined, a binary SimpleITK.Image
                                          used to mask metric evaluation for the fixed image.
                                          False or None disables the mask. Defaults to False.
        moving_structure ([SimpleITK.Image] | bool, optional): If defined, a binary
                                           SimpleITK.Image used to mask metric evaluation for the
                                           moving image. False or None disables the mask.
                                           Defaults to False.
        resolution_staging (list, optional): The multi-resolution downsampling factors.
                                             Defaults to [8, 4, 2].
        smooth_sigmas (list, optional): The multi-resolution smoothing kernel scale (Gaussian).
                                        Defaults to [4, 2, 1].
        sampling_rate (float | list, optional): The fraction of voxels sampled during each
                                         iteration. A single number applies to every resolution
                                         stage; a list/tuple/array gives one value per stage.
                                         Defaults to 0.1.
        optimiser (str, optional): The optimiser algorithm used for image registration
                                   (case-insensitive).
                                   Available options:
                                    - LBFGSB
                                      (limited-memory Broyden–Fletcher–Goldfarb–Shanno (bounded).)
                                    - LBFGS
                                      (limited-memory Broyden–Fletcher–Goldfarb–Shanno
                                      (unbounded).)
                                    - CGLS (conjugate gradient line search)
                                    - gradient_descent
                                    - gradient_descent_line_search
                                   Defaults to "LBFGS".
        metric (str, optional): The metric to be optimised during image registration
                                (case-insensitive).
                                Available options:
                                 - correlation
                                 - mean_squares
                                 - demons
                                 - mutual_information
                                   (used with parameter number_of_histogram_bins_mi)
                                Defaults to "mean_squares".
        initial_grid_spacing (int, optional): Grid spacing of lower resolution stage (in mm).
                                              Defaults to 64.
        grid_scale_factors (list, optional): Factors to determine grid spacing at each
                                             multiresolution stage.
                                             Defaults to [1, 2, 4].
        interp_order (int, optional): Interpolation used for the final resampling. The value is
                                      passed unchanged as the `interpolator` argument of
                                      apply_transform.
                                      NOTE(review): presumably this is interpreted as a SimpleITK
                                      interpolator identifier rather than a polynomial order -
                                      confirm against apply_transform.
                                      Defaults to 3 (cubic).
        default_value (int, optional): Default image value. Defaults to -1000.
        number_of_iterations (int, optional): Number of iterations at each resolution stage.
                                              Defaults to 20.
        isotropic_resample (bool, optional): Flag whether to resample to isotropic resampling
                                             prior to registration.
                                             Defaults to False.
        initial_isotropic_size (int, optional): Voxel size (in mm) of resampled isotropic image
                                                (if used). Defaults to 1.
        number_of_histogram_bins_mi (int, optional): Number of histogram bins used when calculating
                                                     mutual information. Defaults to 30.
        verbose (bool, optional): Print image registration process information. Defaults to False.
        ncores (int, optional): Number of CPU cores used. Defaults to 8.

    Returns:
        [SimpleITK.Image]: The registered moving (secondary) image, resampled onto the original
                           fixed image grid and cast back to the moving image's pixel type.
        [SimpleITK.Transform]: The transform returned by the registration (initialised as a
                               B-spline transform).

    Raises:
        ValueError: If `optimiser` or `metric` is not one of the available options.
        TypeError: If `sampling_rate` is neither a number nor a list/tuple/array of numbers.

    Notes:
     - smooth_sigmas are relative to resolution staging
        e.g. for image spacing of 1x1x1 mm^3, with smooth sigma=2 and resolution_staging=4, the
        scale of the Gaussian filter would be 2x4 = 8mm (i.e. 8x8x8 mm^3)
        NOTE(review): the code enables SmoothingSigmasAreSpecifiedInPhysicalUnitsOn(), which may
        conflict with this description - confirm which is intended.

    """

    # Validate the string options before doing any image work. Previously an unrecognised
    # name silently fell through every branch and the registration ran with whatever
    # optimiser/metric SimpleITK had configured by default.
    supported_optimisers = (
        "lbfgsb",
        "lbfgs",
        "cgls",
        "gradient_descent",
        "gradient_descent_line_search",
    )
    optimiser_name = optimiser.lower()
    if optimiser_name not in supported_optimisers:
        raise ValueError(
            f"Unknown optimiser '{optimiser}'. "
            f"Available options: {', '.join(supported_optimisers)}"
        )

    supported_metrics = ("correlation", "mean_squares", "demons", "mutual_information")
    metric_name = metric.lower()
    if metric_name not in supported_metrics:
        raise ValueError(
            f"Unknown metric '{metric}'. Available options: {', '.join(supported_metrics)}"
        )

    # Re-cast input images
    fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)

    # Remember the moving image's pixel type so the output can be cast back to it
    moving_image_type = moving_image.GetPixelID()
    moving_image = sitk.Cast(moving_image, sitk.sitkFloat32)

    # (Optional) isotropic resample
    # This changes the behaviour, so care should be taken
    # For highly anisotropic images may be preferable

    if isotropic_resample:
        # First, copy the fixed image so we can resample back into this space at the end
        fixed_image_original = fixed_image
        fixed_image_original.MakeUnique()

        fixed_image = smooth_and_resample(
            fixed_image,
            isotropic_voxel_size_mm=initial_isotropic_size,
        )

        moving_image = smooth_and_resample(
            moving_image,
            isotropic_voxel_size_mm=initial_isotropic_size,
        )

    else:
        fixed_image_original = fixed_image

    # Set up image registration method
    registration = sitk.ImageRegistrationMethod()
    registration.SetNumberOfThreads(ncores)

    registration.SetShrinkFactorsPerLevel(resolution_staging)
    registration.SetSmoothingSigmasPerLevel(smooth_sigmas)
    registration.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Choose optimiser (name already validated above)
    if optimiser_name == "lbfgsb":
        registration.SetOptimizerAsLBFGSB(
            gradientConvergenceTolerance=1e-5,
            numberOfIterations=number_of_iterations,
            maximumNumberOfCorrections=5,
            maximumNumberOfFunctionEvaluations=1024,
            costFunctionConvergenceFactor=1e7,
            trace=verbose,
        )
    elif optimiser_name == "lbfgs":
        registration.SetOptimizerAsLBFGS2(
            numberOfIterations=number_of_iterations,
            solutionAccuracy=1e-2,
            hessianApproximateAccuracy=6,
            deltaConvergenceDistance=0,
            deltaConvergenceTolerance=0.01,
            lineSearchMaximumEvaluations=40,
            lineSearchMinimumStep=1e-20,
            lineSearchMaximumStep=1e20,
            lineSearchAccuracy=0.01,
        )
    elif optimiser_name == "cgls":
        registration.SetOptimizerAsConjugateGradientLineSearch(
            learningRate=0.05, numberOfIterations=number_of_iterations)
        registration.SetOptimizerScalesFromPhysicalShift()
    elif optimiser_name == "gradient_descent":
        registration.SetOptimizerAsGradientDescent(
            learningRate=5.0,
            numberOfIterations=number_of_iterations,
            convergenceMinimumValue=1e-6,
            convergenceWindowSize=10,
        )
        registration.SetOptimizerScalesFromPhysicalShift()
    else:  # "gradient_descent_line_search"
        registration.SetOptimizerAsGradientDescentLineSearch(
            learningRate=1.0, numberOfIterations=number_of_iterations)
        registration.SetOptimizerScalesFromPhysicalShift()

    # Set metric (name already validated above)
    if metric_name == "correlation":
        registration.SetMetricAsCorrelation()
    elif metric_name == "mean_squares":
        registration.SetMetricAsMeanSquares()
    elif metric_name == "demons":
        registration.SetMetricAsDemons()
    else:  # "mutual_information"
        registration.SetMetricAsMattesMutualInformation(
            numberOfHistogramBins=number_of_histogram_bins_mi)

    # Interpolator used while evaluating the metric (not the final resampling)
    registration.SetInterpolator(sitk.sitkLinear)

    # Set sampling
    # A sequence gives one sampling fraction per resolution level; a scalar applies to all.
    # Integers (e.g. 1 for "all voxels") and NumPy scalars are accepted as well as floats;
    # they used to be silently ignored.
    if isinstance(sampling_rate, (list, tuple, np.ndarray)):
        registration.SetMetricSamplingPercentagePerLevel(
            [float(rate) for rate in sampling_rate])
    elif (isinstance(sampling_rate, (int, float, np.integer, np.floating))
          and not isinstance(sampling_rate, bool)):
        registration.SetMetricSamplingPercentage(float(sampling_rate))
    else:
        raise TypeError(
            "sampling_rate must be a number or a list/tuple/array of numbers, "
            f"got {type(sampling_rate).__name__}"
        )

    registration.SetMetricSamplingStrategy(
        sitk.ImageRegistrationMethod.REGULAR)

    # Set masks (False or None means "no mask")
    if moving_structure is not False and moving_structure is not None:
        registration.SetMetricMovingMask(moving_structure)

    if fixed_structure is not False and fixed_structure is not None:
        registration.SetMetricFixedMask(fixed_structure)

    # Set control point spacing
    transform_domain_mesh_size = control_point_spacing_distance_to_number(
        fixed_image, initial_grid_spacing)

    if verbose:
        print(f"Initial grid size: {transform_domain_mesh_size}")

    # Initialise transform
    initial_transform = sitk.BSplineTransformInitializer(
        fixed_image,
        transformDomainMeshSize=[int(i) for i in transform_domain_mesh_size],
    )
    registration.SetInitialTransformAsBSpline(initial_transform,
                                              inPlace=True,
                                              scaleFactors=grid_scale_factors)

    # (Optionally) add iteration commands
    if verbose:
        registration.AddCommand(
            sitk.sitkIterationEvent,
            lambda: registration_command_iteration(registration),
        )
        registration.AddCommand(
            sitk.sitkMultiResolutionIterationEvent,
            lambda: stage_iteration(registration),
        )

    # Run the registration
    output_transform = registration.Execute(fixed=fixed_image,
                                            moving=moving_image)

    # Resample moving image onto the original (pre-isotropic-resample) fixed image grid
    registered_image = apply_transform(
        input_image=moving_image,
        reference_image=fixed_image_original,
        transform=output_transform,
        default_value=default_value,
        interpolator=interp_order,
    )

    # Restore the moving image's original pixel type
    registered_image = sitk.Cast(registered_image, moving_image_type)

    # Return outputs
    return registered_image, output_transform