def compute_metric_hd(label_a, label_b, auto_crop=True):
    """Compute the Hausdorff distance between two masks.

    The value is the one reported by SimpleITK's ``HausdorffDistanceImageFilter``.

    Args:
        label_a (sitk.Image): The first mask.
        label_b (sitk.Image): The second mask.
        auto_crop (bool, optional): Crop both masks to the bounding box of their
            union before computing the metric. Defaults to True.

    Returns:
        float: The Hausdorff distance, or ``np.nan`` if either mask is empty.
    """
    # The distance is undefined when one of the masks has no foreground.
    if any(sitk.GetArrayViewFromImage(label).sum() == 0 for label in (label_a, label_b)):
        return np.nan

    if auto_crop:
        union_mask = (label_a + label_b) > 0
        roi_size, roi_index = label_to_roi(union_mask)
        label_a, label_b = (
            crop_to_roi(label, size=roi_size, index=roi_index) for label in (label_a, label_b)
        )

    hd_filter = sitk.HausdorffDistanceImageFilter()
    hd_filter.Execute(label_a, label_b)
    return hd_filter.GetHausdorffDistance()
def compute_metric_sensitivity(label_a, label_b, auto_crop=True):
    """Compute the sensitivity (true positive rate) of one mask against another.

    ``label_a`` is the reference: sensitivity = TP / (TP + FN) = |A & B| / |A|,
    where voxels are counted as foreground when non-zero.

    Args:
        label_a (sitk.Image): The reference mask.
        label_b (sitk.Image): The mask being evaluated.
        auto_crop (bool, optional): Crop both masks to the bounding box of their
            union before computing the metric. Defaults to True.

    Returns:
        float: The sensitivity in [0, 1], or ``np.nan`` if ``label_a`` is empty
            (sensitivity is undefined without any reference voxels).
    """
    # Check before cropping so label_to_roi is never asked for the bounding box
    # of a possibly empty union, and return NaN explicitly instead of via a
    # 0/0 division (which also emits a RuntimeWarning).
    if not sitk.GetArrayViewFromImage(label_a).any():
        return np.nan

    if auto_crop:
        largest_region = (label_a + label_b) > 0
        crop_box_size, crop_box_index = label_to_roi(largest_region)

        label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
        label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

    arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
    arr_b = sitk.GetArrayFromImage(label_b).astype(bool)

    true_pos = np.count_nonzero(arr_a & arr_b)
    reference_total = np.count_nonzero(arr_a)  # TP + FN

    return float(true_pos / reference_total)
def generate_valve_from_great_vessel(
    label_great_vessel,
    label_ventricle,
    valve_thickness_mm=8,
):
    """Generate a geometrically-defined valve between a great vessel and a ventricle.

    This function is suitable for the pulmonic and aortic valves. The ventricle
    is dilated by (approximately) ``valve_thickness_mm`` and the part of the
    great vessel inside this dilation, after a morphological closing, is taken
    as the valve.

    Args:
        label_great_vessel (SimpleITK.Image): The binary mask for the great vessel
            (pulmonary artery or ascending aorta).
        label_ventricle (SimpleITK.Image): The binary mask for the ventricle
            (left or right).
        valve_thickness_mm (int, optional): Valve thickness, in millimetres.
            Defaults to 8.

    Returns:
        SimpleITK.Image: The geometric valve, as a binary mask in the original
            image space.
    """
    # To speed up binary morphology operations we first crop all images
    template_img = 0 * label_ventricle
    cb_size, cb_index = label_to_roi(
        (label_great_vessel + label_ventricle) > 0, expansion_mm=(20, 20, 20)
    )
    label_ventricle = crop_to_roi(label_ventricle, cb_size, cb_index)
    label_great_vessel = crop_to_roi(label_great_vessel, cb_size, cb_index)

    # Convert the valve thickness to a voxel radius for EACH axis. Using the
    # z spacing for all three axes makes the in-plane dilation much thinner
    # than valve_thickness_mm on anisotropic (e.g. CT) images.
    valve_thickness = [
        int(valve_thickness_mm / spacing) for spacing in label_ventricle.GetSpacing()
    ]

    # Dilate the ventricle
    label_ventricle_dilate = sitk.BinaryDilate(label_ventricle, valve_thickness)

    # The valve is the part of the great vessel inside the dilated ventricle.
    # Note: this overlap is already a subset of the great vessel, so masking it
    # to `label_great_vessel | label_ventricle_dilate` would have no effect.
    overlap = label_great_vessel & label_ventricle_dilate

    label_valve = sitk.BinaryMorphologicalClosing(overlap)

    # Finally, paste back to the original image space
    label_valve = sitk.Paste(
        template_img,
        label_valve,
        label_valve.GetSize(),
        (0, 0, 0),
        cb_index,
    )

    return label_valve
def compute_metric_dsc(label_a, label_b, auto_crop=True):
    """Compute the Dice Similarity Coefficient between two masks.

    DSC = 2 * |A & B| / (|A| + |B|), where voxels are counted as foreground
    when non-zero.

    Args:
        label_a (sitk.Image): A mask to compare.
        label_b (sitk.Image): Another mask to compare.
        auto_crop (bool, optional): Crop both masks to the bounding box of their
            union before computing the metric. Defaults to True.

    Returns:
        float: The Dice Similarity Coefficient in [0, 1], or ``np.nan`` if both
            masks are empty (the coefficient is undefined).
    """
    # With both masks empty the DSC is 0/0. Return NaN explicitly, before
    # cropping, so label_to_roi is never asked for the bounding box of an
    # empty union and no division warning is emitted.
    if not (
        sitk.GetArrayViewFromImage(label_a).any()
        or sitk.GetArrayViewFromImage(label_b).any()
    ):
        return np.nan

    if auto_crop:
        largest_region = (label_a + label_b) > 0
        crop_box_size, crop_box_index = label_to_roi(largest_region)

        label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
        label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

    arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
    arr_b = sitk.GetArrayFromImage(label_b).astype(bool)

    intersection = np.count_nonzero(arr_a & arr_b)
    total = np.count_nonzero(arr_a) + np.count_nonzero(arr_b)

    # Plain float, consistent with the other compute_metric_* functions
    return float(2 * intersection / total)
