def run_filter(filter_net, box_filter_net, tree, boxes, w2i, CNN, precision_f1,
               instance_idx, gold):
    """Shortlist boxes with ``filter_net``, then score each one with ``box_filter_net``.

    Stage 1 runs ``filter_net`` over every box of ``boxes[instance_idx]`` with
    the constant dummy word input ``[0] * 5``. When ``precision_f1 >= 1`` the
    top ``int(precision_f1)`` boxes by score are kept (best first); otherwise
    every box whose ``exp(score)`` exceeds ``precision_f1`` is kept, in box
    order. Stage 2 runs ``box_filter_net`` on each kept box individually,
    using the tree's leaf labels mapped through ``w2i`` (unknown words -> 0).

    Args:
        filter_net: first-stage network, called as ``filter_net(words, box_rep, tree)``.
        box_filter_net: second-stage network, called once per kept box.
        tree: object whose ``leaves()`` expose ``.label``; also passed to both nets.
        boxes: per-instance box lists forwarded to ``get_box_feats``.
        w2i: mapping from word label to integer id.
        CNN: feature source forwarded to ``get_box_feats``.
        precision_f1: top-k count (>= 1) or score threshold (< 1).
        instance_idx: index of the instance in ``boxes``.
        gold: unused; kept for call-site compatibility.

    Returns:
        ``(bf_pred, f1_pred)``: ``bf_pred`` is a list holding the second-stage
        output ``[0][0]`` for every kept box (same order as ``f1_pred``);
        ``f1_pred`` is an integer numpy array of kept indices into
        ``boxes[instance_idx]``.
    """
    instance_boxes = boxes[instance_idx]
    box_feats, spat_feats = get_box_feats(instance_boxes, CNN)
    box_rep = torch.cat((box_feats, spat_feats), 1)
    pred_np = filter_net([0] * 5, box_rep, tree).cpu().data.numpy()
    if precision_f1 >= 1:
        top_k = min(int(precision_f1), pred_np.shape[1])
        f1_pred = np.argsort(-pred_np)[0][:top_k]
    else:
        # Scores are exponentiated before thresholding (presumably
        # log-probabilities -- confirm against filter_net). flatnonzero keeps
        # an integer dtype even when nothing passes, so f1_pred is always a
        # valid index array (np.array([]) would be float64).
        f1_pred = np.flatnonzero(np.exp(pred_np[0]) > precision_f1)

    # The word ids do not depend on the box, so build them once.
    word_ids = [w2i.get(node.label, 0) for node in tree.leaves()]
    bf_pred = []
    for box_index in f1_pred:
        box_feats, spat_feats = get_box_feats([instance_boxes[box_index]], CNN)
        box_rep = torch.cat((box_feats, spat_feats), 1)
        # Assumes box_usage == 0 (box and spatial features concatenated) for box_filter_net.
        box_score = box_filter_net(word_ids, box_rep, tree).data
        bf_pred.append(box_score[0][0])
    return bf_pred, f1_pred
def run_filter(box_filter_net, tree, boxes, w2i, CNN, gold, instance_idx,
               precision_f2):
    """Score every box of one instance with ``box_filter_net`` and keep the best.

    Each box of ``boxes[instance_idx]`` is scored individually by
    ``box_filter_net`` using the tree's leaf labels mapped through ``w2i``
    (unknown words -> 0). When ``precision_f2 >= 1`` the top
    ``int(precision_f2)`` boxes by score are kept (best first); otherwise every
    box whose score exceeds ``precision_f2`` is kept, in box order.

    Args:
        box_filter_net: network called as ``box_filter_net(words, box_rep, tree)``.
        tree: object whose ``leaves()`` expose ``.label``; also passed to the net.
        boxes: per-instance box lists forwarded to ``get_box_feats``.
        w2i: mapping from word label to integer id.
        CNN: feature source forwarded to ``get_box_feats``.
        gold: per-instance collections of gold box indices.
        instance_idx: index of the instance in ``boxes`` / ``gold``.
        precision_f2: top-k count (>= 1) or score threshold (< 1).

    Returns:
        ``(final_prediction, n_gold_kept)``: integer numpy array of kept box
        indices and how many of ``gold[instance_idx]`` are among them.
    """
    # The word ids do not depend on the box, so build them once.
    word_ids = [w2i.get(node.label, 0) for node in tree.leaves()]
    bf_pred = []
    for box in boxes[instance_idx]:
        box_feats, spat_feats = get_box_feats([box], CNN)
        box_rep = torch.cat((box_feats, spat_feats), 1)
        # Assumes box_usage == 0 (box and spatial features concatenated) for box_filter_net.
        prediction = box_filter_net(word_ids, box_rep, tree).data
        bf_pred.append(prediction[0][0])

    scores = np.array(bf_pred)
    if precision_f2 >= 1:
        top_k = min(int(precision_f2), len(bf_pred))
        final_prediction = np.argsort(-scores)[:top_k]
    else:
        # flatnonzero keeps an integer dtype even when nothing passes.
        final_prediction = np.flatnonzero(scores > precision_f2)
    n_gold_kept = len(set(final_prediction).intersection(set(gold[instance_idx])))
    return final_prediction, n_gold_kept
