Exemplo n.º 1
0
def visualize_ced(normed_mean_error_dict,
                  error_threshold,
                  normalized=True,
                  truncated_list=None,
                  display2terminal=True,
                  display_list=None,
                  title='2D PCK curve',
                  debug=True,
                  vis=False,
                  pck_savepath=None,
                  table_savepath=None,
                  closefig=True):
    '''
    visualize the cumulative error distribution curve (also called NME curve or pck curve)
    and summarize the per-method metrics in an ascii table

    parameter:
        normed_mean_error_dict:     a dictionary whose keys are the method name and values are (N, ) numpy array to represent error in evaluation
        error_threshold:            threshold to display in x axis, it is also the upper bound of the AUC integration
        normalized:                 True: the errors are fractions, the x axis is drawn in percentage
                                    False: the errors are drawn as they are
        truncated_list:             optional list of scalar thresholds forwarded to calculate_truncated_mse, three extra columns are added to the table for each of them
        display2terminal:           print the table (and the table save path) to the terminal
        display_list:               optional list of method names giving the order of plotting, all keys of the dictionary are used when it is None
        title:                      title of the figure
        vis, pck_savepath, closefig:    forwarded to save_vis_close_helper
        table_savepath:             optional path of a text file to write the table

    return:
        metrics_dict:               a dictionary mapping the method name to a dictionary of its metrics ('AUC', 'MSE' and the truncated ones)
        metrics_table:              a list of rows (list of strings) with the title row first, the same content as the printed table
    '''
    if debug:
        assert isdict(
            normed_mean_error_dict
        ), 'the input normalized mean error dictionary is not correct'
        assert islogical(
            normalized), 'the normalization flag should be logical'
        if normalized:
            assert error_threshold > 0 and error_threshold < 100, 'threshold percentage is not well set'
        # the save paths are optional, only validate the ones which are provided
        if pck_savepath is not None:
            assert is_path_exists_or_creatable(
                pck_savepath
            ), 'please provide a valid path to save the pck results'
        if table_savepath is not None:
            assert is_path_exists_or_creatable(
                table_savepath
            ), 'please provide a valid path to save the table results'
        assert isstring(title), 'title is not correct'
        if truncated_list is not None:
            assert islistofscalar(
                truncated_list), 'the input truncated list is not correct'
        if display_list is not None:
            assert islist(display_list) and len(
                display_list) == len(normed_mean_error_dict
                                     ), 'the input display list is not correct'
            assert CHECK_EQ_LIST_UNORDERED(
                display_list, list(normed_mean_error_dict.keys()), debug=debug
            ), 'the input display list does not match the error dictionary key list'

    # the default order must be available no matter if the debug mode is on
    # a real list is needed as the methods are indexed by position below
    if display_list is None:
        display_list = list(normed_mean_error_dict.keys())

    # set display parameters
    width, height = 1000, 800
    legend_fontsize = 10
    scale_distance = 48.8  # factor from error to the 'um' value of the legend, presumably micrometers per unit -- TODO confirm
    line_index, color_index = 0, 0

    figsize = width / float(dpi), height / float(dpi)
    fig = plt.figure(figsize=figsize)

    # set figure handle
    num_bins = 1000
    if normalized:
        maximum_x = 1
        scale = num_bins / 100
    else:
        maximum_x = error_threshold + 1
        scale = num_bins / maximum_x
    x_axis = np.linspace(
        0, maximum_x,
        num_bins)  # error axis, percentage of normalization factor
    y_axis = np.zeros(num_bins)
    interval_y = 10
    interval_x = 1
    plt.xlim(0, error_threshold)
    plt.ylim(0, 100)
    plt.yticks(np.arange(0, 100 + interval_y, interval_y))
    plt.xticks(np.arange(0, error_threshold + interval_x, interval_x))
    plt.grid()
    plt.title(title, fontsize=20)
    if normalized:
        plt.xlabel('Normalized error euclidean distance (%)', fontsize=16)
    else:
        plt.xlabel('Absolute error euclidean distance', fontsize=16)

    # number of histogram bins below the threshold, it has to be an integer to
    # slice the curve (the product is a float under true division)
    num_auc_bins = int(error_threshold * scale)

    # calculate metrics for each method
    num_methods = len(normed_mean_error_dict)
    num_images = len(list(normed_mean_error_dict.values())[0])
    metrics_dict = dict()
    metrics_table = list()
    table_title = ['Method Name / Metrics', 'AUC', 'MSE']
    append2title = False
    assert num_images > 0, 'number of error array should be larger than 0'
    if debug:
        # every method needs a unique (linestyle, color) combination
        assert len(linestyle_set) * len(color_set) >= len(
            normed_mean_error_dict)
    for ordered_index in range(num_methods):
        method_name = display_list[ordered_index]
        normed_mean_error = normed_mean_error_dict[method_name]

        if debug:
            assert isnparray(
                normed_mean_error
            ) and normed_mean_error.ndim == 1, 'shape of error distance is not good'
            assert len(
                normed_mean_error
            ) == num_images, 'number of testing images should be equal for all methods'

        color_tmp = color_set[color_index]
        line_tmp = linestyle_set[line_index]

        # fraction of the images whose error is strictly below each bin
        for i in range(num_bins):
            y_axis[i] = float(
                (normed_mean_error <
                 x_axis[i]).sum()) / num_images  # percentage of error

        # calculate area under the curve and mean square error
        entry = dict()
        auc_value = np.sum(y_axis[:num_auc_bins]) / float(num_auc_bins)
        entry['AUC'] = auc_value  # bigger, better
        entry['MSE'] = np.mean(normed_mean_error)  # smaller, better
        metrics_table_tmp = [
            str(method_name),
            '%.2f' % (entry['AUC']),
            '%.1f' % (entry['MSE'])
        ]
        if truncated_list is not None:
            tmse_dict = calculate_truncated_mse(normed_mean_error.tolist(),
                                                truncated_list,
                                                debug=debug)
            for threshold in truncated_list:
                # NOTE(review): this AUC is integrated up to error_threshold
                # and not up to `threshold`, so it equals the overall AUC.
                # kept as it was -- confirm if a per-threshold AUC is intended
                entry['AUC/%s' % threshold] = auc_value  # bigger, better
                entry['MSE/%s' % threshold] = tmse_dict[threshold]['T-MSE']
                entry['percentage/%s' %
                      threshold] = tmse_dict[threshold]['percentage']

                # the title row is extended only while handling the first method
                if not append2title:
                    table_title.append('AUC/%s' % threshold)
                    table_title.append('MSE/%s' % threshold)
                    table_title.append('pct/%s' % threshold)
                metrics_table_tmp.append('%.2f' %
                                         (entry['AUC/%s' % threshold]))
                metrics_table_tmp.append('%.1f' %
                                         (entry['MSE/%s' % threshold]))
                metrics_table_tmp.append(
                    '%.1f' % (100 * entry['percentage/%s' % threshold]) + '%')

        metrics_table.append(metrics_table_tmp)
        append2title = True
        metrics_dict[method_name] = entry

        # draw, the normalized errors are shown in percentage
        label = '%s, AUC: %.2f, MSE: %.1f (%.0f um)' % (
            method_name, entry['AUC'], entry['MSE'],
            entry['MSE'] * scale_distance)
        x_plot = x_axis * 100 if normalized else x_axis
        plt.plot(x_plot,
                 y_axis * 100,
                 color=color_tmp,
                 linestyle=line_tmp,
                 label=label,
                 lw=3)

        # go through all colors first, then move on to the next line style
        color_index += 1
        if color_index == len(color_set):
            line_index += 1
            color_index = 0

    # the legend collects all the labeled curves, so it is built once at the end
    plt.legend(loc=4, fontsize=legend_fontsize)
    plt.ylabel('{} Test Images (%)'.format(num_images), fontsize=16)
    save_vis_close_helper(fig=fig,
                          ax=None,
                          vis=vis,
                          transparent=False,
                          save_path=pck_savepath,
                          debug=debug,
                          closefig=closefig)

    # print table to terminal, the rows are already in the display order
    metrics_table = [table_title] + metrics_table
    table = AsciiTable(metrics_table)
    if display2terminal:
        print('\nprint detailed metrics')
        print(table.table)

    # save table to file
    if table_savepath is not None:
        with open(table_savepath, 'w') as table_file:
            table_file.write(table.table)
        if display2terminal:
            print('\nsave detailed metrics to %s' % table_savepath)

    return metrics_dict, metrics_table