def compute_metric_masd(label_a, label_b, auto_crop=True):
    """Compute the mean absolute surface distance (MASD) between two masks.

    For each direction (A -> B and B -> A), the absolute distance map of one
    mask (``SignedMaurerDistanceMap``, using image spacing) is averaged over the
    contour voxels (``LabelContour``, label value 1) of the other mask. The two
    directional means are then combined, weighted by their contour voxel counts.

    Args:
        label_a (sitk.Image): A mask to compare.
        label_b (sitk.Image): Another mask to compare.
        auto_crop (bool, optional): Crop both masks to the bounding box of their
            union before computing the metric. Defaults to True.

    Returns:
        float: The mean absolute surface distance, or ``np.nan`` if either mask
            is empty.
    """
    if any(sitk.GetArrayViewFromImage(label).sum() == 0 for label in (label_a, label_b)):
        return np.nan

    if auto_crop:
        union_mask = (label_a + label_b) > 0
        roi_size, roi_index = label_to_roi(union_mask)
        label_a, label_b = (
            crop_to_roi(label, size=roi_size, index=roi_index) for label in (label_a, label_b)
        )

    def _directed_surface_stats(reference, moving):
        """Mean |distance to reference| over moving's contour, and the contour size."""
        stats = sitk.LabelIntensityStatisticsImageFilter()
        distance_map = sitk.Abs(
            sitk.SignedMaurerDistanceMap(reference, squaredDistance=False, useImageSpacing=True)
        )
        stats.Execute(sitk.LabelContour(moving), distance_map)
        return stats.GetMean(1), stats.GetNumberOfPixels(1)

    directed_means, contour_counts = zip(
        _directed_surface_stats(label_a, label_b),
        _directed_surface_stats(label_b, label_a),
    )

    weighted_mean = np.dot(directed_means, contour_counts) / np.sum(contour_counts)
    return float(weighted_mean)
def _lv_slice_polar_coords(label_slice, theta_ref):
    """Return (loc_x, loc_y, theta, radii) for the foreground voxels of a 2D slice.

    The origin is the slice centre of mass from ``get_com`` (slice index
    coordinates, not physical). The angle is ``-arctan2(dy, dx) - theta_ref``
    wrapped into [0, 2*pi); the radius is the Euclidean distance from the origin.
    """
    loc_y, loc_x = np.where(sitk.GetArrayViewFromImage(label_slice))

    y_0, x_0 = get_com(label_slice)

    theta = -np.arctan2(loc_y - y_0, loc_x - x_0) - theta_ref
    # np.mod wraps every angle into [0, 2*pi). Adding 2*pi once to negative
    # values is not enough when theta_ref > pi: part of the ring would keep a
    # negative angle and fall outside every segment.
    theta = np.mod(theta, 2 * np.pi)

    radii = np.sqrt((loc_y - y_0) ** 2 + (loc_x - x_0) ** 2)

    return loc_x, loc_y, theta, radii


