Beispiel #1
0
def save_gif_artery(orig_label_pick_path, bound_pick_path):
    """Animate boundary-detection results slice by slice and save them as mp4.

    :param orig_label_pick_path: str, path of original segmentation label
        (pickle with keys "label" and "start")
    :param bound_pick_path: str, path of boundary detection result
        (pickle with keys 'input', 'label', 'pred', 'start' and 'output')
    figures are arranged in the order of
            input | GT_seg | GT_bound | pred_bound
            heatmap[0-256] | heatmap[0-100] | inner bound probmap | outer bound probmap
    besides, we only consider heatmap with range of
            0~(70)~256, namely 0.38438 ~ 0.41066 ~ 0.48048 and
            0~(50)~100, namely 0.38438 ~ 0.40315 ~ 0.42192 respectively

    The animation is written to ``artery.mp4`` in the directory containing
    ``bound_pick_path`` (ffmpeg writer, mpeg4 codec, 10 fps).

    :raises ValueError: if the boundary result starts before the segmentation
        label, or the segmentation label has too few slices after calibration.
    """

    gif_save_dir = '/'.join(bound_pick_path.split('/')[:-1])
    print("Processing {}".format(gif_save_dir))

    # NOTE: pickle.load can execute arbitrary code -- only pass trusted files.
    # load original segmentation label
    with open(orig_label_pick_path, 'rb') as reader:
        data_seg = pickle.load(reader)
    labels_seg, start_seg = data_seg["label"], data_seg["start"]

    with open(bound_pick_path, 'rb') as reader:
        data_bound = pickle.load(reader)
    inputs_bound, labels_bound, preds_bound, start_bound, probmaps = \
        data_bound['input'], data_bound['label'], data_bound['pred'], data_bound['start'], data_bound['output']

    assert len(inputs_bound) == len(labels_bound) == len(preds_bound), "inputs, GT and preds should have the " \
                                                                       "same number of slices"
    print(len(inputs_bound), len(labels_seg), start_bound, start_seg)

    # Calibrate seg labels so that index i refers to the same slice as the
    # boundary data. A negative offset would silently slice from the END of
    # the array and pair the wrong slices, so reject it explicitly.
    calib_offset = start_bound - start_seg
    if calib_offset < 0:
        raise ValueError(
            "boundary result starts at slice {} which is before the "
            "segmentation label start {}".format(start_bound, start_seg))
    labels_seg_cal = labels_seg[calib_offset:]  # seg labels after calibration
    if len(labels_seg_cal) < len(inputs_bound):
        raise ValueError(
            "segmentation label covers {} slices after calibration but {} "
            "boundary slices are required".format(len(labels_seg_cal),
                                                   len(inputs_bound)))

    scale, rows, cols = 4, 2, 4
    fig = plt.figure(figsize=[scale * cols, scale * rows])
    artery_name = '/'.join(gif_save_dir.split('/')[-2:])

    # add subplots for each figure (row-major, layout as in the docstring)
    ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8 = [
        fig.add_subplot(rows, cols, k) for k in range(1, rows * cols + 1)
    ]

    # Titles and hidden axes are the same for every frame, so set them once
    # instead of on every loop iteration.
    ax1.set_title("{} \n {}".format(artery_name, 'Input'), loc='left')
    ax2.set_title('label_seg')
    ax3.set_title('label_bound')
    ax4.set_title("pred_bound", loc='left')
    ax5.set_title("input colormap HU[0~250]")
    ax6.set_title("input colormap HU[0~100]")
    ax7.set_title("inner bound probmap")
    ax8.set_title("outer bound probmap")
    for axis in (ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8):
        axis.axis('off')

    # create customed colormap: 70 reversed blue levels followed by 186 red levels
    # NOTE(review): cm.get_cmap is deprecated in newer matplotlib releases.
    top = cm.get_cmap('Reds', 186)
    bottom = cm.get_cmap('Blues', 70)
    newcolors = np.vstack(
        (bottom(np.linspace(1, 0, 70)), top(np.linspace(0, 1, 186))))
    bluered = ListedColormap(newcolors, name='BlueReds')

    lines = []
    for i, image in enumerate(inputs_bound):
        label_seg = labels_seg_cal[i]
        label_bound = labels_bound[i]
        pred_bound = preds_bound[i]
        probmap = probmaps[i]
        # calculate HDF distance between GT bound and pred bound
        hdf_bound = slicewise_hd95(pred_bound, label_bound, n_classes=3)

        line1 = ax1.imshow(image, cmap='gray', animated=True)
        line1_text = ax1.text(48,
                              -3,
                              "Slice {}".format(i + start_bound),
                              color='red',
                              fontsize=10)

        line2 = ax2.imshow(mask2rgb(label_seg), animated=True)
        line3 = ax3.imshow(mask2rgb(label_bound), animated=True)
        line4 = ax4.imshow(mask2rgb(pred_bound), animated=True)
        # NOTE(review): this text is attached to ax1 but placed at x=400 in
        # ax1's data coordinates, presumably so that it renders above the
        # pred_bound panel -- confirm before moving it to ax4.
        line4_text = ax1.text(400,
                              -3,
                              "Hdf: {:.4f}".format(hdf_bound),
                              color='black',
                              fontsize=10)

        # input intensities through the blue-red colormap; vmin/vmax are the
        # normalized values of the HU windows listed in the docstring
        line5 = ax5.imshow(image,
                           cmap=bluered,
                           vmin=0.38438,
                           vmax=0.48048,
                           animated=True)  # HU window 0~256
        line6 = ax6.imshow(image,
                           cmap=bluered,
                           vmin=0.38438,
                           vmax=0.42192,
                           animated=True)  # HU window 0~100

        # output channels 1 and 2: inner and outer bound probability maps
        line7 = ax7.imshow(probmap[1], cmap='seismic', animated=True)
        line8 = ax8.imshow(probmap[2], cmap='seismic', animated=True)

        # list order is the per-frame draw order used by the animation
        lines.append([
            line1, line1_text, line2, line3, line4, line4_text, line5, line6,
            line7, line8
        ])

    # Build the animation using ArtistAnimation function
    ani = animation.ArtistAnimation(fig, lines, interval=50, blit=True)

    # save into mp4 (gif output via imagemagick is currently disabled)
    # ani.save('{}/artery.gif'.format(gif_save_dir), writer="imagemagick")
    ani.save('{}/artery.mp4'.format(gif_save_dir),
             writer="ffmpeg",
             codec='mpeg4',
             fps=10)
    # release the figure; otherwise every call leaks one open figure
    plt.close(fig)