Exemplo n.º 2
0
def visualize_pts_array(input_pts,
                        color_index=0,
                        pts_size=20,
                        label=False,
                        label_list=None,
                        label_size=20,
                        vis_threshold=0.3,
                        covariance=False,
                        plot_occl=False,
                        xlim=None,
                        ylim=None,
                        fig=None,
                        ax=None,
                        save_path=None,
                        vis=False,
                        warning=True,
                        debug=True,
                        closefig=True):
    '''
    plot keypoints with covariance ellipse

    parameters:
        input_pts:      2(3) x num_pts numpy array, the third channel could be confidence or occlusion
        color_index:    an integer index into color_set_big, or a list of indexes (one per plotted point)
                        the indexes wrap around the size of the color set
        pts_size:       marker size passed to scatter
        label:          True: annotate every plotted point, str(index) is used when label_list is None
        label_list:     optional list of strings used as the annotation of the points
        label_size:     font size of the annotation
        vis_threshold:  a point is visible when its confidence (third row) is bigger than this value
        covariance:     plot the covariance ellipse, only used when the third row exists
        plot_occl:      also plot (and label) the invisible points, in the next color of the color set
        xlim, ylim:     optional [min, max] list to set the range of the axis
        fig, ax:        forwarded to get_fig_ax_helper
        save_path, vis, warning, closefig:  forwarded to save_vis_close_helper

    outputs:
        the value returned by save_vis_close_helper
    '''
    # obtain the points, try the homogeneous layout first
    try:
        pts_array = safe_2dptsarray(input_pts,
                                    homogeneous=True,
                                    warning=warning,
                                    debug=debug)
    except AssertionError:
        pts_array = safe_2dptsarray(input_pts,
                                    homogeneous=False,
                                    warning=warning,
                                    debug=debug)
    if debug:
        assert is2dptsarray(pts_array) or is2dptsarray_occlusion(
            pts_array) or is2dptsarray_confidence(
                pts_array), 'input points are not correct'
    num_pts = pts_array.shape[1]

    # obtain a label list if required but not provided
    if debug: assert islogical(label), 'label flag is not correct'
    if label and (label_list is None):
        label_list = [str(i) for i in range(num_pts)]
    if label_list is not None and debug:
        assert islistofstring(label_list), 'labels are not correct'

    # obtain the color index, wrap around the color set in both cases
    num_colors = len(color_set_big)
    if islist(color_index):
        if debug:
            assert not (
                plot_occl or covariance
            ), 'the occlusion or covariance are not compatible with plotting different colors during scattering'
        color_tmp = [
            color_set_big[index_tmp % num_colors] for index_tmp in color_index
        ]
    else:
        color_tmp = color_set_big[color_index % num_colors]

    fig, ax = get_fig_ax_helper(fig=fig, ax=ax)
    std, conf = None, 0.95
    if is2dptsarray(pts_array):  # only 2d points without third rows
        if debug and islist(color_tmp):
            assert len(
                color_tmp
            ) == num_pts, 'number of points to plot is not equal to number of colors provided'
        ax.scatter(pts_array[0, :],
                   pts_array[1, :],
                   color=color_tmp,
                   s=pts_size)
        pts_ignore_index = []
        pts_invisible_index = []
    else:
        # automatically justify if the third row is confidence or occlusion flag
        # any value other than -1, 0 and 1 means the row stores confidence
        third_row = pts_array[2, :]
        num_float_elements = np.where(
            np.logical_and(
                third_row != -1,
                np.logical_and(third_row != 0, third_row != 1)))[0].tolist()
        if len(num_float_elements) > 0: type_3row = 'conf'
        else: type_3row = 'occu'

        if type_3row == 'occu':
            pts_visible_index = np.where(
                third_row == 1)[0].tolist()  # visible points
            pts_ignore_index = np.where(
                third_row == -1)[0].tolist()  # points which are never plotted
            pts_invisible_index = np.where(
                third_row == 0)[0].tolist()  # invisible points
        else:
            pts_visible_index = np.where(
                third_row > vis_threshold)[0].tolist()
            pts_invisible_index = np.where(
                third_row <= vis_threshold)[0].tolist()
            pts_ignore_index = []

        if debug and islist(color_tmp):
            assert len(color_tmp) == len(
                pts_visible_index
            ), 'number of points to plot is not equal to number of colors provided'
        ax.scatter(pts_array[0, pts_visible_index],
                   pts_array[1, pts_visible_index],
                   color=color_tmp,
                   s=pts_size)
        if plot_occl:
            # the invisible points use the color next to the one of the visible points
            ax.scatter(pts_array[0, pts_invisible_index],
                       pts_array[1, pts_invisible_index],
                       color=color_set_big[(color_index + 1) % num_colors],
                       s=pts_size)
        if covariance:
            visualize_pts_covariance(pts_array[0:2, :],
                                     std=std,
                                     conf=conf,
                                     fig=fig,
                                     ax=ax,
                                     debug=debug,
                                     color=color_tmp)

    # the points which are not plotted are not labeled either
    if plot_occl: not_plot_index = pts_ignore_index
    else: not_plot_index = pts_ignore_index + pts_invisible_index
    if label_list is not None:
        skipped_index = set(not_plot_index)
        per_point_color = islist(color_index)
        for pts_index in range(num_pts):
            label_tmp = label_list[pts_index]
            if pts_index in skipped_index: continue

            # note that the annotation is based on the coordinate instead of the order of plotting the points, so the order in pts_index does not matter
            # the label color is shifted by 5 in the color set from the color of the point
            if per_point_color: color_base = color_index[pts_index]
            else: color_base = color_index
            # NOTE(review): the annotation and the axis limits below go through
            # pyplot (the current axes) instead of `ax`, presumably
            # get_fig_ax_helper leaves `ax` as the current axes -- confirm
            plt.annotate(label_tmp,
                         xy=(pts_array[0, pts_index],
                             pts_array[1, pts_index]),
                         xytext=(-1, 1),
                         color=color_set_big[(color_base + 5) % num_colors],
                         textcoords='offset points',
                         ha='right',
                         va='bottom',
                         fontsize=label_size)

    # set axis
    if xlim is not None:
        if debug:
            assert islist(xlim) and len(xlim) == 2, 'the x lim is not correct'
        plt.xlim(xlim[0], xlim[1])
    if ylim is not None:
        if debug:
            assert islist(ylim) and len(ylim) == 2, 'the y lim is not correct'
        plt.ylim(ylim[0], ylim[1])

    return save_vis_close_helper(fig=fig,
                                 ax=ax,
                                 vis=vis,
                                 save_path=save_path,
                                 warning=warning,
                                 debug=debug,
                                 closefig=closefig,
                                 transparent=False)
