# Example #1
# 0
def _load_predictions_and_select_frames(tasks,
                                        tracker_pred_dir,
                                        log_prefix=''):
    '''Loads all predictions of a tracker and takes the subset of frames with ground truth.

    For every task, the predictions are read from
    `<tracker_pred_dir>/<video>_<object>.csv` and restricted to the frames that
    appear in `task.labels`.

    Args:
        tasks -- VideoObjectDict of Tasks.
        tracker_pred_dir -- Directory that contains files video_object.csv
        log_prefix -- String prepended to the per-track progress message
            (printed to stderr when args.verbose is set).

    Returns:
        VideoObjectDict of SparseTimeSeries of predictions, one entry per track
        whose prediction file could be read.

    Raises:
        IOError if a prediction file cannot be opened or read, unless
        args.permissive is set; in that case a warning is printed and the track
        is left out of the result.
    '''
    preds = oxuva.VideoObjectDict()
    for track_num, vid_obj in enumerate(tasks.keys()):
        vid, obj = vid_obj
        task = tasks[vid_obj]
        track_name = vid + '_' + obj
        log_context = '{}object {}/{} {}'.format(log_prefix, track_num + 1,
                                                 len(tasks), track_name)
        if args.verbose:
            print(log_context, file=sys.stderr)
        pred_file = os.path.join(tracker_pred_dir, '{}.csv'.format(track_name))
        try:
            with open(pred_file, 'r') as fp:
                pred = oxuva.load_predictions_csv(fp)
        except IOError as exc:
            if not args.permissive:
                raise
            print('warning: exclude track {}: {}'.format(
                track_name, str(exc)),
                  file=sys.stderr)
            # Really exclude the track: falling through would use an unbound
            # `pred` (first track) or the previous track's predictions.
            continue
        # Keep only the frames that carry a ground-truth label (presumably
        # filling frames without a prediction from the previous one, as the
        # helper's name suggests).
        preds[vid_obj] = oxuva.subset_using_previous_if_missing(
            pred, task.labels.sorted_keys())
    return preds
# Example #2
# 0
def _dataset_quality(assessments, bootstrap=False, num_trials=10, base_seed=0):
    '''Computes the overall quality of predictions on a dataset.

    The frame assessments of each track are first summed into a single total
    per track; the totals of all tracks are then pooled together.

    Args:
        assessments: VideoObjectDict of SparseTimeSeries of frame assessment dicts.
        bootstrap: Whether to also include statistics that involve bootstrap sampling.
        num_trials: Forwarded to _summarize_bootstrap (presumably the number of
            bootstrap trials); unused when bootstrap is false.
        base_seed: Seed forwarded to _summarize_bootstrap; unused when
            bootstrap is false.

    Returns:
        Mapping of summary statistics (callers read e.g. 'TPR'), extended with
        the bootstrap statistics when bootstrap is true.
    '''
    # One summed assessment per track.
    totals_by_track = {}
    for vid_obj in assessments.keys():
        frame_assessments = assessments[vid_obj].values()
        totals_by_track[vid_obj] = oxuva.assessment_sum(frame_assessments)
    seq_totals = oxuva.VideoObjectDict(totals_by_track)

    quality = _summarize_simple(seq_totals.values())
    if not bootstrap:
        return quality
    bootstrap_stats = _summarize_bootstrap(seq_totals, num_trials,
                                           base_seed=base_seed)
    quality.update(bootstrap_stats)
    return quality
# Example #3
# 0
def _plot_present_absent(assessments,
                         tasks,
                         trackers,
                         iou_threshold,
                         names=None,
                         colors=None,
                         markers=None):
    '''Plots each tracker's TPR on tracks without absent labels (x axis)
    against its TPR on tracks with at least one absent label (y axis).

    Trackers are drawn in order of decreasing overall quality (according to
    _stats_sort_key applied to the statistics over all tracks), and a dotted
    diagonal marks equal TPR on both subsets. The figure is saved as
    present_absent_iou_<iou_threshold>.pdf in analysis/<args.data>/<args.challenge>.

    Args:
        assessments: Dict tracker -> IOU threshold -> VideoObjectDict of
            SparseTimeSeries of frame assessment dicts.
        tasks: Dict track key -> task whose `labels` map frame -> label dict
            containing a 'present' entry.
        trackers: Names of the trackers to plot.
        iou_threshold: IOU threshold that selects the assessments to use.
        names: Optional dict tracker -> legend label (default: tracker name).
        colors: Optional dict tracker -> matplotlib color.
        markers: Optional dict tracker -> matplotlib marker.
    '''
    names = names or {}
    colors = colors or {}
    markers = markers or {}

    # Partition the tracks (in one pass) by whether every labelled frame has
    # the object present.
    subset_all_present = []
    subset_any_absent = []
    for key, task in tasks.items():
        if all(label['present'] for _, label in task.labels.items()):
            subset_all_present.append(key)
        else:
            subset_any_absent.append(key)

    def _subset_quality(tracker, keys):
        # Pool the tracker's assessments over the given subset of tracks only.
        tracker_assessments = assessments[tracker][iou_threshold]
        return _dataset_quality(
            oxuva.VideoObjectDict({
                vid_obj: tracker_assessments[vid_obj] for vid_obj in keys
            }))

    stats_whole = {
        tracker: _dataset_quality(assessments[tracker][iou_threshold])
        for tracker in trackers
    }
    stats_all_present = {
        tracker: _subset_quality(tracker, subset_all_present)
        for tracker in trackers
    }
    stats_any_absent = {
        tracker: _subset_quality(tracker, subset_any_absent)
        for tracker in trackers
    }

    order = sorted(trackers,
                   key=lambda t: _stats_sort_key(stats_whole[t]),
                   reverse=True)
    max_tpr = max(
        max(stats_all_present[tracker]['TPR'] for tracker in trackers),
        max(stats_any_absent[tracker]['TPR'] for tracker in trackers))
    # Both axes share one limit so that the y = x diagonal is meaningful.
    axis_max = _ceil_nearest(CLEARANCE * max_tpr, 0.1)

    plt.figure(figsize=(args.width_inches, args.height_inches))
    plt.xlabel('TPR (tracks without absent labels)')
    plt.ylabel('TPR (tracks with some absent labels)')
    for tracker in order:
        plt.plot([stats_all_present[tracker]['TPR']],
                 [stats_any_absent[tracker]['TPR']],
                 label=names.get(tracker, tracker),
                 color=colors.get(tracker, None),
                 marker=markers.get(tracker, None),
                 markerfacecolor='none',
                 markeredgewidth=2,
                 clip_on=False)
    # left/right and bottom/top are the primary parameter names of
    # set_xlim/set_ylim.
    plt.xlim(left=0, right=axis_max)
    plt.ylim(bottom=0, top=axis_max)
    plt.grid(color=GRID_COLOR)
    # Draw a diagonal line.
    plt.plot([0, 1], [0, 1], color=GRID_COLOR, linewidth=1, linestyle='dotted')
    plot_dir = os.path.join('analysis', args.data, args.challenge)
    _ensure_dir_exists(plot_dir)
    base_name = 'present_absent_iou_{}'.format(_float2str_latex(iou_threshold))
    # _save_fig(os.path.join(plot_dir, base_name + '_no_legend.pdf'))
    plt.legend()
    _save_fig(os.path.join(plot_dir, base_name + '.pdf'))