def generate_left_ventricle_segments(
    contours,
    label_left_ventricle="LEFTVENTRICLE",
    label_left_atrium="LEFTATRIUM",
    label_right_ventricle="RIGHTVENTRICLE",
    label_mitral_valve="MITRALVALVE",
    label_heart="WHOLEHEART",
    myocardium_thickness_mm=10,
    hole_fill_mm=3,
    optimiser_tol_degrees=1,
    optimiser_max_iter=10,
    min_area_mm2=50,
    verbose=False,
):
    """Generate the 17 segments of the left ventricle.

    This function works as follows:

    1. The heart volume is rotated to align the long axis with the z Cartesian
       (physical) axis. Usually this means it aligns with the axial axis (for
       normal simulation CT).
    2. An optimiser refines this alignment to the vector defined by the
       MV COM - LV apex axis (long axis).
    3. The left ventricle is divided into thirds along the long axis.
    4. The myocardium is defined as the outer ``myocardium_thickness_mm``.
    5. Geometric (polar) operations are used to define the segments.
    6. Everything is rotated back to the original orientation.
    7. Post-processing: optional hole filling (morphological closing).

    Args:
        contours (dict): A dictionary with label names (str) as keys and masks
            (SimpleITK.Image) as values. Must contain at least the LV, LA, RV,
            MV and whole heart labels named below.
        label_left_ventricle (str): The name for the left ventricle mask.
        label_left_atrium (str): The name for the left atrium mask.
        label_right_ventricle (str): The name for the right ventricle mask.
        label_mitral_valve (str): The name for the mitral valve mask.
        label_heart (str): The name for the whole heart mask.
        myocardium_thickness_mm (float, optional): Myocardial thickness, in
            millimetres. Defaults to 10.
        hole_fill_mm (float, optional): Radius (mm) of the morphological closing
            applied to each segment; holes smaller than this get filled in.
            0 disables it. Defaults to 3.
        optimiser_tol_degrees (float, optional): Optimiser tolerance (change in
            angle per iteration). Defaults to 1, which typically requires 3-4
            iterations.
        optimiser_max_iter (int, optional): Maximum optimiser iterations.
            Defaults to 10.
        min_area_mm2 (float, optional): Passed as ``min_area_mm2`` to ``extract``
            for every segment slice. Defaults to 50.
        verbose (bool, optional): Print information for debugging. Defaults to False.

    Returns:
        dict: The left ventricle segment dictionary, with labels (int, 1-17) as
            keys and the binary label defining each segment (SimpleITK.Image) as
            values.

    Raises:
        KeyError: If a required label is missing from ``contours``.
        ValueError: If no right ventricle voxels are found in the slices used to
            compute the basal RV insertion angle.
    """
    if verbose:
        print("Beginning LV segmentation algorithm.")

    # Initial set up
    label_list = [
        label_left_ventricle,
        label_left_atrium,
        label_right_ventricle,
        label_mitral_valve,
        label_heart,
    ]
    # Looking up every required label raises KeyError early if one is missing.
    # All entries are replaced by cropped copies below, so no deep copy is needed.
    working_contours = {s: contours[s] for s in label_list}
    output_contours = {}
    overall_transform_list = []

    # Convert physical sizes to per-axis voxel radii
    erode_img = [
        int(myocardium_thickness_mm / i)
        for i in working_contours[label_left_ventricle].GetSpacing()
    ]
    hole_fill_img = [int(hole_fill_mm / i) for i in working_contours[label_heart].GetSpacing()]

    # ------------------------------------------------------------------
    # Module 1 - Preparation
    #   Crop the images
    #   Rotate to the cardiac axis
    # ------------------------------------------------------------------

    # Crop to the smallest volume possible to make it FAST
    cb_size, cb_index = label_to_roi(
        working_contours[label_heart] > 0,
        expansion_mm=(30, 30, 60),  # Better to make it a bit bigger to be safe
    )

    for label in contours:
        working_contours[label] = crop_to_roi(contours[label], cb_size, cb_index)

    if verbose:
        print("Module 1: Cropping and initial alignment.")
        # Compare voxel counts (image size): cropping never changes the spacing,
        # so a spacing-based ratio would always be 1.000.
        vol_before = np.prod(contours[label_heart].GetSize())
        vol_after = np.prod(working_contours[label_heart].GetSize())
        print(f" Images cropped. Volume reduction: {vol_before/vol_after:.3f}")

    # Initially we should reorient based on the cardiac axis
    label_orient = (
        working_contours[label_left_ventricle] + working_contours[label_left_atrium]
    ) > 0

    lsf = sitk.LabelShapeStatisticsImageFilter()  # this will be used throughout
    lsf.Execute(label_orient)
    # First principal axis approx. long axis
    cardiac_axis = np.array(lsf.GetPrincipalAxes(1)[:3])

    # The principal axis isn't guaranteed to point from base to apex.
    # If it points apex to base, we have to invert it, so check that here.
    if cardiac_axis[2] < 0:
        cardiac_axis = -1 * cardiac_axis

    rotation_angle = vector_angle(cardiac_axis[::-1], (0, 0, 1))
    rotation_axis = np.cross(cardiac_axis[::-1], (0, 0, 1))
    rotation_centre = get_com(label_orient, real_coords=True)

    if verbose:
        print(" Alignment computed.")
        print(" Cardiac axis: ", cardiac_axis)
        print(" Rotation axis: ", rotation_axis)
        print(" Rotation angle: ", rotation_angle)
        print(" Rotation centre: ", rotation_centre)

    rotation_transform = sitk.VersorRigid3DTransform()
    rotation_transform.SetCenter(rotation_centre)
    rotation_transform.SetRotation(rotation_axis, rotation_angle)

    overall_transform_list.append(rotation_transform)

    for label in contours:
        working_contours[label] = sitk.Resample(
            working_contours[label],
            rotation_transform,
            sitk.sitkNearestNeighbor,
            0,
            working_contours[label].GetPixelID(),
        )

    # ------------------------------------------------------------------
    # Module 2 - LV orientation alignment
    #   A very simple optimisation regime enables robust computation of the
    #   LV apex. The vector from the MV COM to the LV apex is used for
    #   orientation (i.e. the long axis).
    # ------------------------------------------------------------------
    optimiser_tol_radians = optimiser_tol_degrees * np.pi / 180

    n = 0

    if verbose:
        print("Module 2: LV orientation alignment.")
        print(" Optimiser tolerance (degrees) =", optimiser_tol_degrees)
        print(" Beginning alignment process")

    while n < optimiser_max_iter and np.abs(rotation_angle) > optimiser_tol_radians:
        n += 1

        # Find the LV apex: mean (x, y) of the LV voxels on the lowest array z index
        lv_locations = np.where(
            sitk.GetArrayViewFromImage(working_contours[label_left_ventricle])
        )
        lv_apex_z = lv_locations[0].min()
        on_apex_slice = lv_locations[0] == lv_apex_z
        lv_apex_y = lv_locations[1][on_apex_slice].mean()
        lv_apex_x = lv_locations[2][on_apex_slice].mean()
        lv_apex_loc = np.array([lv_apex_x, lv_apex_y, lv_apex_z])

        # Get the MV COM
        mv_com = np.array(get_com(working_contours[label_mitral_valve], real_coords=True))

        # Define the LV axis
        lv_apex_loc_img = np.array(
            working_contours[label_left_ventricle].TransformContinuousIndexToPhysicalPoint(
                lv_apex_loc.tolist()
            )
        )
        lv_axis = lv_apex_loc_img - mv_com

        # Compute the rotation parameters
        rotation_axis = np.cross(lv_axis, (0, 0, 1))
        rotation_angle = vector_angle(lv_axis, (0, 0, 1))
        rotation_centre = 0.5 * (mv_com + lv_apex_loc_img)

        rotation_transform = sitk.VersorRigid3DTransform()
        rotation_transform.SetCenter(rotation_centre)
        rotation_transform.SetRotation(rotation_axis, rotation_angle)

        overall_transform_list.append(rotation_transform)

        if verbose:
            print(" N: ", n)
            print(" LV apex: ", lv_apex_loc_img)
            print(" MV COM: ", mv_com)
            print(" LV axis: ", lv_axis)
            print(" Rotation axis: ", rotation_axis)
            print(" Rotation centre: ", rotation_centre)
            print(" Rotation angle: ", rotation_angle)

        for label in contours:
            working_contours[label] = sitk.Resample(
                working_contours[label],
                rotation_transform,
                sitk.sitkNearestNeighbor,
                0,
                working_contours[label].GetPixelID(),
            )

    # ------------------------------------------------------------------
    # Module 3 - Compute the myocardium for the whole LV volume
    #   Divide this volume into thirds (from MV COM -> LV apex)
    # ------------------------------------------------------------------
    if verbose:
        print("Module 3: Myocardium generation.")

    # First, let's just extract the myocardium
    label_lv_inner = sitk.BinaryErode(working_contours[label_left_ventricle], erode_img)
    label_lv_myo = working_contours[label_left_ventricle] - label_lv_inner

    # Mask the myo to a dilation of the blood pool
    # This helps improve shape consistency
    label_lv_myo_mask = sitk.BinaryDilate(label_lv_inner, erode_img)
    label_lv_myo = sitk.Mask(label_lv_myo, label_lv_myo_mask)

    # Limits for division into thirds.
    # GetRegion returns [xstart, ystart, zstart, xsize, ysize, zsize].
    # The limits run from the LV inner (blood pool) apex to the MV centre of mass.
    lsf.Execute(label_lv_inner)
    _, _, inf_limit_lv, _, _, _ = lsf.GetRegion(1)

    com_mv, _, _ = get_com(working_contours[label_mitral_valve])

    extent = com_mv - inf_limit_lv
    dc = int(extent / 3)

    # Define limits (cut LV into thirds)
    apical_extent = inf_limit_lv + dc
    mid_extent = inf_limit_lv + 2 * dc
    basal_extent = com_mv  # more complete coverage

    if verbose:
        print(" Apex (long axis) slice: ", inf_limit_lv)
        print(" Apical section extent slice: ", apical_extent)
        print(" Mid section extent slice: ", mid_extent)
        print(" Basal section extent slice: ", basal_extent)
        print(" DeltaCut (DC): ", dc)
        print(" Extent: ", extent)

    # Segment 17: the myocardium below the blood pool apex
    label_lv_myo_apex = label_lv_myo * 1  # make a copy
    label_lv_myo_apex[:, :, inf_limit_lv:] = 0

    # ------------------------------------------------------------------
    # Module 4 - Generate 17 segments
    #   1. Find the basal (anterior) insertion of the RV -> theta_0
    #   2. Find the baseline angle for the apical section -> theta_0_apical
    #   3. Iterate through each section (apical, mid, basal):
    #      a. Convert each myocardium label location to polar coords
    #      b. Assign each label to the appropriate LV segments
    # ------------------------------------------------------------------
    if verbose:
        print("Module 4: Segment generation.")

    # We need the angle for the basal RV insertion
    # This is the most counter-clockwise RV location
    # First, retrieve the first 5 slices of the basal section
    loc_rv_z, loc_rv_y, loc_rv_x = np.where(
        sitk.GetArrayViewFromImage(working_contours[label_right_ventricle])
    )

    loc_rv_z_basal = np.arange(mid_extent, mid_extent + 5)

    if verbose:
        print(" RV basal slices: ", loc_rv_z_basal)

    theta_rv_insertion = []
    for z in loc_rv_z_basal:
        on_slice = loc_rv_z == z
        if not on_slice.any():
            # No RV on this slice, so it cannot contribute an insertion angle
            # (and theta_rv.min() would fail on an empty array).
            continue

        # Now get all the x and y positions
        loc_rv_basal_x = loc_rv_x[on_slice]
        loc_rv_basal_y = loc_rv_y[on_slice]

        # Now define the LV COM on each slice
        lv_com = get_com(working_contours[label_left_ventricle][:, :, int(z)])
        lv_com_basal_x = lv_com[1]
        lv_com_basal_y = lv_com[0]

        # Compute the angle
        theta_rv = np.arctan2(lv_com_basal_y - loc_rv_basal_y, loc_rv_basal_x - lv_com_basal_x)
        theta_rv[theta_rv < 0] += 2 * np.pi

        theta_rv_insertion.append(theta_rv.min())

    if not theta_rv_insertion:
        raise ValueError(
            "No right ventricle voxels found in the basal slices used to compute "
            "the RV insertion angle."
        )

    theta_0 = np.median(theta_rv_insertion)

    if verbose:
        print(" RV insertion angle (basal section): ", theta_0)

    # We also need the angle in the apical section for accurate segmentation
    apical_slices = range(inf_limit_lv, apical_extent)
    lv_com_apical = np.mean(
        [get_com(working_contours[label_left_ventricle][:, :, z]) for z in apical_slices],
        axis=0,
    )
    rv_com_apical = np.mean(
        [get_com(working_contours[label_right_ventricle][:, :, z]) for z in apical_slices],
        axis=0,
    )

    theta_0_apical = np.arctan2(
        lv_com_apical[0] - rv_com_apical[0], rv_com_apical[1] - lv_com_apical[1]
    )

    if verbose:
        print(" Apical LV-RV COM angle: ", theta_0_apical)

    for i in range(17):
        working_contours[i + 1] = 0 * working_contours[label_heart]

    working_contours[17] = label_lv_myo_apex

    # Segment tables: (segment id, lower angle, upper angle, extra extract kwargs)
    apical_segments = (
        (13, 5 * np.pi / 4, 7 * np.pi / 4, {}),
        # cw=True: the 1*pi/4 .. 7*pi/4 pair is interpreted by extract as the
        # arc that wraps through 0 (behaviour defined in extract)
        (14, 1 * np.pi / 4, 7 * np.pi / 4, {"cw": True}),
        (15, 1 * np.pi / 4, 3 * np.pi / 4, {}),
        (16, 3 * np.pi / 4, 5 * np.pi / 4, {}),
    )
    sextant_bounds = (
        (0, np.pi / 3),
        (1 * np.pi / 3, 2 * np.pi / 3),
        (2 * np.pi / 3, 3 * np.pi / 3),
        (3 * np.pi / 3, 4 * np.pi / 3),
        (4 * np.pi / 3, 5 * np.pi / 3),
        (5 * np.pi / 3, 2 * np.pi),
    )
    mid_segments = tuple(
        (segment_id, lower, upper, {})
        for segment_id, (lower, upper) in zip((8, 9, 10, 11, 12, 7), sextant_bounds)
    )
    basal_segments = tuple(
        (segment_id, lower, upper, {"radius_min": 15})
        for segment_id, (lower, upper) in zip((2, 3, 4, 5, 6, 1), sextant_bounds)
    )

    # We compute the segments in cylindrical sections:
    # (name, first slice, stop slice, reference angle, segment table)
    sections = (
        ("apical", inf_limit_lv, apical_extent, theta_0_apical, apical_segments),
        ("mid", apical_extent, mid_extent, theta_0, mid_segments),
        ("basal", mid_extent, basal_extent, theta_0, basal_segments),
    )

    for section_name, z_start, z_stop, theta_ref, segments in sections:
        if verbose:
            print(f" Computing {section_name} segments")

        for z in range(z_start, z_stop):
            label_lv_myo_slice = label_lv_myo[:, :, z]
            loc_x, loc_y, theta, radii = _lv_slice_polar_coords(label_lv_myo_slice, theta_ref)

            for segment_id, theta_min, theta_max, extra_kwargs in segments:
                working_contours[segment_id][:, :, z] = extract(
                    label_lv_myo_slice,
                    theta,
                    radii,
                    theta_min,
                    theta_max,
                    loc_x,
                    loc_y,
                    min_area_mm2=min_area_mm2,
                    **extra_kwargs,
                )

    # ------------------------------------------------------------------
    # Module 5 - Re-orientation into image space
    #   Apply the total inverse transformation and paste the labels back into
    #   the original image space.
    # ------------------------------------------------------------------
    if verbose:
        print(" Module 5: Re-orientation.")

    # Compute the total transform
    overall_transform = sitk.CompositeTransform(overall_transform_list)
    inverse_transform = overall_transform.GetInverse()

    # Rotate back to the original reference space
    for segment in range(1, 18):
        new_structure = sitk.Resample(
            working_contours[segment],
            inverse_transform,
            sitk.sitkNearestNeighbor,
            0,
        )

        if hole_fill_mm > 0:
            new_structure = sitk.BinaryMorphologicalClosing(new_structure, hole_fill_img)

        new_structure = sitk.Paste(
            contours[label_heart] * 0,
            new_structure,
            new_structure.GetSize(),
            (0, 0, 0),
            cb_index,
        )

        output_contours[segment] = new_structure

    if verbose:
        print("Complete!")

    return output_contours
