示例#1
0
def compute_metric_hd(label_a, label_b, auto_crop=True):
    """Compute the Hausdorff distance between two labels

    Args:
        label_a (sitk.Image): A mask to compare
        label_b (sitk.Image): Another mask to compare
        auto_crop (bool, optional): When True, both masks are first cropped to the
            bounding box of their union (``label_to_roi`` / ``crop_to_roi``) so the
            distance filter runs on a smaller image. Defaults to True.

    Returns:
        float: The maximum Hausdorff distance reported by
        ``sitk.HausdorffDistanceImageFilter``, or ``np.nan`` if either mask is empty
        (the distance is undefined in that case).
    """
    # Guard: an empty mask has no surface to measure from.
    a_is_empty = sitk.GetArrayViewFromImage(label_a).sum() == 0
    if a_is_empty or sitk.GetArrayViewFromImage(label_b).sum() == 0:
        return np.nan

    if auto_crop:
        union_mask = (label_a + label_b) > 0
        box_size, box_index = label_to_roi(union_mask)
        label_a, label_b = (
            crop_to_roi(mask, size=box_size, index=box_index)
            for mask in (label_a, label_b)
        )

    hd_filter = sitk.HausdorffDistanceImageFilter()
    hd_filter.Execute(label_a, label_b)
    return hd_filter.GetHausdorffDistance()
示例#2
0
def compute_metric_sensitivity(label_a, label_b, auto_crop=True):
    """Compute the sensitivity (true positive rate) of label_b against label_a

    label_a is treated as the reference and label_b as the mask being evaluated, so
    the metric is NOT symmetric:
        sensitivity = TP / (TP + FN) = |A & B| / |A|
    counted over the non-zero voxels of each mask.

    Args:
        label_a (sitk.Image): The reference mask
        label_b (sitk.Image): The mask evaluated against the reference
        auto_crop (bool, optional): When True, both masks are first cropped to the
            bounding box of their union to reduce work. Defaults to True.

    Returns:
        float: The sensitivity in [0, 1], or ``np.nan`` if label_a is empty
        (sensitivity is undefined without any reference voxels).
    """
    # An empty reference makes TP + FN zero, i.e. 0 / 0. Return NaN explicitly rather
    # than relying on a NumPy division warning. Returning early also avoids asking
    # label_to_roi for the bounding box of an empty union when both masks are empty.
    if sitk.GetArrayViewFromImage(label_a).sum() == 0:
        return np.nan

    if auto_crop:
        largest_region = (label_a + label_b) > 0
        crop_box_size, crop_box_index = label_to_roi(largest_region)

        label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
        label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

    arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
    arr_b = sitk.GetArrayFromImage(label_b).astype(bool)

    true_pos = np.count_nonzero(arr_a & arr_b)
    # TP + FN is exactly the number of reference voxels (non-zero after the guard above,
    # since cropping to the union keeps every foreground voxel).
    return float(true_pos / np.count_nonzero(arr_a))
示例#3
0
def generate_valve_from_great_vessel(
    label_great_vessel,
    label_ventricle,
    valve_thickness_mm=8,
):
    """
    Generates a geometrically-defined valve where a great vessel meets its ventricle.
    This function is suitable for the pulmonic and aortic valves.

    The ventricle is dilated by a radius derived from ``valve_thickness_mm``. The valve
    is the part of the great vessel inside this dilation, smoothed by a binary
    morphological closing.

    Args:
        label_great_vessel (SimpleITK.Image): The binary mask for the great vessel
            (pulmonary artery or ascending aorta)
        label_ventricle (SimpleITK.Image): The binary mask for the ventricle (left or right)
        valve_thickness_mm (int, optional): Valve thickness, in millimetres. Defaults to 8.

    Returns:
        SimpleITK.Image: The geometric valve, as a binary mask on the same grid as
        label_ventricle.
    """
    # Morphology is done inside a crop around both structures (20 mm margin) for speed.
    # The result is pasted back into an empty image on the original grid at the end.
    blank_canvas = 0 * label_ventricle
    roi_size, roi_index = label_to_roi(
        (label_great_vessel + label_ventricle) > 0, expansion_mm=(20, 20, 20))

    ventricle_roi = crop_to_roi(label_ventricle, roi_size, roi_index)
    vessel_roi = crop_to_roi(label_great_vessel, roi_size, roi_index)

    # mm -> voxels, using only the z (slice) spacing and applying it to every axis.
    # NOTE(review): on anisotropic images the in-plane dilation radius is therefore
    # not valve_thickness_mm millimetres - confirm this is intended.
    _, _, slice_spacing = ventricle_roi.GetSpacing()
    radius_vox = int(valve_thickness_mm / slice_spacing)
    dilated_ventricle = sitk.BinaryDilate(ventricle_roi, (radius_vox, ) * 3)

    # Candidate valve: great-vessel voxels that fall inside the dilated ventricle.
    candidate = vessel_roi & dilated_ventricle

    # Mask by the union of the two structures. The intersection above already lies
    # inside this union, so this step does not remove any voxels.
    union = vessel_roi | dilated_ventricle
    candidate = sitk.Mask(candidate, union)

    valve_roi = sitk.BinaryMorphologicalClosing(candidate)

    # Paste the cropped valve back into the full-size, originally-gridded image.
    return sitk.Paste(
        blank_canvas,
        valve_roi,
        valve_roi.GetSize(),
        (0, 0, 0),
        roi_index,
    )
示例#4
0
def compute_metric_dsc(label_a, label_b, auto_crop=True):
    """Compute the Dice Similarity Coefficient between two labels

    DSC = 2 * |A & B| / (|A| + |B|), counted over the non-zero voxels of each mask.

    Args:
        label_a (sitk.Image): A mask to compare
        label_b (sitk.Image): Another mask to compare
        auto_crop (bool, optional): When True, both masks are first cropped to the
            bounding box of their union to reduce work. Defaults to True.

    Returns:
        float: The Dice Similarity Coefficient in [0, 1], or ``np.nan`` if both masks
        are empty (the coefficient is 0 / 0 and therefore undefined).
    """
    # With both masks empty the DSC is 0 / 0. Return NaN explicitly instead of relying
    # on a NumPy division warning, and before asking label_to_roi for the bounding box
    # of an empty union. If only one mask is empty the DSC is a well-defined 0.
    if (
        sitk.GetArrayViewFromImage(label_a).sum() == 0
        and sitk.GetArrayViewFromImage(label_b).sum() == 0
    ):
        return np.nan

    if auto_crop:
        largest_region = (label_a + label_b) > 0
        crop_box_size, crop_box_index = label_to_roi(largest_region)

        label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
        label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

    arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
    arr_b = sitk.GetArrayFromImage(label_b).astype(bool)

    intersection = np.count_nonzero(arr_a & arr_b)
    total = np.count_nonzero(arr_a) + np.count_nonzero(arr_b)
    return float(2 * intersection / total)
