def generate_field_asymmetric_extend(mask_image,
                                     vector_asymmetric_extend=(10, 10, 10),
                                     gaussian_smooth=5):
    """
    Grow a binary structure in a fixed direction by warping it with a uniform
    displacement field that is afterwards restricted to the grown region.

    Procedure:
        1. Build a displacement field that is constant over the whole image grid
           and equal to the negated extension vector, with its components
           reversed from (z, y, x) into (x, y, z) order.
        2. Warp the mask through that field to obtain a first extended mask.
        3. Zero the field outside that first extended mask, optionally smooth it,
           and warp the original mask again through the resulting field.

    Args:
        mask_image ([SimpleITK.Image]): The binary mask to extend.
        vector_asymmetric_extend (tuple, optional): Extension vector, in
            millimetres, ordered (sup/inf, post/ant, left/right). The sign of each
            component selects which border is extended. Defaults to (10, 10, 10).
        gaussian_smooth (int | list, optional): Scale of the Gaussian kernel used
            to smooth the displacement field. A scalar is applied to all three
            axes; a falsy value disables smoothing. Defaults to 5.

    Returns:
        [SimpleITK.Image]: The binary mask following the extension.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the extension.
        [SimpleITK.Image]: The displacement vector field representing the extension.
    """

    def _warp_with(field_image):
        # The transform needs a VectorFloat64 field; the mask is resampled as a structure.
        transform = sitk.DisplacementFieldTransform(
            sitk.Cast(field_image, sitk.sitkVectorFloat64))
        warped = apply_field(mask_image,
                             transform=transform,
                             structure=True,
                             interp=1)
        return warped, transform

    grid_shape = sitk.GetArrayFromImage(mask_image).shape

    # Constant field: every voxel holds -vector, re-ordered from (z, y, x) to (x, y, z).
    displacement = np.array(vector_asymmetric_extend[::-1])[np.newaxis, np.newaxis,
                                                            np.newaxis, :]
    field_arr = np.zeros(grid_shape + (3, )) - displacement

    field_image = sitk.GetImageFromArray(field_arr)
    field_image.CopyInformation(mask_image)

    first_pass_mask, _ = _warp_with(field_image)

    # Only keep displacement inside the region covered by the first extension.
    field_image = sitk.Mask(field_image, first_pass_mask)

    if np.any(gaussian_smooth):
        sigma = gaussian_smooth
        if not hasattr(sigma, "__iter__"):
            sigma = (sigma, ) * 3
        field_image = sitk.SmoothingRecursiveGaussian(field_image, sigma)

    extended_mask, field_transform = _warp_with(field_image)

    return extended_mask, field_transform, field_image
예제 #2
0
def apply_augmentation(image, augmentation, masks=None):
    """
    Apply one or more deformable augmentations to an image and, optionally, to masks.

    Each augmentation's ``augment()`` method is called once. The resulting transforms
    are combined into a single ``sitk.CompositeTransform``, which is used to resample
    the image and every mask. The returned fields are summed into ``dvf``.

    Args:
        image (SimpleITK.Image): The image to deform.
        augmentation (DeformableAugment | Iterable[DeformableAugment]): A single
            augmentation or an iterable (such as a list) of augmentations.
        masks (list, optional): Masks (SimpleITK.Image) to deform with the same
            composite transform. Defaults to None, meaning no masks.

    Returns:
        tuple: ``(image_deformed, masks_deformed, dvf)`` when ``masks`` is non-empty,
        otherwise ``(image_deformed, dvf)``. ``dvf`` is the sum of the fields
        returned by the augmentations.

    Raises:
        AttributeError: If ``image`` is not a SimpleITK.Image, if ``augmentation`` is
            neither a DeformableAugment nor an iterable, or if any element of the
            iterable is not a DeformableAugment.
    """
    # None replaces the shared mutable default; None and [] both mean "no masks".
    if masks is None:
        masks = []

    if not isinstance(image, sitk.Image):
        raise AttributeError("image should be a SimpleITK.Image")

    if isinstance(augmentation, DeformableAugment):
        augmentation = [augmentation]

    if not isinstance(augmentation, Iterable):
        raise AttributeError(
            "augmentation must be a DeformableAugment or an iterable (such as list) of "
            "DeformableAugment's")

    transforms = []
    dvf = None
    for aug in augmentation:

        if not isinstance(aug, DeformableAugment):
            raise AttributeError(
                "Each augmentation must be of type DeformableAugment")

        tfm, field = aug.augment()
        transforms.append(tfm)

        # Out-of-place addition: `dvf` initially aliases the first augmentation's
        # field, so an in-place `+=` could modify that object.
        dvf = field if dvf is None else dvf + field

    transform = sitk.CompositeTransform(transforms)

    image_deformed = apply_field(
        image,
        transform,
        structure=False,
        # The image's minimum intensity is used as apply_field's default_value.
        default_value=int(sitk.GetArrayViewFromImage(image).min()),
    )

    masks_deformed = [
        apply_field(mask, transform=transform, structure=True, interp=1)
        for mask in masks
    ]

    if masks:
        return image_deformed, masks_deformed, dvf

    return image_deformed, dvf
