# Example no. 1
def eval_datasets(model,
                  datasets=('oxford5k', 'paris6k', 'roxford5k', 'rparis6k'),
                  ms=False,
                  tta_gem_p=1.0,
                  logger=None):
    """Evaluate a retrieval model on the standard benchmark datasets.

    Args:
        model: network used to extract global descriptors; switched to
            eval mode before feature extraction.
        datasets: iterable of benchmark names understood by ``configdataset``.
        ms: if True, extract features at three scales (1/sqrt(2), 1, sqrt(2))
            instead of the single original scale.
        tta_gem_p: GeM pooling exponent used for test-time augmentation.
        logger: optional logger forwarded to ``compute_map_and_print``.

    Returns:
        dict mapping each dataset name to its ``compute_map_and_print`` result.
    """
    model = model.eval()

    data_root = os.path.join(get_root(), 'cirtorch')
    root2 = 2 ** (1 / 2)
    scales = [1 / root2, 1.0, root2] if ms else [1.0]

    results = {}
    for name in datasets:
        # Build the benchmark config plus database/query image lists.
        cfg = configdataset(name, os.path.join(data_root, 'test'))
        db_images = [cfg['im_fname'](cfg, idx) for idx in range(cfg['n'])]
        query_images = [cfg['qim_fname'](cfg, idx) for idx in range(cfg['nq'])]
        # Query crops: ground-truth bounding boxes, one per query.
        query_bbxs = [tuple(cfg['gnd'][idx]['bbx']) for idx in range(cfg['nq'])]
        desc = cfg['dataset']

        # Database images are used whole (no bounding boxes).
        db_feats = extract_vectors(model,
                                   images=db_images,
                                   bbxs=None,
                                   scales=scales,
                                   tta_gem_p=tta_gem_p,
                                   tqdm_desc=desc)
        query_feats = extract_vectors(model,
                                      images=query_images,
                                      bbxs=query_bbxs,
                                      scales=scales,
                                      tta_gem_p=tta_gem_p,
                                      tqdm_desc=desc)

        # Similarity of every database item to every query; rank databases
        # per query by descending score (argsort of the negated matrix).
        scores = np.dot(db_feats, query_feats.T)
        ranks = np.argsort(-scores, axis=0)
        results[name] = compute_map_and_print(name,
                                              ranks,
                                              cfg['gnd'],
                                              kappas=[1, 5, 10],
                                              logger=logger)

    return results
# Example no. 2
import argparse
import os

import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.model_zoo import load_url
from torchvision import transforms

from cirtorch.datasets.datahelpers import cid2filename
from cirtorch.datasets.testdataset import configdataset
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_root, get_data_root, htime
from cirtorch.utils.whiten import whitenlearn, whitenapply

# Location of the cirtorch data directory under the project root.
data_root = os.path.join(get_root(), 'cirtorch')

# Download URLs for off-the-shelf retrieval-SfM-120k pretrained networks.
PRETRAINED = {
    'retrievalSfM120k-vgg16-gem': (
        'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/'
        'retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth'),
    'retrievalSfM120k-resnet101-gem': (
        'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/'
        'retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth'),
}

# Benchmark datasets and whitening sets accepted on the command line.
datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
whitening_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']

# Command-line interface for the test script.
parser = argparse.ArgumentParser(
    description='PyTorch CNN Image Retrieval Testing')

# network