def run_filter(filter_net, tree, boxes, CNN, gold, instance_idx, precision_f1,
               annid2catid):
    """Shortlist boxes with ``filter_net``, then keep only gold-category boxes.

    ``filter_net`` is run over every box of ``boxes[instance_idx]`` with the
    constant dummy word input ``[0] * 5``. When ``precision_f1 >= 1`` the top
    ``int(precision_f1)`` boxes are kept (best first); otherwise every box
    whose ``exp(score)`` exceeds ``precision_f1`` is kept, in box order. The
    shortlist is then reduced to boxes whose object category (looked up as
    ``annid2catid[str(CNN['meta'][box])]``) matches the category of some gold
    box.

    Args:
        filter_net: network called as ``filter_net(words, box_rep, tree)``.
        tree: passed through to ``filter_net``.
        boxes: per-instance box lists forwarded to ``get_box_feats``.
        CNN: feature source; ``CNN['meta']`` maps a box to its annotation id.
        gold: per-instance collections of gold box positions.
        instance_idx: index of the instance in ``boxes`` / ``gold``.
        precision_f1: top-k count (>= 1) or score threshold (< 1).
        annid2catid: mapping from annotation id (as ``str``) to category id.

    Returns:
        ``(final_pred, n_gold_kept)``: integer numpy array of kept box indices
        and how many of ``gold[instance_idx]`` are among them.
    """
    instance_boxes = boxes[instance_idx]
    box_feats, spat_feats = get_box_feats(instance_boxes, CNN)
    box_rep = torch.cat((box_feats, spat_feats), 1)
    pred_np = filter_net([0] * 5, box_rep, tree).cpu().data.numpy()
    if precision_f1 >= 1:
        top_k = min(int(precision_f1), pred_np.shape[1])
        f1_pred = np.argsort(-pred_np)[0][:top_k]
    else:
        # flatnonzero keeps an integer dtype even when nothing passes, so the
        # fancy indexing below cannot fail on an empty float array.
        f1_pred = np.flatnonzero(np.exp(pred_np[0]) > precision_f1)

    gold_cats = set(
        annid2catid[str(CNN['meta'][instance_boxes[gold_box]])]
        for gold_box in gold[instance_idx])
    cat_ids = [annid2catid[str(ann_id)]
               for ann_id in CNN['meta'][instance_boxes][f1_pred]]
    # Filter by object category: keep shortlist positions whose category matches gold.
    keep = [pos for pos, cat_id in enumerate(cat_ids) if cat_id in gold_cats]
    final_pred = f1_pred[keep]
    n_gold_kept = len(set(final_pred).intersection(set(gold[instance_idx])))
    return final_pred, n_gold_kept
def evaluate(net, split, CNN, config, experiment_log, box_usage=0,
             verbose=False, tst_json=None, out_file=''):
    """Measure top-1 box accuracy of ``net`` on ``split``.

    For every instance the highest-scoring box is a hit when it is one of the
    gold boxes. When ``tst_json`` is non-empty, each entry gets a
    ``'predicted_bounding_boxes'`` field naming the predicted box, and the
    list is written to ``out_file`` as JSON.

    Args:
        net: model called as ``net(word_ids, box_rep, tree)``.
        split: 4-tuple ``(trees, boxes, iou, gold)``; ``iou`` is unused.
        CNN: feature source forwarded to ``get_box_feats``.
        config: dict; ``config['box_usage']`` selects the box representation
            (0 = box + spatial features, 1 = box only, 2 = spatial only).
        experiment_log: writable file-like object receiving the summary lines.
        box_usage: ignored; ``config['box_usage']`` takes precedence.
        verbose: show a tqdm progress bar.
        tst_json: optional list of per-instance dicts with ``'box_names'``,
            annotated in place.
        out_file: path of the JSON dump when ``tst_json`` is non-empty.

    Returns:
        ``(accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unknown ``box_usage`` value.
    """
    if tst_json is None:
        tst_json = []
    box_usage = config['box_usage']
    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split
    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes
    preds = []
    net.eval()
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)
        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        prediction = net([w2i.get(node.label, 0) for node in tree.leaves()],
                         box_rep, tree).data
        _, pred = torch.max(prediction, 1)
        best_box = pred[0][0]
        hit = 1.0 if best_box in set(gold[j]) else 0.0
        correct += hit
        n += 1
        preds.append(int(best_box))
    eval_time = time.time() - eval_start

    log = ["Total correct {}/{}".format(correct, n), "=" * 20]
    for line in log:
        print(line)
        # One log entry per line in the experiment log.
        experiment_log.write(line + "\n")

    if tst_json:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
def run_filter(filter_net, tree, boxes, CNN, gold, instance_idx, precision_f1):
    """Shortlist candidate boxes for one instance with ``filter_net``.

    ``filter_net`` is run over every box of ``boxes[instance_idx]`` with the
    constant dummy word input ``[0] * 5``. When ``precision_f1 >= 1`` the top
    ``int(precision_f1)`` boxes by score are kept (best first); otherwise every
    box whose ``exp(score)`` exceeds ``precision_f1`` is kept, in box order.

    Args:
        filter_net: network called as ``filter_net(words, box_rep, tree)``.
        tree: passed through to ``filter_net``.
        boxes: per-instance box lists forwarded to ``get_box_feats``.
        CNN: feature source forwarded to ``get_box_feats``.
        gold: per-instance collections of gold box indices.
        instance_idx: index of the instance in ``boxes`` / ``gold``.
        precision_f1: top-k count (>= 1) or score threshold (< 1).

    Returns:
        ``(f1_pred, n_gold_kept)``: integer numpy array of kept box indices
        and how many of ``gold[instance_idx]`` are among them.
    """
    box_feats, spat_feats = get_box_feats(boxes[instance_idx], CNN)
    box_rep = torch.cat((box_feats, spat_feats), 1)
    pred_np = filter_net([0] * 5, box_rep, tree).cpu().data.numpy()
    if precision_f1 >= 1:
        top_k = min(int(precision_f1), pred_np.shape[1])
        f1_pred = np.argsort(-pred_np)[0][:top_k]
    else:
        # flatnonzero returns an integer array even when no box passes, so the
        # result is always usable as an index array (np.array([]) is float64).
        f1_pred = np.flatnonzero(np.exp(pred_np[0]) > precision_f1)
    n_gold_kept = len(set(f1_pred).intersection(set(gold[instance_idx])))
    return f1_pred, n_gold_kept
