예제 #1
0
def feature_extract(eval_set, model, device, opt, config):
    """Extract PCA-reduced descriptors for every item of ``eval_set`` and save them.

    The global descriptors of all images are written as one array to
    ``<opt.output_features_dir>/globalfeats.npy``. When
    ``config['global_params']['pooling']`` is ``'patchnetvlad'`` (case
    insensitive), one additional ``.npy`` file per image and per patch size is
    written, named ``patchfeats_psize<size>_<image basename>.npy``.

    Args:
        eval_set: Dataset consumed by ``DataLoader``. Each batch must unpack to
            ``(input_data, indices)``; ``eval_set.images[index]`` must give the
            image path used to name the per-image patch files.
        model: Object exposing ``encoder``, ``pool`` (and ``pool.patch_sizes``
            for patch pooling); it is also handed to ``get_pca_encoding``.
        device: Device the input batches are moved to before encoding.
        opt: Options object providing ``output_features_dir`` and ``nocuda``.
        config: Nested mapping with ``global_params`` (``num_pcs``, ``threads``,
            ``pooling``) and ``feature_extract`` (``cacheBatchSize``) entries.

    Returns:
        None. Results are only written to disk.
    """
    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    makedirs(opt.output_features_dir, exist_ok=True)

    output_local_features_prefix = join(opt.output_features_dir, 'patchfeats')
    output_global_features_filename = join(opt.output_features_dir,
                                           'globalfeats.npy')

    pool_size = int(config['global_params']['num_pcs'])
    # Loop-invariant, so evaluate it once instead of once per batch.
    use_patch_pooling = config['global_params']['pooling'].lower() == 'patchnetvlad'

    test_data_loader = DataLoader(
        dataset=eval_set,
        num_workers=int(config['global_params']['threads']),
        batch_size=int(config['feature_extract']['cacheBatchSize']),
        shuffle=False,
        pin_memory=(not opt.nocuda))

    model.eval()
    with torch.no_grad():
        tqdm.write('====> Extracting Features')
        db_feat = np.empty((len(eval_set), pool_size), dtype=np.float32)

        for input_data, indices in tqdm(test_data_loader, position=1, leave=False,
                                        desc='Test Iter'.rjust(15)):
            indices_np = indices.detach().numpy()
            input_data = input_data.to(device)
            image_encoding = model.encoder(input_data)

            pooled = model.pool(image_encoding)
            if use_patch_pooling:
                # Patch pooling yields per-patch-size local descriptors plus the global one.
                vlad_local, vlad_global = pooled
            else:
                vlad_local, vlad_global = (), pooled

            vlad_global_pca = get_pca_encoding(model, vlad_global)
            db_feat[indices_np, :] = vlad_global_pca.detach().cpu().numpy()

            for this_iter, this_local in enumerate(vlad_local):
                this_patch_size = model.pool.patch_sizes[this_iter]

                # Move the last axis to the front and flatten so every patch
                # descriptor is one row for the PCA, then restore the
                # (size(0), pool_size, size(2)) layout.
                this_local_pca = get_pca_encoding(
                    model,
                    this_local.permute(2, 0, 1).reshape(-1, this_local.size(1))
                ).reshape(this_local.size(2), this_local.size(0), pool_size).permute(1, 2, 0)

                # Direct conversion instead of fancy-index assignment through
                # np.indices: that index array contained index 0 many times, and
                # NumPy gives no ordering guarantee when an element is assigned
                # more than once. ascontiguousarray keeps the C-ordered float32
                # layout the pre-allocated buffer used to provide.
                db_feat_patches = np.ascontiguousarray(
                    this_local_pca.detach().cpu().numpy(), dtype=np.float32)

                for i, val in enumerate(indices_np):
                    image_name = os.path.splitext(
                        os.path.basename(eval_set.images[val]))[0]
                    filename = output_local_features_prefix + '_' + 'psize{}_'.format(
                        this_patch_size) + image_name + '.npy'
                    np.save(filename, db_feat_patches[i, :, :])

    np.save(output_global_features_filename, db_feat)
