-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
97 lines (72 loc) · 2.65 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
from data import hash_evaluation, multi_evaluation
import cPickle as cp
import time
from sdh import SDH
from dksh import DKSHv2
from ksh import KSH
from lsh import LSH
from loader import Cifar10Loader, Cifar100Loader, NuswideLoader
def dump(filename, obj):
    """Serialize ``obj`` with pickle into the file named ``filename``.

    The file is opened in binary write mode, so any existing content is
    replaced.
    """
    with open(filename, 'wb') as out_file:
        cp.dump(obj, out_file)
def load(filename):
    """Return the object unpickled from the file named ``filename``."""
    # NOTE: unpickling can execute arbitrary code; only use this on files
    # written by this project (see ``dump``), never on untrusted input.
    with open(filename, 'rb') as in_file:
        restored = cp.load(in_file)
    return restored
def eva_checkpoint(algo_name, nbit, li_results):
    """Save the per-round results and their aggregate, then print a summary.

    Two pickle files are written under ``results_nus/``:

    * ``<algo_name>_<nbit>_step``  -- ``li_results`` as given;
    * ``<algo_name>_<nbit>_total`` -- the value returned by
      ``multi_evaluation(li_results)``.

    The mean line is printed when ``li_results`` has at least one entry and
    the std line when it has at least two.

    Args:
        algo_name: name used as the first component of the output file names.
        nbit: code length used as the second component of the file names.
        li_results: list of per-round evaluation results collected so far.
    """
    import os  # function-scope import keeps this block self-contained

    # Make sure the output directory exists before writing: without this the
    # first dump raises on a fresh checkout and the finished round is lost.
    out_dir = 'results_nus'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    dump('results_nus/{}_{}_step'.format(algo_name, nbit), li_results)
    res = multi_evaluation(li_results)
    dump('results_nus/{}_{}_total'.format(algo_name, nbit), res)

    # Single-argument print(...) produces the same output under the Python 2
    # print statement and the Python 3 print function.
    if len(li_results) >= 1:
        print('mean: mAP={}, pre2={}'.format(res['map_mean'], res['pre2_mean']))
    if len(li_results) >= 2:
        print('std: mAP={}, pre2={}'.format(res['map_std'], res['pre2_std']))
def hash_factory(algo_name, nbits, nlabels, nanchors):
    """Build the hashing model selected by ``algo_name``.

    ``'SDH'``, ``'DKSH'`` and ``'KSH'`` are constructed with
    ``(nbits, nanchors, nlabels, RBF)``; ``'LSH'`` only receives ``nbits``.
    Any other name yields ``None``.
    """
    if algo_name == 'LSH':
        return LSH(nbits)

    # The three kernel-based hashers share one constructor signature, so pick
    # the class first and build it in a single place.
    if algo_name == 'SDH':
        hasher_cls = SDH
    elif algo_name == 'DKSH':
        hasher_cls = DKSHv2
    elif algo_name == 'KSH':
        hasher_cls = KSH
    else:
        return None
    return hasher_cls(nbits, nanchors, nlabels, RBF)
def test(list_algo_name, list_bits, loader):
    """Benchmark every (algorithm, code length) pair over a fixed seed list.

    For each seed, ``loader.split(seed)`` supplies train/base/test data and
    labels. A fresh hasher from ``hash_factory`` is trained and timed, the
    test and base sets are hashed, and the ``hash_evaluation`` result is
    appended to the list for the current configuration, which is then handed
    to ``eva_checkpoint``.

    Args:
        list_algo_name: iterable of algorithm names understood by
            ``hash_factory``.
        list_bits: iterable of code lengths to evaluate.
        loader: object exposing ``split(seed)`` returning the 6-tuple
            ``(traindata, trainlabel, basedata, baselabel, testdata,
            testlabel)``.
    """
    # Further seeds (37, 47, 67, 97, 107, 127, 137, 157) are currently disabled.
    seeds = [7, 17]
    for algo_name in list_algo_name:
        for nbit in list_bits:
            print('======execute {} at bit {}======'.format(algo_name, nbit))
            print('====total process round: {}====='.format(len(seeds)))
            li_results = []
            for round_no, sd in enumerate(seeds, 1):
                print('\nround #{}...'.format(round_no))
                (traindata, trainlabel,
                 basedata, baselabel,
                 testdata, testlabel) = loader.split(sd)

                # 21 labels and 300 anchors are fixed for every run here.
                alg = hash_factory(algo_name, nbit, 21, 300)

                # NOTE(review): time.clock() exists on Python 2 (which this
                # file targets) but was removed in Python 3.8.
                tic = time.clock()
                alg.train(traindata, trainlabel)
                toc = time.clock()
                elapsed = toc - tic
                print('time: ' + str(elapsed))

                H_test = alg.queryhash(testdata)
                H_base = alg.queryhash(basedata)

                # A base item counts as relevant to a query when the dot
                # product of their label vectors is at least 1 (presumably
                # 0/1 label indicators -- confirm against the loader).
                gnd_truth = (np.dot(testlabel, baselabel.T) >= 1).astype(np.int8)

                print('testing...')
                n_base = len(baselabel)
                res = hash_evaluation(H_test, H_base, gnd_truth, n_base, n_base,
                                      trn_time=elapsed)
                li_results.append(res)
                # NOTE(review): the source's indentation was lost; this call
                # is assumed to run once per round (checkpointing after each
                # seed). Confirm against the original file.
                eva_checkpoint(algo_name, nbit, li_results)
def RBF(X, Y, width=0.4):
    """Gaussian (RBF) kernel matrix between the rows of ``X`` and ``Y``.

    Entry ``(i, j)`` equals ``exp(-||X[i] - Y[j]||**2 / width)``, computed
    through the expansion ``2*x.y - ||x||**2 - ||y||**2``.

    Args:
        X: 2-D array, one sample per row.
        Y: 2-D array with the same number of columns as ``X``.
        width: denominator of the exponent. The default ``0.4`` is the value
            that used to be hard-coded, so two-argument calls are unchanged.

    Returns:
        Array of shape ``(X.shape[0], Y.shape[0])``.
    """
    # Column / row vectors of squared norms; NumPy broadcasting expands them
    # to the full matrix, so no dense products with all-ones matrices (and no
    # extra lenX-by-lenY temporaries) are needed.
    sq_norm_x = np.sum(X * X, axis=1).reshape((X.shape[0], 1))
    sq_norm_y = np.sum(Y * Y, axis=1).reshape((1, Y.shape[0]))
    neg_sq_dist = 2 * np.dot(X, Y.T) - sq_norm_x - sq_norm_y
    # float() guards against integer floor division under Python 2.
    return np.exp(neg_sq_dist / float(width))
if __name__ == "__main__":
    # Benchmark configuration: which hashers to run and at which code lengths.
    list_algo_name = ['LSH']
    list_nbits = [32, 64]

    # Dataset provider handed to the benchmark driver.
    loader = NuswideLoader()

    test(list_algo_name, list_nbits, loader)