# if __name__ == "__main__":
#     # file_name = 'test_result_2.pickle'
#     # print("file name: {}".format(file_name))
#     # with open(file_name, 'rb') as reader:
#     #     data = pickle.load(reader)
#     # labels = data['GT']  # [N, H, W]
#     # outputs = data['output'] # [N, C, H, W]
#     #
#     # binary_class_slice_wise_pr(labels, outputs, fig_name= 'test_2_binary_pr')
#     # multi_class_slice_wise_pr(labels, outputs, fig_name='test_2_multi_pr_micro')
#     # average_precision(labels, outputs)
#     # path of original annotation
#     orig_label_path = "./PlaqueSegmentation/OrigAnnotation/2d_res_unet_dp_0.001_0.90_0.9_theta-1.0-0.0_100_2_10_dice_Adam_" \
#                       "r-True_flip-True_w-True_rcp-True_tr-False_ns-Falseptr-False_mv-False_sl-False_ds-2_a-0.5_lr-StepLR_" \
#                       "wt-None_o-5_b-False_cal0gt-False_cf-config_dp-0.0_ig-None_w0-10.0_sg-5.0_96_wt-1_mo_False"
#
#     # seg_data_path = "/home/mil/huang/CPR_Segmentation_ver7/PlaqueSegmentation/Experiment23/2d_res_unet_dp_0.0001_0.90_" \
#     #                 "0.9_theta-1.0-0.0_100_100_10_ceb_Adam_r-True_flip-True_w-True_rcp-True_tr-False_ns-Falseptr-False_mv" \
#     #                 "-False_sl-False_ds-2_a-0.5_lr-StepLR_wt-None_o-3_b-False_cal0gt-False_cf-config_dp-0.0_ig-None_w0-10.0_" \
#     #                 "sg-5.0_96_wt-1_mo_False"
#     # bound_data_2d_path = "/home/mil/huang/CPR_Segmentation_ver7/PlaqueBound/Experiment3/2d_res_unet_dp_0.001_0.0_100_100_10" \
#     #                   "_whd_Adam_r-True_flip-True_w-False_ptr-False_mv-False_sl-False_lr-StepLR_wt-None_o-2_b-True_cf-config" \
#     #                   "_dp-0.0_w1-10.0_w2-10.0_sg1-5.0_sg2-5.0_rs-96_wt-2_bt-outer_whda-4_whdb-1"
#     # bound_data_3d_path = "./BoundDetection/Experiment4/3d_res_unet_0.001_100_100_whd_Adam_w-False_sl-True_lr-StepLR_wt-None_o" \
#     #                      "-2_b-True_cf-config_dp-0.0_rs-96_cc-192_wt-2_bt-outer_whda-4_whdb-1_whdr-0.5"
#     # fig_save_dir = "/home/mil/huang/CPR_Segmentation_ver7/PlaqueDetection_20181127/ResultsComparison/seg_bound_comp_debug3"
#
#     # seg_bound_comparison(orig_label_path, seg_data_path, bound_data_2d_path, bound_data_3d_path, fig_save_dir, sample_stack_rows=50)
#
#     bound_data_path = "./BoundDetection/Experiment7/HybridResUNet_ds1int15_0.167"
#     gif_generation(orig_label_path, bound_data_path)
Beispiel #2
0
def sample_wnet(data_list,
                rows=15,
                start_with=0,
                show_every=2,
                scale=4,
                fig_name=None,
                start_inx=0,
                n_class=5,
                width=1):
    """ show segmentation result with bound and corresponding hdf calculated
        plot input, annotation, prediction, bounds and F1 scores

    Slices ``start_with, start_with + show_every, ...`` of ``data_list`` are
    drawn one per row (at most ``rows`` rows). Columns per row:
    input | ground truth | prediction | bound overlap (HD95 in title) |
    one prob map per bound channel except channel 0 | per-class F1 scores.

    :param data_list: list, list of data in which each element is a dictionary
        with keys 'input', 'GT', 'pred' and 'bound' (bound probmaps, channel first)
    :param rows: int, maximum number of figure rows
    :param start_with: int, index in data_list of the first slice to draw
    :param show_every: int, step between drawn slices
    :param scale: size (inches) of each subplot
    :param fig_name: str or None, figure is saved to ``fig_name + '.pdf'`` if given
    :param start_inx: int, starting slice index for current figure
    :param n_class: int, number of classes used for F1 and HD95
    :param width: int, bound width passed to mask2innerouterbound
    """

    n_probmaps = data_list[0]['bound'].shape[0]  # number of bound channels
    cols = 5 + n_probmaps - 1
    # squeeze=False keeps ``ax`` 2-D, so ax[i, j] also works when rows == 1
    fig, ax = plt.subplots(rows,
                           cols,
                           figsize=[scale * cols, scale * rows],
                           squeeze=False)

    # titles of the prob-map columns (channel 0 is not plotted)
    if n_probmaps >= 3:
        output_title = ['prob map (inner bound)', 'prob map (outer bound)']
    else:
        output_title = ['prob map']
    f_col = 3 + n_probmaps  # column of the F-score scatter

    for ind, data in enumerate(data_list):
        offset = ind - start_with
        if offset % show_every != 0:
            continue
        i = offset // show_every
        # Slices before start_with give a negative row index, which would wrap
        # around onto the bottom rows of the grid -- only accept 0 <= i < rows.
        # Metrics are computed below only for slices that are actually drawn.
        if not 0 <= i < rows:
            continue

        image = data['input']
        label = data['GT']
        pred = data['pred']
        bound_probmap = data['bound']  # predicted bound probmap

        # per-class F1; classes absent from the GT score 0 and are excluded
        # from the average
        label_binary = label_binarize(label.flatten(), classes=range(n_class))
        pred_binary = label_binarize(pred.flatten(), classes=range(n_class))

        f_score = np.zeros(n_class, dtype=np.float32)
        slice_effect_class = 0
        for c in range(n_class):
            if np.sum(label_binary[:, c]) != 0:
                slice_effect_class += 1
                f_score[c] = f1_score(label_binary[:, c], pred_binary[:, c])

        ave_f_score = np.sum(f_score) / slice_effect_class

        # calculate average HFD between inner/outer bounds of pred and GT
        label_bound = mask2innerouterbound(label, width=width)
        pred_bound = mask2innerouterbound(pred, width=width)
        hdf = slicewise_hd95(pred_bound, label_bound, n_class)

        slice_no = ind + start_inx

        ax[i, 0].imshow(image, cmap='gray')
        ax[i, 0].set_title("Slice {} : {}".format(slice_no, 'input'))
        ax[i, 0].axis('off')

        ax[i, 1].imshow(mask2rgb(label))
        ax[i, 1].set_title('Slice %d : %s' % (slice_no, 'ground truth'))
        ax[i, 1].axis('off')

        ax[i, 2].imshow(mask2rgb(pred))
        ax[i, 2].set_title('Slice %d : %s' % (slice_no, 'prediction'))
        ax[i, 2].axis('off')

        # plot overlapping between pred_bound and label_bound (GT drawn as 4)
        overlap = pred_bound.copy()
        overlap[label_bound != 0] = 4
        ax[i, 3].imshow(mask2rgb(overlap))
        ax[i, 3].set_title("Slice {:d} : bound hdf={:.4f}".format(slice_no, hdf))
        ax[i, 3].axis('off')

        # plot prob maps for intermediate bounds
        for c_inx in range(1, n_probmaps):
            # generic title when there are more channels than known names
            title = (output_title[c_inx - 1] if c_inx <= len(output_title)
                     else 'prob map (channel {})'.format(c_inx))
            ax[i, 3 + c_inx].imshow(bound_probmap[c_inx], cmap='seismic')
            ax[i, 3 + c_inx].set_title("Slice {:d} : {}".format(slice_no, title))
            ax[i, 3 + c_inx].axis('off')

        ax[i, f_col].scatter(range(0, n_class), f_score)
        ax[i, f_col].set_title('Slice %d : Ave F-score = %0.2f' %
                               (slice_no, ave_f_score))
        ax[i, f_col].set_ylabel('F score')
        ax[i, f_col].set_ylim([-0.1, 1.1])

    if fig_name:
        fig.savefig(fig_name + '.pdf')
    plt.close(fig)
