/
face_align_client_casia_model.py
210 lines (186 loc) · 6.66 KB
/
face_align_client_casia_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
#!/usr/bin/env python
# encoding: utf-8
import socket
from retrieval.load_data import get_files
import numpy as np
import Image
import os
import cPickle as pickle
from sklearn.metrics.pairwise import cosine_similarity
from face_verify_client import send
from scipy.io import loadmat
from config import CNN_FEA_HOST, CNN_FEA_PORT, ALIGN_HOST, ALIGN_PORT
def send2align(src_path):
    """
    Ask the alignment server to detect and align the face in an image.

    Opens a TCP connection to (ALIGN_HOST, ALIGN_PORT), sends ``src_path``
    and returns whatever the server answers.

    @params: src_path, the face image path to be detected and aligned
    @return: the server reply, read with a single ``recv`` of at most 1024
             bytes. Per the original documentation the reply is the aligned
             face image path, or failure info such as 'no face found' (the
             server side is not visible from this file). If anything goes
             wrong on the client side, the string "error <exception>" is
             returned instead of raising.
    """
    HOST, PORT = ALIGN_HOST, ALIGN_PORT
    # Start from None so the cleanup below is safe even when creating the
    # socket itself fails (previously that path hit an unbound ``sock``).
    sock = None
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((HOST, PORT))
        sock.sendall(src_path)
        # NOTE(review): one recv() is not guaranteed to deliver the whole
        # reply; this assumes replies are short and arrive in one piece.
        return sock.recv(1024)
    except Exception as e:
        # Deliberate best-effort: callers get an error string, not a raise.
        return "error {0}".format(e)
    finally:
        if sock is not None:
            sock.close()
def path2arr(path):
    """
    Open the image at ``path`` and resize it with size tuple (47, 55).

    Despite the name, no numpy conversion happens here (that step is
    disabled); the resized image object is returned as-is.

    @params: path, path of the image file to read
    @return: the resized image object, or None if the image could not be
             opened or resized
    """
    try:
        return Image.open(path).resize((47, 55))
    except Exception:
        # Best-effort loader: every ordinary failure is reported as None.
        # Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
def valid_img_path(path):
    """
    Check whether ``path`` can be opened as an image.

    "Valid" here means only that ``Image.open(path)`` succeeds.

    @params: path, candidate image path (may also be an error message
             coming back from the alignment step, which simply fails the
             check)
    @return: True if the image could be opened, False otherwise
    """
    try:
        Image.open(path)
    except Exception:
        # Any ordinary failure means "not a usable image". Unlike a bare
        # ``except:``, KeyboardInterrupt and SystemExit still propagate.
        return False
    return True
def get_cnn_fea(aligned_paths):
    """
    Convert aligned face images into CNN feature arrays.

    Every path is sent, suffixed with ';1', to the feature server at
    (CNN_FEA_HOST, CNN_FEA_PORT) through ``send``. Each reply is then
    loaded with ``loadmat`` and its 'data' entry is converted to a float
    numpy array.

    @params: aligned_paths, paths of already aligned face images
    @return: list of float numpy arrays, one per input path, in input order
    """
    print("Getting cnn fea.")
    print(aligned_paths)
    host = CNN_FEA_HOST
    port = CNN_FEA_PORT
    # Issue all feature requests first; the replies are loaded afterwards.
    mat_files = []
    for aligned in aligned_paths:
        mat_files.append(send(aligned + ';1', host, port))
    return [np.asarray(loadmat(mat_file)['data'], dtype='float')
            for mat_file in mat_files]