예제 #2
0
def match_two(model, device, opt, config):
    """Match two images with patch-level descriptors and print their similarity.

    Both images are loaded, transformed, encoded and pooled in a single batch.
    The local descriptors of every patch size are PCA-reduced and passed to a
    ``PatchMatcher``. The negated, normalised matcher score is printed. When
    the configured matcher is ``'RANSAC'`` the inlier keypoints are also
    plotted via ``plot_two`` and saved to ``opt.plot_save_path``.

    Args:
        model: Object exposing ``encoder`` and ``pool``; ``pool`` must return a
            ``(local_descriptors, global_descriptor)`` pair. It is also handed
            to ``get_pca_encoding``.
        device: Device the image tensors are moved to.
        opt: Options object providing ``first_im_path``, ``second_im_path``
            and ``plot_save_path``.
        config: Nested mapping with ``global_params`` (``num_pcs``,
            ``patch_sizes``, ``strides``), ``feature_extract``
            (``imageresizeH``, ``imageresizeW``) and ``feature_match``
            (``patchWeights2Use``, ``matcher``) entries.

    Returns:
        None. The score is printed and the plot, if any, is written by
        ``plot_two``.
    """
    pool_size = int(config['global_params']['num_pcs'])
    resize_h = int(config['feature_extract']['imageresizeH'])
    resize_w = int(config['feature_extract']['imageresizeW'])

    model.eval()

    it = input_transform((resize_h, resize_w))

    # Open the images in context managers so the file handles are released as
    # soon as the pixel data has been converted; the originals were never closed.
    with Image.open(opt.first_im_path) as pil_one:
        im_one = it(pil_one).unsqueeze(0)
    with Image.open(opt.second_im_path) as pil_two:
        im_two = it(pil_two).unsqueeze(0)

    input_data = torch.cat((im_one.to(device), im_two.to(device)), 0)

    tqdm.write('====> Extracting Features')
    with torch.no_grad():
        image_encoding = model.encoder(input_data)

        # Only the local descriptors are needed here; the global one is discarded.
        vlad_local, _ = model.pool(image_encoding)

        local_feats_one = []
        local_feats_two = []
        for this_local in vlad_local:
            # Flatten so each patch descriptor is one row for the PCA, then
            # restore the (image, pool_size, patch) layout.
            this_local_feats = get_pca_encoding(
                model,
                this_local.permute(2, 0, 1).reshape(-1, this_local.size(1))
            ).reshape(this_local.size(2), this_local.size(0), pool_size).permute(1, 2, 0)
            # The first image's features are passed transposed, the second's as-is.
            local_feats_one.append(
                torch.transpose(this_local_feats[0, :, :], 0, 1))
            local_feats_two.append(this_local_feats[1, :, :])

    tqdm.write('====> Calculating Keypoint Positions')
    patch_sizes = [
        int(s) for s in config['global_params']['patch_sizes'].split(",")
    ]
    strides = [int(s) for s in config['global_params']['strides'].split(",")]
    patch_weights = np.array(
        config['feature_match']['patchWeights2Use'].split(",")).astype(float)

    all_keypoints = []
    all_indices = []

    tqdm.write('====> Matching Local Features')
    for patch_size, stride in zip(patch_sizes, strides):
        # we currently only provide support for square patches, but this can be easily modified for future works
        keypoints, indices = calc_keypoint_centers_from_patches(
            config['feature_match'], patch_size, patch_size, stride, stride)
        all_keypoints.append(keypoints)
        all_indices.append(indices)

    matcher = PatchMatcher(config['feature_match']['matcher'], patch_sizes,
                           strides, all_keypoints, all_indices)

    scores, inlier_keypoints_one, inlier_keypoints_two = matcher.match(
        local_feats_one, local_feats_two)
    # Negated so that, as the printed message states, larger means a better match.
    score = -normalise_func(scores, len(patch_sizes), patch_weights)

    print(
        "Similarity score between the two images is: '{:.5f}'. In this example, a larger score indicates a better match."
        .format(score))

    if config['feature_match']['matcher'] == 'RANSAC':
        tqdm.write('====> Plotting Local Features')

        # using cv2 for their in-built keypoint correspondence plotting tools
        cv_im_one = cv2.imread(opt.first_im_path, -1)
        cv_im_two = cv2.imread(opt.second_im_path, -1)
        # cv2.resize takes (width, height), the reverse of the transform above.
        cv_im_one = cv2.resize(cv_im_one, (resize_w, resize_h))
        cv_im_two = cv2.resize(cv_im_two, (resize_w, resize_h))
        # cv2 resize slightly different from torch, but for visualisation only not a big problem

        plot_two(cv_im_one, cv_im_two, inlier_keypoints_one,
                 inlier_keypoints_two, opt.plot_save_path)