Example #1
0
def run():
    """Train the network on randomly drawn training pairs.

    Each training step draws one random data point from the training
    generator. For the 'hp' training set it is optionally augmented. The step
    then runs the current network on every image of the data point, matches
    the resulting points and determines inliers. From these it builds a batch
    plus inlier labels and runs ``g.train_step`` on them.

    Validation returns ``evaluate.succinctness`` on the evaluation pairs and,
    for the 'hp' training set, additionally on a subsample of the training
    pairs (index 1 of the returned list).

    ``step`` and ``validate`` are handed to ``Trainer``, which is run via
    ``trainUpToNumIts`` with ``FLAGS.its``. NOTE(review): checkpointing and
    best-model selection (``best_index=0``, ``best_is_max=True``) are
    presumably handled inside ``Trainer`` — confirm there.
    """
    hyperparams.announceTraining()

    g = graph.Graph()
    train_pairs = hyperparams.getTrainDataGen()
    # Validation on (a subsample of) the training pairs only exists for 'hp';
    # bind the name unconditionally so its scope is explicit.
    tval_pairs = None
    if FLAGS.tds == 'hp':
        tval_pairs = train_pairs.subSampled(5)
    eval_pairs = hyperparams.getEvalDataGen()

    tf_sess = tf.Session()

    forward_passer = hyperparams.getForwardPasser(g, tf_sess)

    def step(sess):
        """Run one optimisation step on a single random training pair."""
        dp = train_pairs.getRandomDataPoint()
        if FLAGS.augment and FLAGS.tds == 'hp':
            dp = augment.augmentHpatchPair(dp)
        # Feedforward
        fps = [forward_passer(im) for im in dp.im]
        # Multi-scale passes are flattened for matching and inlier
        # computation; the unflattened passes are still used for the batch.
        if isinstance(fps[0], multiscale.ForwardPass):
            flat = [fp.toFlatForwardPass() for fp in fps]
        else:
            flat = fps
        matched_indices = system.match(flat)
        inliers = system.getInliers(dp, flat, matched_indices)
        print('%d inliers' % np.count_nonzero(inliers))
        batch, inlier_labels = makeBatchAndInlierLabelsDouble(
            dp, fps, matched_indices, inliers)
        sess.run(g.train_step,
                 feed_dict={
                     g.train_input: batch,
                     g.inlier_mask: inlier_labels
                 })

    def validate(_):
        """Return the list of validation scores (index 0: evaluation set)."""
        scores = [evaluate.succinctness(eval_pairs, forward_passer, 10)]
        if FLAGS.tds == 'hp':
            scores.append(
                evaluate.succinctness(tval_pairs, forward_passer, 10))
        return scores

    trainer = Trainer(step,
                      validate,
                      hyperparams.checkpointRoot(),
                      hyperparams.shortString(),
                      check_every=FLAGS.val_every,
                      best_index=0,
                      best_is_max=True)
    trainer.trainUpToNumIts(tf_sess, FLAGS.its)
def run():
    """Render the point matching of every evaluation pair.

    Restores the model from its checkpoint and wraps its forward passer in a
    ``system.ForwardPassCache``. For each evaluation pair it obtains the
    forward passes of the pair's images, matches them and computes inliers,
    then hands everything to ``evaluate.renderMatching``.
    """
    hyperparams.announceEval()
    eval_pairs = hyperparams.getEvalDataGen()
    net_graph, net_sess = hyperparams.modelFromCheckpoint()
    fp_cache = system.ForwardPassCache(
        hyperparams.getForwardPasser(net_graph, net_sess))

    for idx in range(len(eval_pairs)):
        print('%d/%d' % (idx, len(eval_pairs)))
        pair = eval_pairs[idx]
        print(pair.name())
        passes = [fp_cache[image] for image in pair.im]

        # Sanity check: both sides must yield the configured point count.
        for side in (0, 1):
            assert len(passes[side]) == FLAGS.num_test_pts
        matches = system.match(passes)
        inliers = system.getInliers(pair, passes, matches)
        matched = [passes[side][matches[side]] for side in (0, 1)]
        evaluate.renderMatching(pair, passes, matched, inliers)