def generate_valve_using_cylinder(
    label_atrium,
    label_ventricle,
    radius_mm=15,
    height_mm=10,
):
    """Generate a geometrically-defined valve between an atrium and a ventricle.

    This function is suitable for the tricuspid and mitral valves. A cylinder is
    placed at the centre of mass of the (dilated) atrium/ventricle overlap and
    rotated to follow the atrium -> ventricle centre-of-mass direction.

    Args:
        label_atrium (SimpleITK.Image): The binary mask for the (left or right) atrium.
        label_ventricle (SimpleITK.Image): The binary mask for the (left or right) ventricle.
        radius_mm (int, optional): The valve radius, in mm. Defaults to 15.
        height_mm (int, optional): The valve height (i.e. perpendicular extent),
            in mm. Defaults to 10.

    Returns:
        SimpleITK.Image: The geometrically defined valve, in the original image space.

    Raises:
        ValueError: If either chamber mask is empty.
    """
    # An empty chamber can never produce any overlap, so the dilation loop
    # below would never terminate. Fail fast instead.
    for name, label in (("label_atrium", label_atrium), ("label_ventricle", label_ventricle)):
        if not sitk.GetArrayViewFromImage(label).any():
            raise ValueError(f"{name} is empty; cannot locate the valve.")

    # To speed up binary morphology operations we first crop all images
    template_img = 0 * label_ventricle
    cb_size, cb_index = label_to_roi(
        (label_atrium + label_ventricle) > 0, expansion_mm=(20, 20, 20)
    )

    label_atrium = crop_to_roi(label_atrium, cb_size, cb_index)
    label_ventricle = crop_to_roi(label_ventricle, cb_size, cb_index)

    # Define the overlap region (using binary dilation).
    # Increase the dilation (1 mm steps) until the overlap volume exceeds 2000
    # (in spacing units cubed, i.e. mm^3 for mm spacing).
    dilation = 1
    overlap_vol = 0
    while overlap_vol <= 2000:
        dilation_img = [int(dilation / i) for i in label_ventricle.GetSpacing()]
        overlap = sitk.BinaryDilate(label_atrium, dilation_img) & sitk.BinaryDilate(
            label_ventricle, dilation_img
        )
        overlap_vol = np.sum(
            sitk.GetArrayViewFromImage(overlap) * np.prod(overlap.GetSpacing())
        )
        dilation += 1

    # Now we can calculate the location of the valve
    valve_loc = get_com(overlap, as_int=True)
    valve_loc_real = get_com(overlap, real_coords=True)

    # Now we create a cylinder with the user-defined parameters
    cylinder = insert_cylinder_image(0 * label_ventricle, radius_mm, height_mm, valve_loc[::-1])

    # Use the offset between the chamber centres of mass as a proxy for the
    # long axis of the ventricle; this was preferred over the principal axes of
    # the combined chambers or of the overlap region as it is more robust.
    orientation_vector = np.array(get_com(label_ventricle, real_coords=True)) - np.array(
        get_com(label_atrium, real_coords=True)
    )

    # Get the rotation parameters
    rotation_angle = vector_angle(orientation_vector, (0, 0, 1), smallest=False)
    rotation_axis = np.cross(orientation_vector, (0, 0, 1))

    # Rotate the cylinder to define the valve
    label_valve = rotate_image(
        cylinder,
        rotation_centre=valve_loc_real,
        rotation_axis=rotation_axis,
        rotation_angle_radians=rotation_angle,
        interpolation=sitk.sitkNearestNeighbor,
        default_value=0,
    )

    # Finally, paste back to the original image space
    label_valve = sitk.Paste(
        template_img,
        label_valve,
        label_valve.GetSize(),
        (0, 0, 0),
        cb_index,
    )

    return label_valve