Beispiel #3
0
def plot_seg_bound_comparison(data_list,
                              rows,
                              start_with,
                              show_every,
                              start_inx,
                              n_class,
                              fig_name=None,
                              width=2,
                              scale=4):
    """ plot result comparison between seg and bound detection

    Slices ``start_with, start_with + show_every, ...`` of ``data_list`` are
    drawn one per row (at most ``rows`` rows) with the columns
    input | label_seg | label_bound | bound from seg | 2D bound | 3D bound.
    Each of the last three columns overlays the GT bound (drawn as label 4)
    on a predicted bound and shows its HD95 against the GT bound.

    :param data_list: list of dicts with keys 'input', 'GT_seg', 'pred_seg',
        'GT_bound', 'pred_2d_bound' and 'pred_3d_bound'
    :param rows: int, maximum number of figure rows
    :param start_with: int, index in data_list of the first slice to draw
    :param show_every: int, step between drawn slices
    :param start_inx: int, slice index displayed for data_list[0]
    :param n_class: int, number of classes passed to slicewise_hd95
    :param fig_name: str or None, figure is saved to ``fig_name + '.pdf'`` if given
    :param width: int, bound width passed to mask2outerbound
    :param scale: size (inches) of each subplot
    """
    cols = 6  # [input, label_seg, label_bound, pred_bound(converted), pred_bound_2d, pred_bound_3d]
    # squeeze=False keeps ``ax`` 2-D, so ax[i, j] also works when rows == 1
    fig, ax = plt.subplots(rows,
                           cols,
                           figsize=[scale * cols, scale * rows],
                           squeeze=False)

    for ind, data in enumerate(data_list):
        offset = ind - start_with
        if offset % show_every != 0:
            continue
        i = offset // show_every
        # Slices before start_with give a negative row index, which would wrap
        # around onto the bottom rows of the grid -- only accept 0 <= i < rows.
        # HD95 is computed below only for slices that are actually drawn.
        if not 0 <= i < rows:
            continue

        slice_no = ind + start_inx
        image = data['input']
        label_seg = data['GT_seg']
        label_bound = data['GT_bound']
        # the seg prediction itself is not plotted; it is converted into a
        # bound map with mask2outerbound and compared against the GT bound
        pred_bound_conv = mask2outerbound(data['pred_seg'], width=width)
        pred_bound_2d = data['pred_2d_bound']
        pred_bound_3d = data['pred_3d_bound']

        ax[i, 0].imshow(image, cmap='gray')
        ax[i, 0].set_title("Slice {} : {}".format(slice_no, 'input'))
        ax[i, 0].axis('off')

        ax[i, 1].imshow(mask2rgb(label_seg))
        ax[i, 1].set_title('Slice %d : %s' % (slice_no, 'label_seg'))
        ax[i, 1].axis('off')

        # draw every GT bound pixel with label 4
        label_bound_cp = label_bound.copy()
        label_bound_cp[label_bound != 0] = 4
        ax[i, 2].imshow(mask2rgb(label_bound_cp))
        ax[i, 2].set_title('Slice %d : %s' % (slice_no, 'label_bound'))
        ax[i, 2].axis('off')

        # overlap of each predicted bound with the GT bound, plus its HD95
        panels = (
            (3, pred_bound_conv, 'bound from seg'),
            (4, pred_bound_2d, '2D bound'),
            (5, pred_bound_3d, '3D bound'),
        )
        for col, pred_bound, name in panels:
            hdf = slicewise_hd95(pred_bound, label_bound, n_class)
            overlap = pred_bound.copy()
            overlap[label_bound != 0] = 4
            ax[i, col].imshow(mask2rgb(overlap))
            ax[i, col].set_title("Slice {:d} : {} (hdf={:.4f})".format(
                slice_no, name, hdf))
            ax[i, col].axis('off')

    if fig_name:
        fig.savefig(fig_name + '.pdf')

    plt.close(fig)