示例#5
0
def compute_metric_masd(label_a, label_b, auto_crop=True):
    """Compute the mean absolute surface distance between two labels

    For each direction (a -> b and b -> a), the unsigned distance map of one mask
    (SignedMaurerDistanceMap with image spacing, then absolute value) is sampled on
    the contour voxels of the other mask (LabelContour). The two directional means
    are combined, weighted by their contour voxel counts. This equals the mean over
    the contour voxels of both masks.

    Args:
        label_a (sitk.Image): A mask to compare
        label_b (sitk.Image): Another mask to compare
        auto_crop (bool, optional): When True, both masks are first cropped to the
            bounding box of their union to reduce work. Defaults to True.

    Returns:
        float: The mean absolute surface distance, or ``np.nan`` if either mask is
        empty.
    """
    if (
        sitk.GetArrayViewFromImage(label_a).sum() == 0
        or sitk.GetArrayViewFromImage(label_b).sum() == 0
    ):
        return np.nan

    if auto_crop:
        union_mask = (label_a + label_b) > 0
        roi_size, roi_index = label_to_roi(union_mask)

        label_a = crop_to_roi(label_a, size=roi_size, index=roi_index)
        label_b = crop_to_roi(label_b, size=roi_size, index=roi_index)

    def _directed_surface_stats(reference, moving):
        """Return (mean distance, voxel count) over moving's contour w.r.t. reference."""
        stats = sitk.LabelIntensityStatisticsImageFilter()
        distance_map = sitk.Abs(
            sitk.SignedMaurerDistanceMap(reference, squaredDistance=False, useImageSpacing=True)
        )
        stats.Execute(sitk.LabelContour(moving), distance_map)
        # Statistics are read for label value 1 - presumably the masks are 0/1 binary;
        # confirm for inputs with other foreground values.
        return stats.GetMean(1), stats.GetNumberOfPixels(1)

    mean_ab, count_ab = _directed_surface_stats(label_a, label_b)
    mean_ba, count_ba = _directed_surface_stats(label_b, label_a)

    means = [mean_ab, mean_ba]
    counts = [count_ab, count_ba]
    weighted_mean = np.dot(means, counts) / np.sum(counts)
    return float(weighted_mean)
