Esempio n. 1
0
def segmentation_and_groundtruth_plot(prediction: np.ndarray, ground_truth: np.ndarray, subject_id: int,
                                      structure_name: str, plane: Plane, output_img_dir: Path, annotator: str = None,
                                      save_fig: bool = True) -> None:
    """
    Overlay the borders of a predicted and a ground-truth segmentation on one 2D slice.

    The slice shown is always the middle one along the viewing axis chosen for ``plane`` (to match the
    surface distance plots), so the figure can be empty when the structure does not reach that slice.
    Ground-truth borders are drawn with the "Greens" colormap, and predicted borders are drawn on top
    with "Reds" at 0.7 alpha. The view is cropped with ``plotting_util.get_cropped_axes`` when that
    succeeds, otherwise the whole slice is shown.

    :param prediction: 3D volume (X x Y x Z) of predicted segmentation
    :param ground_truth: 3D volume (X x Y x Z) of ground truth segmentation
    :param subject_id: ID of subject, used in the title and the output file name
    :param structure_name: Name of structure, used in the title and the output file name
    :param plane: The plane to view images in (axial, sagittal or coronal)
    :param output_img_dir: Directory, relative to "outputs/", in which to store the plot
    :param annotator: Optional annotator name; an empty string is used when it is not given
    :param save_fig: If True, saves the image. Otherwise displays it.
    :return: None
    """
    pred_border = extract_border(prediction, connectivity=3)
    truth_border = extract_border(ground_truth, connectivity=3)

    fig, ax = plt.subplots()

    axis, img_origin = plotting_util.get_view_dim_and_origin(plane)
    centre_index = pred_border.shape[axis] // 2
    pred_2d = np.take(pred_border, centre_index, axis=axis)
    truth_2d = np.take(truth_border, centre_index, axis=axis)

    # Crop to the area covered by either border; on IndexError fall back to the full slice.
    combined = pred_2d + truth_2d
    try:
        crop = plotting_util.get_cropped_axes(combined)
    except IndexError:
        n_rows = combined.shape[0]
        n_cols = combined.shape[1]
        crop = (slice(0, n_rows + 1), slice(0, n_cols + 1))

    # Ground truth first, then the prediction semi-transparently on top of it.
    ax.imshow(truth_2d[crop], cmap="Greens", origin=img_origin, aspect='equal')
    ax.imshow(pred_2d[crop], cmap="Reds", origin=img_origin, alpha=0.7, aspect='equal')

    annot_str = annotator if annotator else ""

    ax.set_aspect('equal')
    fig.suptitle(f"{subject_id} {structure_name} ground truth and pred - {annot_str}")

    if not save_fig:
        fig.show()
        return

    file_name = f"{subject_id}_{structure_name}_ground_truth_and_pred_{annot_str}.png"
    figpath = str(Path("outputs/") / output_img_dir / file_name)
    print(f"saving to {figpath}")
    resize_and_save(5, 5, figpath)
Esempio n. 2
0
def surface_distance_ground_truth_plot(ct: np.ndarray, ground_truth: np.ndarray, sds_full: np.ndarray, subject_id: int,
                                       structure: str, plane: Plane, output_img_dir: Path, dice: float = None,
                                       save_fig: bool = True,
                                       annotator: str = None) -> None:
    """
    Plot surface distances where prediction > 0, with the ground truth, on the middle slice.

    The slice shown is the middle one along the viewing axis chosen for ``plane``. Zero-valued
    surface distances are masked out, so only non-zero distances are coloured (colormap "RdYlGn_r",
    binned with a BoundaryNorm). If a CT is given, it is drawn in greyscale underneath. Pixels where
    the ground truth is not 1 are then overlaid in black. Without a CT, the ground-truth slice itself
    is used as the background.

    :param ct: CT scan, or None to use the ground truth slice as background
    :param ground_truth: Ground truth segmentation (assumed to have the same shape as sds_full; not checked)
    :param sds_full: Surface distances (full = where prediction > 0)
    :param subject_id: ID of subject for annotating plot; must be int-convertible when saving
    :param structure: Name of structure for annotating plot
    :param plane: The plane to view images in (axial, sagittal or coronal)
    :param output_img_dir: Directory, relative to "outputs", in which to store the plot
    :param dice: Optional dice score for annotating plot. A score of 0.0 is shown; only None is omitted.
    :param save_fig: If True, saves the image. Otherwise displays it.
    :param annotator: Optional annotator name for annotating plot
    :return: None
    """
    # Dimension to slice across to get the best 2D view, and the middle index along it.
    view_dim, origin = plotting_util.get_view_dim_and_origin(plane)
    midpoint = ground_truth.shape[view_dim] // 2

    sds_full_slice = np.take(sds_full, midpoint, axis=view_dim)
    # The surface distance array covers everywhere with pred > 0. Mask the zeros, otherwise the
    # centre of every structure would be coloured.
    masked_sds_full_slice = np.ma.masked_where(sds_full_slice == 0, sds_full_slice)

    gt_contour = extract_border(ground_truth, connectivity=3)
    gt_contour_slice = np.take(gt_contour, midpoint, axis=view_dim)

    # Build a fresh array instead of adding in place into sds_full_slice. This keeps the slice
    # unmodified and avoids a float -> int casting error when the distances are integer-typed.
    total_pixels = sds_full_slice + gt_contour_slice.astype(float)

    try:
        bounding_box = plotting_util.get_cropped_axes(total_pixels)
    except IndexError:
        # Fall back to showing the whole 2D slice.
        bounding_box = (slice(None), slice(None))

    fig, ax = plt.subplots()
    black_cmap = colors.ListedColormap('black')
    sds_cmap = plt.get_cmap("RdYlGn_r")

    # Discrete colour bins for the surface distance values.
    bounds = [0.5, 1, 1.5, 2, 2.5, 3, 3.5]
    sds_norm = colors.BoundaryNorm(bounds, sds_cmap.N)

    gt_slice = np.take(ground_truth, midpoint, axis=view_dim)
    if ct is not None:
        ct_slice = np.take(ct, midpoint, axis=view_dim)
        # Mask the ground-truth pixels (== 1) so only the pixels outside it are drawn, in black.
        masked_external_slice = np.ma.masked_where(gt_slice == 1, gt_slice)
        ax.imshow(ct_slice[bounding_box], cmap="Greys", origin=origin)
        ax.imshow(masked_external_slice[bounding_box], cmap=black_cmap, origin=origin, alpha=0.7)
    else:
        ax.imshow(gt_slice[bounding_box], cmap='Greys_r', origin=origin)

    cb = ax.imshow(masked_sds_full_slice[bounding_box], cmap=sds_cmap, norm=sds_norm, origin=origin, alpha=0.7)

    fig.colorbar(cb)

    # Plot title. Compare against None so that a dice score of 0.0 is still reported.
    dice_str = str(dice) if dice is not None else ""
    annot_str = annotator if annotator else ""
    fig.suptitle(f'{subject_id} {structure} sds - {annot_str}. Dice: {dice_str}')

    # Resize image
    ax.set_aspect('equal')

    if save_fig:
        figpath = Path("outputs") / output_img_dir / f"{int(subject_id):03d}_{structure}_sds2_{annot_str}.png"
        print(f"saving to {str(figpath)}")
        resize_and_save(5, 5, figpath)
    else:
        fig.show()