def evaluate(net, split, CNN, config, experiment_log, box_usage=0,
             verbose=False, tst_json=None, out_file='', precision_k=10,
             no_lang=False):
    """Evaluate ``net`` without language input and report precision@1..k.

    The network receives the constant dummy word input ``[0] * 5`` instead of
    the expression. For rank ``r`` (0-based), an instance counts as a hit at
    every rank from its first gold box onwards. Ranks at or beyond the
    instance's number of boxes are not counted for that instance.

    Args:
        net: model called as ``net(words, box_rep, tree)``.
        split: 4-tuple ``(trees, boxes, iou, gold)``; ``iou`` is unused.
        CNN: feature source forwarded to ``get_box_feats``.
        config: dict; ``config['box_usage']`` selects the box representation
            (0 = box + spatial features, 1 = box only, 2 = spatial only).
        experiment_log: unused; kept for call compatibility.
        box_usage: ignored; ``config['box_usage']`` takes precedence.
        verbose: show a tqdm progress bar.
        tst_json: optional list of per-instance dicts with ``'box_names'``;
            when non-empty each entry receives the full ranked box list and an
            (empty) ``'context_box'`` and the list is dumped to ``out_file``.
        out_file: JSON output path.
        precision_k: largest k reported.
        no_lang: unused; kept for call compatibility.

    Returns:
        ``(top1_accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unknown ``box_usage`` value.
    """
    if tst_json is None:
        tst_json = []
    box_usage = config['box_usage']
    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split
    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes
    precision_count = [0.0] * precision_k
    precision_hit = [0.0] * precision_k
    preds = []
    net.eval()
    all_supporting = []
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)
        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        prediction = net([0] * 5, box_rep, tree)
        supporting = []
        _, pred = torch.max(prediction.data, 1)
        pred_np = prediction.cpu().data.numpy()
        ranking = np.argsort(-pred_np)[0]
        gold_set = set(gold[j])

        n_ranks = min(precision_k, pred_np.shape[1])
        for rank, box_idx in enumerate(ranking[:n_ranks]):
            if box_idx in gold_set:
                # First gold box at this rank: a hit for every larger k too.
                for r in range(rank, n_ranks):
                    precision_hit[r] += 1.0
                    precision_count[r] += 1.0
                break
            precision_count[rank] += 1.0

        hit = 1.0 if pred[0][0] in gold_set else 0.0
        correct += hit
        n += 1
        if tst_json:
            preds.append(np.array(tst_json[j]['box_names'])[ranking])
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    print("_" * 10)
    for k in range(precision_k):
        count = precision_count[k]
        if count:
            print("precision@{} for {} instances: {:5.3f} and overall {:5.3f}".format(
                k + 1, count, precision_hit[k] / count,
                precision_hit[k] / len(trees)))
        else:
            # No instance reached this rank; avoid dividing by zero.
            print("precision@{} for 0 instances: n/a".format(k + 1))

    if tst_json:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [list(p) for p in preds[ii]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
def evaluate(net, split, CNN, config, experiment_log, box_usage=0,
             verbose=False, tst_json=None, out_file=''):
    """Measure top-1 box accuracy of ``net`` on ``split``, broken down by tree complexity.

    Complexity is ``len(tree.nonterms())``. Tree-structured models
    (``config['model']`` in groundnet / groundnetflexall / groundnetflexrel /
    treernn) receive the tree as their first input; all others receive the
    leaf labels mapped through ``w2i`` (unknown words -> 0).

    Args:
        net: model called as ``net(first_input, box_rep, tree)``.
        split: 4-tuple ``(trees, boxes, iou, gold)``; ``iou`` is unused.
        CNN: feature source forwarded to ``get_box_feats``.
        config: dict with ``'box_usage'`` (0 = box + spatial features,
            1 = box only, 2 = spatial only) and ``'model'``.
        experiment_log: writable file-like object receiving the summary lines.
        box_usage: ignored; ``config['box_usage']`` takes precedence.
        verbose: show a tqdm progress bar.
        tst_json: optional list of per-instance dicts with ``'box_names'``;
            when non-empty each gets ``'predicted_bounding_boxes'`` and the
            list is dumped to ``out_file``.
        out_file: JSON output path.

    Returns:
        ``(accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unknown ``box_usage`` value.
    """
    if tst_json is None:
        tst_json = []
    box_usage = config['box_usage']
    tree_input = config['model'] in set(
        ["groundnet", "groundnetflexall", "groundnetflexrel", "treernn"])
    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split
    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes
    preds = []
    net.eval()
    # Hits and instance counts keyed by tree complexity.
    stats = {'hit': defaultdict(int), 'cnt': defaultdict(int)}
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)
        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        if tree_input:
            prediction = net(tree, box_rep, tree).data
        else:
            prediction = net([w2i.get(node.label, 0) for node in tree.leaves()],
                             box_rep, tree).data
        _, pred = torch.max(prediction, 1)
        best_box = pred[0][0]
        hit = 1.0 if best_box in set(gold[j]) else 0.0
        correct += hit
        complexity = len(tree.nonterms())
        stats['hit'][complexity] += hit
        stats['cnt'][complexity] += 1
        n += 1
        preds.append(int(best_box))
    eval_time = time.time() - eval_start

    log = ["", "=" * 20, "Stats"]
    # Sorted so the per-complexity table reads in increasing order.
    for complexity in sorted(stats['cnt']):
        log.append("Complexity {:3d} acc {:5.3f} %{:5.3f}".format(
            complexity,
            stats['hit'][complexity] * 1.0 / stats['cnt'][complexity],
            stats['cnt'][complexity] * 1.0 / n))
    log.append("Total correct {}/{}".format(correct, n))
    log.append("=" * 20)
    for line in log:
        print(line)
        # One log entry per line in the experiment log.
        experiment_log.write(line + "\n")

    if tst_json:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
