class ImageSearcher(object):
    """Image searcher API for clothes retrieval web demo.

    Wraps a ``Searcher`` built from the index and feature files stored in the
    ``db`` directory next to this module. Besides the main ``search`` entry
    point it offers re-ranking and query-expansion helpers that edit the
    ``results`` / ``dists`` lists they are given in place.
    """

    def __init__(self):
        """Build the underlying searcher and load the local feature matrix."""
        root_path = os.path.dirname(__file__)
        inds_path = os.path.abspath(os.path.join(root_path, 'db/index'))
        feature_path = os.path.abspath(os.path.join(root_path, 'db/feature.npy'))
        # Resolved against the module directory like the two files above, so
        # construction does not depend on the current working directory.
        local_features_path = os.path.abspath(
            os.path.join(root_path, 'db/local_features.npy'))
        self.searcher = Searcher(inds_path, feature_path)
        self.local_features = np.load(local_features_path)

    def search(self, image_path, do_detection=1, k=10):
        """Return up to ``k`` distinct nearest images for ``image_path``.

        Args:
            image_path: forwarded to ``descriptor.get_descriptor``.
            do_detection: accepted for interface compatibility; not used here.
            k: maximum number of results to return.

        Returns:
            A ``(names, dists)`` pair of lists with at most ``k`` entries.
            Duplicate names are dropped (first occurrence wins) and
            ``dists[i]`` is the distance reported for ``names[i]``.
        """
        extract_timer = Timer()
        extract_timer.tic()
        queryFeatures = descriptor.get_descriptor(image_path)
        extract_timer.toc('Feature Extraction time: ')

        knn_timer = Timer()
        knn_timer.tic()
        # Ask for 5 * k neighbours; presumably so that k unique names are
        # still available once duplicates have been filtered out below.
        results, dists, ind = self.searcher.search(queryFeatures, k=5 * k)
        # NOTE: reranking / queryExpansion / queryExpansion2 can be applied
        # to (results, dists, ind) at this point; they are currently disabled.
        knn_timer.toc('Knn search time: ')

        seen = set()
        names = []
        kept_dists = []
        for name, distance in zip(results, dists):
            if name in seen:
                continue
            seen.add(name)
            names.append(name)
            kept_dists.append(distance)
        return names[:k], kept_dists[:k]

    def reranking(self, queryFeatures, results, dists, ind, rerank_thresh=0.7):
        """Re-rank the candidates whose distance is not below the threshold.

        Entries with ``dists[i] < rerank_thresh`` are counted and left alone.
        The remaining entries are handed to ``self.searcher.research`` together
        with their rows of ``self.local_features``; the names it returns are
        inserted into ``results`` right after the counted entries.

        Args:
            queryFeatures: query descriptor forwarded to ``research``.
            results: list of names, modified in place.
            dists: list of distances aligned with ``results``, modified in
                place.
            ind: indices into ``self.local_features`` aligned with ``results``.
            rerank_thresh: distances below this value are not re-ranked.

        Returns:
            None. Nothing is changed when fewer than 3 candidates qualify.
        """
        features = self.local_features
        confident = 0          # entries below the threshold
        anchor_dist = None     # distance assigned to every inserted entry
        cand_feats = []
        cand_names = []
        for i, index in enumerate(ind):
            if dists[i] < rerank_thresh:
                confident += 1
                continue
            if anchor_dist is None:
                # Use the distance just before the first candidate. When the
                # very first entry is already a candidate there is no previous
                # one (i - 1 would wrap to the end of the list), so use its
                # own distance instead.
                anchor_dist = dists[i - 1] if i > 0 else dists[i]
            cand_feats.append(features[index])
            cand_names.append(results[i])

        if len(cand_feats) < 3:
            return

        reranked, _ = self.searcher.research(
            cand_names, queryFeatures, np.array(cand_feats), 3)
        for offset, name in enumerate(reranked):
            results.insert(confident + offset, name)
            dists.insert(confident + offset, anchor_dist)

    def queryExpansion2(self, results, dists, ind, threshold=0.3, k=10, top=3):
        """Expand the result list by re-querying with the top results.

        For each of the first ``top`` positions (stopping at the first whose
        distance exceeds ``threshold``) the stored feature ``ind[i]`` is used
        as a new query. Every returned neighbour whose distance is within
        ``threshold`` is inserted into ``results`` / ``dists`` in front of the
        first entry at or after position ``i`` that has a larger distance.

        Args:
            results: list of names, modified in place.
            dists: list of distances aligned with ``results``, modified in
                place.
            ind: indices into ``self.searcher.features``; not modified.
            threshold: maximum distance for seeds and for accepted neighbours.
            k: number of neighbours requested for every expansion query.
            top: maximum number of leading positions used as seeds.

        Returns:
            None.
        """
        features = self.searcher.features
        # Never index past the available results.
        seed_count = min(top, len(ind), len(dists))
        for i in range(seed_count):
            if dists[i] > threshold:
                break
            query = features[ind[i]]
            new_result, new_dist, _ = self.searcher.search(query, k=k)
            for j, dist in enumerate(new_dist):
                if dist > threshold:
                    break
                # The loop variable must not be called ``k``: that would
                # overwrite the neighbour count used by the next search call.
                for offset, d in enumerate(dists[i:]):
                    if dist < d:
                        results.insert(i + offset, new_result[j])
                        dists.insert(i + offset, dist)
                        break

    def queryExpansion(self, queryFeatures, results, dists, ind,
                       threshold=0.8, k=10, top=5):
        """Do query expansion with the mean feature of at most ``top`` results.

        The stored features of the first ``top`` results are averaged into a
        single query. Neighbours of that query are merged into ``results`` /
        ``dists`` in front of the first entry with a larger distance, stopping
        at the first neighbour that is farther than the current last entry.

        Args:
            queryFeatures: not used; kept for interface compatibility.
            results: list of names, modified in place.
            dists: list of distances aligned with ``results``, modified in
                place.
            ind: indices into ``self.searcher.features``.
            threshold: not used; kept for interface compatibility.
            k: number of neighbours requested for the expanded query.
            top: maximum number of leading results that are averaged.

        Returns:
            0 when there is nothing to average, otherwise None.
        """
        features = self.searcher.features
        seed_count = min(top, len(dists), len(ind))
        seeds = [features[ind[i]] for i in range(seed_count)]
        if not seeds:
            return 0

        query = np.mean(np.array(seeds), axis=0)
        new_results, new_dists, _ = self.searcher.search(query, k=k)
        # Insertions always happen in front of an existing entry, so the last
        # distance stays the same for the whole merge.
        worst = dists[-1]
        for name, dist in zip(new_results, new_dists):
            if dist > worst:
                break
            for j, d in enumerate(dists):
                if dist < d:
                    results.insert(j, name)
                    dists.insert(j, dist)
                    break