def run_segmentation(img, settings=MUTLIATLAS_SETTINGS_DEFAULTS):
    """Run the multi-atlas segmentation algorithm.

    Pipeline:
        1. Load the atlases (optionally cropped to their structures).
        2. Crop the target using a quick registration of up to 8 atlases.
        3. Linearly register every atlas to the cropped target.
        4. Refine with deformable (fast symmetric forces demons) registration.
        5. Weighted label fusion.
        6. Threshold and paste back into the original image space.
        7. Optional post-processing (hole filling, overlap correction).

    Args:
        img (sitk.Image): The target image to segment.
        settings (dict, optional): Dictionary containing settings for the
            algorithm. Defaults to MUTLIATLAS_SETTINGS_DEFAULTS.

    Returns:
        tuple (dict, dict): ``(results, results_prob)``: the binary
            segmentations and the probability maps, both keyed by structure name
            and in the original image space.
    """
    results = {}
    results_prob = {}

    # Initialisation - read in atlases (image files and structure files)
    #
    # Atlas structure:
    # 'ID': 'Original': 'CT Image'   : sitk.Image
    #                   'Struct A'   : sitk.Image
    #                   'Struct B'   : sitk.Image
    #       'RIR':      'CT Image'   : sitk.Image
    #                   'Transform'  : transform parameter map
    #                   'Struct A'   : sitk.Image
    #                   'Struct B'   : sitk.Image
    #       'DIR':      'CT Image'   : sitk.Image
    #                   'Transform'  : displacement field transform
    #                   'Weight Map' : sitk.Image
    #                   'Struct A'   : sitk.Image
    #                   'Struct B'   : sitk.Image

    logger.info("")
    # Settings
    atlas_settings = settings["atlas_settings"]
    atlas_path = atlas_settings["atlas_path"]
    atlas_id_list = atlas_settings["atlas_id_list"]
    atlas_structure_list = atlas_settings["atlas_structure_list"]

    atlas_image_format = atlas_settings["atlas_image_format"]
    atlas_label_format = atlas_settings["atlas_label_format"]

    crop_atlas_to_structures = atlas_settings["crop_atlas_to_structures"]
    crop_atlas_expansion_mm = atlas_settings["crop_atlas_expansion_mm"]

    atlas_set = {}
    for atlas_id in atlas_id_list:
        atlas_set[atlas_id] = {}
        atlas_set[atlas_id]["Original"] = {}

        image = sitk.ReadImage(f"{atlas_path}/{atlas_image_format.format(atlas_id)}")

        structures = {
            struct: sitk.ReadImage(f"{atlas_path}/{atlas_label_format.format(atlas_id, struct)}")
            for struct in atlas_structure_list
        }

        if crop_atlas_to_structures:
            logger.info(f"Automatically cropping atlas: {atlas_id}")

            original_volume = np.prod(image.GetSize())

            crop_box_size, crop_box_index = label_to_roi(
                structures.values(), expansion_mm=crop_atlas_expansion_mm
            )

            image = crop_to_roi(image, size=crop_box_size, index=crop_box_index)

            final_volume = np.prod(image.GetSize())

            logger.info(f" > Volume reduced by factor {original_volume/final_volume:.2f}")

            for struct in atlas_structure_list:
                structures[struct] = crop_to_roi(
                    structures[struct], size=crop_box_size, index=crop_box_index
                )

        atlas_set[atlas_id]["Original"]["CT Image"] = image

        for struct in atlas_structure_list:
            atlas_set[atlas_id]["Original"][struct] = structures[struct]

    # Step 1 - Automatic cropping
    #   A quick registration of (up to 8) atlases to the target; the average of
    #   the registered images defines a region that is expanded to ensure the
    #   entire volume of interest is enclosed, and the target image is cropped.
    expansion_mm = settings["auto_crop_target_image_settings"]["expansion_mm"]

    quick_reg_settings = {
        "reg_method": "similarity",
        "shrink_factors": [8],
        "smooth_sigmas": [0],
        "sampling_rate": 0.75,
        "default_value": -1000,
        "number_of_iterations": 25,
        "final_interp": sitk.sitkLinear,
        "metric": "mean_squares",
        "optimiser": "gradient_descent_line_search",
    }

    registered_crop_images = []

    logger.info("Running initial Translation tranform to crop image volume")

    for atlas_id in atlas_id_list[:8]:
        logger.info(f" > atlas {atlas_id}")

        # Register the atlases
        atlas_image = atlas_set[atlas_id]["Original"]["CT Image"]

        reg_image, _ = linear_registration(
            img,
            atlas_image,
            **quick_reg_settings,
        )

        registered_crop_images.append(sitk.Cast(reg_image, sitk.sitkFloat32))

        del reg_image

    combined_image = sum(registered_crop_images) / len(registered_crop_images) > -1000

    crop_box_size, crop_box_index = label_to_roi(combined_image, expansion_mm=expansion_mm)

    img_crop = crop_to_roi(img, crop_box_size, crop_box_index)

    logger.info("Calculated crop box:")
    logger.info(f" > {crop_box_index}")
    logger.info(f" > {crop_box_size}")
    logger.info(f" > Vol reduction = {np.prod(img.GetSize())/np.prod(crop_box_size):.2f}")

    # Step 2 - Rigid (linear) registration of the atlases
    #   Each atlas image is registered to the target and the transformation is
    #   used to propagate its labels onto the target.
    linear_registration_settings = settings["linear_registration_settings"]

    logger.info(
        f"Running {linear_registration_settings['reg_method']} tranform to align atlas images"
    )

    for atlas_id in atlas_id_list:
        # Register the atlases
        logger.info(f" > atlas {atlas_id}")

        atlas_set[atlas_id]["RIR"] = {}

        target_reg_image = img_crop
        atlas_reg_image = atlas_set[atlas_id]["Original"]["CT Image"]

        _, initial_tfm = linear_registration(
            target_reg_image,
            atlas_reg_image,
            **linear_registration_settings,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["RIR"]["Transform"] = initial_tfm

        atlas_set[atlas_id]["RIR"]["CT Image"] = apply_transform(
            input_image=atlas_set[atlas_id]["Original"]["CT Image"],
            reference_image=img_crop,
            transform=initial_tfm,
            default_value=-1000,
            interpolator=sitk.sitkLinear,
        )

        for struct in atlas_structure_list:
            input_struct = atlas_set[atlas_id]["Original"][struct]
            atlas_set[atlas_id]["RIR"][struct] = apply_transform(
                input_image=input_struct,
                reference_image=img_crop,
                transform=initial_tfm,
                default_value=0,
                interpolator=sitk.sitkNearestNeighbor,
            )

        # Release the original (uncropped/unregistered) atlas data
        atlas_set[atlas_id]["Original"] = None

    # Step 3 - Deformable image registration (Fast Symmetric Forces Demons)
    deformable_registration_settings = settings["deformable_registration_settings"]
    logger.info("Running DIR to refine atlas image registration")

    for atlas_id in atlas_id_list:
        logger.info(f" > atlas {atlas_id}")

        # Register the atlases
        atlas_set[atlas_id]["DIR"] = {}

        atlas_reg_image = atlas_set[atlas_id]["RIR"]["CT Image"]
        target_reg_image = img_crop

        _, dir_tfm, _ = fast_symmetric_forces_demons_registration(
            target_reg_image,
            atlas_reg_image,
            **deformable_registration_settings,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["DIR"]["Transform"] = dir_tfm

        atlas_set[atlas_id]["DIR"]["CT Image"] = apply_transform(
            input_image=atlas_set[atlas_id]["RIR"]["CT Image"],
            transform=dir_tfm,
            default_value=-1000,
            interpolator=sitk.sitkLinear,
        )

        for struct in atlas_structure_list:
            input_struct = atlas_set[atlas_id]["RIR"][struct]
            atlas_set[atlas_id]["DIR"][struct] = apply_transform(
                input_image=input_struct,
                transform=dir_tfm,
                default_value=0,
                interpolator=sitk.sitkNearestNeighbor,
            )

        # Release the rigidly registered atlas data
        atlas_set[atlas_id]["RIR"] = None

    # Step 4 - Label fusion
    vote_type = settings["label_fusion_settings"]["vote_type"]
    vote_params = settings["label_fusion_settings"]["vote_params"]

    # Compute weight maps
    for atlas_id in atlas_set:
        atlas_image = atlas_set[atlas_id]["DIR"]["CT Image"]
        weight_map = compute_weight_map(
            img_crop, atlas_image, vote_type=vote_type, vote_params=vote_params
        )
        atlas_set[atlas_id]["DIR"]["Weight Map"] = weight_map

    combined_label_dict = combine_labels(atlas_set, atlas_structure_list)

    # Step 5 - Threshold and paste the cropped structures into the original image space
    logger.info("Generating binary segmentations.")

    template_img_binary = sitk.Cast((img * 0), sitk.sitkUInt8)
    template_img_prob = sitk.Cast((img * 0), sitk.sitkFloat64)

    optimal_thresholds = settings["label_fusion_settings"]["optimal_threshold"]

    for structure_name in atlas_structure_list:
        probability_map = combined_label_dict[structure_name]

        optimal_threshold = (
            optimal_thresholds[structure_name] if structure_name in optimal_thresholds else 0.5
        )

        binary_struct = process_probability_image(probability_map, optimal_threshold)

        # Un-crop binary structure
        paste_img_binary = sitk.Paste(
            template_img_binary,
            binary_struct,
            binary_struct.GetSize(),
            (0, 0, 0),
            crop_box_index,
        )
        results[structure_name] = paste_img_binary

        # Un-crop probability map
        paste_prob_img = sitk.Paste(
            template_img_prob,
            probability_map,
            probability_map.GetSize(),
            (0, 0, 0),
            crop_box_index,
        )
        results_prob[structure_name] = paste_prob_img

    # Step 6 - Post-processing
    postprocessing_settings = settings["postprocessing_settings"]

    if postprocessing_settings["run_postprocessing"]:
        logger.info("Running post-processing.")

        # Keep the largest component and perform morphological closing (hole filling)
        binaryfillhole_img = [
            int(postprocessing_settings["binaryfillhole_mm"] / sp) for sp in img.GetSpacing()
        ]

        for structure_name in postprocessing_settings["structures_for_binaryfillhole"]:
            if structure_name not in results:
                continue

            contour_s = results[structure_name]
            contour_s = sitk.RelabelComponent(sitk.ConnectedComponent(contour_s)) == 1
            contour_s = sitk.BinaryMorphologicalClosing(contour_s, binaryfillhole_img)
            results[structure_name] = contour_s

        if len(postprocessing_settings["structures_for_overlap_correction"]) >= 2:
            # Remove any overlaps
            input_overlap = {
                s: results[s] for s in postprocessing_settings["structures_for_overlap_correction"]
            }
            output_overlap = correct_volume_overlap(input_overlap)

            for s in postprocessing_settings["structures_for_overlap_correction"]:
                results[s] = output_overlap[s]

    logger.info("Done!")

    return results, results_prob
def quick_optimise_probability(
    metric_function,
    manual_contour,
    probability_image,
    p_0=0.5,
    delta=0.5,
    tolerance=0.01,
    mode="min",
    create_figure=False,
    auto_crop=True,
    metric_args=None,
):
    """Optimise the probability threshold used to generate a binary segmentation.

    This is a simple parameter sweep around the current best threshold. Each iteration
    evaluates six candidate thresholds either side of the current optimum, and the search
    window then shrinks by a factor of four. At least four iterations are always run; after
    that the search stops once the best metric changes by no more than `tolerance`.
    It will usually converge between 5 and 10 iterations.

    Candidates whose metric is NaN (e.g. `compute_metric_hd` returns NaN when either mask is
    empty) are ignored when choosing the optimum, unless every candidate evaluated so far is NaN.

    Args:
        metric_function (function to return float): The metric function, this takes in two
            binary masks (a reference, and test [SimpleITK.Image]) and returns a metric (as a
            float). Typical choices would be the DSC, a surface distance metric, dose
            difference, relative volume. Additional arguments can be passed in through the
            `metric_args` argument.
        manual_contour (SimpleITK.Image): The reference (manual) contour.
        probability_image (SimpleITK.Image): The probability map from which the optimal
            threshold will be derived. This does NOT have to be scaled to [0,1].
        p_0 (float, optional): Initial guess of the optimal threshold. Defaults to 0.5.
        delta (float, optional): The window size of the optimiser. Defaults to 0.5.
        tolerance (float, optional): If the metric changes by an amount less than
            `tolerance` the optimiser will stop. Defaults to 0.01.
        mode (str, optional): Specifies whether the metric should be maximised ("max") or
            minimised ("min"). Defaults to "min".
        create_figure (bool, optional): Create a matplotlib figure showing the optimisation.
            This is not returned, so make sure you can capture this (e.g. using IPython).
            Defaults to False.
        auto_crop (bool, optional): Crop the image volumes to the region of interest. Speeds
            up the process significantly so don't turn off unless you have a good reason to!
            Defaults to True.
        metric_args (dict, optional): Additional keyword arguments passed to the metric
            function. This could be useful if you are calculating a dose-based metric and
            require a dose grid to be passed to the metric function. Defaults to None
            (no additional arguments).

    Returns:
        tuple (float, float): The optimal probability, optimal metric value.

    Raises:
        ValueError: If `mode` is neither "min" nor "max".
    """
    if mode not in ("min", "max"):
        raise ValueError(f'mode must be "min" or "max", got {mode!r}')

    # Avoid a shared mutable default; the dict is only ever unpacked into the metric call
    metric_args = {} if metric_args is None else metric_args

    # Auto crop images
    if auto_crop:
        cb_size, cb_index = label_to_roi(
            (manual_contour > 0) | (probability_image > 0), expansion_mm=10
        )
        manual_contour = crop_to_roi(manual_contour, cb_size, cb_index)
        probability_image = crop_to_roi(probability_image, cb_size, cb_index)

    def evaluate(threshold):
        # Threshold the probability map and score it against the reference contour
        auto_contour = process_probability_image(probability_image, threshold=threshold)
        return metric_function(manual_contour, auto_contour, **metric_args)

    nan_select = np.nanargmin if mode == "min" else np.nanargmax

    def select_best(p_values, m_values):
        # NaN-aware choice of the optimum over every threshold evaluated so far
        m_arr = np.asarray(m_values, dtype=float)
        if np.isnan(m_arr).all():
            # Nothing comparable yet: keep the first (initial) guess, as np.argmin would
            return p_values[0], m_arr[0]
        idx = nan_select(m_arr)
        return p_values[idx], m_arr[idx]

    # Set up
    n_iter = 0
    p_best = p_0
    m_best = evaluate(p_0)

    print("Starting optimisation.")
    print(f"n = 0 | p = {p_best:.3f} | metric = {m_best:.3f}")

    p_list = [p_best]
    m_list = [m_best]

    improv = 0
    # Always run at least four passes, then stop once the improvement is within tolerance
    while np.abs(improv) > tolerance or n_iter <= 3:
        n_iter += 1
        m_previous = m_best

        p_new = [
            p_best - 3 * delta / 4,
            p_best - delta / 2,
            p_best - delta / 4,
            p_best + delta / 4,
            p_best + delta / 2,
            p_best + 3 * delta / 4,
        ]
        m_new = [evaluate(p) for p in p_new]

        p_list = p_list + p_new
        m_list = m_list + m_new

        p_best, m_best = select_best(p_list, m_list)

        improv = m_best - m_previous
        if np.isnan(improv) and not np.isnan(m_best):
            # The first finite metric appeared this pass: count it as progress
            improv = np.inf

        # Refine the search window around the new optimum
        delta /= 4

        print(f"n = {n_iter} | p = {p_best:.3f} | metric = {m_best:.3f}")

    if create_figure:
        fig, ax = plt.subplots(1, 1)

        ax.scatter(p_list, m_list, c="k", zorder=1)
        ax.plot(*list(zip(*sorted(zip(p_list, m_list)))), c="k", zorder=1)
        ax.scatter(
            (p_best),
            (m_best),
            c="r",
            label=f"Optimum ({p_best:.2f},{m_best:.2f})",
            zorder=2,
        )

        ax.set_xlim(0, 1)
        # The x-values plotted are the absolute thresholds that were evaluated
        ax.set_xlabel("Probability Threshold")
        ax.set_ylabel("Metric Value")
        ax.grid()
        ax.set_axisbelow(True)
        ax.set_title(f"Optimiser | {metric_function.__name__}, mode = {mode}")
        ax.legend()
        fig.show()

    return p_best, m_best
def geometric_sinoatrialnode(label_svc, label_ra, label_wholeheart, radius_mm=10):
    """Geometric definition of the cardiac sinoatrial node (SAN).

    This is largely inspired by Loap et al 2021 [https://doi.org/10.1016/j.prro.2021.02.002]

    The SVC is dilated until, on its most inferior slice, it overlaps the RA. The SAN centre
    is then the voxel on that slice, at least 10 mm inside the whole heart (in-plane), that
    is closest to the overlap location. A sphere of `radius_mm` is drawn there.

    Args:
        label_svc (SimpleITK.Image): The binary mask defining the superior vena cava.
        label_ra (SimpleITK.Image): The binary mask defining the right atrium.
        label_wholeheart (SimpleITK.Image): The binary mask defining the whole heart.
        radius_mm (int, optional): The radius of the SAN, in millimetres. Defaults to 10.

    Returns:
        SimpleITK.Image: A binary mask defining the SAN

    Raises:
        ValueError: If the SVC or RA label is empty, if the dilated SVC never reaches the RA,
            or if no voxel of the eroded whole heart lies on the SAN slice.
    """
    # In-plane margin (mm) the SAN centre is kept inside the whole-heart boundary
    wholeheart_margin_mm = 10

    # To speed up binary morphology operations we first crop all images
    template_img = 0 * label_wholeheart
    cb_size, cb_index = label_to_roi(
        (label_svc + label_ra + label_wholeheart) > 0, expansion_mm=(20, 20, 20)
    )

    label_svc = crop_to_roi(label_svc, cb_size, cb_index)
    label_ra = crop_to_roi(label_ra, cb_size, cb_index)
    label_wholeheart = crop_to_roi(label_wholeheart, cb_size, cb_index)

    arr_svc = sitk.GetArrayFromImage(label_svc)
    arr_ra = sitk.GetArrayFromImage(label_ra)

    # Without these labels the search below would fail (SVC) or never terminate (RA)
    if not arr_svc.any():
        raise ValueError("The SVC label is empty; cannot locate the SAN.")
    if not arr_ra.any():
        raise ValueError("The RA label is empty; cannot locate the SAN.")

    # First, find the most inferior slice of the SVC
    # This defines the z location of the SAN
    inf_limit_svc = np.min(np.where(arr_svc)[0])

    # Now expand the SVC until it touches the RA on the inf slice
    # The cap is a generous safety bound so disjoint labels raise instead of looping forever
    max_dilate = sum(label_svc.GetSize())
    overlap = 0
    dilate = 1
    dilate_ax = 0
    while overlap == 0:
        if dilate > max_dilate:
            raise ValueError("The dilated SVC never overlaps the RA; cannot locate the SAN.")

        label_svc_dilate = sitk.BinaryDilate(label_svc, (dilate, dilate, dilate_ax))
        label_overlap = label_svc_dilate & label_ra
        overlap = sitk.GetArrayFromImage(label_overlap)[inf_limit_svc, :, :].sum()

        dilate += 1
        if dilate >= 3:
            # From the third pass, also dilate axially and track the new inferior limit
            arr_svc = sitk.GetArrayFromImage(label_svc_dilate)
            inf_limit_svc = np.min(np.where(arr_svc)[0])
            dilate_ax += 1

    # Locate the point on intersection
    intersect_loc = get_com(label_overlap)

    # Create an image with a single voxel of value 1 located at the point of intersection
    arr_intersect = arr_ra * 0
    arr_intersect[inf_limit_svc, intersect_loc[1], intersect_loc[2]] = 1
    label_intersect = sitk.GetImageFromArray(arr_intersect)
    label_intersect.CopyInformation(label_ra)

    # Define the locations greater than 10mm from the WH
    # Ensures the SAN doesn't extend outside the heart volume
    # BinaryErode takes its radius in voxels, so convert the in-plane margin from mm
    sp_x, sp_y, _ = label_wholeheart.GetSpacing()
    erode_radius = (
        int(wholeheart_margin_mm / sp_x),
        int(wholeheart_margin_mm / sp_y),
        0,
    )
    potential_san_region = sitk.BinaryErode(label_wholeheart, erode_radius)

    # Find the point in this region closest to the intersection
    # First generate a distance map
    distancemap_san = sitk.SignedMaurerDistanceMap(
        label_intersect, squaredDistance=False, useImageSpacing=True
    )

    # Then get the distance from the intersection at all possible points
    arr_distancemap_san = sitk.GetArrayFromImage(distancemap_san)
    arr_potential_san_region = sitk.GetArrayFromImage(potential_san_region)

    yloc, xloc = np.where(arr_potential_san_region[inf_limit_svc, :, :])
    if yloc.size == 0:
        raise ValueError(
            "No voxel of the eroded whole heart lies on the SAN slice; cannot place the SAN."
        )

    distances = arr_distancemap_san[inf_limit_svc, yloc, xloc]

    # Find where the distance is a minimum
    location_of_min = distances.argmin()

    # Now define the SAN location (array index order: z, y, x)
    sphere_centre = (inf_limit_svc, yloc[location_of_min], xloc[location_of_min])

    # Generate an image
    label_san = insert_sphere_image(label_ra * 0, sp_radius=radius_mm, sp_centre=sphere_centre)

    # Finally, paste the label into the original image space
    label_san = sitk.Paste(
        template_img,
        label_san,
        label_san.GetSize(),
        (0, 0, 0),
        cb_index,
    )

    return label_san
def _erode_until_disjoint(label_moving, label_fixed):
    """Erode a 2D binary mask until it no longer overlaps another 2D binary mask.

    The mask is always eroded at least once (radius 1). Each further pass erodes the
    already-eroded mask with a radius one voxel larger, until ``label_moving & label_fixed``
    contains no foreground voxels.

    Args:
        label_moving (SimpleITK.Image): The 2D mask to erode.
        label_fixed (SimpleITK.Image): The 2D mask it must not overlap (left unchanged).

    Returns:
        SimpleITK.Image: The eroded version of `label_moving`.
    """
    erode = 1
    while True:
        label_moving = sitk.BinaryErode(label_moving, (erode, erode))
        overlap = sitk.GetArrayFromImage(label_moving & label_fixed).sum()
        if overlap == 0:
            return label_moving
        erode += 1


def geometric_atrioventricularnode(label_la, label_lv, label_ra, label_rv, radius_mm=10):
    """Geometric definition of the cardiac atrioventricular node (AVN).

    This is largely inspired by Loap et al 2021 [https://doi.org/10.1016/j.prro.2021.02.002]

    On the axial slice 10 mm above the most inferior LA slice, the four chambers are eroded
    so that neighbouring pairs no longer overlap. The AVN centre is the mean of the four
    closest-point locations between opposing chambers, and a sphere of `radius_mm` is
    drawn there.

    Args:
        label_la (SimpleITK.Image): The binary mask defining the left atrium.
        label_lv (SimpleITK.Image): The binary mask defining the left ventricle.
        label_ra (SimpleITK.Image): The binary mask defining the right atrium.
        label_rv (SimpleITK.Image): The binary mask defining the right ventricle.
        radius_mm (float, optional): The radius of the AVN, in millimetres. Defaults to 10.

    Returns:
        SimpleITK.Image: A binary mask defining the AVN

    Raises:
        ValueError: If the left atrium label is empty.
    """
    # To speed up binary morphology operations we first crop all images
    template_img = 0 * label_ra
    cb_size, cb_index = label_to_roi(
        (label_la + label_lv + label_ra + label_rv) > 0, expansion_mm=(20, 20, 20)
    )

    label_la = crop_to_roi(label_la, cb_size, cb_index)
    label_lv = crop_to_roi(label_lv, cb_size, cb_index)
    label_ra = crop_to_roi(label_ra, cb_size, cb_index)
    label_rv = crop_to_roi(label_rv, cb_size, cb_index)

    # First, find the most inferior slice of the left atrium
    arr_la = sitk.GetArrayFromImage(label_la)
    if not arr_la.any():
        raise ValueError("The LA label is empty; cannot locate the AVN.")
    inf_limit_la = np.min(np.where(arr_la)[0])

    # Now progress 1cm in the superior direction
    # This defines the slice of the AVN centre
    slice_loc = int(inf_limit_la + 10 / label_la.GetSpacing()[2])

    # Create 2D images at this slice location (SimpleITK indexing is x, y, z)
    label_la_2d = label_la[:, :, slice_loc]
    label_lv_2d = label_lv[:, :, slice_loc]
    label_ra_2d = label_ra[:, :, slice_loc]
    label_rv_2d = label_rv[:, :, slice_loc]

    # We now iteratively erode the structures to ensure they do not overlap
    # This ensures we can measure the closest point without any errors
    # Every pass erodes at least once, so all four chambers are actually processed
    label_lv_2d = _erode_until_disjoint(label_lv_2d, label_la_2d)  # LEFT VENTRICLE vs LA
    label_la_2d = _erode_until_disjoint(label_la_2d, label_ra_2d)  # LEFT ATRIUM vs RA
    label_ra_2d = _erode_until_disjoint(label_ra_2d, label_rv_2d)  # RIGHT ATRIUM vs RV
    label_rv_2d = _erode_until_disjoint(label_rv_2d, label_lv_2d)  # RIGHT VENTRICLE vs LV

    # Measure closest points
    y_la, x_la = get_closest_point_2d(label_rv_2d, label_la_2d)
    y_lv, x_lv = get_closest_point_2d(label_ra_2d, label_lv_2d)
    y_ra, x_ra = get_closest_point_2d(label_lv_2d, label_ra_2d)
    y_rv, x_rv = get_closest_point_2d(label_la_2d, label_rv_2d)

    # Take the arithmetic mean (accumulated as int, so the result is an integer index)
    x_location = np.mean((x_la, x_lv, x_ra, x_rv), dtype=int)
    y_location = np.mean((y_la, y_lv, y_ra, y_rv), dtype=int)

    # Now define the AVN location (array index order: z, y, x)
    sphere_centre = (slice_loc, y_location, x_location)

    # Generate an image
    label_avn = insert_sphere_image(label_ra * 0, sp_radius=radius_mm, sp_centre=sphere_centre)

    # Finally, paste the label into the original image space
    label_avn = sitk.Paste(
        template_img,
        label_avn,
        label_avn.GetSize(),
        (0, 0, 0),
        cb_index,
    )

    return label_avn