# Training-loop fragment: resets running statistics, then, for each training
# instance, builds the box representation and runs the model.
# NOTE(review): relies on names defined outside this fragment (args, indexes,
# Xtrn_tree, Xtrn_box, CNN, config, net); confirm against the enclosing code.
closs = 0.0
cinst = 0
correct = 0.0
trn_start = time.time()
# Progress bar only in verbose, non-debug runs.
if args.verbose and not args.debug_mode:
    pbar = tqdm(indexes, desc='trn_loss')
else:
    pbar = indexes
done = 1
for ii in pbar:
    tree = Xtrn_tree[ii]
    # get cnn feat for boxes
    box_feats, spat_feats = get_box_feats(Xtrn_box[ii], CNN)
    # box_usage selects the representation: 0 = box and spatial features
    # concatenated along dim 1, 1 = box features only, 2 = spatial only.
    if config['box_usage'] == 0:
        box_rep = torch.cat((box_feats,spat_feats),1)
    elif config['box_usage'] == 1:
        box_rep = box_feats
    elif config['box_usage'] == 2:
        box_rep = spat_feats
    else:
        raise NotImplementedError()
    if args.debug_mode:
        raise NotImplementedError()
    else:
        # Tree-structured models take the tree itself as their first input.
        if config['model'] in set(["groundnet", "groundnetflexall", "groundnetflexrel","treernn"]):
            prediction = net(tree, box_rep, tree)
def evaluate(box_filter_net, net, split, CNN, config, precision_f2,
             experiment_log, box_usage=0, verbose=False, tst_json=None,
             out_file='', no_lang=False):
    """Evaluate ``net`` on boxes pre-filtered by ``run_filter``/``box_filter_net``.

    Instances for which the filter keeps no gold box are counted as misses
    without running ``net``. Otherwise ``net`` scores only the kept boxes and
    its best position is mapped back to an index into ``boxes[j]``.

    Args:
        box_filter_net: network forwarded to ``run_filter``.
        net: model called as ``net(word_ids, box_rep, tree)``.
        split: 4-tuple ``(trees, boxes, iou, gold)``; ``iou`` is unused.
        CNN: feature source forwarded to ``get_box_feats``.
        config: dict; ``config['box_usage']`` selects the box representation
            (0 = box + spatial features, 1 = box only, 2 = spatial only).
        precision_f2: filter top-k / threshold forwarded to ``run_filter``.
        experiment_log: unused; kept for call compatibility.
        box_usage: ignored; ``config['box_usage']`` takes precedence.
        verbose: show a tqdm progress bar.
        tst_json: optional list of per-instance dicts with ``'box_names'``;
            when non-empty, predictions are written into it and it is dumped
            to ``out_file`` (filtered-out instances get ``[[-1, -1, -1, -1]]``).
        out_file: JSON output path.
        no_lang: unused; kept for call compatibility.

    Returns:
        ``(accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unknown ``box_usage`` value.
    """
    if tst_json is None:
        tst_json = []
    box_usage = config['box_usage']
    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split
    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes
    preds = []
    net.eval()
    all_supporting = []
    for j in pbar:
        tree = trees[j]
        filter_pred, gold_predicted = run_filter(box_filter_net, tree, boxes,
                                                 w2i, CNN, gold, j, precision_f2)
        if gold_predicted == 0:
            # The filter dropped every gold box: count as a miss.
            n += 1
            preds.append(None)
            all_supporting.append(None)
            continue
        # Positions of the surviving gold boxes inside filter_pred, i.e. in
        # the index space of net's output.
        kept = set(filter_pred)
        gold_instance = [(filter_pred == g).nonzero()[0][0]
                         for g in gold[j] if g in kept]
        box_feats, spat_feats = get_box_feats(
            list(np.array(boxes[j])[filter_pred]), CNN)
        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        prediction = net([w2i.get(node.label, 0) for node in tree.leaves()],
                         box_rep, tree)
        supporting = []
        _, pred = torch.max(prediction.data, 1)
        best = pred[0][0]
        hit = 1.0 if best in set(gold_instance) else 0.0
        correct += hit
        n += 1
        # Map the position among kept boxes back to an index into boxes[j].
        preds.append(filter_pred[int(best)])
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    if tst_json:
        for ii, inst in enumerate(tst_json):
            if preds[ii] is None:
                inst['predicted_bounding_boxes'] = [[-1, -1, -1, -1]]
                inst['context_box'] = [-1]
                continue
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
# Training-loop fragment for the filtered pipeline: run_filter first prunes
# each instance's boxes (call matches the box_filter_net variant of
# run_filter -- confirm which definition is live), then net runs on the
# surviving boxes only.
# NOTE(review): relies on names defined outside this fragment (pbar,
# Xtrn_tree, Xtrn_box, Ytrn, box_filter_net, net, w2i, CNN, config, args,
# cinst).
done = 1
for ii in pbar:
    tree = Xtrn_tree[ii]
    filter_pred, gold_predicted = run_filter(box_filter_net, tree, Xtrn_box, w2i, CNN, Ytrn, ii, args.precision_f2)
    # No gold box survived the filter: increment cinst and skip the instance.
    if gold_predicted == 0:
        cinst += 1
        continue
    # Re-express gold boxes as positions within filter_pred, i.e. in the
    # index space of the pruned box list passed to net below.
    gold_instance = []
    for g in Ytrn[ii]:
        if g in set(filter_pred):
            gold_instance.append((filter_pred == g).nonzero()[0][0])
    box_feats, spat_feats = get_box_feats(
        list(np.array(Xtrn_box[ii])[filter_pred]), CNN)
    # box_usage selects the representation: 0 = box and spatial features
    # concatenated along dim 1, 1 = box features only, 2 = spatial only.
    if config['box_usage'] == 0:
        box_rep = torch.cat((box_feats, spat_feats), 1)
    elif config['box_usage'] == 1:
        box_rep = box_feats
    elif config['box_usage'] == 2:
        box_rep = spat_feats
    else:
        raise NotImplementedError()
    if args.debug_mode:
        raise NotImplementedError()
    else:
        prediction = net([w2i.get(n.label, 0) for n in tree.leaves()], box_rep, tree)