Beispiel #4
0
def sample_list_hdf(data_list,
                    rows=15,
                    start_with=0,
                    show_every=2,
                    scale=4,
                    fig_name=None,
                    start_inx=0,
                    n_class=5):
    """ show results as a list with Hausdorff distance calculated from each slice

    Slices ``start_with, start_with + show_every, ...`` of ``data_list`` are
    drawn one per row (at most ``rows`` rows). Columns per row:
    input | ground truth | prediction (HD95 in title) | overlap of GT and pred |
    one prob map per output channel except channel 0.

    Args:
        data_list: list, list of data in which each element is a dictionary
            with keys 'input', 'GT', 'pred' and 'output' ([C, H, W] prob maps)
        rows: int, maximum number of figure rows
        start_with: int, index in data_list of the first slice to draw
        show_every: int, step between drawn slices
        scale: size (inches) of each subplot
        fig_name: str or None, figure is saved to ``fig_name + '.pdf'`` if given
        start_inx: int, starting slice index for current figure
        n_class: int, number of classes passed to slicewise_hd95
    """

    output_cols = len(
        data_list[0]['output'])  # whether single or multiple channels
    # 4 fixed columns plus one prob map per channel except channel 0; the
    # previous "5 + output_cols - 1" left one allocated column never drawn into
    cols = 4 + output_cols - 1
    # squeeze=False keeps ``ax`` 2-D, so ax[i, j] also works when rows == 1
    fig, ax = plt.subplots(rows,
                           cols,
                           figsize=[scale * cols, scale * rows],
                           squeeze=False)

    # titles of the prob-map columns (channel 0 is not plotted)
    if output_cols >= 3:
        output_title = ['prob map (inner bound)', 'prob map (outer bound)']
    else:
        output_title = ['prob map']

    for ind, data in enumerate(data_list):
        offset = ind - start_with
        if offset % show_every != 0:
            continue
        i = offset // show_every
        # Slices before start_with give a negative row index, which would wrap
        # around onto the bottom rows of the grid -- only accept 0 <= i < rows.
        # HD95 is computed below only for slices that are actually drawn.
        if not 0 <= i < rows:
            continue

        slice_no = ind + start_inx
        image = data['input']
        label = data['GT']
        pred = data['pred']
        output = data['output']  # [C, H, W]

        hdf = slicewise_hd95(pred, label, n_class)

        ax[i, 0].imshow(image, cmap='gray')  # we don't consider multiple inputs here
        ax[i, 0].set_title("Slice {} : {}".format(slice_no, 'input'))
        ax[i, 0].axis('off')

        ax[i, 1].imshow(mask2rgb(label))
        ax[i, 1].set_title('Slice %d : %s' % (slice_no, 'ground truth'))
        ax[i, 1].axis('off')

        ax[i, 2].imshow(mask2rgb(pred))
        ax[i, 2].set_title("Slice {:d} : prediction (hdf={:.4f})".format(
            slice_no, hdf))
        ax[i, 2].axis('off')

        # plot overlapping between pred and GT annotation (GT drawn as 4)
        overlap = pred.copy()
        overlap[label != 0] = 4
        ax[i, 3].imshow(mask2rgb(overlap))
        ax[i, 3].set_title("Slice {:d} : {}".format(slice_no,
                                                    'overlap of GT and pred'))
        ax[i, 3].axis('off')

        # plot prob map for every channel except channel 0
        for c_inx in range(1, output_cols):
            # generic title when there are more channels than known names
            title = (output_title[c_inx - 1] if c_inx <= len(output_title)
                     else 'prob map (channel {})'.format(c_inx))
            ax[i, 3 + c_inx].imshow(output[c_inx], cmap='seismic')
            ax[i, 3 + c_inx].set_title("Slice {:d} : {}".format(slice_no, title))
            ax[i, 3 + c_inx].axis('off')

    if fig_name:
        fig.savefig(fig_name + '.pdf')
    plt.close(fig)