示例#6
0
def generate_left_ventricle_segments(
    contours,
    label_left_ventricle="LEFTVENTRICLE",
    label_left_atrium="LEFTATRIUM",
    label_right_ventricle="RIGHTVENTRICLE",
    label_mitral_valve="MITRALVALVE",
    label_heart="WHOLEHEART",
    myocardium_thickness_mm=10,
    hole_fill_mm=3,
    optimiser_tol_degrees=1,
    optimiser_max_iter=10,
    min_area_mm2=50,
    verbose=False,
):
    """
    Generates the 17 segments of the left ventricle

    This function works as follows:
        1.  Heart volume is rotated to align the long axis to the z Cartesian (physical) space.
            Usually means it aligns with the axial axis (for normal simulation CT)
        2.  An optimiser adjusts the orientation to refine this alignment to the vector defined by
            MV COM - LV apex axis (long axis)
        3.  Left ventricle is divided into thirds along the long axis
        4.  Myocardium is defined as the outer ``myocardium_thickness_mm`` (default 10 mm)
        5.  Geometric operations are used to define the segments
        6.  Everything is rotated back to the normal orientation
        7.  Post-processing: each segment is optionally closed (``hole_fill_mm``) and pasted
            back into the full image grid

    Args:
        contours (dict): A dictionary containing strings (label names) as keys and SimpleITK.Image
            (masks) as values. Must contain at least the LV, LA, RV, MV, and whole heart.
        label_left_ventricle (str): The name for the left ventricle mask (contour)
        label_left_atrium (str): The name for the left atrium mask (contour)
        label_right_ventricle (str): The name for the right ventricle mask (contour)
        label_mitral_valve (str): The name for the mitral valve mask (contour)
        label_heart (str): The name for the heart mask (contour)
        myocardium_thickness_mm (float, optional): Myocardial thickness, in millimetres.
            Defaults to 10.
        hole_fill_mm (float, optional): Holes smaller than this get filled in. Defaults to 3.
        optimiser_tol_degrees (float, optional): Optimiser tolerance (change in angle per iter).
            Defaults to 1, which typically requires 3-4 iterations.
        optimiser_max_iter (int, optional): Maximum optimiser iterations. Defaults to 10
        min_area_mm2 (float, optional): Passed unchanged to ``extract`` for every segment
            slice (presumably a minimum region area in mm^2 - confirm in ``extract``).
            Defaults to 50.
        verbose (bool, optional): Print information for debugging. Defaults to False.

    Returns:
        dict : The left ventricle segment dictionary, with labels (int, 1-17) as keys and the
        binary label defining the segment (SimpleITK.Image) as values.
    """

    if verbose:
        print("Beginning LV segmentation algorithm.")

    # Initial set up
    label_list = [
        label_left_ventricle,
        label_left_atrium,
        label_right_ventricle,
        label_mitral_valve,
        label_heart,
    ]
    working_contours = copy.deepcopy({s: contours[s] for s in label_list})
    output_contours = {}
    overall_transform_list = []

    # Some conversions: millimetres -> per-axis voxel radii
    erode_img = [
        int(myocardium_thickness_mm / i)
        for i in working_contours[label_left_ventricle].GetSpacing()
    ]
    hole_fill_img = [
        int(hole_fill_mm / i)
        for i in working_contours[label_heart].GetSpacing()
    ]
    """
    Module 1 - Preparation
    Crop the images
    Rotate to the cardiac axis
    """
    # Crop to the smallest volume possible to make it FAST
    cb_size, cb_index = label_to_roi(
        working_contours[label_heart] > 0,
        expansion_mm=(30, 30, 60),  # Better to make it a bit bigger to be safe
    )

    # NOTE: every label in `contours` (not only those in label_list) is cropped and
    # carried through the rotations below.
    for label in contours:
        working_contours[label] = crop_to_roi(contours[label], cb_size,
                                              cb_index)

    if verbose:
        print("Module 1: Cropping and initial alignment.")
        # Compare voxel counts. Spacing is unchanged by cropping, so it cannot be used
        # here. np.prod replaces np.product, which was removed in NumPy 2.0.
        vol_before = np.prod(contours[label_heart].GetSize())
        vol_after = np.prod(working_contours[label_heart].GetSize())
        print(
            f"  Images cropped. Volume reduction: {vol_before/vol_after:.3f}")

    # Initially we should reorient based on the cardiac axis
    label_orient = (working_contours[label_left_ventricle] +
                    working_contours[label_left_atrium]) > 0

    lsf = sitk.LabelShapeStatisticsImageFilter(
    )  # this will be used throughout
    lsf.Execute(label_orient)
    cardiac_axis = np.array(
        lsf.GetPrincipalAxes(1)[:3])  # First principal axis approx. long axis

    # The principal axis isn't guaranteed to point from base to apex
    # If is points apex to base, we have to invert it
    # So check that here
    if cardiac_axis[2] < 0:
        cardiac_axis = -1 * cardiac_axis

    # NOTE(review): the axis components are reversed before computing the rotation;
    # confirm the intended component ordering of GetPrincipalAxes here.
    rotation_angle = vector_angle(cardiac_axis[::-1], (0, 0, 1))
    rotation_axis = np.cross(cardiac_axis[::-1], (0, 0, 1))
    rotation_centre = get_com(label_orient, real_coords=True)

    if verbose:
        print("  Alignment computed.")
        print("    Cardiac axis:    ", cardiac_axis)
        print("    Rotation axis:   ", rotation_axis)
        print("    Rotation angle:  ", rotation_angle)
        print("    Rotation centre: ", rotation_centre)

    rotation_transform = sitk.VersorRigid3DTransform()
    rotation_transform.SetCenter(rotation_centre)
    rotation_transform.SetRotation(rotation_axis, rotation_angle)

    # Every applied rotation is recorded so the composite can be inverted in Module 5
    overall_transform_list.append(rotation_transform)

    for label in contours:
        working_contours[label] = sitk.Resample(
            working_contours[label],
            rotation_transform,
            sitk.sitkNearestNeighbor,
            0,
            working_contours[label].GetPixelID(),
        )
    """
    Module 2 - LV orientation alignment
    We use a very simple optimisation regime to enable robust computation of the LV apex
    We compute the vector from the MV COM to the LV apex
    This will be used for orientation (i.e. the long axis)
    """
    optimiser_tol_radians = optimiser_tol_degrees * np.pi / 180

    n = 0

    if verbose:
        print("Module 2: LV orientation alignment.")
        print("  Optimiser tolerance (degrees) =", optimiser_tol_degrees)
        print("  Beginning alignment process")

    # Iterate until the latest corrective rotation is below tolerance (the first check
    # uses the Module 1 angle) or the iteration budget is spent.
    while n < optimiser_max_iter and np.abs(
            rotation_angle) > optimiser_tol_radians:

        n += 1

        # Find the LV apex: the lowest LV slice, averaged in-plane
        lv_locations = np.where(
            sitk.GetArrayViewFromImage(working_contours[label_left_ventricle]))
        lv_apex_z = lv_locations[0].min()
        lv_apex_y = lv_locations[1][lv_locations[0] == lv_apex_z].mean()
        lv_apex_x = lv_locations[2][lv_locations[0] == lv_apex_z].mean()
        lv_apex_loc = np.array([lv_apex_x, lv_apex_y, lv_apex_z])

        # Get the MV COM
        mv_com = np.array(
            get_com(working_contours[label_mitral_valve], real_coords=True))

        # Define the LV axis
        lv_apex_loc_img = np.array(working_contours[label_left_ventricle].
                                   TransformContinuousIndexToPhysicalPoint(
                                       lv_apex_loc.tolist()))
        lv_axis = lv_apex_loc_img - mv_com

        # Compute the rotation parameters
        rotation_axis = np.cross(lv_axis, (0, 0, 1))
        rotation_angle = vector_angle(lv_axis, (0, 0, 1))
        rotation_centre = 0.5 * (
            mv_com + lv_apex_loc_img
        )  # get_com(working_contours[label_left_ventricle], real_coords=True)

        rotation_transform = sitk.VersorRigid3DTransform()
        rotation_transform.SetCenter(rotation_centre)
        rotation_transform.SetRotation(rotation_axis, rotation_angle)

        overall_transform_list.append(rotation_transform)

        if verbose:
            print("    N:               ", n)
            print("    LV apex:         ", lv_apex_loc_img)
            print("    MV COM:          ", mv_com)
            print("    LV axis:         ", lv_axis)
            print("    Rotation axis:   ", rotation_axis)
            print("    Rotation centre: ", rotation_centre)
            print("    Rotation angle:  ", rotation_angle)

        for label in contours:
            working_contours[label] = sitk.Resample(
                working_contours[label],
                rotation_transform,
                sitk.sitkNearestNeighbor,
                0,
                working_contours[label].GetPixelID(),
            )
    """
    Module 3 - Compute the myocardium for the whole LV volume

    Divide this volume into thirds (from MV COM -> LV apex)
    """

    if verbose:
        print("Module 3: Myocardium generation.")

    # First, let's just extract the myocardium
    label_lv_inner = sitk.BinaryErode(working_contours[label_left_ventricle],
                                      erode_img)
    label_lv_myo = working_contours[label_left_ventricle] - label_lv_inner

    # Mask the myo to a dilation of the blood pool
    # This helps improve shape consistency
    label_lv_myo_mask = sitk.BinaryDilate(label_lv_inner, erode_img)
    label_lv_myo = sitk.Mask(label_lv_myo, label_lv_myo_mask)

    # Computing limits for division into thirds
    # [xstart, ystart, zstart, xsize, ysize, zsize]
    # For the limits, we will use the centre of mass of the MV to the LV apex
    # The inner limit is used to assign the top portion (basal) of the LV to the anterior segment
    lsf.Execute(label_lv_inner)
    _, _, inf_limit_lv, _, _, extent = lsf.GetRegion(1)

    # The first component of get_com is used as a z (slice) index below
    com_mv, _, _ = get_com(working_contours[label_mitral_valve])

    extent = com_mv - inf_limit_lv
    dc = int(extent / 3)

    # Define limits (cut LV into thirds)
    apical_extent = inf_limit_lv + dc
    mid_extent = inf_limit_lv + 2 * dc
    basal_extent = com_mv  # more complete coverage

    if verbose:
        print("  Apex (long axis) slice:      ", inf_limit_lv)
        print("  Apical section extent slice: ", apical_extent)
        print("  Mid section extent slice:    ", mid_extent)
        print("  Basal section extent slice:  ", basal_extent)
        print("    DeltaCut (DC): ", dc)
        print("    Extent:        ", extent)

    # Segment 17
    label_lv_myo_apex = label_lv_myo * 1  # make a copy
    label_lv_myo_apex[:, :, inf_limit_lv:] = 0

    # The apical segment
    label_lv_myo_apical = label_lv_myo * 1  # make a copy
    label_lv_myo_apical[:, :, :inf_limit_lv] = 0
    label_lv_myo_apical[:, :, apical_extent:] = 0

    # The mid segment
    label_lv_myo_mid = label_lv_myo * 1  # make a copy
    label_lv_myo_mid[:, :, :apical_extent] = 0
    label_lv_myo_mid[:, :, mid_extent:] = 0

    # The basal segment
    label_lv_myo_basal = label_lv_myo * 1  # make a copy
    label_lv_myo_basal[:, :, :mid_extent] = 0
    label_lv_myo_basal[:, :, basal_extent:] = 0
    """
    Module 4 - Generate 17 segments

        1. Find the basal (anterior) insertion of the RV
            This defines theta_0
        2. Find the baseline angle for the apical section
            This defines theta_0_apical
        3. Iterate through each section (apical, mid, basal):
            a. Convert each myocardium label loc to polar coords
            b. Assign each label to the appropriate LV segments
    """

    if verbose:
        print("Module 4: Segment generation.")

    # We need the angle for the basal RV insertion
    # This is the most counter-clockwise RV location
    # First, retrieve the first 5 slices of the basal section (starting at mid_extent)
    loc_rv_z, loc_rv_y, loc_rv_x = np.where(
        sitk.GetArrayViewFromImage(working_contours[label_right_ventricle]))
    loc_rv_z_basal = np.arange(mid_extent, mid_extent + 5)

    if verbose:
        print("  RV basal slices: ", loc_rv_z_basal)

    theta_rv_insertion = []
    for z in loc_rv_z_basal:
        # Now get all the x and y positions
        loc_rv_basal_x = loc_rv_x[np.where(np.in1d(loc_rv_z, z))]
        loc_rv_basal_y = loc_rv_y[np.where(np.in1d(loc_rv_z, z))]

        # Now define the LV COM on each slice
        lv_com = get_com(working_contours[label_left_ventricle][:, :, int(z)])
        lv_com_basal_x = lv_com[1]
        lv_com_basal_y = lv_com[0]

        # Compute the angle
        # NOTE(review): theta_rv.min() raises if the RV is absent on this slice.
        theta_rv = np.arctan2(lv_com_basal_y - loc_rv_basal_y,
                              loc_rv_basal_x - lv_com_basal_x)
        theta_rv[theta_rv < 0] += 2 * np.pi
        theta_rv_insertion.append(theta_rv.min())

    theta_0 = np.median(theta_rv_insertion)

    if verbose:
        print("  RV insertion angle (basal section): ", theta_0)

    # We also need the angle in the apical section for accurate segmentation
    lv_com_apical_list = []
    rv_com_apical_list = []
    for n in range(inf_limit_lv, apical_extent):
        lv_com_apical_list.append(
            get_com(working_contours[label_left_ventricle][:, :, n]))
        rv_com_apical_list.append(
            get_com(working_contours[label_right_ventricle][:, :, n]))

    lv_com_apical = np.mean(lv_com_apical_list, axis=0)
    rv_com_apical = np.mean(rv_com_apical_list, axis=0)

    theta_0_apical = np.arctan2(lv_com_apical[0] - rv_com_apical[0],
                                rv_com_apical[1] - lv_com_apical[1])

    if verbose:
        print(" Apical LV-RV COM angle: ", theta_0_apical)

    # Segments 1-17 start as empty images on the rotated, cropped heart grid
    for i in range(17):
        working_contours[i + 1] = 0 * working_contours[label_heart]

    working_contours[17] = label_lv_myo_apex

    if verbose:
        print("  Computing apical segments")
    # We are now going to compute the segments in cylindical sections
    # First up - apical slices
    for n in range(inf_limit_lv, apical_extent):

        label_lv_myo_slice = label_lv_myo[:, :, n]

        # We will need numpy arrays here
        arr_lv_myo_slice = sitk.GetArrayViewFromImage(label_lv_myo_slice)
        loc_y, loc_x = np.where(arr_lv_myo_slice)

        # Now the origin
        y_0, x_0 = get_com(label_lv_myo_slice)

        # Compute the angle(s)
        theta = -np.arctan2(loc_y - y_0, loc_x - x_0) - theta_0_apical
        # Convert to [0,2*np.pi]
        theta[theta < 0] += 2 * np.pi

        # Compute the radii
        radii = np.sqrt((loc_y - y_0)**2 + (loc_x - x_0)**2)

        # Now assign to different segments
        working_contours[13][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            5 * np.pi / 4,
            7 * np.pi / 4,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )
        working_contours[14][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            1 * np.pi / 4,
            7 * np.pi / 4,
            loc_x,
            loc_y,
            cw=True,
            min_area_mm2=min_area_mm2,
        )
        working_contours[15][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            1 * np.pi / 4,
            3 * np.pi / 4,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )
        working_contours[16][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            3 * np.pi / 4,
            5 * np.pi / 4,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )

    if verbose:
        print("  Computing mid segments")
    # Second up - mid slices
    for n in range(apical_extent, mid_extent):

        label_lv_myo_slice = label_lv_myo[:, :, n]

        # We will need numpy arrays here
        arr_lv_myo_slice = sitk.GetArrayViewFromImage(label_lv_myo_slice)
        loc_y, loc_x = np.where(arr_lv_myo_slice)

        # Now the origin
        y_0, x_0 = get_com(label_lv_myo_slice)

        # Compute the angle(s)
        theta = -np.arctan2(loc_y - y_0, loc_x - x_0) - theta_0
        # Convert to [0,2*np.pi]
        theta[theta < 0] += 2 * np.pi

        # Compute the radii
        radii = np.sqrt((loc_y - y_0)**2 + (loc_x - x_0)**2)

        # Now assign to different segments
        working_contours[8][:, :, n] = extract(label_lv_myo_slice,
                                               theta,
                                               radii,
                                               0,
                                               np.pi / 3,
                                               loc_x,
                                               loc_y,
                                               min_area_mm2=min_area_mm2)
        working_contours[9][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            1 * np.pi / 3,
            2 * np.pi / 3,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )
        working_contours[10][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            2 * np.pi / 3,
            3 * np.pi / 3,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )
        working_contours[11][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            3 * np.pi / 3,
            4 * np.pi / 3,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )
        working_contours[12][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            4 * np.pi / 3,
            5 * np.pi / 3,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )
        working_contours[7][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            5 * np.pi / 3,
            2 * np.pi,
            loc_x,
            loc_y,
            min_area_mm2=min_area_mm2,
        )

    if verbose:
        print("  Computing basal segments")
    # Third up - basal slices
    for n in range(mid_extent, basal_extent):

        label_lv_myo_slice = label_lv_myo[:, :, n]

        # We will need numpy arrays here
        arr_lv_myo_slice = sitk.GetArrayViewFromImage(label_lv_myo_slice)
        loc_y, loc_x = np.where(arr_lv_myo_slice)

        # Now the origin
        y_0, x_0 = get_com(label_lv_myo_slice)

        # Compute the angle(s)
        theta = -np.arctan2(loc_y - y_0, loc_x - x_0) - theta_0
        # Convert to [0,2*np.pi]
        theta[theta < 0] += 2 * np.pi

        # Compute the radii
        radii = np.sqrt((loc_y - y_0)**2 + (loc_x - x_0)**2)

        # Now assign to different segments
        working_contours[2][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            0,
            np.pi / 3,
            loc_x,
            loc_y,
            radius_min=15,
            min_area_mm2=min_area_mm2,
        )
        working_contours[3][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            1 * np.pi / 3,
            2 * np.pi / 3,
            loc_x,
            loc_y,
            radius_min=15,
            min_area_mm2=min_area_mm2,
        )
        working_contours[4][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            2 * np.pi / 3,
            3 * np.pi / 3,
            loc_x,
            loc_y,
            radius_min=15,
            min_area_mm2=min_area_mm2,
        )
        working_contours[5][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            3 * np.pi / 3,
            4 * np.pi / 3,
            loc_x,
            loc_y,
            radius_min=15,
            min_area_mm2=min_area_mm2,
        )
        working_contours[6][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            4 * np.pi / 3,
            5 * np.pi / 3,
            loc_x,
            loc_y,
            radius_min=15,
            min_area_mm2=min_area_mm2,
        )
        working_contours[1][:, :, n] = extract(
            label_lv_myo_slice,
            theta,
            radii,
            5 * np.pi / 3,
            2 * np.pi,
            loc_x,
            loc_y,
            radius_min=15,
            min_area_mm2=min_area_mm2,
        )
    """
    Module 5 - re-orientation into image space

    We perform the total inverse transformation, and paste the labels back into the image space
    """

    if verbose:
        print("  Module 5: Re-orientation.")

    # Compute the total transform
    overall_transform = sitk.CompositeTransform(overall_transform_list)
    inverse_transform = overall_transform.GetInverse()

    # Rotate back to the original reference space
    for segment in range(17):
        new_structure = sitk.Resample(
            working_contours[segment + 1],
            inverse_transform,
            sitk.sitkNearestNeighbor,
            0,
        )

        if hole_fill_mm > 0:
            new_structure = sitk.BinaryMorphologicalClosing(
                new_structure, hole_fill_img)

        # Paste the cropped segment into an empty image on the full heart-mask grid
        new_structure = sitk.Paste(
            contours[label_heart] * 0,
            new_structure,
            new_structure.GetSize(),
            (0, 0, 0),
            cb_index,
        )

        output_contours[segment + 1] = new_structure

    if verbose:
        print("Complete!")

    return output_contours
