コード例 #1
0
def run():
    """Render the matching result of every evaluation pair.

    Restores the model from its checkpoint, obtains forward passes for both
    images of each evaluation pair through a ``system.ForwardPassCache``,
    matches them, computes inliers with ``system.getInliers`` and passes the
    full and matched forward passes to ``evaluate.renderMatching``.

    Raises:
        AssertionError: If a forward pass does not contain exactly
            ``FLAGS.num_test_pts`` points.
    """
    hyperparams.announceEval()
    eval_pairs = hyperparams.getEvalDataGen()
    graph, sess = hyperparams.modelFromCheckpoint()
    cache = system.ForwardPassCache(hyperparams.getForwardPasser(graph, sess))

    for idx in range(len(eval_pairs)):
        # Progress indicator, printed before the pair is fetched.
        print('%d/%d' % (idx, len(eval_pairs)))
        pair = eval_pairs[idx]
        print(pair.name())
        passes = [cache[image] for image in pair.im]

        # Both images must yield exactly the configured number of points.
        for side in (0, 1):
            assert len(passes[side]) == FLAGS.num_test_pts

        matches = system.match(passes)
        inlier_mask = system.getInliers(pair, passes, matches)
        matched = [passes[side][matches[side]] for side in (0, 1)]
        evaluate.renderMatching(pair, passes, matched, inlier_mask)
コード例 #2
0
def plot(label=None, k=10):
    """Plot how reliably pairs reach at least ``k`` inliers.

    For the 'lfnet' baseline, the LF-Net outputs are parsed for 50, 100 and
    150 points, and the fraction of evaluation pairs with at least ``k``
    inliers is plotted as discrete markers. For every other method, the
    per-pair results of ``evaluate.leastNumForKInliers`` are turned into a
    cumulative curve with ``evaluate.lessThanCurve`` and drawn as a step plot.

    Args:
        label: Name used in the legend entry. When ``None`` (the default), the
            name is ``hyperparams.methodString()`` for 'lfnet' and
            ``hyperparams.label()`` otherwise. This argument used to be
            ignored.
        k: Minimum number of inliers for a pair to count as a success.

    Returns:
        The area under the curve, as computed by ``evaluate.auc`` with bound
        200. Returns ``None`` for the 'lfnet' baseline, where no AUC is
        computed (its legend shows 'N/A').
    """
    hyperparams.announceEval()
    eval_pairs = hyperparams.getEvalDataGen()
    # Very special case lfnet: outputs are parsed per fixed point count
    # instead of being loaded from the forward-pass cache.
    if FLAGS.baseline == 'lfnet':
        x = [50, 100, 150]
        y = []
        for num_pts in x:
            forward_pass_dict = baselines.parseLFNetOuts(eval_pairs, num_pts)
            success = np.zeros(len(eval_pairs), dtype=bool)
            for pair_i, pair in enumerate(eval_pairs):
                # Pair names have the form '<folder> <image a> <image b>'.
                folder, a, b = pair.name().split(' ')
                forward_passes = [
                    forward_pass_dict['%s%s' % (folder, i)] for i in [a, b]
                ]
                matched_indices = system.match(forward_passes)
                inliers = system.getInliers(pair, forward_passes,
                                            matched_indices)
                success[pair_i] = np.count_nonzero(inliers) >= k
            y.append(np.mean(success.astype(float)))
        name = hyperparams.methodString() if label is None else label
        plt.plot(x, y, 'x', label='%s: N/A' % (name,))
        return None

    pair_outs = cache_forward_pass.loadOrCalculateOuts()
    # Multi-scale outputs of the own method need their own deserializer.
    if FLAGS.num_scales > 1 and FLAGS.baseline == '':
        from_hicklable = multiscale.forwardPassFromHicklable
    else:
        from_hicklable = system.forwardPassFromHicklable
    fps = [[from_hicklable(im) for im in pair] for pair in pair_outs]
    stats = [
        evaluate.leastNumForKInliers(pair, pair_fps, k)
        for pair, pair_fps in zip(eval_pairs, fps)
    ]
    x, y = evaluate.lessThanCurve(stats)
    auc = evaluate.auc(x, y, 200)

    name = hyperparams.label() if label is None else label
    plt.step(x, y, label='%s: %.2f' % (name, auc))
    return auc