Exemplo n.º 3
0
def visualize_pts(pts,
                  title=None,
                  fig=None,
                  ax=None,
                  display_range=False,
                  xlim=[-100, 100],
                  ylim=[-100, 100],
                  display_list=None,
                  covariance=False,
                  mse=False,
                  mse_value=None,
                  vis=True,
                  save_path=None,
                  debug=True,
                  closefig=True,
                  warning=True):
    '''
    visualize point scatter plot

    parameter:
        pts:            2 x num_pts numpy array or a dictionary containing 2 x num_pts numpy array
        title:          title of the figure, a default one is used when it is None
        fig, ax:        forwarded to get_fig_ax_helper, the title and axis labels are only set when ax is None
        display_range:  True: restrict the plot to xlim and ylim and draw ticks in 20 intervals
        xlim, ylim:     [min, max] list, only used when display_range is True (never modified here)
        display_list:   list of method names giving the order of the legend when pts is a dictionary
                        all keys of the dictionary are used when it is None
        covariance:     plot the covariance ellipse and report its number in the legend
        mse:            report the MSE in the legend and return it
        mse_value:      precomputed MSE (a dictionary keyed by method name when pts is a dictionary)
                        when it is None the value is computed with pts_euclidean against the origin
        vis, save_path, closefig, warning:  forwarded to save_vis_close_helper

    return:
        mse_return:     a dictionary of MSE keyed by method name when pts is a dictionary (empty if mse is False)
                        otherwise the MSE itself, or None if mse is False
    '''

    if debug:
        if isdict(pts):
            for pts_tmp in pts.values():
                assert is2dptsarray(
                    pts_tmp
                ), 'input points within dictionary are not correct: (2, num_pts) vs %s' % print_np_shape(
                    pts_tmp)
            if display_list is not None:
                assert islist(display_list) and len(display_list) == len(
                    pts), 'the input display list is not correct'
                assert CHECK_EQ_LIST_UNORDERED(
                    display_list, list(pts.keys()), debug=debug
                ), 'the input display list does not match the points key list'
        else:
            assert is2dptsarray(
                pts
            ), 'input points are not correct: (2, num_pts) vs %s' % print_np_shape(
                pts)
        if title is not None: assert isstring(title), 'title is not correct'
        assert islogical(
            display_range
        ), 'the flag determine if to display in a specific range should be logical value'
        if display_range:
            assert islist(xlim) and islist(ylim) and len(xlim) == 2 and len(
                ylim) == 2, 'the input range for x and y is not correct'
            assert xlim[1] > xlim[0] and ylim[1] > ylim[
                0], 'the input range for x and y is not correct'

    # the defaults must be available no matter if the debug mode is on
    if title is None: title = 'Point Error Vector Distribution Map'
    if isdict(pts) and display_list is None:
        display_list = list(pts.keys())

    # figure setting
    width, height = 1024, 1024
    fig, _ = get_fig_ax_helper(fig=fig, ax=ax, width=width, height=height)
    if ax is None:
        plt.title(title, fontsize=20)
        if isdict(pts):
            pts_array_list = list(pts.values())
            num_pts_all = pts_array_list[0].shape[1]
            if all(pts_tmp.shape[1] == num_pts_all
                   for pts_tmp in pts_array_list):
                plt.xlabel('x coordinate (%d points)' % num_pts_all,
                           fontsize=16)
                plt.ylabel('y coordinate (%d points)' % num_pts_all,
                           fontsize=16)
            else:
                print('number of points is different across different methods')
                plt.xlabel('x coordinate', fontsize=16)
                plt.ylabel('y coordinate', fontsize=16)
        else:
            plt.xlabel('x coordinate (%d points)' % pts.shape[1], fontsize=16)
            plt.ylabel('y coordinate (%d points)' % pts.shape[1], fontsize=16)
        plt.axis('equal')
        ax = plt.gca()
        ax.grid()

    # internal parameters
    pts_size = 5
    std = None
    conf = 0.98
    color_index = 0
    marker_index = 0
    hatch_index = 0
    alpha = 0.2
    legend_fontsize = 10
    scale_distance = 48.8  # factor from MSE to the 'um' value of the legend, presumably micrometers per unit -- TODO confirm
    linewidth = 2

    # plot points
    mse_return = None  # stays None when a single array is given and mse is off
    handle_dict = dict()  # for legend
    if isdict(pts):
        num_methods = len(pts)
        assert len(color_set) * len(marker_set) >= num_methods and len(
            color_set
        ) * len(
            hatch_set
        ) >= num_methods, 'color in color set is not enough to use, please use different markers'
        mse_return = dict()
        label2method = dict()  # legend text -> method name, used to order the legend
        for method_name, pts_tmp in pts.items():
            color_tmp = color_set[color_index]
            marker_tmp = marker_set[marker_index]
            hatch_tmp = hatch_set[hatch_index]

            # plot covariance ellipse
            if covariance:
                _, covariance_number = visualize_pts_covariance(
                    pts_tmp[0:2, :],
                    std=std,
                    conf=conf,
                    ax=ax,
                    debug=debug,
                    color=color_tmp,
                    hatch=hatch_tmp,
                    linewidth=linewidth)
            handle_tmp = ax.scatter(pts_tmp[0, :],
                                    pts_tmp[1, :],
                                    color=color_tmp,
                                    marker=marker_tmp,
                                    s=pts_size,
                                    alpha=alpha)
            if mse:
                if mse_value is None:
                    num_pts = pts_tmp.shape[1]
                    mse_tmp, _ = pts_euclidean(pts_tmp[0:2, :],
                                               np.zeros((2, num_pts),
                                                        dtype='float32'),
                                               debug=debug)
                else:
                    mse_tmp = mse_value[method_name]
                display_string = '%s, MSE: %.1f (%.1f um)' % (
                    method_name, mse_tmp, mse_tmp * scale_distance)
                # the covariance number only exists when the ellipse is plotted
                if covariance:
                    display_string += ', Covariance: %.1f' % covariance_number
                mse_return[method_name] = mse_tmp
            else:
                display_string = method_name
            handle_dict[display_string] = handle_tmp
            label2method[display_string] = method_name

            # go through all colors first, then move on to the next marker and hatch
            color_index += 1
            if color_index == len(color_set):
                marker_index += 1
                hatch_index += 1
                color_index = 0

        # reorder the handle before plot
        handle_key_list = list(handle_dict.keys())
        handle_value_list = list(handle_dict.values())
        order_index_list = [
            display_list.index(label2method[display_string])
            for display_string in handle_key_list
        ]
        ordered_handle_key_list = list_reorder(handle_key_list,
                                               order_index_list,
                                               debug=debug)
        ordered_handle_value_list = list_reorder(handle_value_list,
                                                 order_index_list,
                                                 debug=debug)
        plt.legend(list2tuple(ordered_handle_value_list),
                   list2tuple(ordered_handle_key_list),
                   scatterpoints=1,
                   markerscale=4,
                   loc='lower left',
                   fontsize=legend_fontsize)

    else:
        color_tmp = color_set[color_index]
        marker_tmp = marker_set[marker_index]
        hatch_tmp = hatch_set[hatch_index]
        handle_tmp = ax.scatter(pts[0, :],
                                pts[1, :],
                                color=color_tmp,
                                marker=marker_tmp,
                                s=pts_size,
                                alpha=alpha)

        # plot covariance ellipse
        if covariance:
            _, covariance_number = visualize_pts_covariance(
                pts[0:2, :],
                std=std,
                conf=conf,
                ax=ax,
                debug=debug,
                color=color_tmp,
                hatch=hatch_tmp,
                linewidth=linewidth)

        if mse:
            if mse_value is None:
                num_pts = pts.shape[1]
                mse_tmp, _ = pts_euclidean(pts[0:2, :],
                                           np.zeros((2, num_pts),
                                                    dtype='float32'),
                                           debug=debug)
            else:
                mse_tmp = mse_value
            display_string = 'MSE: %.1f (%.1f um)' % (
                mse_tmp, mse_tmp * scale_distance)
            # the covariance number only exists when the ellipse is plotted
            if covariance:
                display_string += ', Covariance: %.1f' % covariance_number
            mse_return = mse_tmp
            handle_dict[display_string] = handle_tmp
            plt.legend(list2tuple(list(handle_dict.values())),
                       list2tuple(list(handle_dict.keys())),
                       scatterpoints=1,
                       markerscale=4,
                       loc='lower left',
                       fontsize=legend_fontsize)

    # display only specific range
    if display_range:
        axis_bin = 10 * 2
        # float division, so that a range smaller than the number of bins does not give a zero interval
        interval_x = (xlim[1] - xlim[0]) / float(axis_bin)
        interval_y = (ylim[1] - ylim[0]) / float(axis_bin)
        plt.xlim(xlim[0], xlim[1])
        plt.ylim(ylim[0], ylim[1])
        plt.xticks(np.arange(xlim[0], xlim[1] + interval_x, interval_x))
        plt.yticks(np.arange(ylim[0], ylim[1] + interval_y, interval_y))
    # NOTE(review): when ax was None, ax.grid() has already been called above.
    # matplotlib documents grid() without arguments as a toggle, so this second
    # call may hide the grid again. kept as it was -- confirm the intent
    plt.grid()

    save_vis_close_helper(fig=fig,
                          ax=ax,
                          vis=vis,
                          save_path=save_path,
                          warning=warning,
                          debug=debug,
                          closefig=closefig,
                          transparent=False)
    return mse_return