예제 #3
0
파일: run.py 프로젝트: SimonBiggs/platipy
def _auto_crop_atlas(image, structures):
    """Crop an atlas image and its structures to the bounding box of all structures.

    Args:
        image (sitk.Image): Atlas image.
        structures (dict): Structure name -> sitk.Image label on the same grid as image.

    Returns:
        tuple: (cropped image, dict of cropped structures with the same keys).
    """
    original_volume = np.prod(image.GetSize())

    label_stats_image_filter = sitk.LabelStatisticsImageFilter()
    label_stats_image_filter.Execute(image, sum(structures.values()) > 0)
    # GetBoundingBox returns inclusive (min_x, max_x, min_y, max_y, min_z, max_z)
    # indices, so the extent along each axis is max - min + 1 voxels.
    bounding_box = list(label_stats_image_filter.GetBoundingBox(1))
    index = [bounding_box[x * 2] for x in range(3)]
    size = [
        bounding_box[(x * 2) + 1] - bounding_box[x * 2] + 1
        for x in range(3)
    ]

    image = sitk.RegionOfInterest(image, size=size, index=index)

    final_volume = np.prod(image.GetSize())
    logger.info(
        f"  > Volume reduced by factor {original_volume/final_volume:.2f}"
    )

    cropped_structures = {
        struct: sitk.RegionOfInterest(structure, size=size, index=index)
        for struct, structure in structures.items()
    }
    return image, cropped_structures


def _place_in_output_space(binary_struct, template_img_binary, crop_box_index,
                           return_as_cropped):
    """Return a cropped-space structure as-is, or pasted into the full image grid."""
    if return_as_cropped:
        return binary_struct

    return sitk.Paste(
        template_img_binary,
        binary_struct,
        binary_struct.GetSize(),
        (0, 0, 0),
        crop_box_index,
    )