コード例 #3
0
def plot(do_plot=True):
    """Compute success, rotation and translation AUCs and optionally plot.

    Loads (or evaluates) the per-pair success values, rotation errors and
    translation errors via ``cache_forward_pass.loadOrEvaluate``. Each one is
    turned into a cumulative curve with ``evaluate.lessThanCurve``, and its
    area is computed with ``evaluate.auc``.

    Args:
        do_plot: If true, draw step plots of the rotation and translation
            curves, labelled with the method string and their AUC.

    Returns:
        Tuple ``(sauc, rauc, tauc)`` with these AUC bounds:

        * success: 200
        * rotation: 5 when ``FLAGS.ds == 'eu'`` (the code prints
          '5 degrees'), otherwise 1
        * translation: 1

    Raises:
        AssertionError: If no rotation errors are available.
    """
    hyperparams.announceEval()
    succ, Rerr, terr = cache_forward_pass.loadOrEvaluate()
    assert Rerr is not None

    succ_x, succ_y = evaluate.lessThanCurve(succ)
    succ_auc = evaluate.auc(succ_x, succ_y, 200)

    rot_x, rot_y = evaluate.lessThanCurve(Rerr)
    if FLAGS.ds != 'eu':
        rot_limit = 1
    else:
        print('5 degrees')
        rot_limit = 5
    rot_auc = evaluate.auc(rot_x, rot_y, rot_limit)

    trans_x, trans_y = evaluate.lessThanCurve(terr)
    trans_auc = evaluate.auc(trans_x, trans_y, 1)

    if do_plot:
        rot_label = '%s R: %.2f' % (hyperparams.methodString(), rot_auc)
        plt.step(rot_x, rot_y, label=rot_label)
        trans_label = '%s t: %.2f' % (hyperparams.methodString(), trans_auc)
        plt.step(trans_x, trans_y, label=trans_label)
    return succ_auc, rot_auc, trans_auc