def getCategories(category, trees, boxes, gold_labels, CNN, prediction, cut_off={ 'loc': 5, 'boxes': 20, 'words': 18, 'depth': 3 }, annid2catid={}, annid2imgid={}, imgid2catid2count={}, vocabulary={}, imgid2catid2bsize={}, annid2bsize={}, n_supporting=0):
    """Bucket instance indices by a per-instance statistic selected by ``category``.

    Builds ``categories`` (category name -> ``defaultdict(list)``), fills only
    ``categories[category]`` with ``bucket key -> [instance index, ...]`` and
    returns the whole dict. Supported values of ``category``:

    * ``'loc'``: number of ``"loc"`` nonterminals, capped at ``cut_off['loc']``.
    * ``'box_order'``: rank of the first gold box when boxes are sorted by the
      last spatial feature in descending order, capped at 9.
    * ``'box_size'``: last spatial feature of the first gold box (presumably
      its relative size -- confirm), times 100, capped at 99, divided by 40.
    * ``'boxes'``: number of boxes, capped at ``cut_off['boxes']``.
    * ``'words'``: number of space-separated tokens of ``tree.original_text``,
      capped at ``cut_off['words']``.
    * ``'depth'``: number of ``')'`` characters in the last space-separated
      token of ``tree.getRaw()``, capped at ``cut_off['depth']``.
    * ``'obj_cat'``: category id of each prediction's annotation.
    * ``'obj_box_count'``: ``imgid2catid2count`` entry for the prediction's
      image and category.
    * ``'obj_box_order'``: descending rank of the annotation's box size among
      the sizes listed in ``imgid2catid2bsize`` for its image and category.
    * ``'box_distance'``: squared midpoints of spatial columns (0, 2) and
      (1, 3) summed for the first gold box (presumably the squared distance of
      the box center from the origin), times 100, capped at 15000, divided by 33.
    * ``'box2d'``: one ``'x<xx>y<yy>'`` key per grid cell spanned by the first
      gold box's quantized spatial columns 0/2 (x) and 1/3 (y).
    * ``'oov'``: number of tree leaves whose label is not in ``vocabulary``.
    * ``'n_sup'``, ``'sup_acc'``, ``'missing'``, ``'missing_all'``: bucket the
      first ``n_supporting`` predictions by ``supportingCheck`` and print
      summaries.

    The dict defaults are only read, never mutated.

    Raises:
        NotImplementedError: for any other ``category``.
    """
    # NOTE(review): 'n_sup', 'sup_acc', 'missing' and 'missing_all' read the
    # names supportingCheck and Ytst, which are not defined in this function.
    categories = { 'loc': defaultdict(list), 'boxes': defaultdict(list), 'words': defaultdict(list), 'depth': defaultdict(list), 'obj_cat': defaultdict(list), 'obj_box_count': defaultdict(list), 'oov': defaultdict(list), 'box_size': defaultdict(list), 'box_order': defaultdict(list), 'obj_box_order': defaultdict(list), 'box_distance': defaultdict(list), 'n_sup': defaultdict(list), 'sup_acc': defaultdict(list), 'missing': defaultdict(list), 'box2d': defaultdict(list), 'missing_all': defaultdict(list) }
    if category == 'loc':
        for i, tree in enumerate(trees):
            n_loc = Counter([n.label for n in tree.nonterms()])['loc']
            categories[category][min(n_loc, cut_off[category])].append(i)
    elif category == 'box_order':
        for i, box in enumerate(boxes):
            box_feats, spat_feats = get_box_feats(boxes[i], CNN, convert_spat=False, convert_box=False)
            for j in gold_labels[i][:1]: ### only the first gold box is used
                bracket = min( list(np.argsort(spat_feats[:, -1])[::-1]).index(j), 9)
                categories[category][bracket].append(i)
    elif category == 'box_size':
        for i, box in enumerate(boxes):
            box_feats, spat_feats = get_box_feats(boxes[i], CNN, convert_spat=False, convert_box=False)
            for j in gold_labels[i][:1]: ### only the first gold box is used
                # Integer bucket under Python 2 division; would be a float key on Python 3.
                bracket = int(min(99, spat_feats[j, -1] * 100)) / 40
                categories[category][bracket].append(i)
    elif category == 'boxes':
        for i, box in enumerate(boxes):
            n_box = len(box)
            categories[category][min(n_box, cut_off[category])].append(i)
    elif category == 'words':
        for i, tree in enumerate(trees):
            n_words = len(tree.original_text.split(" "))
            categories[category][min(n_words, cut_off[category])].append(i)
    elif category == 'depth':
        for i, tree in enumerate(trees):
            # Counts ')' characters in the final token of the raw tree string.
            depth = Counter(tree.getRaw().split(" ")[-1])[')']
            categories[category][min(depth, cut_off[category])].append(i)
    elif category == 'obj_cat':
        for i, pred in enumerate(prediction):
            obj_cat = annid2catid[str(pred['annotation_id'])]
            categories[category][obj_cat].append(i)
    elif category == 'obj_box_count':
        for i, pred in enumerate(prediction):
            imgid = annid2imgid[str(pred['annotation_id'])]
            catid = annid2catid[str(pred['annotation_id'])]
            count = imgid2catid2count[imgid][catid]
            categories[category][count].append(i)
    elif category == 'obj_box_order':
        for i, pred in enumerate(prediction):
            imgid = annid2imgid[str(pred['annotation_id'])]
            catid = annid2catid[str(pred['annotation_id'])]
            count = imgid2catid2count[imgid][catid]
            # Position of this annotation's size in the size list, then its
            # rank when that list is sorted in descending order.
            idx = imgid2catid2bsize[imgid][catid].index(annid2bsize[str(
                pred['annotation_id'])])
            bracket = list(np.argsort(
                imgid2catid2bsize[imgid][catid])[::-1]).index(idx)
            categories[category][bracket].append(i)
    elif category == 'box_distance':
        for i, box in enumerate(boxes):
            box_feats, spat_feats = get_box_feats(boxes[i], CNN, convert_spat=False, convert_box=False)
            distances = ((spat_feats[:, 2] + spat_feats[:, 0]) / 2.0)**2 + (
                (spat_feats[:, 3] + spat_feats[:, 1]) / 2.0)**2
            for j in gold_labels[i][:1]: ### only the first gold box is used
                # Integer bucket under Python 2 division.
                bracket = int(min(15000, distances[j] * 100)) / 33
                categories[category][bracket].append(i)
    elif category == 'box2d':
        for i, box in enumerate(boxes):
            box_feats, spat_feats = get_box_feats(boxes[i], CNN, convert_spat=False, convert_box=False)
            idx = gold_labels[i][0]
            density = 25
            # Quantize each coordinate into grid cells of size `density`
            # (integer division under Python 2).
            x1 = int(min(spat_feats[idx, 0] * 100 + 100, 199) / 2) / density
            x2 = int(min(spat_feats[idx, 2] * 100 + 100, 199) / 2) / density
            y1 = int(min(spat_feats[idx, 1] * 100 + 100, 199) / 2) / density
            y2 = int(min(spat_feats[idx, 3] * 100 + 100, 199) / 2) / density
            for xx in xrange(x1, x2 + 1):
                for yy in xrange(y1, y2 + 1):
                    categories[category]['x' + str(xx) + 'y' + str(yy)].append(i)
    elif category == 'oov':
        for i, tree in enumerate(trees):
            oov = 0
            for n in tree.leaves():
                if n.label not in vocabulary:
                    oov += 1
            categories[category][oov].append(i)
    elif category == 'n_sup':
        # NOTE(review): counters start at 1.0 (presumably to avoid division by
        # zero), and n_same_all is never incremented, so the "all" percentage
        # is always 1 / n_total_all -- confirm intent.
        n_total = 1.0
        n_same = 0.0
        n_total_all = 1.0
        n_same_all = 1.0
        for i in xrange(n_supporting):
            n_sup = supportingCheck(i, prediction[i]['context_box'], Ytst[i], mode=category)
            categories[category][n_sup].append(i)
            pred_box = "_".join(
                [str(v) for v in prediction[i]['predicted_bounding_boxes'][0]])
            # NOTE(review): rebinds the `boxes` parameter; it is not read
            # again in this branch.
            boxes = [
                "_".join([str(v) for v in box]) for box in prediction[i]['box_names']
            ]
            target_box = boxes.index(pred_box)
            # Among instances with supporting objects, count how often the
            # first context box equals the predicted target box.
            if n_sup > 0:
                if len(prediction[i]['context_box']
                       ) > 0 and prediction[i]['context_box'][0] == target_box:
                    n_same += 1
                n_total += 1
            n_total_all += 1
        for key in categories[category]:
            print key, len(categories[category][key])
        print "=" * 20
        print "total {} the same {} percentage {}".format(
            n_total, n_same, n_same / n_total)
        print "all {} the same {} percentage {}".format(
            n_total_all, n_same_all, n_same_all / n_total_all)
        print "=" * 20
    elif category == 'sup_acc':
        for i in xrange(n_supporting):
            sup = supportingCheck(i, prediction[i]['context_box'], Ytst[i], mode=category)
            categories[category][sup].append(i)
        print "# of instances with no supporting objects:", len(
            categories[category][0])
        # NOTE(review): ZeroDivisionError below if buckets 1-3 are all empty.
        total_sup = (len(categories[category][1]) + len(
            categories[category][2]) + len(categories[category][3])) * 1.0
        print "# of instances with supporting objects:", total_sup
        print "# of instances all supporting objects are correct:", len(
            categories[category][1]) / total_sup
        print "# of instances at least one supporting objects is correct:", (
            len(categories[category][1]) + len(categories[category][2])) / total_sup
    elif category == 'missing':
        for i in xrange(n_supporting):
            missing = supportingCheck(i, prediction[i]['context_box'], Ytst[i], mode=category)
            # Boolean bucket: True when supportingCheck reports anything missing.
            categories[category][missing > 0].append(i)
        for key in categories[category]:
            print key, len(categories[category][key])
    elif category == 'missing_all':
        for i in xrange(n_supporting):
            missing_all = supportingCheck(i, prediction[i]['context_box'], Ytst[i], mode=category)
            categories[category][missing_all].append(i)
        for key in categories[category]:
            print key, len(categories[category][key])
    else:
        raise NotImplementedError()
    return categories
