""" Load dataset """ data_path = '{}/{}'.format(opt.data_path, 'hico') image_path = '{}/{}/{}'.format(opt.data_path, 'hico', 'images') cand_dir = '{}/{}/{}'.format(opt.data_path, 'hico', 'detections') dset = Hico(data_path, \ image_path, \ opt.test_split, \ cand_dir=cand_dir,\ thresh_file=opt.thresh_file, \ add_gt=False, \ train_mode=False, \ jittering=False, \ nms_thresh=opt.nms_thresh) """ Load the test triplets """ target_triplets = dset.get_zeroshottriplets( ) # uncomment to eval zeroshot triplets #target_triplets = dset.visualphrases.words() # uncomment to eval all triplets """ Keys to analyze """ keys = ['s-sro-o', 's-r-o-sro'] """ Aggregate csv result files (from official HICO eval code) """ # Logger path logger_path = osp.join(opt.logger_dir, opt.exp_name) detection_path = parser.get_res_dir(opt, 'detections_' + opt.embedding_type) res_path = parser.get_res_dir(opt, 'res_' + opt.embedding_type) for key in keys: """ File out : 1 file for AP results : group all zeroshot triplets AP """ filename_out = osp.join(res_path, 'results_{}_{}_{}_{}.csv'.format(\ opt.cand_test,\ opt.test_split,\
# print("*****************************************") model.load_state_dict(checkpoint['model'], False) if isinstance(model, torch.nn.DataParallel): model = model.module model.eval() ###################### """ Query triplets """ ###################### if opt.train_split == 'trainval' or opt.train_split == 'train': triplet_queries = dset.visualphrases.words() else: triplet_queries = dset.get_zeroshottriplets() ############### """ Analogy """ ############### if opt.use_analogy: model.precomp_language_features() # pre-compute unigram emb model.precomp_sim_tables() # pre-compute similarity tables for speed-up # Target queries indices queries_sro, triplet_queries_idx = model.precomp_target_queries( triplet_queries) # Pre-compute language features in joint sro space
image_path = '{}/{}/{}'.format(opt.data_path, 'hico', 'images')
cand_dir = '{}/{}/{}'.format(opt.data_path, 'hico', 'detections')

dset = Hico(data_path,
            image_path,
            opt.test_split,
            cand_dir=cand_dir,
            thresh_file=opt.thresh_file,
            add_gt=False,
            train_mode=False,
            jittering=False,
            nms_thresh=opt.nms_thresh)

""" Key types """
keys = ['s-r-o', 's-sro-o', 's-r-o-sro']

""" Load the test triplets """
target_triplets = dset.get_zeroshottriplets()
subset = 'zeroshottriplet'

for key in keys:

    """ Load AP results for all triplets """
    filename_in = osp.join(opt.logger_dir, opt.exp_name,
                           'results_{}_{}_{}_{}_def.csv'.format(
                               opt.cand_test,
                               opt.test_split,
                               opt.epoch_model,
                               key))

    with open(filename_in) as f:
        reader = csv.DictReader(f)
        ap_results = [r for r in reader]

    """ Write csv subset of triplets """
    filename_out = osp.join(opt.logger_dir, opt.exp_name,
                            'results_{}_{}_{}_{}_{}_def.csv'.format(
import os.path as osp
import csv
import pickle

from datasets.hico_api import Hico

# Load vocabulary of triplets
root_path = './data'
data_path = '{}/{}'.format(root_path, 'hico')
image_path = '{}/{}/{}'.format(root_path, 'hico', 'images')
cand_dir = '{}/{}/{}'.format(root_path, 'hico', 'detections')
split = 'trainval'  # 'train', 'trainval'

dset = Hico(data_path, image_path, split, cand_dir)

# Get the set of triplets to exclude (zero-shot triplets)
triplets_remove = dset.get_zeroshottriplets()
triplet_cat_remove = []
for l in range(len(triplets_remove)):
    triplet_cat_remove.append(dset.visualphrases.word2idx[triplets_remove[l]])

# Build a new set of candidates excluding the triplet categories
cand_positives = pickle.load(open(osp.join(data_path, 'cand_positives_' + split + '.pkl'), 'rb'))

idx_keep = []
for j in range(cand_positives.shape[0]):
    if j % 100000 == 0:
        print('Done {}/{}'.format(j, cand_positives.shape[0]))
    im_id = cand_positives[j, 0]
    cand_id = cand_positives[j, 1]
    # Load the gt label of visualphrase