class DataBase(object):
    """
    Face feature store backed by pickle files inside one folder.

    Each pickle file holds a list of (imgID, cnn_fea) tuples, at most
    ``batch_size`` entries per file, and is named '<n>.pkl' inside
    ``data_path``. ``self.data`` is the in-memory concatenation of those
    lists as read by ``loadData``.
    """

    def __init__(self, data_path):
        """
        @params: data_path, folder that contains (or will contain) the
                 pickle files; its content is loaded immediately
        """
        self.data_path = data_path
        # Maximum number of (imgID, cnn_fea) entries stored per pickle file.
        self.batch_size = 1000
        self.data = []
        self.loadData()

    def loadData(self):
        """
        (Re)load every stored (imgID, cnn_fea) entry into ``self.data``.

        Reads each file returned by ``get_files(self.data_path)`` and
        concatenates the unpickled lists in that order.
        """
        data = []
        for fn in get_files(self.data_path):
            # NOTE: unpickling can execute arbitrary code, so data_path must
            # only ever contain files written by this class.
            with open(fn, 'rb') as f:
                data.extend(pickle.load(f))
        self.data = data

    def __save2database(self, imgIDs, cnn_feas):
        """
        save data to database(pickle files)
        save style: [(id_1, fea_1), .... , (id_batch_size, fea_batch_size)]

        Starts at file '<n>.pkl', where n is the number of files currently
        reported by ``get_files`` (1 when there are none), tops that file up
        to ``batch_size`` entries and spills the remainder into the
        following file numbers.

        @params: imgIDs, ids to store; cnn_feas, the matching features
                 (paired with imgIDs position by position)
        """
        pending = list(zip(imgIDs, cnn_feas))
        n_files = len(get_files(self.data_path))
        count = n_files if n_files != 0 else 1
        while True:
            # os.path.join keeps the file inside data_path even when the
            # folder was given without a trailing separator.
            fn = os.path.join(self.data_path, str(count) + '.pkl')
            exists = os.path.isfile(fn)
            stored = []
            if exists:
                with open(fn, 'rb') as f:
                    stored = pickle.load(f)
            # Clamp at zero: an over-full file must not produce a negative
            # slice bound that would pick entries from the end.
            room = max(self.batch_size - len(stored), 0)
            chunk = pending[:room]
            # Skip rewriting a file that receives nothing new, but still
            # create a (possibly empty) file when none exists yet.
            if chunk or not exists:
                with open(fn, 'wb') as f:
                    pickle.dump(stored + chunk, f)
            pending = pending[room:]
            if not pending:
                break
            count += 1

    def add_faces(self, imgIDs, img_paths):
        """
        add faces to database,
        return invalid image ids

        Every image is sent to the alignment server; images whose reply is
        not an openable image path are reported as invalid, the others are
        converted to CNN features and appended to the pickle files.

        The ``imgIDs`` argument is not modified. ``self.data`` is not
        refreshed here; call ``loadData`` afterwards to make the new faces
        visible to ``search``.

        @params: imgIDs, ids of the images; img_paths, the matching image
                 paths (paired with imgIDs position by position)
        @return: list of the ids whose image could not be aligned
        """
        print("Adding faces.")
        inval_ids = []
        valid_ids = []
        aligned_paths = []
        for path, img_id in zip(img_paths, imgIDs):
            aligned_path = send2align(path)
            if valid_img_path(aligned_path):
                valid_ids.append(img_id)
                aligned_paths.append(aligned_path)
            else:
                inval_ids.append(img_id)
        if aligned_paths:
            cnn_feas = get_cnn_fea(aligned_paths)
            self.__save2database(valid_ids, cnn_feas)
        return inval_ids

    @staticmethod
    def _rank_ids(sims, ids, top_n, sim_thresh):
        """
        Order ids by descending similarity and drop duplicates.

        Only the part of each id before the first '-' is kept, and each
        such prefix appears once, at the rank of its best match.

        @params: sims, similarity score per stored entry;
                 ids, stored imgIDs (strings), parallel to sims;
                 top_n, maximum result size, used only when sim_thresh is
                 None;
                 sim_thresh, when not None, ranking stops at the first
                 score below it and top_n is ignored
        @return: list of id prefixes, best match first ([] for no entries)
        """
        # tolist() yields plain ints and, unlike np.nditer, copes with an
        # empty index array.
        order = np.argsort(-np.array(sims)).tolist()
        result = []
        seen = set()
        for index in order:
            if sim_thresh is None:
                if len(result) >= top_n:
                    break
            elif sims[index] < sim_thresh:
                break
            cur_id = ids[index].split('-')[0]
            if cur_id not in seen:
                seen.add(cur_id)
                result.append(cur_id)
        return result

    def search(self, fn, top_n=10, sim_thresh=None):
        """
        retrieval face from database,
        return top_n similar faces' imgIDs, return None if failed

        The query image is aligned and converted to a CNN feature, which is
        compared by cosine similarity against every entry of ``self.data``.

        @params: fn, path of the query image;
                 top_n, maximum number of ids returned when sim_thresh is
                 None (clamped to the number of stored entries);
                 sim_thresh, when given, every id whose similarity is at
                 least this value is returned and top_n is ignored
        @return: list of id prefixes (text before the first '-' of the
                 stored imgID), best match first; None if the query image
                 could not be aligned
        """
        print("\n\nsearch...\n\n")
        top_n = min(top_n, len(self.data))
        aligned_fn = send2align(fn)
        if not valid_img_path(aligned_fn):
            print("align none.")
            return None
        cnn_fea = get_cnn_fea([aligned_fn])[0]
        # NOTE(review): assumes cnn_fea[0] and item[1][0] are feature
        # vectors in a shape cosine_similarity accepts - confirm against
        # the installed scikit-learn version.
        sims = [cosine_similarity(cnn_fea[0], item[1][0])[0][0]
                for item in self.data]
        ids = [item[0] for item in self.data]
        return self._rank_ids(sims, ids, top_n, sim_thresh)
if __name__ == '__main__':
    # Manual smoke test: open the store under ./temp/ and run two queries
    # against it, one ranked by top_n and one filtered by a threshold.
    dataBase = DataBase('./temp/')
    test_fns = [
        '/home/g206/data/baidu/origin/man/1/58.jpg',
        '/home/g206/data/baidu/origin/man/1/53.jpg',
        '/home/g206/data/baidu/origin/man/1/51.jpg',
        '/home/g206/data/baidu/origin/man/1/57.jpg',
        '/home/g206/data/baidu/origin/man/1/14.jpg',
        '/home/g206/data/baidu/origin/man/1/16.jpg',
        '/home/g206/data/baidu/origin/man/1/34.jpg',
        '/home/g206/data/baidu/origin/man/1/25.jpg',
        '/home/g206/data/baidu/origin/man/1/22.jpg',
        '/home/g206/data/baidu/origin/man/1/21.jpg',
    ]
    # Ids look like '<position>-1'; only needed when (re)populating the store.
    test_ids = [str(position) + '-1' for position, _ in enumerate(test_fns)]
    # dataBase.add_faces(test_ids, test_fns)
    dataBase.loadData()
    # Default query: up to top_n best matches.
    print(dataBase.search(test_fns[1]))
    # Threshold query: every id with similarity >= 0.0.
    print(dataBase.search(test_fns[0], sim_thresh=0.0))