コード例 #4
0
ファイル: plot_ransac.py プロジェクト: uzh-rpg/sips2_open
def plot(color=plt.get_cmap('tab10').colors[0]):
    """Plot cumulative P3P rotation and translation error curves.

    Builds matched forward passes for every evaluation pair:

    * 'lfnet' baseline: parsed LF-Net outputs.
    * Otherwise: cached outputs, each reduced with
      ``.reducedTo(FLAGS.num_test_pts)``.

    ``evaluate.p3pMaskAndError`` is then applied to each pair's matches. Pairs
    whose mask has fewer than 10 nonzero entries get infinite rotation and
    translation errors. For the own method (``FLAGS.baseline == ''``), every
    pair's matching is also passed to ``evaluate.renderMatching``. Both
    curves are drawn with a logarithmic x-axis.

    Args:
        color: Matplotlib color shared by the rotation (dashed) and
            translation (dotted) curves.

    Raises:
        NotImplementedError: For the own method with ``FLAGS.num_scales > 1``.
    """
    n = FLAGS.num_test_pts
    hyperparams.announceEval()
    eval_pairs = hyperparams.getEvalDataGen()
    # Unmatched (reduced) forward passes per pair. They are only collected,
    # and only needed for rendering, when FLAGS.baseline == ''.
    full_fps = []
    fps = []
    # Very special case lfnet:
    if FLAGS.baseline == 'lfnet':
        forward_pass_dict = baselines.parseLFNetOuts(eval_pairs, n)
        for pair in eval_pairs:
            # Pair names have the form '<folder> <image a> <image b>'.
            folder, a, b = pair.name().split(' ')
            forward_passes = [
                forward_pass_dict['%s%s' % (folder, i)] for i in [a, b]
            ]
            matched_indices = system.match(forward_passes)
            fps.append([forward_passes[i][matched_indices[i]] for i in [0, 1]])
    else:
        pair_outs = cache_forward_pass.loadOrCalculateOuts()
        if FLAGS.num_scales > 1 and FLAGS.baseline == '':
            raise NotImplementedError
        for pair in pair_outs:
            reduced = [
                system.forwardPassFromHicklable(i).reducedTo(n)
                for i in pair
            ]
            full_fps.append(reduced)
            matched_indices = system.match(reduced)
            fps.append([reduced[i][matched_indices[i]] for i in [0, 1]])

    # Materialize the pairing: it is iterated twice (error evaluation and
    # rendering). A bare zip() iterator is exhausted after the first pass in
    # Python 3, which silently skipped the rendering loop below.
    pairs_fps = list(zip(eval_pairs, fps))
    masks_errs = [
        evaluate.p3pMaskAndError(pair, matched) for pair, matched in pairs_fps
    ]

    if FLAGS.baseline == '':
        for (mask, _, _), (pair, matched), full in zip(masks_errs, pairs_fps,
                                                        full_fps):
            evaluate.renderMatching(pair, full, matched, mask)

    # Pairs with too few inliers count as failures (infinite error).
    ninl = np.array([np.count_nonzero(i[0]) for i in masks_errs])
    too_few = ninl < 10
    rerrs = np.array([i[1] for i in masks_errs])
    rerrs[too_few] = np.inf
    terrs = np.array([i[2] for i in masks_errs])
    terrs[too_few] = np.inf

    # Only SIFT gets separate legend entries for rotation and translation;
    # for the others the translation curve is unlabelled.
    if FLAGS.baseline != 'sift':
        rlabel, tlabel = hyperparams.label(), None
    else:
        rlabel = '%s, rot.' % hyperparams.label()
        tlabel = '%s, transl.' % hyperparams.label()

    x, y = evaluate.lessThanCurve(rerrs)
    plt.semilogx(x, y, '--', color=color, label=rlabel)
    x, y = evaluate.lessThanCurve(terrs)
    plt.semilogx(x, y, ':', color=color, label=tlabel)
コード例 #5
0
ファイル: render_matching.py プロジェクト: uzh-rpg/sips2_open
import os

import rpg_common_py.geometry

import sips2.cache_forward_pass as cache_forward_pass
import sips2.evaluate as evaluate
import sips2.flags as flags
import sips2.hyperparams as hyperparams
import sips2.multiscale as multiscale
import sips2.system as system

# Module-level shortcut to the flag values defined in sips2.flags.
FLAGS = flags.FLAGS


if __name__ == '__main__':
    # Script entry point: announce the evaluation and load the evaluation pairs.
    hyperparams.announceEval()

    eval_pairs = hyperparams.getEvalDataGen()

    # Own method (no baseline): run the network restored from its checkpoint.
    if FLAGS.baseline == '':
        graph, sess = hyperparams.modelFromCheckpoint()
        forward_passer = hyperparams.getForwardPasser(graph, sess)

        # Per-pair results of evaluate.leastNumForKInliers (k = 10).
        nmins, Rerrs, terrs = [], [], []
        for pair in eval_pairs:
            # One forward pass per image of the pair (no caching here).
            fps = [forward_passer(im) for im in pair.im]
            # NOTE(review): save_final / get_rt presumably make the helper
            # save its final matching and also return R/t errors -- confirm
            # in sips2.evaluate.
            nmin, Rerr, terr = evaluate.leastNumForKInliers(
                pair, fps, 10, save_final=True, get_rt=True)
            nmins.append(nmin)
            Rerrs.append(Rerr)
            terrs.append(terr)