print("--------------------------------------------------")
print(" searching starts")
print("--------------------------------------------------")

# Read the query image and display it (plt.show() opens a figure window).
queryImg = mpimg.imread(query)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()

# Build the feature extractor. Although the class is named VGGNet, the call
# below uses its ResNet branch; it must match the extractor that produced the
# database features in `feats`, otherwise the scores are meaningless.
model = VGGNet()

# Extract the query image's feature vector.
# To search with VGG16 features instead, use model.vgg_extract_feat(query)
# (and rebuild the database features with the same extractor).
queryVec = model.resnet_extract_feat(query)

# Similarity of the query against every database image.
# NOTE(review): this is cosine similarity only if both the query vector and
# the rows of `feats` are L2-normalised by the extractor; confirm in VGGNet.
# ravel() keeps `scores` 1-D even if the extractor returns a (1, d) row
# vector; with a 2-D `scores`, the [::-1] below would reverse the size-1
# axis instead of the ranking.
scores = np.dot(queryVec, feats.T).ravel()

# Database indices ordered from most to least similar, and their scores.
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]
print(rank_score)

# Number of top-ranked images to show (retrieve the three most similar).
maxres = 3
if __name__ == "__main__":
    # Folder of images to index, and the HDF5 file path the features go to.
    database = 'database'
    index = 'models/vgg_featureCNN.resnet50.h5'

    img_list = get_imlist(database)
    total = len(img_list)

    banner = "--------------------------------------------------"
    print(banner)
    print(" feature extraction starts")
    print(banner)

    # Parallel lists: feats[k] is the feature vector of the image names[k].
    feats = []
    names = []
    model = VGGNet()
    for i, img_path in enumerate(img_list):
        # To index with VGG16 features instead, use
        # model.vgg_extract_feat(img_path); the search side must match.
        norm_feat = model.resnet_extract_feat(img_path)
        img_name = os.path.basename(img_path)
        feats.append(norm_feat)
        names.append(img_name)
        print("extracting feature from image No. {} , {} images in total".format(i + 1, total))

    # Stack the per-image vectors into a single array (presumably one row
    # per image if each vector is 1-D; confirm against the extractor).
    feats = np.array(feats)

    # Path of the output file the extracted features are written to
    # (a file path, not a directory).
    output = index

    print(banner)
    print(" writing feature extraction results ...")
    print(banner)