Example #1
    def __init__(self, model_path, **config):
        self.model_path = model_path
        # Update config
        self.config = dict_update(getattr(self, 'default_config', {}), config)
        self._init_model()
        if model_path is None:
            print("No pretrained model specified!")
            self.sess = None
        else:
            ext = os.path.splitext(model_path)[1]

            sess_config = tf.compat.v1.ConfigProto()
            sess_config.gpu_options.allow_growth = True

            if ext.startswith('.pb'):
                graph = load_frozen_model(self.model_path, print_nodes=False)
                self.sess = tf.compat.v1.Session(graph=graph, config=sess_config)
            elif ext.startswith('.ckpt'):
                self.sess = tf.compat.v1.Session(config=sess_config)
                meta_graph_path = model_path + '.meta'
                if not os.path.exists(meta_graph_path):
                    self._construct_network()
                    recoverer(self.sess, model_path)
                else:
                    recoverer(self.sess, model_path, meta_graph_path)
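The examples on this page all rely on a load_frozen_model helper whose body is not shown here. Below is a minimal sketch of what such a helper typically looks like, assuming it only needs to parse a frozen GraphDef from a .pb file and import it into a fresh graph; the print_nodes flag is an assumption inferred from the call sites, not the project's documented API.

import tensorflow as tf


def load_frozen_model(model_path, print_nodes=False):
    """Load a frozen *.pb graph into a new tf.Graph (sketch, not the repo's exact helper)."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(model_path, 'rb') as fin:
        graph_def.ParseFromString(fin.read())
    graph = tf.Graph()
    with graph.as_default():
        # Import with an empty name scope so tensors keep names like "input:0".
        tf.compat.v1.import_graph_def(graph_def, name='')
    if print_nodes:
        for node in graph_def.node:
            print(node.name)
    return graph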
Example #2
def main(argv=None):  # pylint: disable=unused-argument
    """Program entrance."""
    # create sift detector.
    sift_wrapper = SiftWrapper(n_sample=FLAGS.max_kpt_num)
    sift_wrapper.half_sigma = FLAGS.half_sigma
    sift_wrapper.pyr_off = FLAGS.pyr_off
    sift_wrapper.ori_off = FLAGS.ori_off
    sift_wrapper.create()
    # create deep feature extractor.
    graph = load_frozen_model(FLAGS.model_path, print_nodes=False)
    sess = tf.Session(graph=graph)
    # extract deep feature from images.
    deep_feat1, cv_kpts1, img1 = extract_deep_features(sift_wrapper,
                                                       sess,
                                                       FLAGS.img1_path,
                                                       qtz=False)
    deep_feat2, cv_kpts2, img2 = extract_deep_features(sift_wrapper,
                                                       sess,
                                                       FLAGS.img2_path,
                                                       qtz=False)
    # match features by OpenCV brute-force matcher (CPU).
    matcher_wrapper = MatcherWrapper()
    # the ratio criterion is set to 0.89 for GeoDesc as described in the paper.
    deep_good_matches, deep_mask = matcher_wrapper.get_matches(
        deep_feat1,
        deep_feat2,
        cv_kpts1,
        cv_kpts2,
        ratio=0.89,
        cross_check=True,
        info='deep')

    deep_display = matcher_wrapper.draw_matches(img1, cv_kpts1, img2, cv_kpts2,
                                                deep_good_matches, deep_mask)
    # compare with SIFT.
    if FLAGS.cf_sift:
        sift_feat1 = sift_wrapper.compute(img1, cv_kpts1)
        sift_feat2 = sift_wrapper.compute(img2, cv_kpts2)
        sift_good_matches, sift_mask = matcher_wrapper.get_matches(
            sift_feat1,
            sift_feat2,
            cv_kpts1,
            cv_kpts2,
            ratio=0.80,
            cross_check=True,
            info='sift')
        sift_display = matcher_wrapper.draw_matches(img1, cv_kpts1, img2,
                                                    cv_kpts2,
                                                    sift_good_matches,
                                                    sift_mask)
        display = np.concatenate((sift_display, deep_display), axis=0)
    else:
        display = deep_display

    cv2.imshow('display', display)
    cv2.waitKey()

    sess.close()
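Example #2 reads its inputs from TF1-style command-line flags. The flag names below are taken from the call sites; the defaults and help strings are placeholders, so treat this as an assumed setup rather than the project's actual flag definitions.

import tensorflow as tf

FLAGS = tf.compat.v1.flags.FLAGS

tf.compat.v1.flags.DEFINE_string('model_path', None, 'Path to the frozen descriptor model (*.pb).')
tf.compat.v1.flags.DEFINE_string('img1_path', None, 'Path to the first image.')
tf.compat.v1.flags.DEFINE_string('img2_path', None, 'Path to the second image.')
tf.compat.v1.flags.DEFINE_integer('max_kpt_num', 8192, 'Maximum number of SIFT keypoints to sample (placeholder default).')
tf.compat.v1.flags.DEFINE_boolean('half_sigma', True, 'Halve the SIFT sigma (assumed meaning).')
tf.compat.v1.flags.DEFINE_boolean('pyr_off', False, 'Disable the scale pyramid (assumed meaning).')
tf.compat.v1.flags.DEFINE_boolean('ori_off', False, 'Disable orientation estimation (assumed meaning).')
tf.compat.v1.flags.DEFINE_boolean('cf_sift', True, 'Also compute and display SIFT matches for comparison.')

if __name__ == '__main__':
    # tf.compat.v1.app.run() parses the flags and then calls main(argv).
    tf.compat.v1.app.run()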
Example #3
    def __init__(self, model_path, **config):
        self.model_path = model_path
        # Update config
        self.config = dict_update(getattr(self, 'default_config', {}), config)
        self._init_model()
        ext = os.path.splitext(model_path)[1]

        sess_config = tf.compat.v1.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        if ext.startswith('.pb'):
            graph = load_frozen_model(self.model_path, print_nodes=False)
            self.sess = tf.compat.v1.Session(graph=graph, config=sess_config)
        elif ext.startswith('.ckpt'):
            self._construct_network()
            self.sess = tf.compat.v1.Session(config=sess_config)
            recoverer(self.sess, model_path)
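Examples #1 and #3 restore checkpoints through a recoverer helper that is also not shown on this page. A plausible sketch, assuming it either imports a saved meta graph or reuses the already-constructed network before restoring the variables; the signature is inferred from the call sites.

import tensorflow as tf


def recoverer(sess, model_path, meta_graph_path=None):
    """Restore variables from a TF1 checkpoint (sketch, not the repo's exact helper)."""
    if meta_graph_path is not None:
        # Rebuild the graph structure from the exported meta graph, then restore the weights.
        saver = tf.compat.v1.train.import_meta_graph(meta_graph_path)
    else:
        # The network has already been constructed in the default graph.
        saver = tf.compat.v1.train.Saver()
    saver.restore(sess, model_path)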
Example #4
def main(argv=None):  # pylint: disable=unused-argument
    """Program entrance."""
    hpatches_seq_list = open(FLAGS.hpatches_seq_list).read().splitlines()
    n_seq = len(hpatches_seq_list)

    graph = load_frozen_model(FLAGS.model_path, print_nodes=False)
    with tf.Session(graph=graph) as sess:
        for i in range(n_seq):
            if i % 16 == 0:
                print(i, '/', n_seq)
            strs = hpatches_seq_list[i].split('/')
            save_folder = os.path.join(FLAGS.feat_out_path, 'geodesc',
                                       strs[-2])
            if not os.path.exists(save_folder):
                os.makedirs(save_folder)
            seq_data = load_seq(
                os.path.join(FLAGS.hpatches_root, hpatches_seq_list[i]))
            feat = sess.run("squeeze_1:0", feed_dict={"input:0": seq_data})
            csv_path = os.path.join(save_folder,
                                    os.path.splitext(strs[-1])[0] + '.csv')
            np.savetxt(csv_path, feat, delimiter=",", fmt='%.8f')
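Example #4 writes one descriptor matrix per HPatches sequence as plain CSV, one row per patch. Reading the descriptors back for evaluation is a one-liner; the file name below is a placeholder, and this usage note is not part of the original example.

import numpy as np

# Rows are patch descriptors; the file was written with np.savetxt(..., delimiter=",").
feat = np.loadtxt('ref.csv', delimiter=',', dtype=np.float32)
print(feat.shape)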
Example #5
def main(argv=None):  # pylint: disable=unused-argument
    """Program entrance."""
    local_model_path = FLAGS.model_prefix + '.pb'
    regional_model_path = FLAGS.regional_model
    aug_model_path = FLAGS.model_prefix + '-aug.pb'

    input_imgs = [FLAGS.img1_path, FLAGS.img2_path]
    num_img = len(input_imgs)
    # detect SIFT keypoints and crop image patches.
    patches, npy_kpts, cv_kpts, sift_desc, imgs = get_kpt_and_patch(input_imgs)
    # extract local features and keypoint matchability.
    local_graph = load_frozen_model(local_model_path, print_nodes=False)
    local_feat = []
    l2_local_feat = []
    kpt_m = []
    with tf.Session(graph=local_graph) as sess:
        for i in patches:
            local_returns = sess.run(
                ["conv6_feat:0", "kpt_mb:0", "l2norm_feat:0"],
                feed_dict={"input:0": np.expand_dims(i, -1)})
            local_feat.append(local_returns[0])
            kpt_m.append(local_returns[1])
            l2_local_feat.append(local_returns[2])
    tf.reset_default_graph()
    # extract regional features.
    regional_graph = load_frozen_model(regional_model_path, print_nodes=False)
    regional_feat = []
    with tf.Session(graph=regional_graph) as sess:
        for i in imgs:
            regional_returns = sess.run(
                "res5c:0", feed_dict={"input:0": np.expand_dims(i, 0)})
            regional_feat.append(regional_returns)
    tf.reset_default_graph()
    # local feature augmentation.
    aug_graph = load_frozen_model(aug_model_path, print_nodes=False)
    aug_feat = []
    with tf.Session(graph=aug_graph) as sess:
        for i in range(num_img):
            aug_returns = sess.run("l2norm:0",
                                   feed_dict={
                                       "input/local_feat:0":
                                       np.expand_dims(local_feat[i], 0),
                                       "input/regional_feat:0":
                                       regional_feat[i],
                                       "input/kpt_m:0":
                                       np.expand_dims(kpt_m[i], 0),
                                       "input/kpt_xy:0":
                                       np.expand_dims(npy_kpts[i], 0),
                                   })
            aug_returns = np.squeeze(aug_returns, axis=0)
            aug_feat.append(aug_returns)
    tf.reset_default_graph()
    # feature matching and draw matches.
    matcher = MatcherWrapper()
    sift_match, sift_mask = matcher.get_matches(
        sift_desc[0],
        sift_desc[1],
        cv_kpts[0],
        cv_kpts[1],
        ratio=0.8 if FLAGS.ratio_test else None,
        cross_check=False,
        err_thld=4,
        info='sift')

    base_match, base_mask = matcher.get_matches(
        l2_local_feat[0],
        l2_local_feat[1],
        cv_kpts[0],
        cv_kpts[1],
        ratio=0.89 if FLAGS.ratio_test else None,
        cross_check=False,
        err_thld=4,
        info='base')

    aug_match, aug_mask = matcher.get_matches(
        aug_feat[0],
        aug_feat[1],
        cv_kpts[0],
        cv_kpts[1],
        ratio=0.89 if FLAGS.ratio_test else None,
        cross_check=False,
        err_thld=4,
        info='aug')

    sift_disp = matcher.draw_matches(imgs[0], cv_kpts[0], imgs[1], cv_kpts[1],
                                     sift_match, sift_mask)
    base_disp = matcher.draw_matches(imgs[0], cv_kpts[0], imgs[1], cv_kpts[1],
                                     base_match, base_mask)
    aug_disp = matcher.draw_matches(imgs[0], cv_kpts[0], imgs[1], cv_kpts[1],
                                    aug_match, aug_mask)

    rows, cols = sift_disp.shape[0:2]
    # Integer division: np.ones needs an integer shape under Python 3.
    white = (np.ones((rows // 50, cols, 3)) * 255).astype(np.uint8)
    disp = np.concatenate([sift_disp, white, base_disp, white, aug_disp],
                          axis=0)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(disp)
    plt.show()
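Every example that matches descriptors goes through MatcherWrapper.get_matches, whose body is not shown on this page. The ratio and cross_check arguments suggest the usual OpenCV brute-force matching with Lowe's ratio test plus a mutual-consistency check; a minimal sketch of that procedure is below, with the geometric verification behind err_thld and the returned mask omitted. This is an illustration of the technique, not the repo's MatcherWrapper.

import cv2
import numpy as np


def ratio_test_matches(feat1, feat2, ratio=0.89, cross_check=True):
    """Brute-force L2 matching with Lowe's ratio test and optional cross-check (sketch)."""
    bf = cv2.BFMatcher(cv2.NORM_L2)
    feat1 = feat1.astype(np.float32)
    feat2 = feat2.astype(np.float32)
    # Forward matching: for each descriptor in image 1, take its two nearest neighbours in image 2.
    fwd = bf.knnMatch(feat1, feat2, k=2)
    good = [p[0] for p in fwd if len(p) == 2 and p[0].distance < ratio * p[1].distance]
    if cross_check:
        # Keep a match only if image 2's best match points back to the same keypoint in image 1.
        bwd = bf.match(feat2, feat1)
        back = {m.queryIdx: m.trainIdx for m in bwd}
        good = [m for m in good if back.get(m.trainIdx) == m.queryIdx]
    return good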