def generate_field_asymmetric_extend(mask_image, vector_asymmetric_extend=(10, 10, 10), gaussian_smooth=5):
    """
    Extends a structure (defined using a binary mask) using a specified vector.

    A uniform displacement field equal to minus the extension vector is built
    on the mask's grid and used to deform the mask. The field is then masked
    by that deformed mask, optionally smoothed, and the final transform (and
    extended mask) is rebuilt from the masked field.

    Args:
        mask_image ([SimpleITK.Image]): The binary mask to extend.
        vector_asymmetric_extend (tuple, optional): The extension vector applied to the entire
            binary mask. Convention: (+/-, +/-, +/-) = (sup/inf, post/ant, left/right) border is
            extended, i.e. the vector is given in (z, y, x) order. Defined in millimetres.
            Defaults to (10, 10, 10).
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the
            deformation vector field. A falsy value disables smoothing. Defaults to 5.

    Returns:
        [SimpleITK.Image]: The binary mask following the extension.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the extension.
        [SimpleITK.Image]: The displacement vector field representing the extension.
    """
    # Only the grid shape is needed, so read it from the image header instead
    # of copying the whole mask into a NumPy array. Array order is (z, y, x).
    array_shape = tuple(mask_image.GetSize())[::-1]

    # The template deformation field (array layout: z, y, x, component).
    # Vector components are stored (x, y, z), so the user's (z, y, x) vector is reversed.
    dvf_arr = np.zeros(array_shape + (3,))
    dvf_arr -= np.asarray(vector_asymmetric_extend, dtype=np.float64)[::-1]
    dvf_template = sitk.GetImageFromArray(dvf_arr)

    # Copy image information (origin, spacing, direction) from the mask
    dvf_template.CopyInformation(mask_image)

    dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64))
    mask_image_asymmetric_extend = apply_field(mask_image, transform=dvf_tfm, structure=True, interp=1)

    # Restrict the displacement to the region covered by the deformed mask
    dvf_template = sitk.Mask(dvf_template, mask_image_asymmetric_extend)

    # smooth
    if np.any(gaussian_smooth):
        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth,) * 3
        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth)

    # Rebuild the transform from the masked (and possibly smoothed) field so the
    # returned transform, mask and field are all consistent.
    dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64))
    mask_image_asymmetric_extend = apply_field(mask_image, transform=dvf_tfm, structure=True, interp=1)

    return mask_image_asymmetric_extend, dvf_tfm, dvf_template
def apply_augmentation(image, augmentation, masks=None):
    """Apply one or more deformable augmentations to an image (and optional masks).

    Each augmentation's ``augment()`` is called once; the resulting transforms
    are combined into a single ``sitk.CompositeTransform``, which is applied to
    the image and to every mask.

    Args:
        image (SimpleITK.Image): The image to deform. Regions sampled from
            outside the image are filled with the image's minimum value.
        augmentation (DeformableAugment | Iterable[DeformableAugment]): A single
            augmentation or an iterable of augmentations.
        masks (list of SimpleITK.Image, optional): Structures to deform with the
            same transform. Defaults to None (no masks).

    Returns:
        tuple: ``(image_deformed, masks_deformed, dvf)`` when ``masks`` is
        non-empty, otherwise ``(image_deformed, dvf)``. ``dvf`` is the sum of the
        fields returned by each augmentation.

    Raises:
        AttributeError: If ``image`` is not a SimpleITK.Image, ``augmentation``
            is neither a DeformableAugment nor an iterable, or any element of the
            iterable is not a DeformableAugment.
    """
    if masks is None:
        masks = []

    if not isinstance(image, sitk.Image):
        raise AttributeError("image should be a SimpleITK.Image")

    if isinstance(augmentation, DeformableAugment):
        augmentation = [augmentation]

    if not isinstance(augmentation, Iterable):
        raise AttributeError(
            "augmentation must be a DeformableAugment or an iterable (such as list) of "
            "DeformableAugment's"
        )

    transforms = []
    dvf = None
    for aug in augmentation:
        if not isinstance(aug, DeformableAugment):
            raise AttributeError("Each augmentation must be of type DeformableAugment")

        tfm, field = aug.augment()
        transforms.append(tfm)

        # Use a fresh sum rather than += so the field object returned by the
        # first augmentation is never modified in place.
        # NOTE(review): summing fields only approximates the composite transform
        # used below (composition is not addition) — confirm this is intended.
        dvf = field if dvf is None else dvf + field

    transform = sitk.CompositeTransform(transforms)

    image_deformed = apply_field(
        image,
        transform,
        structure=False,
        default_value=int(sitk.GetArrayViewFromImage(image).min()),
    )

    masks_deformed = [
        apply_field(mask, transform=transform, structure=True, interp=1) for mask in masks
    ]

    if masks:
        return image_deformed, masks_deformed, dvf

    return image_deformed, dvf
