def query(image):  # image or image_path
	# 读取提取好的图片特征集
	h5f = h5py.File('CNN_extracted_image_feature_many.h5', 'r')
	img_feats_set = h5f['dataset_1'][:]
	img_names_set = h5f['dataset_2'][:]
	img_names_decode = [bytes(img_name).decode('utf-8', 'ignore') for img_name in img_names_set]
	h5f.close()

	# TODO: 在加载模型之前清除后台session,避免报错 - TypeError: Cannot interpret feed_dict key as Tensor:
	# Tensor Tensor("Placeholder_8:0", shape=(3, 3, 128, 256), dtype=float32) is not an element of this graph.
	keras.backend.clear_session()

	# init VGGNet16 model
	model = VGGNet()

	# TODO: 随便生成一个向量让 model 执行一次 predict 函数
	# TODO: 避免报错 - ValueError: Tensor Tensor("dense_2/Softmax:0", shape=(?, 8), dtype=float32) is not an element of this graph.
	model.model.predict(np.zeros((1, 224, 224, 3)))

	# extract query image's feature,then compute similarity score and sort
	query_img_feat = model.extract_feature(image)
	# dot()函数计算并返回两个numpy数组的内积
	# 即**检索图片与图片库内各图片的相似度的数组**
	simil_scores = np.dot(query_img_feat, img_feats_set.T)
	# argsort()函数返回将数组元素从小到大排列后所对应的原索引号组成的数组
	# 列表切片操作[::-1]则将该索引数组的内容翻转输出
	rank_index = np.argsort(simil_scores)[::-1]
	rank_scores = simil_scores[rank_index]

	# # output similar images by required number
	# max_ret = 3
	# ret_img_list = [img_names_decode[index] for i, index in enumerate(rank_index[0:max_ret])]

	# output similar images by similarity scores
	# return similar images' path list
	rank_scores_index = [index for index, element in enumerate(rank_scores) if element >= 0.5]
	rank_index_ok = [rank_index[index] for index in rank_scores_index]
	# TODO: 此处待优化,否则函数可重用性将受限
	ret_img_nams = ['../static/database-many/'+img_names_decode[index] for index in rank_index_ok]
	print("retrieve images's index: ", rank_index_ok)
	print('retrieved images in order are: ', ret_img_nams)
	# return similar images' PIL.Image.Image obj list
	# Image.open() method returns An :py:class:`~PIL.Image.Image` object.
	# similar_images = [Image.open("database-less/"+name) for name in ret_img_nams]
	return ret_img_nams
print("               searching starts                   ")
print("--------------------------------------------------")

# read and show query image
queryDir = args["query"]
queryImg = mpimg.imread(queryDir)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()
plt.close()

# init VGGNet16 model
model = VGGNet()

# extract query image's feature,compute similarity score and sort
query_img_feat = model.extract_feature(queryDir)
# dot()函数计算并返回两个numpy数组的内积
# 即**查询图片与图片库内各图片的相似度数组**
simil_scores = np.dot(query_img_feat, img_feats_set.T)
print("about shape:", query_img_feat.shape, img_feats_set[0].shape,
      img_feats_set.shape)
print("dot(query_img_feat, img_feats_set[1]): ",
      np.dot(query_img_feat, img_feats_set[1]))
# argsort函数返回将数组元素从小到大排列后所对应的原索引号组成的数组
# 列表切片操作[::-1]则将该索引数组的内容翻转输出
rank_index = np.argsort(simil_scores)[::-1]
rank_scores = simil_scores[rank_index]
# TODO
print(type(simil_scores), type(rank_index), type(rank_scores))
print("similarity score array: ", simil_scores)
print("img_index_array: ", rank_index)
Exemple #3
0
if __name__ == "__main__":

    db = args["database"]
    imgs_path_list = get_imlist(db)
    print(imgs_path_list)

    print("--------------------------------------------------")
    print("           feature extraction starts              ")
    print("--------------------------------------------------")

    feats = []
    names = []

    model = VGGNet()
    for i, img_path in enumerate(imgs_path_list):
        img_feat = model.extract_feature(img_path)
        # os.path.split(PATH)函数以PATH参数的最后一个'\'作为分隔符,
        # 返回目录名和文件名组成的元组,索引0为目录名,索引1则为文件名
        img_name = os.path.split(img_path)[1]
        feats.append(img_feat)
        names.append(img_name)
        print(
            "extracting feature from image No. %d , %d images in total -> " %
            ((i + 1), len(imgs_path_list)), img_name)

    print("--------------------------------------------------")
    print("        writing feature extraction results        ")
    print("--------------------------------------------------")

    # file of storing extracted features
    # 创建一个h5py文件(对象)