Exemplo n.º 4
0
def load_list_from_folder(folder_path,
                          ext_filter=None,
                          depth=1,
                          recursive=False,
                          sort=True,
                          save_path=None,
                          debug=True):
    '''
    load a list of files or folders from a system path

    parameters:
        folder_path:    root to search 
        ext_filter:     a string or a list of strings to represent the extension of files interested
        depth:          maximum depth of folder to search, when it's None, all levels of folders will be searched
        recursive:      False: only return current level
                        True: return all levels till to the input depth
        sort:           sort the matches of every wildcard before they are appended
                        (the concatenated list is not sorted as a whole)
        save_path:      optional path of a text file to write the list, one element per line

    outputs:
        fulllist:       a list of elements
        num_elem:       number of the elements
    '''
    folder_path = safepath(folder_path)
    if debug:
        assert isfolder(
            folder_path), 'input folder path is not correct: %s' % folder_path
    if not is_path_exists(folder_path): return [], 0

    if debug:
        assert islogical(
            recursive), 'recursive should be a logical variable: {}'.format(
                recursive)
        assert depth is None or (
            isinteger(depth)
            and depth >= 1), 'input depth is not correct {}'.format(depth)
        assert ext_filter is None or (islist(ext_filter) and all(
            isstring(ext_tmp) for ext_tmp in ext_filter)) or isstring(
                ext_filter), 'extension filter is not correct'
    if isstring(ext_filter):  # convert to a list
        ext_filter = [ext_filter]

    # build the wildcards to search, one for each extension
    if depth is None:  # find all files recursively, '**' is handled by glob2
        glob_function = glob2.glob
        wildcard_prefix = '**'
        if ext_filter is None:
            wildcard_list = [wildcard_prefix]
        else:
            wildcard_list = [
                os.path.join(wildcard_prefix, '*' + string2ext_filter(ext_tmp))
                for ext_tmp in ext_filter
            ]
    else:  # find files exactly at the input depth, one '*' for each level
        glob_function = glob.glob
        wildcard_prefix = '*'
        for _ in range(depth - 1):
            wildcard_prefix = os.path.join(wildcard_prefix, '*')
        if ext_filter is None:
            wildcard_list = [wildcard_prefix]
        else:
            wildcard_list = [
                wildcard_prefix + string2ext_filter(ext_tmp)
                for ext_tmp in ext_filter
            ]

    fulllist = list()
    for wildcard in wildcard_list:
        curlist = glob_function(os.path.join(folder_path, wildcard))
        if sort:
            curlist = sorted(curlist)
        fulllist += curlist

    # collect the shallower levels as well, with the same options as this call
    if depth is not None and recursive and depth > 1:
        newlist, _ = load_list_from_folder(folder_path=folder_path,
                                           ext_filter=ext_filter,
                                           depth=depth - 1,
                                           recursive=True,
                                           sort=sort,
                                           debug=debug)
        fulllist += newlist

    fulllist = [os.path.normpath(path_tmp) for path_tmp in fulllist]
    num_elem = len(fulllist)

    # save list to a path
    if save_path is not None:
        save_path = safepath(save_path)
        if debug:
            assert is_path_exists_or_creatable(
                save_path), 'the file cannot be created'
        with open(save_path, 'w') as list_file:
            for item in fulllist:
                list_file.write('%s\n' % item)

    return fulllist, num_elem