def run_cardiac_segmentation(img, settings=CARDIAC_SETTINGS_DEFAULTS):
    """Runs the atlas-based cardiac segmentation.

    Pipeline:
        0. Read the atlas images/structures (optionally cropping each atlas to
           the bounding box of its structures).
        1. Crop the target image using a quick similarity registration of up to
           8 atlas images.
        2. Register every atlas to the cropped target (method from
           ``settings["rigidSettings"]["initialReg"]``) and propagate structures.
        3. Deformably register every atlas (fast symmetric forces demons) and
           propagate structures.
        4. Iterative atlas removal (only if a reference structure is configured).
        5. Vessel splining (only if vessel names are configured).
        6. Label fusion.
        7. Threshold fused probability maps and, unless ``returnAsCropped`` is
           set, paste them back into the original image space.

    Args:
        img (sitk.Image): The target image to segment.
        settings (dict, optional): Dictionary containing settings for algorithm.
            Defaults to CARDIAC_SETTINGS_DEFAULTS.

    Returns:
        dict: Maps each structure name to its binary segmentation. When
        ``settings["returnAsCropped"]`` is truthy, segmentations are in the
        cropped image space and the cropped image is stored under
        ``"CROP_IMAGE"``; otherwise they are pasted into an image with the same
        geometry as ``img``.
    """
    results = {}

    return_as_cropped = settings["returnAsCropped"]

    # Initialisation - Read in atlases (image files and structure files)
    #
    # Atlas structure:
    # 'ID': 'Original': 'CT Image'   : sitk.Image
    #                   'Struct A'   : sitk.Image
    #                   'Struct B'   : sitk.Image
    #       'RIR'     : 'CT Image'   : sitk.Image
    #                   'Transform'  : transform parameter map
    #                   'Struct A'   : sitk.Image
    #                   'Struct B'   : sitk.Image
    #       'DIR'     : 'CT Image'   : sitk.Image
    #                   'Transform'  : displacement field transform
    #                   'Weight Map' : sitk.Image
    #                   'Struct A'   : sitk.Image
    #                   'Struct B'   : sitk.Image

    logger.info("")

    # Settings
    atlas_path = settings["atlasSettings"]["atlasPath"]
    atlas_id_list = settings["atlasSettings"]["atlasIdList"]
    atlas_structures = settings["atlasSettings"]["atlasStructures"]

    atlas_image_format = settings["atlasSettings"]["atlasImageFormat"]
    atlas_label_format = settings["atlasSettings"]["atlasLabelFormat"]

    auto_crop_atlas = settings["atlasSettings"]["autoCropAtlas"]

    atlas_set = {}
    for atlas_id in atlas_id_list:
        atlas_set[atlas_id] = {}
        atlas_set[atlas_id]["Original"] = {}

        image = sitk.ReadImage(f"{atlas_path}/{atlas_image_format.format(atlas_id)}")

        structures = {
            struct: sitk.ReadImage(f"{atlas_path}/{atlas_label_format.format(atlas_id, struct)}")
            for struct in atlas_structures
        }

        if auto_crop_atlas:
            logger.info(f"Automatically cropping atlas: {atlas_id}")

            original_volume = np.prod(image.GetSize())

            # Label 1 marks every voxel covered by at least one structure
            label_stats_image_filter = sitk.LabelStatisticsImageFilter()
            label_stats_image_filter.Execute(image, sum(structures.values()) > 0)

            # GetBoundingBox returns (min_x, max_x, min_y, max_y, min_z, max_z)
            # with INCLUSIVE maxima, so the extent along each axis is max - min + 1.
            bounding_box = list(label_stats_image_filter.GetBoundingBox(1))
            index = [bounding_box[x * 2] for x in range(3)]
            size = [bounding_box[(x * 2) + 1] - bounding_box[x * 2] + 1 for x in range(3)]

            image = sitk.RegionOfInterest(image, size=size, index=index)

            final_volume = np.prod(image.GetSize())

            logger.info(f" > Volume reduced by factor {original_volume/final_volume:.2f}")

            for struct in atlas_structures:
                structures[struct] = sitk.RegionOfInterest(structures[struct], size=size, index=index)

        atlas_set[atlas_id]["Original"]["CT Image"] = image

        for struct in atlas_structures:
            atlas_set[atlas_id]["Original"][struct] = structures[struct]

    # Step 1 - Automatic cropping
    #   - Quick similarity registration of (at most 8) atlas images
    #   - The registered images are averaged; voxels above -1000 define the extent
    #   - The target image is cropped to that extent (plus an expansion margin)

    # Settings
    quick_reg_settings = {
        "shrink_factors": [8],
        "smooth_sigmas": [0],
        "sampling_rate": 0.75,
        "default_value": -1024,
        "number_of_iterations": 25,
        "final_interp": 3,
        "metric": "mean_squares",
        "optimiser": "gradient_descent_line_search",
    }

    registered_crop_images = []

    logger.info("Running initial Translation tranform to crop image volume")

    # Slicing already clamps to the list length, so this uses up to 8 atlases
    for atlas_id in atlas_id_list[:8]:
        logger.info(f"  > atlas {atlas_id}")

        # Register the atlases
        atlas_image = atlas_set[atlas_id]["Original"]["CT Image"]

        reg_image, _ = initial_registration(
            img,
            atlas_image,
            moving_structure=False,
            fixed_structure=False,
            options=quick_reg_settings,
            trace=False,
            reg_method="Similarity",
        )

        registered_crop_images.append(sitk.Cast(reg_image, sitk.sitkFloat32))

    combined_image_extent = (
        sum(registered_crop_images) / len(registered_crop_images) > -1000
    )

    # Crop image to region of interest (ROI) --> defined by the registered images
    expansion = settings["autoCropSettings"]["expansion"]
    expansion_array = expansion * np.array(img.GetSpacing())

    crop_box_size, crop_box_index = label_to_roi(
        img, combined_image_extent, expansion=expansion_array
    )
    img_crop = crop_to_roi(img, crop_box_size, crop_box_index)

    logger.info(
        f"Calculated crop box\n{crop_box_index}\n{crop_box_size}\n\n"
        f"Volume reduced by factor {np.prod(img.GetSize())/np.prod(crop_box_size)}"
    )

    # Step 2 - Rigid registration of target images
    #   - Individual atlas images are registered to the target
    #   - The transformation is used to propagate the labels onto the target
    initial_reg = settings["rigidSettings"]["initialReg"]
    rigid_options = settings["rigidSettings"]["options"]
    trace = settings["rigidSettings"]["trace"]
    guide_structure = settings["rigidSettings"]["guideStructure"]

    logger.info(f"Running {initial_reg} tranform to align atlas images")

    for atlas_id in atlas_id_list:
        # Register the atlases
        logger.info(f"  > atlas {atlas_id}")

        atlas_set[atlas_id]["RIR"] = {}

        atlas_image = atlas_set[atlas_id]["Original"]["CT Image"]

        if guide_structure:
            atlas_struct = atlas_set[atlas_id]["Original"][guide_structure]
        else:
            atlas_struct = False

        rigid_image, initial_tfm = initial_registration(
            img_crop,
            atlas_image,
            moving_structure=atlas_struct,
            options=rigid_options,
            trace=trace,
            reg_method=initial_reg,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["RIR"]["CT Image"] = rigid_image
        atlas_set[atlas_id]["RIR"]["Transform"] = initial_tfm

        for struct in atlas_structures:
            input_struct = atlas_set[atlas_id]["Original"][struct]
            atlas_set[atlas_id]["RIR"][struct] = transform_propagation(
                img_crop,
                input_struct,
                initial_tfm,
                structure=True,
                interp=sitk.sitkLinear,
            )

    # Step 3 - Deformable image registration
    #   - Using Fast Symmetric Diffeomorphic Demons

    # Settings
    isotropic_resample = settings["deformableSettings"]["isotropicResample"]
    resolution_staging = settings["deformableSettings"]["resolutionStaging"]
    iteration_staging = settings["deformableSettings"]["iterationStaging"]
    smoothing_sigmas = settings["deformableSettings"]["smoothingSigmas"]
    ncores = settings["deformableSettings"]["ncores"]
    trace = settings["deformableSettings"]["trace"]

    logger.info("Running DIR to register atlas images")

    for atlas_id in atlas_id_list:
        logger.info(f"  > atlas {atlas_id}")

        # Register the atlases
        atlas_set[atlas_id]["DIR"] = {}
        atlas_image = atlas_set[atlas_id]["RIR"]["CT Image"]

        # Set target voxels outside the registered atlas' field of view to -1024
        cleaned_img_crop = sitk.Mask(img_crop, atlas_image > -1023, outsideValue=-1024)

        deform_image, deform_field = fast_symmetric_forces_demons_registration(
            cleaned_img_crop,
            atlas_image,
            resolution_staging=resolution_staging,
            iteration_staging=iteration_staging,
            isotropic_resample=isotropic_resample,
            smoothing_sigmas=smoothing_sigmas,
            ncores=ncores,
            trace=trace,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["DIR"]["CT Image"] = deform_image
        atlas_set[atlas_id]["DIR"]["Transform"] = deform_field

        for struct in atlas_structures:
            input_struct = atlas_set[atlas_id]["RIR"][struct]
            atlas_set[atlas_id]["DIR"][struct] = apply_field(
                input_struct, deform_field, structure=True, interp=sitk.sitkLinear
            )

    # Step 4 - Iterative atlas removal
    #   - Attempts to remove inconsistent atlases from the entire set
    # Weight maps use simple global weighted voting (GWV), which minimises the
    # potentially negative influence of mis-registered atlases.
    reference_structure = settings["IARSettings"]["referenceStructure"]

    if reference_structure:
        smooth_distance_maps = settings["IARSettings"]["smoothDistanceMaps"]
        smooth_sigma = settings["IARSettings"]["smoothSigma"]
        z_score_statistic = settings["IARSettings"]["zScoreStatistic"]
        outlier_method = settings["IARSettings"]["outlierMethod"]
        outlier_factor = settings["IARSettings"]["outlierFactor"]
        min_best_atlases = settings["IARSettings"]["minBestAtlases"]
        project_on_sphere = settings["IARSettings"]["project_on_sphere"]

        for atlas_id in atlas_id_list:
            atlas_image = atlas_set[atlas_id]["DIR"]["CT Image"]
            weight_map = compute_weight_map(img_crop, atlas_image, vote_type="global")
            atlas_set[atlas_id]["DIR"]["Weight Map"] = weight_map

        # run_iar may drop atlases, so later steps iterate atlas_set, not atlas_id_list
        atlas_set = run_iar(
            atlas_set=atlas_set,
            structure_name=reference_structure,
            smooth_maps=smooth_distance_maps,
            smooth_sigma=smooth_sigma,
            z_score=z_score_statistic,
            outlier_method=outlier_method,
            min_best_atlases=min_best_atlases,
            n_factor=outlier_factor,
            iteration=0,
            single_step=False,
            project_on_sphere=project_on_sphere,
        )

    else:
        logger.info("IAR: No reference structure, skipping iterative atlas removal.")

    # Step 5 - Vessel Splining
    vessel_name_list = settings["vesselSpliningSettings"]["vesselNameList"]
    segmented_vessel_dict = {}

    if len(vessel_name_list) > 0:
        vessel_radius_mm = settings["vesselSpliningSettings"]["vesselRadius_mm"]
        splining_direction = settings["vesselSpliningSettings"]["spliningDirection"]
        stop_condition = settings["vesselSpliningSettings"]["stopCondition"]
        stop_condition_value = settings["vesselSpliningSettings"]["stopConditionValue"]

        segmented_vessel_dict = vesselSplineGeneration(
            img_crop,
            atlas_set,
            vessel_name_list,
            vessel_radius_mm,
            stop_condition,
            stop_condition_value,
            splining_direction,
        )
    else:
        logger.info("No vessel splining required, continue.")

    # Step 6 - Label Fusion
    vote_type = settings["labelFusionSettings"]["voteType"]
    vote_params = settings["labelFusionSettings"]["voteParams"]

    # Compute weight maps for the (possibly reduced) atlas set
    for atlas_id in list(atlas_set.keys()):
        atlas_image = atlas_set[atlas_id]["DIR"]["CT Image"]
        weight_map = compute_weight_map(
            img_crop, atlas_image, vote_type=vote_type, vote_params=vote_params
        )
        atlas_set[atlas_id]["DIR"]["Weight Map"] = weight_map

    combined_label_dict = combine_labels(atlas_set, atlas_structures)

    # Step 7 - Threshold, then paste the cropped structures into the original image space
    template_img_binary = sitk.Cast((img * 0), sitk.sitkUInt8)

    optimal_thresholds = settings["labelFusionSettings"]["optimalThreshold"]
    for structure_name, optimal_threshold in optimal_thresholds.items():
        probability_map = combined_label_dict[structure_name]
        binary_struct = process_probability_image(probability_map, optimal_threshold)

        if return_as_cropped:
            results[structure_name] = binary_struct
        else:
            results[structure_name] = sitk.Paste(
                template_img_binary,
                binary_struct,
                binary_struct.GetSize(),
                (0, 0, 0),
                crop_box_index,
            )

    for structure_name in vessel_name_list:
        binary_struct = segmented_vessel_dict[structure_name]

        if return_as_cropped:
            results[structure_name] = binary_struct
        else:
            results[structure_name] = sitk.Paste(
                template_img_binary,
                binary_struct,
                binary_struct.GetSize(),
                (0, 0, 0),
                crop_box_index,
            )

    if return_as_cropped:
        results["CROP_IMAGE"] = img_crop

    return results
def generate_field_radial_bend(
    reference_image,
    body_mask,
    reference_point,
    axis_of_rotation=(0, 0, -1),
    scale=0.1,
    mask_bend_from_reference_point=("z", "inf"),
    gaussian_smooth=5,
):
    """
    Generates a synthetic field characterised by radial bending.
    Typically, this field would be used to simulate a moving head and so masking is important.

    At every voxel inside ``body_mask`` (after the optional half-space masking
    below), the displacement is ``scale * cross(p - reference_point, axis)``,
    computed in voxel-index coordinates.

    Args:
        reference_image ([SimpleITK.Image]): The image to be deformed.
        body_mask ([SimpleITK.Image]): A binary mask in which the deformation field will be defined.
            Assumed to share the grid of reference_image.
        reference_point ([tuple]): The voxel index (z,y,x) about which the rotation field is
            defined. Entries must be integers, as they are used for array slicing.
        axis_of_rotation (tuple, optional): The axis of rotation (z,y,x). It is normalised
            internally. Defaults to (0, 0, -1).
        scale (float, optional): The deformation vector length at each point will equal scale
            multiplied by the (voxel) distance to that point from reference_point.
            If False, the field is left at zero. Defaults to 0.1.
        mask_bend_from_reference_point (tuple, optional): The dimension (z=axial, y=coronal,
            x=sagittal) and limit (inf/sup, post/ant, left/right) for masking the vector field,
            relative to the reference point. A falsy value (e.g. False or None) disables this
            masking. Defaults to ("z", "inf").
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the
            deformation vector field. A falsy value disables smoothing. Defaults to 5.

    Returns:
        [SimpleITK.Image]: The reference image following the bending deformation.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the bending.
        [SimpleITK.Image]: The displacement vector field representing the bending.
    """
    body_mask_arr = sitk.GetArrayFromImage(body_mask)

    # Zero the body mask on one side of the reference point so that only the
    # remaining half-space is deformed. Array axes are (z, y, x).
    if mask_bend_from_reference_point:
        dimension = mask_bend_from_reference_point[0]
        limit = mask_bend_from_reference_point[1]

        if dimension == "z":
            if limit == "inf":
                body_mask_arr[: reference_point[0], :, :] = 0
            elif limit == "sup":
                body_mask_arr[reference_point[0] :, :, :] = 0
        elif dimension == "y":
            if limit == "post":
                body_mask_arr[:, reference_point[1] :, :] = 0
            elif limit == "ant":
                body_mask_arr[:, : reference_point[1], :] = 0
        elif dimension == "x":
            if limit == "left":
                body_mask_arr[:, :, reference_point[2] :] = 0
            elif limit == "right":
                body_mask_arr[:, :, : reference_point[2]] = 0

    # Offsets (z, y, x) from the reference point to every masked voxel; shape (3, N)
    pt_arr = np.array(np.where(body_mask_arr))
    vector_ref_to_pt = pt_arr - np.array(reference_point)[:, None]

    # Normalise the normal vector (axis_of_rotation)
    axis_of_rotation = np.array(axis_of_rotation)
    axis_of_rotation = axis_of_rotation / np.linalg.norm(axis_of_rotation)

    # Both operands are reversed to (x, y, z), the component order of the field
    deformation_vectors = np.cross(vector_ref_to_pt[::-1].T, axis_of_rotation[::-1])

    # Zero-initialised field on the reference grid, layout (z, y, x, component)
    dvf_template_arr = np.zeros(tuple(reference_image.GetSize())[::-1] + (3,))

    if scale is not False:
        dvf_template_arr[np.where(body_mask_arr)] = deformation_vectors * scale

    dvf_template = sitk.GetImageFromArray(dvf_template_arr, isVector=True)
    dvf_template.CopyInformation(reference_image)

    # smooth
    if np.any(gaussian_smooth):
        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth,) * 3
        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth)

    dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64))

    # Voxels sampled from outside the image get the image's minimum value
    reference_image_bend = apply_field(
        reference_image,
        transform=dvf_tfm,
        structure=False,
        default_value=int(sitk.GetArrayViewFromImage(reference_image).min()),
        interp=2,
    )

    return reference_image_bend, dvf_tfm, dvf_template