def run_cardiac_segmentation(img, settings=CARDIAC_SETTINGS_DEFAULTS):
    """Runs the atlas-based cardiac segmentation.

    Pipeline:
        0. Read the atlases (optionally auto-cropping each one to its structures).
        1. Crop the target using quick "Similarity" registrations of up to 8 atlases.
        2. Register every atlas to the cropped target with ``rigidSettings.initialReg``
           and propagate its structures.
        3. Deformable registration (fast_symmetric_forces_demons_registration) and
           structure propagation.
        4. Optional iterative atlas removal (IAR) and optional vessel splining.
        5. Weighted label fusion.
        6. Threshold each fused structure and, unless ``returnAsCropped`` is set,
           paste it back into the grid of ``img``.

    Args:
        img (sitk.Image): The target image to segment.
        settings (dict, optional): Dictionary containing settings for algorithm.
            Defaults to CARDIAC_SETTINGS_DEFAULTS.

    Returns:
        dict: Structure name -> sitk.Image. When ``returnAsCropped`` is set, the
        structures are in the cropped image space and the cropped image is stored
        under "CROP_IMAGE"; otherwise they are pasted into a UInt8 image on the
        grid of ``img``.
    """

    results = {}
    return_as_cropped = settings["returnAsCropped"]

    # Initialisation - read in atlases (image files and structure files).
    #
    # Atlas structure:
    # 'ID': 'Original': 'CT Image'    : sitk.Image
    #                   'Struct A'    : sitk.Image
    #                   'Struct B'    : sitk.Image
    #       'RIR'     : 'CT Image'    : sitk.Image
    #                   'Transform'   : transform parameter map
    #                   'Struct A'    : sitk.Image
    #                   'Struct B'    : sitk.Image
    #       'DIR'     : 'CT Image'    : sitk.Image
    #                   'Transform'   : displacement field transform
    #                   'Weight Map'  : sitk.Image
    #                   'Struct A'    : sitk.Image
    #                   'Struct B'    : sitk.Image

    logger.info("")
    atlas_settings = settings["atlasSettings"]
    atlas_path = atlas_settings["atlasPath"]
    atlas_id_list = atlas_settings["atlasIdList"]
    atlas_structures = atlas_settings["atlasStructures"]

    atlas_image_format = atlas_settings["atlasImageFormat"]
    atlas_label_format = atlas_settings["atlasLabelFormat"]

    auto_crop_atlas = atlas_settings["autoCropAtlas"]

    atlas_set = {}
    for atlas_id in atlas_id_list:
        image = sitk.ReadImage(
            f"{atlas_path}/{atlas_image_format.format(atlas_id)}")

        structures = {
            struct: sitk.ReadImage(
                f"{atlas_path}/{atlas_label_format.format(atlas_id, struct)}")
            for struct in atlas_structures
        }

        if auto_crop_atlas:
            logger.info(f"Automatically cropping atlas: {atlas_id}")
            image, structures = _auto_crop_atlas(image, structures)

        atlas_set[atlas_id] = {"Original": {"CT Image": image}}
        for struct in atlas_structures:
            atlas_set[atlas_id]["Original"][struct] = structures[struct]

    # Step 1 - Automatic cropping using a quick similarity registration
    # - Registration of atlas images (maximum 8)
    # - Potential expansion of the bounding box to ensure entire volume of interest is enclosed
    # - Target image is cropped
    quick_reg_settings = {
        "shrink_factors": [8],
        "smooth_sigmas": [0],
        "sampling_rate": 0.75,
        "default_value": -1024,
        "number_of_iterations": 25,
        "final_interp": 3,
        "metric": "mean_squares",
        "optimiser": "gradient_descent_line_search",
    }

    registered_crop_images = []

    logger.info(f"Running initial Translation tranform to crop image volume")

    for atlas_id in atlas_id_list[:8]:

        logger.info(f"  > atlas {atlas_id}")

        atlas_image = atlas_set[atlas_id]["Original"]["CT Image"]

        reg_image, _ = initial_registration(
            img,
            atlas_image,
            moving_structure=False,
            fixed_structure=False,
            options=quick_reg_settings,
            trace=False,
            reg_method="Similarity",
        )

        registered_crop_images.append(sitk.Cast(reg_image, sitk.sitkFloat32))

        del reg_image

    # Region where the mean of the registered atlas images exceeds -1000
    # (the registration fills unmapped voxels with -1024).
    combined_image_extent = sum(registered_crop_images) / len(
        registered_crop_images) > -1000

    # Crop image to region of interest (ROI), defined by the registered atlases.
    expansion = settings["autoCropSettings"]["expansion"]
    # Scaled by voxel spacing - presumably converting an expansion in voxels to mm.
    expansion_array = expansion * np.array(img.GetSpacing())

    crop_box_size, crop_box_index = label_to_roi(img,
                                                 combined_image_extent,
                                                 expansion=expansion_array)
    img_crop = crop_to_roi(img, crop_box_size, crop_box_index)
    logger.info(f"Calculated crop box\n\
                {crop_box_index}\n\
                {crop_box_size}\n\n\
                Volume reduced by factor {np.prod(img.GetSize())/np.prod(crop_box_size)}"
                )

    # Step 2 - Rigid registration of target images
    # - Individual atlas images are registered to the target
    # - The transformation is used to propagate the labels onto the target
    rigid_settings = settings["rigidSettings"]
    initial_reg = rigid_settings["initialReg"]
    rigid_options = rigid_settings["options"]
    trace = rigid_settings["trace"]
    guide_structure = rigid_settings["guideStructure"]

    logger.info(f"Running {initial_reg} tranform to align atlas images")

    for atlas_id in atlas_id_list:

        logger.info(f"  > atlas {atlas_id}")

        atlas_image = atlas_set[atlas_id]["Original"]["CT Image"]

        if guide_structure:
            atlas_struct = atlas_set[atlas_id]["Original"][guide_structure]
        else:
            atlas_struct = False

        rigid_image, initial_tfm = initial_registration(
            img_crop,
            atlas_image,
            moving_structure=atlas_struct,
            options=rigid_options,
            trace=trace,
            reg_method=initial_reg,
        )

        atlas_set[atlas_id]["RIR"] = {
            "CT Image": rigid_image,
            "Transform": initial_tfm,
        }

        for struct in atlas_structures:
            input_struct = atlas_set[atlas_id]["Original"][struct]
            atlas_set[atlas_id]["RIR"][struct] = transform_propagation(
                img_crop,
                input_struct,
                initial_tfm,
                structure=True,
                interp=sitk.sitkLinear,
            )

    # Step 3 - Deformable image registration
    # - Using fast_symmetric_forces_demons_registration
    deformable_settings = settings["deformableSettings"]
    isotropic_resample = deformable_settings["isotropicResample"]
    resolution_staging = deformable_settings["resolutionStaging"]
    iteration_staging = deformable_settings["iterationStaging"]
    smoothing_sigmas = deformable_settings["smoothingSigmas"]
    ncores = deformable_settings["ncores"]
    trace = deformable_settings["trace"]

    logger.info(f"Running DIR to register atlas images")

    for atlas_id in atlas_id_list:

        logger.info(f"  > atlas {atlas_id}")

        atlas_image = atlas_set[atlas_id]["RIR"]["CT Image"]

        # Set target voxels to -1024 wherever the rigidly registered atlas is <= -1023.
        cleaned_img_crop = sitk.Mask(img_crop,
                                     atlas_image > -1023,
                                     outsideValue=-1024)

        deform_image, deform_field = fast_symmetric_forces_demons_registration(
            cleaned_img_crop,
            atlas_image,
            resolution_staging=resolution_staging,
            iteration_staging=iteration_staging,
            isotropic_resample=isotropic_resample,
            smoothing_sigmas=smoothing_sigmas,
            ncores=ncores,
            trace=trace,
        )

        atlas_set[atlas_id]["DIR"] = {
            "CT Image": deform_image,
            "Transform": deform_field,
        }

        for struct in atlas_structures:
            input_struct = atlas_set[atlas_id]["RIR"][struct]
            atlas_set[atlas_id]["DIR"][struct] = apply_field(
                input_struct,
                deform_field,
                structure=True,
                interp=sitk.sitkLinear)

    # Step 4a - Iterative atlas removal
    # - Automatically removes inconsistent atlases from the set.
    # Global weighted voting ("global") weight maps are used here, as this
    # minimises the potentially negative influence of mis-registered atlases.
    reference_structure = settings["IARSettings"]["referenceStructure"]

    if reference_structure:

        iar_settings = settings["IARSettings"]
        smooth_distance_maps = iar_settings["smoothDistanceMaps"]
        smooth_sigma = iar_settings["smoothSigma"]
        z_score_statistic = iar_settings["zScoreStatistic"]
        outlier_method = iar_settings["outlierMethod"]
        outlier_factor = iar_settings["outlierFactor"]
        min_best_atlases = iar_settings["minBestAtlases"]
        project_on_sphere = iar_settings["project_on_sphere"]

        for atlas_id in atlas_id_list:
            atlas_image = atlas_set[atlas_id]["DIR"]["CT Image"]
            weight_map = compute_weight_map(img_crop,
                                            atlas_image,
                                            vote_type="global")
            atlas_set[atlas_id]["DIR"]["Weight Map"] = weight_map

        atlas_set = run_iar(
            atlas_set=atlas_set,
            structure_name=reference_structure,
            smooth_maps=smooth_distance_maps,
            smooth_sigma=smooth_sigma,
            z_score=z_score_statistic,
            outlier_method=outlier_method,
            min_best_atlases=min_best_atlases,
            n_factor=outlier_factor,
            iteration=0,
            single_step=False,
            project_on_sphere=project_on_sphere,
        )

    else:
        logger.info(
            "IAR: No reference structure, skipping iterative atlas removal.")

    # Step 4b - Vessel splining
    vessel_name_list = settings["vesselSpliningSettings"]["vesselNameList"]
    segmented_vessel_dict = {}

    if len(vessel_name_list) > 0:

        vessel_settings = settings["vesselSpliningSettings"]
        vessel_radius_mm = vessel_settings["vesselRadius_mm"]
        splining_direction = vessel_settings["spliningDirection"]
        stop_condition = vessel_settings["stopCondition"]
        stop_condition_value = vessel_settings["stopConditionValue"]

        segmented_vessel_dict = vesselSplineGeneration(
            img_crop,
            atlas_set,
            vessel_name_list,
            vessel_radius_mm,
            stop_condition,
            stop_condition_value,
            splining_direction,
        )
    else:
        logger.info("No vessel splining required, continue.")

    # Step 5 - Label fusion over the atlases remaining after IAR
    vote_type = settings["labelFusionSettings"]["voteType"]
    vote_params = settings["labelFusionSettings"]["voteParams"]

    for atlas_id in atlas_set:
        atlas_image = atlas_set[atlas_id]["DIR"]["CT Image"]
        weight_map = compute_weight_map(img_crop,
                                        atlas_image,
                                        vote_type=vote_type,
                                        vote_params=vote_params)
        atlas_set[atlas_id]["DIR"]["Weight Map"] = weight_map

    combined_label_dict = combine_labels(atlas_set, atlas_structures)

    # Step 6 - Paste the cropped structures into the original image space
    template_img_binary = sitk.Cast((img * 0), sitk.sitkUInt8)

    optimal_thresholds = settings["labelFusionSettings"]["optimalThreshold"]

    for structure_name, optimal_threshold in optimal_thresholds.items():
        probability_map = combined_label_dict[structure_name]
        binary_struct = process_probability_image(probability_map,
                                                  optimal_threshold)
        results[structure_name] = _place_in_output_space(
            binary_struct, template_img_binary, crop_box_index,
            return_as_cropped)

    for structure_name in vessel_name_list:
        results[structure_name] = _place_in_output_space(
            segmented_vessel_dict[structure_name], template_img_binary,
            crop_box_index, return_as_cropped)

    if return_as_cropped:
        results["CROP_IMAGE"] = img_crop

    return results