def evaluate(net, split, CNN, config, experiment_log, box_usage=0,
             verbose=False):
    """Measure binary-label accuracy of ``net`` on ``split``.

    The network output ``[0][0]`` is rounded to an integer label and compared
    with ``gold[j]``. The summary log reports the overall hit count and how
    many instances received label 0 and label 1.

    Args:
        net: model called as ``net(word_ids, box_rep, tree)``.
        split: 4-tuple ``(trees, boxes, iou, gold)``; ``iou`` is unused.
        CNN: feature source forwarded to ``get_box_feats``.
        config: dict; ``config['box_usage']`` selects the box representation
            (0 = box + spatial features, 1 = box only, 2 = spatial only).
        experiment_log: writable file-like object receiving the summary lines.
        box_usage: ignored; ``config['box_usage']`` takes precedence.
        verbose: show a tqdm progress bar.

    Returns:
        ``(accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unknown ``box_usage`` value.
    """
    box_usage = config['box_usage']
    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split
    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes
    preds = []
    net.eval()
    # Instance indices grouped by predicted label.
    predict_label = {0: [], 1: []}
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)
        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        prediction = net([w2i.get(node.label, 0) for node in tree.leaves()],
                         box_rep, tree).data
        pred = int(np.round(prediction[0][0]))
        hit = 1.0 if pred == gold[j] else 0.0
        correct += hit
        n += 1
        # setdefault tolerates a rounded label outside {0, 1}.
        predict_label.setdefault(pred, []).append(j)
        preds.append(pred)
    eval_time = time.time() - eval_start

    log = [
        "",
        "=" * 20,
        "Total correct {}/{}".format(correct, n),
        "0 : {}, 1 : {}".format(len(predict_label[0]), len(predict_label[1])),
        "=" * 20,
    ]
    for line in log:
        print(line)
        # One log entry per line in the experiment log.
        experiment_log.write(line + "\n")
    return correct / n, len(trees) / eval_time