示例#7
0
def generate_valve_using_cylinder(
    label_atrium,
    label_ventricle,
    radius_mm=15,
    height_mm=10,
):
    """
    Generates a geometrically-defined valve.
    This function is suitable for the tricuspid and mitral valves.

    Steps:
    1. Dilate the atrium and ventricle masks in 1 mm steps until their overlap exceeds
       2000 mm^3.
    2. Place a cylinder of the requested size at the centre of mass of this overlap.
    3. Rotate the cylinder by the angle between the atrium -> ventricle centre-of-mass
       vector and the z axis (presumably aligning the cylinder axis with that vector;
       confirm against ``insert_cylinder_image``).

    Args:
        label_atrium (SimpleITK.Image): The binary mask for the (left or right) atrium.
        label_ventricle (SimpleITK.Image): The binary mask for the (left or right) ventricle.
        radius_mm (int, optional): The valve radius, in mm. Defaults to 15.
        height_mm (int, optional): The valve height (i.e. perpendicular extent), in mm.
            Defaults to 10.

    Returns:
        SimpleITK.Image: The geometrically defined valve

    Raises:
        ValueError: If either chamber mask is empty. The overlap could then never grow,
            and the dilation loop would not terminate.
    """
    if (
        sitk.GetArrayViewFromImage(label_atrium).sum() == 0
        or sitk.GetArrayViewFromImage(label_ventricle).sum() == 0
    ):
        raise ValueError("Both label_atrium and label_ventricle must be non-empty.")

    # To speed up binary morphology operations we first crop all images
    template_img = 0 * label_ventricle
    cb_size, cb_index = label_to_roi((label_atrium + label_ventricle) > 0,
                                     expansion_mm=(20, 20, 20))

    label_atrium = crop_to_roi(label_atrium, cb_size, cb_index)
    label_ventricle = crop_to_roi(label_ventricle, cb_size, cb_index)

    # Define the overlap region (using binary dilation)
    # Grow the dilation radius 1 mm per iteration until the overlap exceeds 2000 mm^3,
    # so there are enough voxels to locate the valve.
    spacing = label_ventricle.GetSpacing()
    # np.prod replaces np.product, which was removed in NumPy 2.0. The voxel volume
    # is loop-invariant because dilation does not change the image grid.
    voxel_volume = np.prod(spacing)
    dilation = 1
    overlap_vol = 0
    while overlap_vol <= 2000:
        dilation_img = [int(dilation / i) for i in spacing]
        overlap = sitk.BinaryDilate(label_atrium, dilation_img) & sitk.BinaryDilate(
            label_ventricle, dilation_img)
        overlap_vol = np.sum(sitk.GetArrayViewFromImage(overlap) * voxel_volume)
        dilation += 1

    # Now we can calculate the location of the valve
    valve_loc = get_com(overlap, as_int=True)
    valve_loc_real = get_com(overlap, real_coords=True)

    # Now we create a cylinder with the user_defined parameters
    cylinder = insert_cylinder_image(0 * label_ventricle, radius_mm, height_mm,
                                     valve_loc[::-1])

    # Orientation: the offset between the chamber centres of mass is used as a proxy
    # for the long axis of the LV/RV. It is considered more robust than the principal
    # axes of the combined chambers or of the overlap region.
    orientation_vector = np.array(get_com(
        label_ventricle, real_coords=True)) - np.array(
            get_com(label_atrium, real_coords=True))

    # Get the rotation parameters
    rotation_angle = vector_angle(orientation_vector, (0, 0, 1),
                                  smallest=False)
    rotation_axis = np.cross(orientation_vector, (0, 0, 1))

    # Rotate the cylinder to define the valve
    label_valve = rotate_image(
        cylinder,
        rotation_centre=valve_loc_real,
        rotation_axis=rotation_axis,
        rotation_angle_radians=rotation_angle,
        interpolation=sitk.sitkNearestNeighbor,
        default_value=0,
    )

    # Finally, paste back to the original image space
    label_valve = sitk.Paste(
        template_img,
        label_valve,
        label_valve.GetSize(),
        (0, 0, 0),
        cb_index,
    )

    return label_valve
