Пример #1
0
def summary_plot(md_list: DetectionMetricDataList,
                 metrics: DetectionMetrics,
                 min_precision: float,
                 min_recall: float,
                 dist_th_tp: float,
                 savepath: str = None) -> None:
    """
    Build a grid with one row per detection class: the precision-recall curve
    on the left and the recall-vs-TP-error curves on the right.
    :param md_list: DetectionMetricDataList instance.
    :param metrics: DetectionMetrics instance.
    :param min_precision: Minimum precision value.
    :param min_recall: Minimum recall value.
    :param dist_th_tp: The distance threshold used to determine matches.
    :param savepath: If given, save the figure to this path and close it.
    """
    num_rows = len(DETECTION_NAMES)
    last_row = num_rows - 1
    _, grid = plt.subplots(nrows=num_rows,
                           ncols=2,
                           figsize=(15, 5 * num_rows))

    for row, name in enumerate(DETECTION_NAMES):
        # Column titles are only drawn on the top row.
        if row == 0:
            pr_title, tp_title = 'Recall vs Precision', 'Recall vs Error'
        else:
            pr_title = tp_title = None

        # Left column: precision-recall axis, labelled with the class name.
        pr_ax = setup_axis(xlim=1,
                           ylim=1,
                           title=pr_title,
                           min_precision=min_precision,
                           min_recall=min_recall,
                           ax=grid[row, 0])
        class_label = PRETTY_DETECTION_NAMES[name]
        pr_ax.set_ylabel('{} \n \n Precision'.format(class_label), size=20)

        # Right column: recall-vs-error axis.
        tp_ax = setup_axis(xlim=1,
                           title=tp_title,
                           min_recall=min_recall,
                           ax=grid[row, 1])

        # The shared x label is only drawn under the bottom row.
        if row == last_row:
            for bottom_ax in (pr_ax, tp_ax):
                bottom_ax.set_xlabel('Recall', size=20)

        class_pr_curve(md_list,
                       metrics,
                       name,
                       min_precision,
                       min_recall,
                       ax=pr_ax)
        class_tp_curve(md_list,
                       metrics,
                       name,
                       min_recall,
                       dist_th_tp=dist_th_tp,
                       ax=tp_ax)

    plt.tight_layout()

    if savepath is not None:
        plt.savefig(savepath)
        plt.close()
Пример #2
0
def recall_metric_curve(md_list: TrackingMetricDataList,
                        metric_name: str,
                        min_recall: float,
                        savepath: str = None,
                        ax: Axis = None) -> None:
    """
    Draw recall against a tracking metric, one curve per tracking class.
    :param md_list: TrackingMetricDataList instance.
    :param metric_name: The name of the metric to plot (upper-cased for the y label).
    :param min_recall: Minimum recall value.
    :param savepath: If given, save the figure to this path and close it.
    :param ax: Axes onto which to render, or None to create a new figure and axes.
    """
    if ax is None:
        ax = plt.subplots(1, 1, figsize=(7.5, 5))[1]
    ax = setup_axis(xlabel='Recall', ylabel=metric_name.upper(),
                    xlim=1, ylim=None, min_recall=min_recall, ax=ax,
                    show_spines='bottomleft')

    # One recall-vs-metric curve per tracking class.
    for tracking_name, md in md_list.md.items():
        confidence = md.confidence
        recall_values = md.recall_hypo
        metric_values = md.get_metric(metric_name)

        # Entries with a NaN confidence are unachieved recall thresholds; the
        # curve starts at the first entry with a valid confidence.
        reached = np.nonzero(~np.isnan(confidence))[0]
        if reached.size == 0:
            continue
        start = reached[0]
        assert not np.isnan(confidence[-1])

        ax.plot(recall_values[start:],
                metric_values[start:],
                label='%s' % PRETTY_TRACKING_NAMES[tracking_name],
                color=TRACKING_COLORS[tracking_name])

    # Count statistics and FAF use a symmetric-log y scale.
    if metric_name in ('mt', 'ml', 'faf', 'tp', 'fp', 'fn', 'ids', 'frag'):
        ax.set_yscale('symlog')

    # These metrics are bounded above by 1; every other metric except MOTP is
    # only floored at 0.
    if metric_name in ('amota', 'motar', 'recall', 'mota'):
        ax.set_ylim(0, 1)
    elif metric_name != 'motp':
        ax.set_ylim(bottom=0)

    ax.legend(loc='upper right', borderaxespad=0)
    plt.tight_layout()
    if savepath is not None:
        plt.savefig(savepath)
        plt.close()
Пример #3
0
def dist_pr_curve(md_list: DetectionMetricDataList,
                  metrics: DetectionMetrics,
                  dist_th: float,
                  min_precision: float,
                  min_recall: float,
                  savepath: str = None) -> None:
    """
    Plot the PR curve of every detection class at a single distance threshold.
    The legend, which shows each class's AP in percent, is drawn in a
    separate narrow axis to the right of the curves.
    :param md_list: DetectionMetricDataList instance.
    :param metrics: DetectionMetrics instance.
    :param dist_th: Distance threshold for matching.
    :param min_precision: Minimum precision value.
    :param min_recall: Minimum recall value.
    :param savepath: If given, save the figure to this path and close it.
    """
    # Wide axis for the curves, narrow axis (1/5 of the width) for the legend.
    _, (ax, lax) = plt.subplots(ncols=2,
                                gridspec_kw={"width_ratios": [4, 1]},
                                figsize=(7.5, 5))
    ax = setup_axis(xlabel='Recall',
                    ylabel='Precision',
                    xlim=1,
                    ylim=1,
                    min_precision=min_precision,
                    min_recall=min_recall,
                    ax=ax)

    # Plot the recall vs. precision curve for each detection class.
    # get_dist_data already yields (metric data, class name) pairs for this
    # threshold, so the data is used directly instead of being looked up again.
    for md, detection_name in md_list.get_dist_data(dist_th):
        ap = metrics.get_label_ap(detection_name, dist_th)
        ax.plot(md.recall,
                md.precision,
                label='{}: {:.1f}%'.format(
                    PRETTY_DETECTION_NAMES[detection_name], ap * 100),
                color=DETECTION_COLORS[detection_name])

    # Move the legend into the side axis and hide that axis' frame and ticks.
    handles, labels = ax.get_legend_handles_labels()
    lax.legend(handles, labels, borderaxespad=0)
    lax.axis("off")
    plt.tight_layout()
    if savepath is not None:
        plt.savefig(savepath)
        plt.close()
