Code example #1
def match_two(model, device, config, im_one, im_two, plot_save_path):

    pool_size = int(config['global_params']['num_pcs'])

    model.eval()

    it = input_transform((int(config['feature_extract']['imageresizeH']),
                          int(config['feature_extract']['imageresizeW'])))

    im_one_pil = Image.fromarray(cv2.cvtColor(im_one, cv2.COLOR_BGR2RGB))
    im_two_pil = Image.fromarray(cv2.cvtColor(im_two, cv2.COLOR_BGR2RGB))

    im_one_pil = it(im_one_pil).unsqueeze(0)
    im_two_pil = it(im_two_pil).unsqueeze(0)

    input_data = torch.cat((im_one_pil.to(device), im_two_pil.to(device)), 0)

    tqdm.write('====> Extracting Features')
    with torch.no_grad():
        image_encoding = model.encoder(input_data)

        vlad_local, _ = model.pool(image_encoding)
        # global_feats = get_pca_encoding(model, vlad_global).cpu().numpy()

        local_feats_one = []
        local_feats_two = []
        for this_iter, this_local in enumerate(vlad_local):
            # apply PCA to each local descriptor, then reshape back to
            # (batch, pool_size, num_patches)
            this_local_feats = get_pca_encoding(
                model, this_local.permute(2, 0, 1).reshape(-1, this_local.size(1))
            ).reshape(this_local.size(2), this_local.size(0),
                      pool_size).permute(1, 2, 0)
            local_feats_one.append(
                torch.transpose(this_local_feats[0, :, :], 0, 1))
            local_feats_two.append(this_local_feats[1, :, :])

    tqdm.write('====> Calculating Keypoint Positions')
    patch_sizes = [
        int(s) for s in config['global_params']['patch_sizes'].split(",")
    ]
    strides = [int(s) for s in config['global_params']['strides'].split(",")]
    patch_weights = np.array(
        config['feature_match']['patchWeights2Use'].split(",")).astype(float)

    all_keypoints = []
    all_indices = []

    tqdm.write('====> Matching Local Features')
    for patch_size, stride in zip(patch_sizes, strides):
        # we currently only provide support for square patches, but this can be easily modified for future works
        keypoints, indices = calc_keypoint_centers_from_patches(
            config['feature_match'], patch_size, patch_size, stride, stride)
        all_keypoints.append(keypoints)
        all_indices.append(indices)

    matcher = PatchMatcher(config['feature_match']['matcher'], patch_sizes,
                           strides, all_keypoints, all_indices)

    scores, inlier_keypoints_one, inlier_keypoints_two = matcher.match(
        local_feats_one, local_feats_two)
    score = -apply_patch_weights(scores, len(patch_sizes), patch_weights)

    print(
        f"Similarity score between the two images is: {score:.5f}. Larger scores indicate better matches."
    )

    if config['feature_match']['matcher'] == 'RANSAC':
        if plot_save_path is not None:
            tqdm.write('====> Plotting Local Features and saving them to ' +
                       str(join(plot_save_path, 'patchMatchings.png')))

        # using cv2 for their in-built keypoint correspondence plotting tools
        cv_im_one = cv2.resize(
            im_one, (int(config['feature_extract']['imageresizeW']),
                     int(config['feature_extract']['imageresizeH'])))
        cv_im_two = cv2.resize(
            im_two, (int(config['feature_extract']['imageresizeW']),
                     int(config['feature_extract']['imageresizeH'])))
        # cv2 resize slightly different from torch, but for visualisation only not a big problem

        plot_two(cv_im_one, cv_im_two, inlier_keypoints_one,
                 inlier_keypoints_two, plot_save_path)
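A rough usage sketch for match_two (the config file name, the load_patchnetvlad_model helper and the image paths below are placeholder assumptions for illustration, not part of the snippet above):

import configparser
import cv2
import torch

config = configparser.ConfigParser()
config.read('performance.ini')                    # assumed config file name

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_patchnetvlad_model(config, device)   # hypothetical helper

im_one = cv2.imread('query.jpg')                  # BGR images, as match_two expects
im_two = cv2.imread('reference.jpg')

match_two(model, device, config, im_one, im_two, plot_save_path='./results')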
Code example #2
def get_clusters(cluster_set, model, encoder_dim, device, opt, config):
    cuda = device.type == 'cuda'
    nDescriptors = 50000
    nPerImage = 100
    nIm = ceil(nDescriptors / nPerImage)

    cluster_sampler = SubsetRandomSampler(
        np.random.choice(len(cluster_set.dbImages), nIm, replace=False))

    cluster_data_loader = DataLoader(
        dataset=ImagesFromList(cluster_set.dbImages,
                               transform=input_transform()),
        num_workers=opt.threads,
        batch_size=int(config['train']['cachebatchsize']),
        shuffle=False,
        pin_memory=cuda,
        sampler=cluster_sampler)

    if not exists(join(opt.cache_path, 'centroids')):
        makedirs(join(opt.cache_path, 'centroids'))

    initcache_clusters = join(
        opt.cache_path, 'centroids', 'vgg16_' + 'mapillary_' +
        config['train']['num_clusters'] + '_desc_cen.hdf5')
    with h5py.File(initcache_clusters, mode='w') as h5_file:
        with torch.no_grad():
            model.eval()
            tqdm.write('====> Extracting Descriptors')
            dbFeat = h5_file.create_dataset("descriptors",
                                            [nDescriptors, encoder_dim],
                                            dtype=np.float32)

            for iteration, (input_data, indices) in enumerate(
                    tqdm(cluster_data_loader, desc='Iter'.rjust(15)), 1):
                input_data = input_data.to(device)
                image_descriptors = model.encoder(input_data).view(
                    input_data.size(0), encoder_dim, -1).permute(0, 2, 1)
                # we L2-normalise descriptors before VLAD, so normalise here as well
                image_descriptors = F.normalize(image_descriptors, p=2, dim=2)

                batchix = (iteration - 1) * int(
                    config['train']['cachebatchsize']) * nPerImage
                for ix in range(image_descriptors.size(0)):
                    # sample different location for each image in batch
                    sample = np.random.choice(image_descriptors.size(1),
                                              nPerImage,
                                              replace=False)
                    startix = batchix + ix * nPerImage
                    dbFeat[startix:startix + nPerImage, :] = image_descriptors[
                        ix, sample, :].detach().cpu().numpy()

                del input_data, image_descriptors

        tqdm.write('====> Clustering..')
        niter = 100
        kmeans = faiss.Kmeans(encoder_dim,
                              int(config['train']['num_clusters']),
                              niter=niter,
                              verbose=False)
        kmeans.train(dbFeat[...])

        tqdm.write('====> Storing centroids ' + str(kmeans.centroids.shape))
        h5_file.create_dataset('centroids', data=kmeans.centroids)
        tqdm.write('====> Done!')
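The descriptors and centroids written above can later be read back from the HDF5 cache; a minimal sketch, assuming the same path construction as initcache_clusters:

import h5py
from os.path import join

initcache = join(opt.cache_path, 'centroids', 'vgg16_' + 'mapillary_' +
                 config['train']['num_clusters'] + '_desc_cen.hdf5')
with h5py.File(initcache, mode='r') as h5:
    centroids = h5.get('centroids')[...]      # shape (num_clusters, encoder_dim)
    descriptors = h5.get('descriptors')[...]  # shape (nDescriptors, encoder_dim)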
Code example #3
        isParallel = True

    model = model.to(device)

    pool_size = encoder_dim
    if config['global_params']['pooling'].lower() == 'netvlad':
        pool_size *= int(config['global_params']['num_clusters'])

    print('===> Loading PCA dataset(s)')

    exclude_panos_training = not config['train'].getboolean('includepanos')

    pca_train_set = MSLS(opt.dataset_root_dir,
                         mode='test',
                         cities='train',
                         transform=input_transform(),
                         bs=int(config['train']['cachebatchsize']),
                         threads=opt.threads,
                         margin=float(config['train']['margin']),
                         exclude_panos=exclude_panos_training)

    nFeatures = 10000
    if nFeatures > len(pca_train_set.dbImages):
        nFeatures = len(pca_train_set.dbImages)

    sampler = SubsetRandomSampler(
        np.random.choice(len(pca_train_set.dbImages), nFeatures,
                         replace=False))

    data_loader = DataLoader(dataset=ImagesFromList(
        pca_train_set.dbImages, transform=input_transform()),
Code example #4
    model = model.to(device)

    pool_size = encoder_dim
    if config['global_params']['pooling'].lower() == 'netvlad':
        pool_size *= int(config['global_params']['num_clusters'])

    print('===> Loading PCA dataset(s)')

    nFeatures = 10000
    if opt.dataset_choice == 'mapillary':
        exclude_panos_training = not config['train'].getboolean('includepanos')

        pca_train_set = MSLS(opt.dataset_root_dir,
                             mode='test',
                             cities='train',
                             transform=input_transform(),
                             bs=int(config['train']['cachebatchsize']),
                             threads=opt.threads,
                             margin=float(config['train']['margin']),
                             exclude_panos=exclude_panos_training)

        pca_train_images = pca_train_set.dbImages
    elif opt.dataset_choice == 'pitts':
        dataset_file_path = join(PATCHNETVLAD_ROOT_DIR, 'dataset_imagenames',
                                 'pitts30k_imageNames_index.txt')
        pca_train_set = PlaceDataset(None, dataset_file_path,
                                     opt.dataset_root_dir, None,
                                     config['train'])
        pca_train_images = pca_train_set.images
    else:
        raise ValueError('Unknown dataset choice: ' + opt.dataset_choice)
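Either branch leaves a list of image paths in pca_train_images; a sketch of how it could then be subsampled and loaded, mirroring code example #3 (the loader arguments are assumptions):

import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler

nFeatures = min(nFeatures, len(pca_train_images))
sampler = SubsetRandomSampler(
    np.random.choice(len(pca_train_images), nFeatures, replace=False))

data_loader = DataLoader(dataset=ImagesFromList(pca_train_images,
                                                transform=input_transform()),
                         num_workers=opt.threads,
                         batch_size=int(config['train']['cachebatchsize']),
                         shuffle=False,
                         pin_memory=(device.type == 'cuda'),
                         sampler=sampler)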
Code example #5
        if opt.cluster_path:
            if isfile(opt.cluster_path):
                if opt.cluster_path != initcache:
                    shutil.copyfile(opt.cluster_path, initcache)
            else:
                raise FileNotFoundError(
                    "=> no cluster data found at '{}'".format(
                        opt.cluster_path))
        else:
            print('===> Finding cluster centroids')

            print('===> Loading dataset(s) for clustering')
            train_dataset = MSLS(opt.dataset_root_dir,
                                 mode='test',
                                 cities='train',
                                 transform=input_transform(),
                                 bs=int(config['train']['cachebatchsize']),
                                 threads=opt.threads,
                                 margin=float(config['train']['margin']))

            model = model.to(device)

            print('===> Calculating descriptors and clusters')
            get_clusters(train_dataset, model, encoder_dim, device, opt,
                         config)

            # a little hacky, but needed to easily run init_params
            model = model.to(device="cpu")

        with h5py.File(initcache, mode='r') as h5:
            clsts = h5.get("centroids")[...]
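            # NOTE: the excerpt ends above; the lines below are an assumed
            # continuation based on common NetVLAD initialisation code and are
            # not part of the original snippet.
            traindescs = h5.get("descriptors")[...]
            model.pool.init_params(clsts, traindescs)  # assumed init_params API
            del clsts, traindescs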
Code example #6
File: val.py  Project: QVPR/Patch-NetVLAD
def val(eval_set,
        model,
        encoder_dim,
        device,
        opt,
        config,
        writer,
        epoch_num=0,
        write_tboard=False,
        pbar_position=0):
    cuda = device.type == 'cuda'
    eval_set_queries = ImagesFromList(eval_set.qImages,
                                      transform=input_transform())
    eval_set_dbs = ImagesFromList(eval_set.dbImages,
                                  transform=input_transform())
    test_data_loader_queries = DataLoader(
        dataset=eval_set_queries,
        num_workers=opt.threads,
        batch_size=int(config['train']['cachebatchsize']),
        shuffle=False,
        pin_memory=cuda)
    test_data_loader_dbs = DataLoader(dataset=eval_set_dbs,
                                      num_workers=opt.threads,
                                      batch_size=int(
                                          config['train']['cachebatchsize']),
                                      shuffle=False,
                                      pin_memory=cuda)

    model.eval()
    with torch.no_grad():
        tqdm.write('====> Extracting Features')
        pool_size = encoder_dim
        if config['global_params']['pooling'].lower() == 'netvlad':
            pool_size *= int(config['global_params']['num_clusters'])
        qFeat = np.empty((len(eval_set_queries), pool_size), dtype=np.float32)
        dbFeat = np.empty((len(eval_set_dbs), pool_size), dtype=np.float32)

        for feat, test_data_loader in zip(
            [qFeat, dbFeat], [test_data_loader_queries, test_data_loader_dbs]):
            for iteration, (input_data, indices) in \
                    enumerate(tqdm(test_data_loader, position=pbar_position, leave=False, desc='Test Iter'.rjust(15)), 1):
                input_data = input_data.to(device)
                image_encoding = model.encoder(input_data)

                vlad_encoding = model.pool(image_encoding)
                feat[indices.detach().numpy(), :] = \
                    vlad_encoding.detach().cpu().numpy()

                del input_data, image_encoding, vlad_encoding

    del test_data_loader_queries, test_data_loader_dbs

    tqdm.write('====> Building faiss index')
    faiss_index = faiss.IndexFlatL2(pool_size)
    # noinspection PyArgumentList
    faiss_index.add(dbFeat)

    tqdm.write('====> Calculating recall @ N')
    n_values = [1, 5, 10, 20, 50, 100]

    # for each query get those within threshold distance
    gt = eval_set.all_pos_indices

    # any combination of mapillary cities will work as a val set
    qEndPosTot = 0
    dbEndPosTot = 0
    for cityNum, (qEndPos, dbEndPos) in enumerate(
            zip(eval_set.qEndPosList, eval_set.dbEndPosList)):
        faiss_index = faiss.IndexFlatL2(pool_size)
        faiss_index.add(dbFeat[dbEndPosTot:dbEndPosTot + dbEndPos, :])
        _, preds = faiss_index.search(
            qFeat[qEndPosTot:qEndPosTot + qEndPos, :], max(n_values))
        if cityNum == 0:
            predictions = preds
        else:
            predictions = np.vstack((predictions, preds))
        qEndPosTot += qEndPos
        dbEndPosTot += dbEndPos

    correct_at_n = np.zeros(len(n_values))
    # TODO can we do this on the matrix in one go?
    for qIx, pred in enumerate(predictions):
        for i, n in enumerate(n_values):
            # if in top N then also in top NN, where NN > N
            if np.any(np.in1d(pred[:n], gt[qIx])):
                correct_at_n[i:] += 1
                break
    recall_at_n = correct_at_n / len(eval_set.qIdx)

    all_recalls = {}  # make dict for output
    for i, n in enumerate(n_values):
        all_recalls[n] = recall_at_n[i]
        tqdm.write("====> Recall@{}: {:.4f}".format(n, recall_at_n[i]))
        if write_tboard:
            writer.add_scalar('Val/Recall@' + str(n), recall_at_n[i],
                              epoch_num)

    return all_recalls
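val returns the all_recalls dict keyed by N; a minimal caller sketch (the training-loop variables epoch and best_recall_at_5, and the checkpointing step, are hypothetical):

recalls = val(eval_set, model, encoder_dim, device, opt, config,
              writer, epoch_num=epoch, write_tboard=True)

# recalls maps each N in n_values to recall@N
if recalls[5] > best_recall_at_5:
    best_recall_at_5 = recalls[5]
    # save a model checkpoint here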