Пример #1
0
""" Load dataset """
# Every HICO asset lives under <opt.data_path>/hico
data_path = '{0}/hico'.format(opt.data_path)
image_path = '{0}/hico/images'.format(opt.data_path)
cand_dir = '{0}/hico/detections'.format(opt.data_path)

# Evaluation-time dataset: add_gt / train_mode / jittering are all switched off.
dset = Hico(
    data_path,
    image_path,
    opt.test_split,
    cand_dir=cand_dir,
    thresh_file=opt.thresh_file,
    add_gt=False,
    train_mode=False,
    jittering=False,
    nms_thresh=opt.nms_thresh,
)

# Load the test triplets: zero-shot triplets only.
# To evaluate every triplet instead, use the commented-out line below.
target_triplets = dset.get_zeroshottriplets()
# target_triplets = dset.visualphrases.words()

# Keys to analyze
keys = ['s-sro-o', 's-r-o-sro']

# Aggregate csv result files (from official HICO eval code)
logger_path = osp.join(opt.logger_dir, opt.exp_name)  # logger path

detection_path = parser.get_res_dir(opt, 'detections_' + opt.embedding_type)
res_path = parser.get_res_dir(opt, 'res_' + opt.embedding_type)

for key in keys:
    """ File out : 1 file for AP results : group all zeroshot triplets AP """
    filename_out = osp.join(res_path, 'results_{}_{}_{}_{}.csv'.format(\
                                        opt.cand_test,\
                                        opt.test_split,\
Пример #2
0
# print("*****************************************")

# Restore the trained weights. The second positional argument of
# load_state_dict is `strict`; it is spelled out here for readability.
# NOTE(review): with strict=False, keys that do not match are silently ignored.
# The load happens before the DataParallel unwrap below, so if the checkpoint's
# 'module.' key prefixes do not match the current wrapping, weights may be
# skipped without an error. Confirm the checkpoint format.
model.load_state_dict(checkpoint['model'], strict=False)
if isinstance(model, torch.nn.DataParallel):
    # Work with the bare module from here on.
    model = model.module
model.eval()

# ---- Query triplets ----
# If training used 'train' or 'trainval', query every visual phrase in the
# vocabulary. Otherwise query only the zero-shot triplets.
triplet_queries = (dset.visualphrases.words()
                   if opt.train_split in ('trainval', 'train')
                   else dset.get_zeroshottriplets())

# ---- Analogy ----
if opt.use_analogy:
    model.precomp_language_features()  # pre-compute unigram emb
    model.precomp_sim_tables()  # pre-compute similarity tables for speed-up

# Target queries indices
queries_sro, triplet_queries_idx = model.precomp_target_queries(triplet_queries)

# Pre-compute language features in joint sro space
Пример #3
0
image_path = '{0}/hico/images'.format(opt.data_path)
cand_dir = '{0}/hico/detections'.format(opt.data_path)

# Evaluation-time dataset: add_gt / train_mode / jittering are all switched off.
dset = Hico(
    data_path,
    image_path,
    opt.test_split,
    cand_dir=cand_dir,
    thresh_file=opt.thresh_file,
    add_gt=False,
    train_mode=False,
    jittering=False,
    nms_thresh=opt.nms_thresh,
)

# Key types
keys = ['s-r-o', 's-sro-o', 's-r-o-sro']

# Load the test triplets (zero-shot subset) and the tag used for that subset
target_triplets = dset.get_zeroshottriplets()
subset = 'zeroshottriplet'

for key in keys:
    """ Load ap results for all triplets """
    filename_in = osp.join(opt.logger_dir, opt.exp_name, 'results_{}_{}_{}_{}_def.csv'.format(\
                                        opt.cand_test,\
                                        opt.test_split,\
                                        opt.epoch_model,\
                                        key))

    with open(filename_in) as f:
        reader = csv.DictReader(f)
        ap_results = [r for r in reader]
    """  Write csv subset of triplets """
    filename_out = osp.join(opt.logger_dir, opt.exp_name, 'results_{}_{}_{}_{}_{}_def.csv'.format(\
Пример #4
0
import os.path as osp
from datasets.hico_api import Hico
import csv
import pickle

# Load vocabulary of triplets
root_path = './data'
data_path = '{}/{}'.format(root_path, 'hico')
image_path = '{}/{}/{}'.format(root_path, 'hico', 'images')
cand_dir = '{}/{}/{}'.format(root_path, 'hico', 'detections')

split = 'trainval'  # one of 'train', 'trainval'
dset = Hico(data_path, image_path, split, cand_dir)

# Get set of triplets to remove (the zero-shot triplets)
triplets_remove = dset.get_zeroshottriplets()

# Category index of each removed triplet in the visual-phrase vocabulary
triplet_cat_remove = [dset.visualphrases.word2idx[triplet]
                      for triplet in triplets_remove]

# Build a new set of candidates excluding the triplet categories.
# The `with` block closes the file handle as soon as the pickle is read.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# this is only safe on trusted, locally generated data files.
cand_positives_path = osp.join(data_path, 'cand_positives_' + split + '.pkl')
with open(cand_positives_path, 'rb') as cand_positives_file:
    cand_positives = pickle.load(cand_positives_file)

idx_keep = []
for j in range(cand_positives.shape[0]):
    if j%100000==0:
        print('Done {}/{}'.format(j, cand_positives.shape[0]))
    im_id = cand_positives[j,0]
    cand_id = cand_positives[j,1]
    # Load the gt label of visualphrase