Пример #4
0
def class_pr_curve(md_list: DetectionMetricDataList,
                   metrics: DetectionMetrics,
                   detection_name: str,
                   min_precision: float,
                   min_recall: float,
                   savepath: str = None,
                   ax: Axis = None) -> None:
    """
    Plot the precision-recall curves of one class, one curve per distance threshold.
    :param md_list: DetectionMetricDataList instance.
    :param metrics: DetectionMetrics instance.
    :param detection_name: The detection class.
    :param min_precision: Minimum precision value.
    :param min_recall: Minimum recall value.
    :param savepath: If given, save the figure to this path and close it.
    :param ax: Axes onto which to render, or None to let setup_axis create one.
    """
    if ax is None:
        ax = setup_axis(title=PRETTY_DETECTION_NAMES[detection_name],
                        xlabel='Recall',
                        ylabel='Precision',
                        xlim=1,
                        ylim=1,
                        min_precision=min_precision,
                        min_recall=min_recall)

    # (metric data, distance threshold) pairs for the requested class.
    per_threshold = md_list.get_class_data(detection_name)

    for class_md, threshold in per_threshold:
        class_md: DetectionMetricData
        class_ap = metrics.get_label_ap(detection_name, threshold)
        curve_label = 'Dist. : {}, AP: {:.1f}'.format(threshold, class_ap * 100)
        ax.plot(class_md.recall, class_md.precision, label=curve_label)

    ax.legend(loc='best')
    if savepath is not None:
        plt.savefig(savepath)
        plt.close()
Пример #5
0
def class_tp_curve(md_list: DetectionMetricDataList,
                   metrics: DetectionMetrics,
                   detection_name: str,
                   min_recall: float,
                   dist_th_tp: float,
                   savepath: str = None,
                   ax: Axis = None) -> None:
    """
    Plot the true positive (recall vs. error) curves for the specified class.
    :param md_list: DetectionMetricDataList instance.
    :param metrics: DetectionMetrics instance.
    :param detection_name: The detection class.
    :param min_recall: Minimum recall value.
    :param dist_th_tp: The distance threshold used to determine matches.
    :param savepath: If given, save the figure to this path and close it.
    :param ax: Axes onto which to render, or None to let setup_axis create one.
    """
    # Get metric data for given detection class with tp distance threshold.
    md = md_list[(detection_name, dist_th_tp)]
    # Recall is indexed in percent, so min_recall is scaled by 100.
    min_recall_ind = round(100 * min_recall)
    # True when the class reached at least min_recall.
    has_data = min_recall_ind <= md.max_recall_ind

    ylimit = 1.0
    if has_data:
        # For traffic_cone and barrier only a subset of the metrics are plotted:
        # TP metrics whose label value is NaN are excluded from the y limit.
        rel_metrics = [
            m for m in TP_METRICS
            if not np.isnan(metrics.get_label_tp(detection_name, m))
        ]
        # Guard against every metric being NaN; max() of an empty sequence raises.
        if rel_metrics:
            ylimit = max(
                max(getattr(md, metric)[min_recall_ind:md.max_recall_ind + 1])
                for metric in rel_metrics) * 1.1

    # Prepare axis.
    if ax is None:
        ax = setup_axis(title=PRETTY_DETECTION_NAMES[detection_name],
                        xlabel='Recall',
                        ylabel='Error',
                        xlim=1,
                        min_recall=min_recall)
    ax.set_ylim(0, ylimit)

    # Plot the recall vs. error curve for each tp metric.
    for metric in TP_METRICS:
        tp = metrics.get_label_tp(detection_name, metric)
        # np.isnan instead of `tp is np.nan`: an identity check misses NaNs that
        # are not the np.nan singleton (float('nan'), np.float64('nan'), ...).
        tp_missing = bool(np.isnan(tp))

        # Plot only if we have valid data.
        if not tp_missing and has_data:
            recall = md.recall[:md.max_recall_ind + 1]
            error = getattr(md, metric)[:md.max_recall_ind + 1]
        else:
            recall, error = [], []

        # Change legend based on tp value.
        if tp_missing:
            label = '{}: n/a'.format(PRETTY_TP_METRICS[metric])
        elif not has_data:
            label = '{}: nan'.format(PRETTY_TP_METRICS[metric])
        else:
            label = '{}: {:.2f} ({})'.format(PRETTY_TP_METRICS[metric], tp,
                                             TP_METRICS_UNITS[metric])
        ax.plot(recall, error, label=label)

    # Dash-dot vertical marker at md.max_recall.
    ax.axvline(x=md.max_recall, linestyle='-.', color=(0, 0, 0, 0.3))
    ax.legend(loc='best')

    if savepath is not None:
        plt.savefig(savepath)
        plt.close()