def generate_field_radial_bend(
        reference_image,
        body_mask,
        reference_point,
        axis_of_rotation=(0, 0, -1),
        scale=0.1,
        mask_bend_from_reference_point=("z", "inf"),
        gaussian_smooth=5,
):
    """
    Generates a synthetic field characterised by radial bending.
    Typically, this field would be used to simulate a moving head and so masking is important.

    At every non-zero voxel of the body mask, after the optional one-sided masking,
    the displacement is ``scale * cross(r, a)``:
        - ``r`` is the voxel-index offset from ``reference_point``;
        - ``a`` is the normalised axis of rotation.
    The displacement length is therefore ``scale`` times the distance, in voxels,
    from the rotation axis passing through ``reference_point``.

    Args:
        reference_image ([SimpleITK.Image]): The image to be deformed.
        body_mask ([SimpleITK.Image]): A binary mask in which the deformation field will be
            defined. Its non-zero voxel indices are written into a field with the size of
            reference_image, so both must share the same voxel grid.
        reference_point ([tuple]): The voxel index (z, y, x) about which the rotation field is
            defined. Must be integers, as it is used to slice the mask array.
        axis_of_rotation (tuple, optional): The axis of rotation (z, y, x); normalised
            internally. Defaults to (0, 0, -1).
        scale (float, optional): Multiplier applied to the cross product (see above).
            If False, the field is left as zeros. Defaults to 0.1.
        mask_bend_from_reference_point (tuple, optional): (dimension, limit) restricting the
            field to one side of reference_point:
                ("z", "inf")   zeroes z indices <  reference_point[0];
                ("z", "sup")   zeroes z indices >= reference_point[0];
                ("y", "post")  zeroes y indices >= reference_point[1];
                ("y", "ant")   zeroes y indices <  reference_point[1];
                ("x", "left")  zeroes x indices >= reference_point[2];
                ("x", "right") zeroes x indices <  reference_point[2].
            Pass False (or None) to disable. Defaults to ("z", "inf").
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the
            deformation vector field. A scalar applies to all axes; a falsy value disables
            smoothing. Defaults to 5.

    Returns:
        [SimpleITK.Image]: The reference image resampled through the bending field.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the bend.
        [SimpleITK.Image]: The displacement vector field representing the bend.
    """

    # GetArrayFromImage returns a copy, so body_mask itself is not modified below.
    body_mask_arr = sitk.GetArrayFromImage(body_mask)

    if mask_bend_from_reference_point:
        dimension = mask_bend_from_reference_point[0]
        if dimension == "z":
            if mask_bend_from_reference_point[1] == "inf":
                body_mask_arr[:reference_point[0], :, :] = 0
            elif mask_bend_from_reference_point[1] == "sup":
                body_mask_arr[reference_point[0]:, :, :] = 0
        elif dimension == "y":
            if mask_bend_from_reference_point[1] == "post":
                body_mask_arr[:, reference_point[1]:, :] = 0
            elif mask_bend_from_reference_point[1] == "ant":
                body_mask_arr[:, :reference_point[1], :] = 0
        elif dimension == "x":
            if mask_bend_from_reference_point[1] == "left":
                body_mask_arr[:, :, reference_point[2]:] = 0
            elif mask_bend_from_reference_point[1] == "right":
                body_mask_arr[:, :, :reference_point[2]] = 0

    # Indices of voxels that receive a displacement (computed once, reused below).
    voxel_indices = np.where(body_mask_arr)
    # Offsets from the reference point: shape (3, N), rows in (z, y, x) order.
    vector_ref_to_pt = np.array(voxel_indices) - np.array(reference_point)[:, None]

    axis_of_rotation = np.array(axis_of_rotation)
    axis_of_rotation = axis_of_rotation / np.linalg.norm(axis_of_rotation)

    # Reverse both operands to (x, y, z) so the vector components match
    # SimpleITK's displacement ordering; result has shape (N, 3).
    deformation_vectors = np.cross(vector_ref_to_pt[::-1].T,
                                   axis_of_rotation[::-1])

    dvf_template = sitk.Image(reference_image.GetSize(),
                              sitk.sitkVectorFloat64, 3)
    dvf_template_arr = sitk.GetArrayFromImage(dvf_template)

    if scale is not False:
        dvf_template_arr[voxel_indices] = deformation_vectors * scale

    dvf_template = sitk.GetImageFromArray(dvf_template_arr)
    dvf_template.CopyInformation(reference_image)

    if np.any(gaussian_smooth):

        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth, ) * 3

        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template,
                                                       gaussian_smooth)

    dvf_tfm = sitk.DisplacementFieldTransform(
        sitk.Cast(dvf_template, sitk.sitkVectorFloat64))
    reference_image_bend = apply_field(
        reference_image,
        transform=dvf_tfm,
        structure=False,
        # The image minimum is passed as apply_field's default_value.
        default_value=int(sitk.GetArrayViewFromImage(reference_image).min()),
        interp=2,
    )

    return reference_image_bend, dvf_tfm, dvf_template