Example #3
0
        print('Writing to %s' % self._out)

        exts = ['.jpg', '.png']
        if FLAGS.ext != '':
            exts.append(FLAGS.ext)
        images = sorted(os.listdir(self._in))
        self._images = [i for i in images if os.path.splitext(i)[1] in exts]

    def __getitem__(self, item):
        """Return ``(input_path, output_root)`` for the ``item``-th image.

        ``output_root`` is the image's file name without its extension,
        placed in the output directory; callers append their own suffixes.
        """
        file_name = self._images[item]
        stem = os.path.splitext(file_name)[0]
        return (os.path.join(self._in, file_name),
                os.path.join(self._out, stem))

    def __len__(self):
        """Return the number of selected images in the input folder."""
        images = self._images
        return len(images)


if __name__ == '__main__':
    # NOTE(review): presumably makes modelFromCheckpoint load the
    # best-validation checkpoint — confirm in hyperparams.
    FLAGS.val_best = True
    print(hyperparams.methodEvalString())
    graph, sess = hyperparams.modelFromCheckpoint()
    forward_passer = hyperparams.getForwardPasser(graph, sess)
    folder_inference = FolderInference()

    # FolderInference has __getitem__/__len__, so it iterates by index until
    # the underlying image list raises IndexError.
    for i, (in_path, out_root) in enumerate(folder_inference):
        print('%d/%d' % (i, len(folder_inference)))
        image = cv2.imread(in_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # fail loudly here rather than deep inside the forward pass.
        if image is None:
            raise IOError('Could not read image %s' % in_path)
        fp = forward_passer(image)
        # One line per interest point: row, column, score.
        np.savetxt(out_root + '.txt', np.vstack((fp.ips_rc, fp.ip_scores)).T,
                   header='row column score', fmt='%5.0f %5.0f %.3f')
        cv2.imwrite(out_root + '_render.png', fp.render())
Example #4
0
def process(spath, irange, seq, fpasser):
    """Run the network on a batch of sequence frames and save the results.

    Args:
        spath: Output directory for this sequence.
        irange: Contiguous range of frame indices (indices into
            ``seq.images``) forming one batch.
        seq: Sequence whose ``images`` are image file paths.
        fpasser: Forward passer; its ``parallelForward`` is called once with
            all grayscale frames of the batch.

    For every frame ``i`` this writes ``<spath>/<i:05d>.jpg`` (the rendering
    of the forward pass) and ``<spath>/<i:05d>.hkl`` (a hickle dump of
    ``fp.hicklable()``). With ``FLAGS.debug_plot`` the rendering is also
    shown in an OpenCV window.
    """
    frames = [cv2.imread(seq.images[i], cv2.IMREAD_GRAYSCALE) for i in irange]
    fps = fpasser.parallelForward(frames)
    for i in irange:
        print('%d/%d' % (i, len(seq.images)))
        path = os.path.join(spath, '%05d' % i)
        # fps is indexed relative to the start of the batch.
        fp = fps[i - irange[0]]
        rendering = fp.render()
        if FLAGS.debug_plot:
            cv2.imshow('render', rendering)
            cv2.waitKey(1)
        cv2.imwrite(path + '.jpg', rendering)
        # Pass the path rather than an open text-mode handle: hickle opens
        # the file itself, and the previous handle was never closed.
        hkl.dump(fp.hicklable(), path + '.hkl', mode='w')


def batchRanges(n, batch_size):
    """Split ``range(n)`` into consecutive ranges of at most ``batch_size``.

    Every range except possibly the last has exactly ``batch_size`` elements;
    no empty range is produced, so ``n == 0`` yields ``[]``.

    >>> [list(r) for r in batchRanges(5, 2)]
    [[0, 1], [2, 3], [4]]
    """
    return [range(start, min(start + batch_size, n))
            for start in range(0, n, batch_size)]


if __name__ == '__main__':
    seqs = hyperparams.getEvalSequences()
    g, sess = hyperparams.modelFromCheckpoint()
    fpasser = hyperparams.getForwardPasser(g, sess)
    for seq in seqs:
        print(seq.name())
        spath = os.path.join(hyperparams.seqFpsPath(), seq.name())
        if not os.path.exists(spath):
            os.makedirs(spath)
        # Clamp the last batch to the sequence length. Fixed-size ranges used
        # to overrun seq.images when len % fpbs != 0, re-process the
        # remainder, and send an empty batch when len % fpbs == 0.
        for irange in batchRanges(len(seq.images), FLAGS.fpbs):
            process(spath, irange, seq, fpasser)