示例#8
0
文件: run.py 项目: pyplati/platipy
def run_segmentation(img, settings=MUTLIATLAS_SETTINGS_DEFAULTS):
    """Runs the atlas-based segmentation algorithm

    Pipeline:
        0. Read in every atlas image and its structures. Optionally crop each atlas to the
           (expanded) bounding box of its structures.
        1. Crop the target image. Up to the first eight atlases are quickly registered
           (similarity transform) to the target, and the region they cover is used as the ROI.
        2. Linearly register every atlas to the cropped target and propagate its structures.
        3. Refine each atlas with ``fast_symmetric_forces_demons_registration`` (deformable).
        4. Compute a weight map per atlas and fuse the propagated labels.
        5. Threshold each fused probability map. Paste binary and probability images back into
           the full target image space.
        6. Optionally post-process: keep the largest component, apply morphological closing,
           and correct overlaps between structures.

    Args:
        img (sitk.Image): The target image to segment.
        settings (dict, optional): Dictionary containing settings for algorithm. It is only read,
            never modified. Defaults to MUTLIATLAS_SETTINGS_DEFAULTS.

    Returns:
        tuple(dict, dict): Two dictionaries keyed by structure name. The first holds the binary
        segmentations; the second holds the fused probability maps. Both contain sitk.Images in
        the space of ``img``.
    """

    results = {}
    results_prob = {}

    # Initialisation - Read in atlases
    # - image files
    # - structure files
    #
    #     Atlas structure:
    #     'ID': 'Original': 'CT Image'    : sitk.Image
    #                       'Struct A'    : sitk.Image
    #                       'Struct B'    : sitk.Image
    #           'RIR'     : 'CT Image'    : sitk.Image
    #                       'Transform'   : transform parameter map
    #                       'Struct A'    : sitk.Image
    #                       'Struct B'    : sitk.Image
    #           'DIR'     : 'CT Image'    : sitk.Image
    #                       'Transform'   : displacement field transform
    #                       'Weight Map'  : sitk.Image
    #                       'Struct A'    : sitk.Image
    #                       'Struct B'    : sitk.Image

    logger.info("")
    # Settings
    atlas_path = settings["atlas_settings"]["atlas_path"]
    atlas_id_list = settings["atlas_settings"]["atlas_id_list"]
    atlas_structure_list = settings["atlas_settings"]["atlas_structure_list"]

    atlas_image_format = settings["atlas_settings"]["atlas_image_format"]
    atlas_label_format = settings["atlas_settings"]["atlas_label_format"]

    crop_atlas_to_structures = settings["atlas_settings"]["crop_atlas_to_structures"]
    crop_atlas_expansion_mm = settings["atlas_settings"]["crop_atlas_expansion_mm"]

    atlas_set = {}
    for atlas_id in atlas_id_list:
        atlas_set[atlas_id] = {}
        atlas_set[atlas_id]["Original"] = {}

        image = sitk.ReadImage(f"{atlas_path}/{atlas_image_format.format(atlas_id)}")

        structures = {
            struct: sitk.ReadImage(f"{atlas_path}/{atlas_label_format.format(atlas_id, struct)}")
            for struct in atlas_structure_list
        }

        if crop_atlas_to_structures:
            logger.info(f"Automatically cropping atlas: {atlas_id}")

            # np.prod (not np.product, which NumPy deprecated in 1.25 and removed in 2.0)
            original_volume = np.prod(image.GetSize())

            crop_box_size, crop_box_index = label_to_roi(
                structures.values(), expansion_mm=crop_atlas_expansion_mm
            )

            image = crop_to_roi(image, size=crop_box_size, index=crop_box_index)

            final_volume = np.prod(image.GetSize())

            logger.info(f"  > Volume reduced by factor {original_volume/final_volume:.2f}")

            for struct in atlas_structure_list:
                structures[struct] = crop_to_roi(
                    structures[struct], size=crop_box_size, index=crop_box_index
                )

        atlas_set[atlas_id]["Original"]["CT Image"] = image

        for struct in atlas_structure_list:
            atlas_set[atlas_id]["Original"][struct] = structures[struct]

    # Step 1 - Automatic cropping
    # - a quick registration of (up to) the first eight atlases to the target
    # - expansion of the bounding box to ensure the entire volume of interest is enclosed
    # - the target image is cropped

    expansion_mm = settings["auto_crop_target_image_settings"]["expansion_mm"]

    quick_reg_settings = {
        "reg_method": "similarity",
        "shrink_factors": [8],
        "smooth_sigmas": [0],
        "sampling_rate": 0.75,
        "default_value": -1000,
        "number_of_iterations": 25,
        "final_interp": sitk.sitkLinear,
        "metric": "mean_squares",
        "optimiser": "gradient_descent_line_search",
    }

    registered_crop_images = []

    logger.info("Running initial Translation tranform to crop image volume")

    # Slicing already clamps to the list length, so this is "at most eight atlases"
    for atlas_id in atlas_id_list[:8]:

        logger.info(f"  > atlas {atlas_id}")

        atlas_image = atlas_set[atlas_id]["Original"]["CT Image"]

        reg_image, _ = linear_registration(
            img,
            atlas_image,
            **quick_reg_settings,
        )

        registered_crop_images.append(sitk.Cast(reg_image, sitk.sitkFloat32))

        del reg_image

    # Voxels where the mean registered atlas is above the -1000 fill value were covered by
    # at least one atlas; their bounding box (plus expansion) defines the crop.
    combined_image = sum(registered_crop_images) / len(registered_crop_images) > -1000

    crop_box_size, crop_box_index = label_to_roi(combined_image, expansion_mm=expansion_mm)

    img_crop = crop_to_roi(img, crop_box_size, crop_box_index)

    logger.info("Calculated crop box:")
    logger.info(f"  > {crop_box_index}")
    logger.info(f"  > {crop_box_size}")
    logger.info(f"  > Vol reduction = {np.prod(img.GetSize())/np.prod(crop_box_size):.2f}")

    # Step 2 - Rigid registration of target images
    # - Individual atlas images are registered to the target
    # - The transformation is used to propagate the labels onto the target
    linear_registration_settings = settings["linear_registration_settings"]

    logger.info(
        f"Running {linear_registration_settings['reg_method']} tranform to align atlas images"
    )

    for atlas_id in atlas_id_list:
        # Register the atlases

        logger.info(f"  > atlas {atlas_id}")

        atlas_set[atlas_id]["RIR"] = {}

        target_reg_image = img_crop
        atlas_reg_image = atlas_set[atlas_id]["Original"]["CT Image"]

        _, initial_tfm = linear_registration(
            target_reg_image,
            atlas_reg_image,
            **linear_registration_settings,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["RIR"]["Transform"] = initial_tfm

        atlas_set[atlas_id]["RIR"]["CT Image"] = apply_transform(
            input_image=atlas_set[atlas_id]["Original"]["CT Image"],
            reference_image=img_crop,
            transform=initial_tfm,
            default_value=-1000,
            interpolator=sitk.sitkLinear,
        )

        for struct in atlas_structure_list:
            input_struct = atlas_set[atlas_id]["Original"][struct]
            atlas_set[atlas_id]["RIR"][struct] = apply_transform(
                input_image=input_struct,
                reference_image=img_crop,
                transform=initial_tfm,
                default_value=0,
                interpolator=sitk.sitkNearestNeighbor,
            )

        # Release the original (uncropped-to-target) images; they are no longer needed
        atlas_set[atlas_id]["Original"] = None

    # Step 3 - Deformable image registration
    # - Using fast_symmetric_forces_demons_registration

    # Settings
    deformable_registration_settings = settings["deformable_registration_settings"]

    logger.info("Running DIR to refine atlas image registration")

    for atlas_id in atlas_id_list:

        logger.info(f"  > atlas {atlas_id}")

        # Register the atlases
        atlas_set[atlas_id]["DIR"] = {}

        atlas_reg_image = atlas_set[atlas_id]["RIR"]["CT Image"]
        target_reg_image = img_crop

        _, dir_tfm, _ = fast_symmetric_forces_demons_registration(
            target_reg_image,
            atlas_reg_image,
            **deformable_registration_settings,
        )

        # Save in the atlas dict
        atlas_set[atlas_id]["DIR"]["Transform"] = dir_tfm

        atlas_set[atlas_id]["DIR"]["CT Image"] = apply_transform(
            input_image=atlas_set[atlas_id]["RIR"]["CT Image"],
            transform=dir_tfm,
            default_value=-1000,
            interpolator=sitk.sitkLinear,
        )

        for struct in atlas_structure_list:
            input_struct = atlas_set[atlas_id]["RIR"][struct]
            atlas_set[atlas_id]["DIR"][struct] = apply_transform(
                input_image=input_struct,
                transform=dir_tfm,
                default_value=0,
                interpolator=sitk.sitkNearestNeighbor,
            )

        # Release the linearly-registered images; only the DIR results are used from here on
        atlas_set[atlas_id]["RIR"] = None

    # Step 4 - Label Fusion
    vote_type = settings["label_fusion_settings"]["vote_type"]
    vote_params = settings["label_fusion_settings"]["vote_params"]

    # Compute weight maps
    for atlas_id in list(atlas_set.keys()):
        atlas_image = atlas_set[atlas_id]["DIR"]["CT Image"]
        weight_map = compute_weight_map(
            img_crop, atlas_image, vote_type=vote_type, vote_params=vote_params
        )
        atlas_set[atlas_id]["DIR"]["Weight Map"] = weight_map

    combined_label_dict = combine_labels(atlas_set, atlas_structure_list)

    # Step 6 - Paste the cropped structure into the original image space
    logger.info("Generating binary segmentations.")
    template_img_binary = sitk.Cast((img * 0), sitk.sitkUInt8)
    template_img_prob = sitk.Cast((img * 0), sitk.sitkFloat64)

    for structure_name in atlas_structure_list:

        probability_map = combined_label_dict[structure_name]

        # Per-structure threshold if configured, otherwise 0.5
        optimal_threshold = settings["label_fusion_settings"]["optimal_threshold"].get(
            structure_name, 0.5
        )

        binary_struct = process_probability_image(probability_map, optimal_threshold)

        # Un-crop binary structure
        paste_img_binary = sitk.Paste(
            template_img_binary,
            binary_struct,
            binary_struct.GetSize(),
            (0, 0, 0),
            crop_box_index,
        )
        results[structure_name] = paste_img_binary

        # Un-crop probability map
        paste_prob_img = sitk.Paste(
            template_img_prob,
            probability_map,
            probability_map.GetSize(),
            (0, 0, 0),
            crop_box_index,
        )
        results_prob[structure_name] = paste_prob_img

    # Step 8 - Post-processing
    postprocessing_settings = settings["postprocessing_settings"]

    if postprocessing_settings["run_postprocessing"]:
        logger.info("Running post-processing.")

        # Remove any smaller components and perform morphological closing (hole filling).
        # The closing radius is converted from mm to voxels along each axis.
        binaryfillhole_img = [
            int(postprocessing_settings["binaryfillhole_mm"] / sp) for sp in img.GetSpacing()
        ]

        for structure_name in postprocessing_settings["structures_for_binaryfillhole"]:

            if structure_name not in results:
                continue

            contour_s = results[structure_name]
            contour_s = sitk.RelabelComponent(sitk.ConnectedComponent(contour_s)) == 1
            contour_s = sitk.BinaryMorphologicalClosing(contour_s, binaryfillhole_img)
            results[structure_name] = contour_s

        # Remove any overlaps. Like the hole-filling step above, structures not present in the
        # results are skipped rather than raising KeyError.
        overlap_structures = [
            s
            for s in postprocessing_settings["structures_for_overlap_correction"]
            if s in results
        ]

        if len(overlap_structures) >= 2:
            input_overlap = {s: results[s] for s in overlap_structures}
            output_overlap = correct_volume_overlap(input_overlap)

            for s in overlap_structures:
                results[s] = output_overlap[s]

    logger.info("Done!")

    return results, results_prob
示例#9
0
def quick_optimise_probability(
    metric_function,
    manual_contour,
    probability_image,
    p_0=0.5,
    delta=0.5,
    tolerance=0.01,
    mode="min",
    create_figure=False,
    auto_crop=True,
    metric_args=None,
):
    """Optimise the probability threshold used to generate a binary segmentation.
    This is a simple parameter sweep, with linearly decreasing resolution. It will usually converge
    between 5 and 10 iterations.

    Each iteration evaluates six thresholds around the current best (at +/- delta/4, delta/2 and
    3*delta/4), keeps the best value seen so far, then divides ``delta`` by 4. At least four
    iterations are always run. The sweep stops once the best metric changes by no more than
    ``tolerance``.

    Args:
        metric_function (function to return float): The metric function, this takes in two binary
            masks (a reference, and test [SimpleITK.Image]) and returns a metric (as a float).
            Typical choices would be the DSC, a surface distance metric, dose difference,
            relative volume. Additional arguments can be passed in through `metric_args` argument.
            NaN results (e.g. from an empty mask) are ignored when choosing the optimum.
        manual_contour (SimpleITK.Image): The reference (manual) contour.
        probability_image (SimpleITK.Image): The probability map from which the optimal threshold
            will be derived. This does NOT have to be scaled to [0,1].
        p_0 (float, optional): Initial guess of the optimal threshold. Defaults to 0.5.
        delta (float, optional): The window size of the optimiser. Defaults to 0.5.
        tolerance (float, optional): If the metric changes by an amount less that `tolerance`
            the optimiser will stop. Defaults to 0.01.
        mode (str, optional): Specifies whether the metric should be maximised ("max") or minimised
            ("min"). Defaults to "min".
        create_figure (bool, optional): Create a matplotlib figure showing the optimisation. This
            is not returned, so make sure you can capture this (e.g. using IPython).
            Defaults to False.
        auto_crop (bool, optional): Crop the image volumes to the region of interest. Speeds up the
            process signficiantly so don't turn off unless you have a good reason to!
            Defaults to True.
        metric_args (dict, optional): Additional arguments passes to the metric function. This
            could be useful if you are calculating a dose-based metric and require a dose grid to
            be passed to the metric function. Defaults to None (no additional arguments).

    Returns:
        tuple (float, float): The optimal probability, optimal metric value.

    Raises:
        ValueError: If `mode` is neither "min" nor "max".
    """
    # Previously an unknown mode silently returned the initial guess
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")

    # Avoid a shared mutable default argument
    if metric_args is None:
        metric_args = {}

    # Auto crop images
    if auto_crop:
        cb_size, cb_index = label_to_roi(
            (manual_contour > 0) | (probability_image > 0), expansion_mm=10)
        manual_contour = crop_to_roi(manual_contour, cb_size, cb_index)
        probability_image = crop_to_roi(probability_image, cb_size, cb_index)

    def evaluate(threshold):
        """Threshold the probability map and score it against the manual contour."""
        auto_contour = process_probability_image(probability_image, threshold=threshold)
        return metric_function(manual_contour, auto_contour, **metric_args)

    # NaN-aware selection: metrics such as compute_metric_hd return np.nan for empty masks, and
    # plain np.argmin / np.min would otherwise select that NaN as the "optimum".
    select_best = np.nanargmin if mode == "min" else np.nanargmax

    # Set up
    n_iter = 0
    p_best = p_0

    m_n = evaluate(p_0)
    m_best = m_n

    print("Starting optimisation.")
    print(f"n = 0 | p = {p_best:.3f} | metric = {m_n:.3f}")

    p_list = [p_best]
    m_list = [m_best]

    improv = 0

    while np.abs(improv) > tolerance or n_iter <= 3:

        n_iter += 1
        m_n = m_best

        p_new = [
            p_best - 3 * delta / 4,
            p_best - delta / 2,
            p_best - delta / 4,
            p_best + delta / 4,
            p_best + delta / 2,
            p_best + 3 * delta / 4,
        ]
        m_new = [evaluate(p) for p in p_new]

        p_list = p_list + p_new
        m_list = m_list + m_new

        m_array = np.asarray(m_list, dtype=float)
        # If every evaluation so far is NaN there is nothing to choose; keep the current guess
        if not np.all(np.isnan(m_array)):
            i_best = select_best(m_array)
            p_best = p_list[i_best]
            m_best = m_array[i_best]

        improv = m_best - m_n

        delta /= 4

        print(f"n = {n_iter} | p = {p_best:.3f} | metric = {m_best:.3f}")

    if create_figure:
        fig, ax = plt.subplots(1, 1)
        ax.scatter(p_list, m_list, c="k", zorder=1)
        ax.plot(*list(zip(*sorted(zip(p_list, m_list)))), c="k", zorder=1)
        ax.scatter((p_best), (m_best),
                   c="r",
                   label=f"Optimum ({p_best:.2f},{m_best:.2f})",
                   zorder=2)
        # Probabilities need not lie in [0, 1], so widen the axis to cover every probed value
        ax.set_xlim(min(0, min(p_list)), max(1, max(p_list)))
        ax.set_xlabel("Probability Threshold")
        ax.set_ylabel("Metric Value")
        ax.grid()
        ax.set_axisbelow(True)
        ax.set_title(f"Optimiser | {metric_function.__name__}, mode = {mode}")
        ax.legend()
        fig.show()

    return p_best, m_best
示例#10
0
def geometric_sinoatrialnode(label_svc, label_ra, label_wholeheart, radius_mm=10):
    """Geometric definition of the cardiac sinoatrial node (SAN).
    This is largely inspired by Loap et al 2021 [https://doi.org/10.1016/j.prro.2021.02.002]

    Outline:
        1. Crop all masks to their union (expanded by 20 mm) to speed up morphology.
        2. Take the most inferior SVC slice as the SAN slice. Dilate the SVC until it overlaps
           the RA on that slice; the overlap's centre of mass marks the SVC/RA junction.
        3. Within an eroded whole-heart region on that slice, pick the voxel closest to the
           junction as the SAN centre.
        4. Place a sphere of ``radius_mm`` there. Paste it back into the uncropped image space.

    Args:
        label_svc (SimpleITK.Image): The binary mask defining the superior vena cava.
        label_ra (SimpleITK.Image): The binary mask defining the right atrium.
        label_wholeheart (SimpleITK.Image): The binary mask defining the whole heart.
        radius_mm (int, optional): The radius of the SAN, in millimetres. Defaults to 10.

    Returns:
        SimpleITK.Image: A binary mask defining the SAN, on the same grid as ``label_wholeheart``.
    """

    # To speed up binary morphology operations we first crop all images
    template_img = 0 * label_wholeheart
    cb_size, cb_index = label_to_roi(
        (label_svc + label_ra + label_wholeheart) > 0, expansion_mm=(20, 20, 20)
    )

    label_svc = crop_to_roi(label_svc, cb_size, cb_index)
    label_ra = crop_to_roi(label_ra, cb_size, cb_index)
    label_wholeheart = crop_to_roi(label_wholeheart, cb_size, cb_index)

    # SimpleITK -> numpy reverses the axis order, so axis 0 of these arrays is the slice index
    arr_svc = sitk.GetArrayFromImage(label_svc)
    arr_ra = sitk.GetArrayFromImage(label_ra)

    # First, find the most inferior slice of the SVC
    # This defines the z location of the SAN
    inf_limit_svc = np.min(np.where(arr_svc)[0])

    # Now expand the SVC until it touches the RA on the inf slice
    # Each pass dilates the ORIGINAL SVC mask with a larger in-plane radius (in voxels).
    # Once the in-plane radius reaches 3, the SVC is also dilated along z, and the inferior
    # slice is re-evaluated on the dilated mask (so it can move down).
    # NOTE(review): there is no iteration cap; this relies on the dilated SVC eventually
    # overlapping the RA on the chosen slice.
    overlap = 0
    dilate = 1
    dilate_ax = 0
    while overlap == 0:
        label_svc_dilate = sitk.BinaryDilate(label_svc, (dilate, dilate, dilate_ax))
        label_overlap = label_svc_dilate & label_ra
        overlap = sitk.GetArrayFromImage(label_overlap)[inf_limit_svc, :, :].sum()
        dilate += 1

        if dilate >= 3:
            arr_svc = sitk.GetArrayFromImage(label_svc_dilate)
            inf_limit_svc = np.min(np.where(arr_svc)[0])
            dilate_ax += 1

    # Locate the point on intersection
    # NOTE(review): assumes get_com returns integer array-order (z, y, x) indices, since
    # elements [1] and [2] are used directly as numpy indices below — confirm.
    intersect_loc = get_com(label_overlap)

    # Create an image with a single voxel of value 1 located at the point of intersection
    # (the slice index is taken from inf_limit_svc, not from the centre of mass)
    arr_intersect = arr_ra * 0
    arr_intersect[inf_limit_svc, intersect_loc[1], intersect_loc[2]] = 1
    label_intersect = sitk.GetImageFromArray(arr_intersect)
    label_intersect.CopyInformation(label_ra)

    # Define the locations greater than 10mm from the WH
    # Ensures the SAN doesn't extend outside the heart volume
    # NOTE(review): BinaryErode's kernel radius is in voxels, so this is 10 voxels in-plane
    # (no erosion along z); it equals 10 mm only for 1 mm in-plane spacing.
    potential_san_region = sitk.BinaryErode(label_wholeheart, (10, 10, 0))

    # Find the point in this region closest to the intersection
    # First generate a distance map (in physical units, since useImageSpacing=True)
    distancemap_san = sitk.SignedMaurerDistanceMap(
        label_intersect, squaredDistance=False, useImageSpacing=True
    )

    # Then get the distance from the intersection at all possible points
    arr_distancemap_san = sitk.GetArrayFromImage(distancemap_san)
    arr_potential_san_region = sitk.GetArrayFromImage(potential_san_region)

    yloc, xloc = np.where(arr_potential_san_region[inf_limit_svc, :, :])

    distances = arr_distancemap_san[inf_limit_svc, yloc, xloc]

    # Find where the distance is a minimum
    # NOTE(review): argmin raises ValueError if the eroded heart region is empty on this slice.
    location_of_min = distances.argmin()

    # Now define the SAN location (array order: slice, row, column)
    sphere_centre = (inf_limit_svc, yloc[location_of_min], xloc[location_of_min])

    # Generate an image
    # NOTE(review): assumes insert_sphere_image expects sp_centre in array (z, y, x) order — confirm.
    label_san = insert_sphere_image(label_ra * 0, sp_radius=radius_mm, sp_centre=sphere_centre)

    # Finally, paste the label into the original image space
    label_san = sitk.Paste(
        template_img,
        label_san,
        label_san.GetSize(),
        (0, 0, 0),
        cb_index,
    )

    return label_san
示例#11
0
def _erode_until_disjoint(label_to_erode, label_other):
    """Iteratively erode a 2D mask until it no longer overlaps another 2D mask.

    The mask is always eroded at least once. Each pass erodes the previous result with a
    kernel radius (in voxels) one larger than the pass before, so shrinkage accelerates.

    Args:
        label_to_erode (SimpleITK.Image): The 2D binary mask to erode.
        label_other (SimpleITK.Image): The 2D binary mask that must not be overlapped.

    Returns:
        SimpleITK.Image: The eroded mask, which does not overlap ``label_other``.
    """
    overlap = 1
    erode = 1
    while overlap > 0:
        label_to_erode = sitk.BinaryErode(label_to_erode, (erode, erode))
        overlap = sitk.GetArrayFromImage(label_to_erode & label_other).sum()
        erode += 1
    return label_to_erode


def geometric_atrioventricularnode(label_la, label_lv, label_ra, label_rv, radius_mm=10):
    """Geometric definition of the cardiac atrioventricular node (AVN).
    This is largely inspired by Loap et al 2021 [https://doi.org/10.1016/j.prro.2021.02.002]

    Outline:
        1. Crop all chamber masks to their union (expanded by 20 mm).
        2. Choose the axial slice 10 mm superior to the most inferior left-atrium slice.
        3. On that slice, erode the chambers so neighbouring chambers no longer overlap.
        4. For each chamber, find its point closest to the diagonally opposite chamber.
           Average these four points to get the AVN centre.
        5. Place a sphere of ``radius_mm`` there. Paste it back into the uncropped image space.

    Args:
        label_la (SimpleITK.Image): The binary mask defining the left atrium.
        label_lv (SimpleITK.Image): The binary mask defining the left ventricle.
        label_ra (SimpleITK.Image): The binary mask defining the right atrium.
        label_rv (SimpleITK.Image): The binary mask defining the right ventricle.
        radius_mm (float, optional): The radius of the AVN, in millimetres. Defaults to 10.

    Returns:
        SimpleITK.Image: A binary mask defining the AVN, on the same grid as ``label_ra``.
    """

    # To speed up binary morphology operations we first crop all images
    template_img = 0 * label_ra
    cb_size, cb_index = label_to_roi(
        (label_la + label_lv + label_ra + label_rv) > 0, expansion_mm=(20, 20, 20)
    )

    label_la = crop_to_roi(label_la, cb_size, cb_index)
    label_lv = crop_to_roi(label_lv, cb_size, cb_index)
    label_ra = crop_to_roi(label_ra, cb_size, cb_index)
    label_rv = crop_to_roi(label_rv, cb_size, cb_index)

    # First, find the most inferior slice of the left atrium
    # (axis 0 of the numpy array is the slice index)
    arr_la = sitk.GetArrayFromImage(label_la)
    inf_limit_la = np.min(np.where(arr_la)[0])

    # Now progress 1cm in the superior direction
    # This defines the slice of the AVN centre
    slice_loc = int(inf_limit_la + 10 / label_la.GetSpacing()[2])

    # Create 2D images at this slice location (SimpleITK indexing is x, y, z)
    label_la_2d = label_la[:, :, slice_loc]
    label_lv_2d = label_lv[:, :, slice_loc]
    label_ra_2d = label_ra[:, :, slice_loc]
    label_rv_2d = label_rv[:, :, slice_loc]

    # We now iteratively erode the structures to ensure they do not overlap
    # This ensures we can measure the closest point without any errors.
    # The order matters: each step uses the already-eroded masks from the previous steps.
    # (Previously only the first of these loops ran; the other three started with
    # overlap = 0 and so never executed.)
    # LEFT VENTRICLE (vs left atrium)
    label_lv_2d = _erode_until_disjoint(label_lv_2d, label_la_2d)
    # LEFT ATRIUM (vs right atrium)
    label_la_2d = _erode_until_disjoint(label_la_2d, label_ra_2d)
    # RIGHT ATRIUM (vs right ventricle)
    label_ra_2d = _erode_until_disjoint(label_ra_2d, label_rv_2d)
    # RIGHT VENTRICLE (vs left ventricle)
    label_rv_2d = _erode_until_disjoint(label_rv_2d, label_lv_2d)

    # Measure closest points: each chamber's point nearest the diagonally opposite chamber
    y_la, x_la = get_closest_point_2d(label_rv_2d, label_la_2d)
    y_lv, x_lv = get_closest_point_2d(label_ra_2d, label_lv_2d)
    y_ra, x_ra = get_closest_point_2d(label_lv_2d, label_ra_2d)
    y_rv, x_rv = get_closest_point_2d(label_la_2d, label_rv_2d)

    # Take the arithmetic mean (dtype=int makes the result an integer, truncated)
    x_location = np.mean((x_la, x_lv, x_ra, x_rv), dtype=int)
    y_location = np.mean((y_la, y_lv, y_ra, y_rv), dtype=int)

    # Now define the AVN location (array order: slice, row, column)
    sphere_centre = (slice_loc, y_location, x_location)

    # Generate an image
    label_avn = insert_sphere_image(label_ra * 0, sp_radius=radius_mm, sp_centre=sphere_centre)

    # Finally, paste the label into the original image space
    label_avn = sitk.Paste(
        template_img,
        label_avn,
        label_avn.GetSize(),
        (0, 0, 0),
        cb_index,
    )

    return label_avn