def evaluate(filter_net, net, split, CNN, config, precision_f1, verbose=False,
             tst_json=None, out_file='', annid2catid=None):
    """Evaluate ``net`` on boxes pre-filtered by ``run_filter`` (category variant).

    Instances for which the filter keeps no gold box are counted as misses.
    Otherwise ``net`` scores only the kept boxes, with the tree's leaf labels
    passed through ``perturb(..., config['perturbation'])`` and mapped via
    ``w2i`` (unknown words -> 0).

    Args:
        filter_net: network forwarded to ``run_filter``.
        net: model called as ``net(word_ids, box_rep, tree)``.
        split: 4-tuple ``(trees, boxes, iou, gold)``; ``iou`` is unused.
        CNN: feature source forwarded to ``get_box_feats`` / ``run_filter``.
        config: dict with ``'box_usage'`` (0 = box + spatial features,
            1 = box only, 2 = spatial only) and ``'perturbation'``.
        precision_f1: filter top-k / threshold forwarded to ``run_filter``.
        verbose: show a tqdm progress bar.
        tst_json: optional list of per-instance dicts with ``'box_names'``;
            when non-empty each entry receives the ranked names of the kept
            boxes (``[[-1, -1, -1, -1]]`` when filtered out) and the list is
            dumped to ``out_file``.
        out_file: JSON output path.
        annid2catid: annotation-id -> category-id mapping for ``run_filter``.

    Returns:
        ``(accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unknown ``box_usage`` value.
    """
    if tst_json is None:
        tst_json = []
    if annid2catid is None:
        annid2catid = {}
    box_usage = config['box_usage']
    perturbation = config['perturbation']
    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split
    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes
    preds = []
    net.eval()
    all_supporting = []
    for j in pbar:
        tree = trees[j]
        filter_pred, gold_predicted = run_filter(filter_net, tree, boxes, CNN,
                                                 gold, j, precision_f1,
                                                 annid2catid)
        if gold_predicted == 0:
            # The filter dropped every gold box: count as a miss.
            n += 1
            preds.append([[-1, -1, -1, -1]])
            all_supporting.append([])
            continue
        # Positions of the surviving gold boxes inside filter_pred, i.e. in
        # the index space of net's output.
        kept = set(filter_pred)
        gold_instance = [(filter_pred == g).nonzero()[0][0]
                         for g in gold[j] if g in kept]
        box_feats, spat_feats = get_box_feats(
            list(np.array(boxes[j])[filter_pred]), CNN)
        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()
        words = perturb([node.label for node in tree.leaves()], perturbation)
        prediction = net([w2i.get(w, 0) for w in words], box_rep, tree)
        supporting = []
        _, pred = torch.max(prediction.data, 1)
        hit = 1.0 if pred[0][0] in set(gold_instance) else 0.0
        correct += hit
        n += 1
        if tst_json:
            pred_np = prediction.cpu().data.numpy()
            # net scored only the kept boxes, so map the ranked positions back
            # through filter_pred before indexing the full box_names list.
            ranked_boxes = filter_pred[np.argsort(-pred_np)[0]]
            preds.append(np.array(tst_json[j]['box_names'])[ranked_boxes])
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    if tst_json:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [list(p) for p in preds[ii]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time
def evaluate(net, split, CNN, config, experiment_log, box_usage=0,
             verbose=False, tst_json=None, out_file='', precision_k=10):
    """Evaluate ``net`` with precision@1..k and record supporting-object predictions.

    For models whose name contains "ground" (case-insensitive), ``net`` is
    called on the tree with ``decorate=True``. The per-node argmax of every
    ``"loc"`` nonterminal after the first is collected as a supporting box.
    Other models return ``(prediction, sub, obj, rel)``, and the argmax of
    ``obj`` along dim 0 is the supporting box when it can be computed.

    Precision@(r+1) counts an instance as a hit at every rank from its first
    gold box onwards. Ranks at or beyond the instance's number of boxes are
    not counted for that instance.

    Args:
        net: model; ``net.evaluate`` is set to True.
        split: 4-tuple ``(trees, boxes, iou, gold)``; ``iou`` is unused.
        CNN: feature source forwarded to ``get_box_feats``.
        config: dict with ``'box_usage'`` (0 = box + spatial features,
            1 = box only, 2 = spatial only) and ``'model'``.
        experiment_log: unused; kept for call compatibility.
        box_usage: ignored; ``config['box_usage']`` takes precedence.
        verbose: show a tqdm progress bar.
        tst_json: optional list of per-instance dicts with ``'box_names'``;
            when non-empty each gets ``'predicted_bounding_boxes'`` and
            ``'context_box'`` and the list is dumped to ``out_file``.
        out_file: JSON output path.
        precision_k: largest k reported.

    Returns:
        ``(top1_accuracy, instances_per_second)``.

    Raises:
        NotImplementedError: for an unknown ``box_usage`` value.
    """
    if tst_json is None:
        tst_json = []
    box_usage = config['box_usage']
    tree_model = "ground" in config['model'].lower()
    eval_start = time.time()
    n = correct = 0.0
    trees, boxes, _iou, gold = split
    indexes = range(len(trees))
    pbar = tqdm(indexes) if verbose else indexes
    precision_count = [0.0] * precision_k
    precision_hit = [0.0] * precision_k
    preds = []
    net.eval()
    net.evaluate = True
    all_supporting = []
    for j in pbar:
        tree = trees[j]
        box_feats, spat_feats = get_box_feats(boxes[j], CNN)
        if box_usage == 0:
            box_rep = torch.cat((box_feats, spat_feats), 1)
        elif box_usage == 1:
            box_rep = box_feats
        elif box_usage == 2:
            box_rep = spat_feats
        else:
            raise NotImplementedError()

        if tree_model:
            prediction = net(tree, box_rep, tree, decorate=True)
            supporting = []
            loc_count = 0
            for node in tree.nonterms():
                _, phrase_pred = torch.max(node._expr.data, 1)
                # Drop the per-node tensors once read so they do not stay on the tree.
                del node._expr, node._attn
                if node.label == "loc":
                    loc_count += 1
                    # Every "loc" node after the first is a supporting object.
                    if loc_count >= 2:
                        supporting.append(phrase_pred[0][0])
        else:
            prediction, sub, obj, rel = net(
                [w2i.get(node.label, 0) for node in tree.leaves()],
                box_rep, tree)
            try:
                _, obj_pred = torch.max(obj.data, 0)
                supporting = [obj_pred[0][0]]
            except Exception:
                # Best effort: if no object prediction can be extracted
                # (presumably obj is missing or shaped differently), report
                # no supporting box. Does not swallow KeyboardInterrupt.
                supporting = []

        _, pred = torch.max(prediction.data, 1)
        pred_np = prediction.cpu().data.numpy()
        gold_set = set(gold[j])
        n_ranks = min(precision_k, pred_np.shape[1])
        for rank, box_idx in enumerate(np.argsort(-pred_np)[0][:n_ranks]):
            if box_idx in gold_set:
                # First gold box at this rank: a hit for every larger k too.
                for r in range(rank, n_ranks):
                    precision_hit[r] += 1.0
                    precision_count[r] += 1.0
                break
            precision_count[rank] += 1.0

        hit = 1.0 if pred[0][0] in gold_set else 0.0
        correct += hit
        n += 1
        preds.append(int(pred[0][0]))
        all_supporting.append(supporting)
    eval_time = time.time() - eval_start

    print("_" * 10)
    for k in range(precision_k):
        count = precision_count[k]
        if count:
            print("precision@{} for {} instances: {}".format(
                k + 1, count, precision_hit[k] / count))
        else:
            # No instance reached this rank; avoid dividing by zero.
            print("precision@{} for 0 instances: n/a".format(k + 1))

    if tst_json:
        for ii, inst in enumerate(tst_json):
            inst['predicted_bounding_boxes'] = [inst['box_names'][preds[ii]]]
            inst['context_box'] = all_supporting[ii]
        with open(out_file, 'w') as out:
            json.dump(tst_json, out)
    return correct / n, len(trees) / eval_time