def generate_field_expand(
    mask_image,
    bone_mask=False,
    expand=3,
    gaussian_smooth=5,
):
    """
    Expands a structure (defined using a binary mask) using a specified vector to define the
    dilation kernel.

    The mask is morphologically dilated and/or eroded, and a deformable
    (demons) registration between the original and modified masks recovers a
    smooth displacement field that reproduces the change.

    Args:
        mask_image ([SimpleITK.Image]): The binary mask to expand.
        bone_mask ([SimpleITK.Image, optional]): A binary mask defining regions where we expect
            restricted deformations. Defaults to False (not used).
        expand (int | tuple, optional): The expansion applied to the entire binary mask.
            Convention: (z,y,x) size of expansion kernel. Defined in millimetres. Positive
            values dilate, negative values erode; mixed signs dilate first, then erode.
            Defaults to 3.
        gaussian_smooth (int | list, optional): Scale of a Gaussian kernel used to smooth the
            deformation vector field. A falsy value disables smoothing. Defaults to 5.

    Returns:
        [SimpleITK.Image]: The binary mask following the expansion.
        [SimpleITK.DisplacementFieldTransform]: The transform representing the expansion.
        [SimpleITK.Image]: The displacement vector field representing the expansion.
    """
    if bone_mask is not False:
        mask_image_original = mask_image + bone_mask
    else:
        mask_image_original = mask_image

    # Must come after mask_image_original is assigned above.
    registration_mask_original = convert_mask_to_reg_structure(mask_image_original)

    # Build the morphological kernel radius (one value per axis)
    if not hasattr(expand, "__iter__"):
        expand = (expand,) * 3

    expand = np.array(expand)

    # Convert millimetres to voxels (expand is (z,y,x), so divide by reversed spacing)
    expand = expand / np.array(mask_image.GetSpacing()[::-1])

    # Re-order to (x,y,z), the axis order used by SimpleITK kernel radii
    expand = expand[::-1]

    def to_radius(kernel):
        # Integer (truncated) absolute radius per axis, as SimpleITK expects
        return np.abs(kernel).astype(int).tolist()

    # If all negative: erode
    if np.all(np.array(expand) <= 0):
        print("All factors negative: shrinking only.")
        mask_image_expand = sitk.BinaryErode(mask_image, to_radius(expand), sitk.sitkBall)

    # If all positive: dilate
    elif np.all(np.array(expand) >= 0):
        print("All factors positive: expansion only.")
        mask_image_expand = sitk.BinaryDilate(mask_image, to_radius(expand), sitk.sitkBall)

    # Otherwise: sequential operations (dilate the positive axes, then erode the negative ones)
    else:
        print("Mixed factors: shrinking and expansion.")
        expansion_kernel = expand * (expand > 0)
        shrink_kernel = expand * (expand < 0)

        mask_image_expand = sitk.BinaryDilate(
            mask_image, to_radius(expansion_kernel), sitk.sitkBall
        )
        mask_image_expand = sitk.BinaryErode(
            mask_image_expand, to_radius(shrink_kernel), sitk.sitkBall
        )

    registration_mask_expand = convert_mask_to_reg_structure(mask_image_expand)

    # NOTE(review): here bone_mask is added AFTER conversion, whereas the original
    # mask has it added BEFORE conversion — confirm this asymmetry is intended.
    if bone_mask is not False:
        registration_mask_expand = registration_mask_expand + bone_mask

    # Use DIR to find the deformation
    _, _, dvf_template = fast_symmetric_forces_demons_registration(
        registration_mask_expand,
        registration_mask_original,
        isotropic_resample=True,
        resolution_staging=[4, 2],
        iteration_staging=[10, 10],
        ncores=8,
        return_field=True,
    )

    # smooth
    if np.any(gaussian_smooth):
        if not hasattr(gaussian_smooth, "__iter__"):
            gaussian_smooth = (gaussian_smooth,) * 3
        dvf_template = sitk.SmoothingRecursiveGaussian(dvf_template, gaussian_smooth)

    dvf_tfm = sitk.DisplacementFieldTransform(sitk.Cast(dvf_template, sitk.sitkVectorFloat64))

    mask_image_symmetric_expand = apply_field(
        mask_image, transform=dvf_tfm, structure=True, interp=1
    )

    return mask_image_symmetric_expand, dvf_tfm, dvf_template