# Example #4
# 0
def main():
    '''Command-line entry point of the analysis script.

    Parses the arguments into the module-level `args`, then:
      1. loads the ground-truth tasks of every requested dataset;
      2. loads (through a pickle cache) each tracker's predictions on the
         labelled frames;
      3. assesses the predictions at every IOU threshold in args.iou_thresholds;
      4. runs the table or plot selected by args.subcommand.
    '''
    global args
    parser = argparse.ArgumentParser(formatter_class=ARGS_FORMATTER)
    _add_arguments(parser)
    args = parser.parse_args()
    logging.basicConfig(level=getattr(logging, args.loglevel.upper()))

    dataset_names = _get_datasets(args.data)
    dataset_tasks = {}
    for dataset in dataset_names:
        annotation_file = os.path.join(REPO_DIR, 'dataset', 'annotations',
                                       dataset + '.csv')
        dataset_tasks[dataset] = _load_tasks(annotation_file)

    # Union of the tasks of all datasets.
    tasks = {}
    for dataset in dataset_names:
        tasks.update(dataset_tasks[dataset].items())

    tracker_names = _load_tracker_names()
    # Colors and markers follow the case-insensitive alphabetical order of the
    # trackers so that each tracker looks the same in every plot.
    trackers = sorted(tracker_names.keys(), key=lambda name: name.lower())
    tracker_colors = dict(zip(trackers, _generate_colors(len(trackers))))
    tracker_markers = dict(zip(trackers, itertools.cycle(MARKERS)))

    # predictions[tracker] merges, over all datasets, the per-track predictions
    # restricted to frames with ground-truth labels. That subset is much smaller
    # than the predictions for all frames, so it is cached per dataset/tracker.
    predictions = {}
    num_trackers = len(trackers)
    for dataset in dataset_names:
        for position, tracker in enumerate(trackers, start=1):
            log_context = 'tracker {}/{} {}'.format(position, num_trackers,
                                                    tracker)
            cache_file = os.path.join(dataset, 'predictions',
                                      '{}.pickle'.format(tracker))
            loaded = oxuva.cache_pickle(
                os.path.join(args.cache_dir, 'analyze', cache_file),
                lambda: _load_predictions_and_select_frames(
                    dataset_tasks[dataset],
                    os.path.join(REPO_DIR, 'predictions', dataset, tracker),
                    log_prefix=log_context + ': '),
                ignore_existing=args.ignore_cache)
            predictions.setdefault(tracker, {}).update(loaded)

    # Assess at every IOU threshold up front so all graphs can share their axes.
    # TODO: Is it unsafe to use float (iou) as dictionary key?
    assessments = {}
    for tracker in trackers:
        by_iou = {}
        assessments[tracker] = by_iou
        for iou in args.iou_thresholds:
            logger.info(
                'assess predictions of tracker "%s" with IOU threshold %g',
                tracker, iou)
            by_iou[iou] = oxuva.VideoObjectDict({
                track: oxuva.assess_sequence(tasks[track].labels,
                                             predictions[tracker][track], iou)
                for track in tasks
            })

    def _for_each_iou(plot_fn):
        # Produce one figure per IOU threshold.
        for iou in args.iou_thresholds:
            plot_fn(assessments, tasks, trackers, iou, tracker_names,
                    tracker_colors, tracker_markers)

    handlers = {
        'table': lambda: _print_statistics(assessments, trackers,
                                           tracker_names),
        'plot_tpr_tnr': lambda: _plot_tpr_tnr_overall(
            assessments, tasks, trackers, tracker_names, tracker_colors,
            tracker_markers),
        'plot_tpr_tnr_intervals': lambda: _plot_tpr_tnr_intervals(
            assessments, tasks, trackers, tracker_names, tracker_colors,
            tracker_markers),
        'plot_tpr_time': lambda: _for_each_iou(_plot_intervals),
        'plot_present_absent': lambda: _for_each_iou(_plot_present_absent),
    }
    # An unrecognised subcommand does nothing.
    handler = handlers.get(args.subcommand)
    if handler is not None:
        handler()