def generate_field_expand(
    mask_image,
    bone_mask=False,
    expand=3,
    gaussian_smooth=5,
):
    """
    Expands (or shrinks) a structure defined using a binary mask, and derives the
    deformation field that maps the original structure onto the modified one.

    Steps:
        1. Dilate and/or erode the mask with a ball kernel whose per-axis radius is
           ``expand`` converted from millimetres to whole voxels.
        2. Run demons deformable registration between registration structures built
           from the modified and the original masks; this yields the displacement
           field.
        3. Optionally smooth the field, then apply it to ``mask_image``.

    Args:
        mask_image ([SimpleITK.Image]): The binary mask to expand.
        bone_mask ([SimpleITK.Image, optional]): A binary mask defining regions where we
            expect restricted deformations. It is added to both registration inputs.
            Defaults to False (not used).
        expand (int | tuple, optional): Size of the expansion kernel in millimetres.
            A scalar applies to all axes; a tuple is ordered (z, y, x).
            - Positive values dilate and negative values erode.
            - With mixed signs, dilation is applied first, then erosion.
            - Values are truncated to whole voxels after dividing by the spacing.
            Defaults to 3.
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth
            the deformation vector field. A scalar applies to all axes; a falsy value
            disables smoothing. Defaults to 5.

    Returns:
        [SimpleITK.Image]: The binary mask following the expansion.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion.
        [SimpleITK.Image]: The displacement vector field representing the expansion.
    """

    if bone_mask is not False:
        mask_image_original = mask_image + bone_mask
    else:
        mask_image_original = mask_image

    # Registration input for the unmodified structure (plus bone, if given).
    # This must come after mask_image_original is assigned above.
    registration_mask_original = convert_mask_to_reg_structure(
        mask_image_original)

    if not hasattr(expand, "__iter__"):
        expand = (expand, ) * 3

    # Millimetres -> voxels: expand and the reversed spacing are both (z, y, x).
    expand = np.array(expand) / np.array(mask_image.GetSpacing()[::-1])

    # SimpleITK kernel radii are ordered (x, y, z).
    expand = expand[::-1]

    # If all non-positive: erode only
    if np.all(expand <= 0):
        print("All factors negative: shrinking only.")
        mask_image_expand = sitk.BinaryErode(
            mask_image,
            np.abs(expand).astype(int).tolist(), sitk.sitkBall)

    # If all non-negative: dilate only
    elif np.all(expand >= 0):
        print("All factors positive: expansion only.")
        mask_image_expand = sitk.BinaryDilate(
            mask_image,
            np.abs(expand).astype(int).tolist(), sitk.sitkBall)

    # Otherwise: dilate along the positive axes, then erode along the negative ones
    else:
        print("Mixed factors: shrinking and expansion.")
        expansion_kernel = expand * (expand > 0)
        shrink_kernel = expand * (expand < 0)

        mask_image_expand = sitk.BinaryDilate(
            mask_image,
            np.abs(expansion_kernel).astype(int).tolist(), sitk.sitkBall)
        mask_image_expand = sitk.BinaryErode(
            mask_image_expand,
            np.abs(shrink_kernel).astype(int).tolist(), sitk.sitkBall)

    registration_mask_expand = convert_mask_to_reg_structure(mask_image_expand)
    if bone_mask is not False:
        registration_mask_expand = registration_mask_expand + bone_mask

    # Use DIR to find the deformation
    _, _, dvf_template = fast_symmetric_forces_demons_registration(
        registration_mask_expand,
        registration_mask_original,
        isotropic_resample=True,
        resolution_staging=[4, 2],
        iteration_staging=[10, 10],
        ncores=8,
        return_field=True,
    )

    if np.any(gaussian_smooth):

        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth, ) * 3

        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template,
                                                       gaussian_smooth)

    dvf_tfm = sitk.DisplacementFieldTransform(
        sitk.Cast(dvf_template, sitk.sitkVectorFloat64))

    mask_image_symmetric_expand = apply_field(mask_image,
                                              transform=dvf_tfm,
                                              structure=True,
                                              interp=1)

    return mask_image_symmetric_expand, dvf_tfm, dvf_template