def main():
    """Train word2vec on text8, store the vectors in Milvus and print the words nearest to 'king'.

    Raises:
        RuntimeError: if inserting into or searching the Milvus collection fails.
    """
    sentences = word2vec.Text8Corpus("text8")  # load the corpus
    model = word2vec.Word2Vec(sentences, size=200, window=5, min_count=5)  # train the model
    word_set = model.wv.index2word  # vocabulary; word_set[i] is the word of row i of word_vec
    word_vec = model.wv.vectors  # word2vec embedding matrix

    milvus = Milvus()
    milvus.connect(host='localhost', port='19530')
    try:
        param = {
            'collection_name': 'word2vec',
            'dimension': 200,  # must match Word2Vec(size=200) above
            'index_file_size': 1024,
            'metric_type': MetricType.L2,
        }
        milvus.create_collection(param)
        status, ids = milvus.insert(collection_name='word2vec', records=word_vec)
        if not status.OK():
            raise RuntimeError("Insert failed: {}".format(status))
        # Map every returned id to its word once, instead of an O(n) ids.index()
        # for every hit. This pairs ids with words in record order, as the
        # original ids.index() lookup also assumed.
        id_to_word = dict(zip(ids, word_set))
        # Flush so the freshly inserted vectors are visible to indexing and search.
        milvus.flush(['word2vec'])

        # IVF_FLAT index: the vectors are partitioned into nlist=100 clusters.
        ivf_param = {'nlist': 100}
        milvus.create_index('word2vec', IndexType.IVF_FLAT, ivf_param)

        # Look up the words most similar to 'king'.
        query = [list(word_vec[word_set.index('king')])]
        status, results = milvus.search(collection_name='word2vec',
                                        query_records=query,
                                        top_k=10,
                                        params={'nprobe': 16})
        if not status.OK():
            raise RuntimeError("Search failed: {}".format(status))
        # Iterate the hits actually returned instead of assuming exactly 10.
        for hit in results[0]:
            # Ids from an older run of the collection have no word here; print the raw id.
            print(id_to_word.get(hit.id, hit.id))
    finally:
        milvus.disconnect()
def run_offline_paper():
    """Precompute nearest-neighbour recommendations for every paper.

    Reads ``(ID, doc_vector)`` rows from the ``paper`` table. ``doc_vector`` is a
    comma-separated list of numbers. For each paper:

    * search the ``ideaman`` Milvus collection for the top 51 hits;
    * drop the first hit (presumably the paper itself -- TODO confirm);
    * insert the remaining ids, comma-joined, into ``offline_paper``.

    A failure on one paper is reported and skipped, so one bad row does not
    abort the whole batch. Uses the module-level ``milvus_ip``, ``cur`` and ``conn``.
    """
    client = Milvus(host=milvus_ip, port='19530')
    try:
        cur.execute("SELECT ID ,doc_vector FROM paper")
        papers = cur.fetchall()
        for paper_id, doc_vector in papers:
            try:
                # float() instead of eval(): never execute database content as code.
                vec = [float(j) for j in doc_vector.split(",")]
                status, results = client.search(collection_name='ideaman',
                                                query_records=[vec], top_k=51)
                if not status.OK():
                    continue
                top_k_ids = [str(j) for j in results.id_array[0]]
                paper_vecs = ",".join(top_k_ids[1:])
                sql = 'INSERT INTO offline_paper(paper_id , recs) VALUES({} , "{}")'.format(
                    paper_id, paper_vecs)
                cur.execute(sql)
                try:
                    conn.commit()
                except Exception:
                    conn.rollback()
            except Exception as e:
                # Best-effort batch: report the failure and continue with the next paper.
                print("offline recommendation failed for paper {}: {}".format(paper_id, e))
    finally:
        client.close()
def milvus_test(usr_features, IS_INFER, mov_features=None, ids=None):
    """Recommend movies for users by inner-product search in Milvus.

    When ``IS_INFER`` is true, the ``recommender_demo`` collection is dropped.
    If the collection is then missing, it is rebuilt from ``mov_features``
    (optionally with explicit ``ids``). Otherwise the existing collection is
    reused.

    ``usr_features`` is normalised and searched with top_k=5. The hits for the
    first query vector are printed with their movielens title and
    ``distance * 5`` as the score.

    Exits the process (``sys.exit(1)``) if the server is unreachable, or if the
    collection has to be built but ``mov_features`` is None.
    """
    _HOST = '127.0.0.1'
    _PORT = '19530'  # default value
    table_name = 'recommender_demo'

    milvus = Milvus()
    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if status.OK():
        print("Server connected.")
    else:
        print("Server connect fail.")
        sys.exit(1)

    try:
        if IS_INFER:
            milvus.drop_collection(table_name)
            time.sleep(3)  # give the server time to finish the drop

        status, ok = milvus.has_collection(table_name)
        if not ok:
            if mov_features is None:
                print("Insert vectors is none!")
                sys.exit(1)
            param = {
                'collection_name': table_name,
                'dimension': 200,
                'index_file_size': 1024,  # optional
                'metric_type': MetricType.IP  # optional
            }
            print(milvus.create_collection(param))
            insert_vectors = normaliz_data(mov_features)
            status, ids = milvus.insert(collection_name=table_name, records=insert_vectors, ids=ids)
            time.sleep(1)  # let the inserted vectors become searchable

        status, result = milvus.count_collection(table_name)
        print("rows in table recommender_demo:", result)

        search_vectors = normaliz_data(usr_features)
        param = {
            'collection_name': table_name,
            'query_records': search_vectors,
            'top_k': 5,
            'params': {
                'nprobe': 16
            }
        }
        status, results = milvus.search(**param)
        if not status.OK():
            print("Search failed:", status)
            return

        # Load the movie metadata once; it does not depend on the individual hit.
        movie_info = paddle.dataset.movielens.movie_info()
        print("Top\t", "Ids\t", "Title\t", "Score")
        for i, hit in enumerate(results[0]):
            title = movie_info[int(hit.id)].title
            print(i, "\t", hit.id, "\t", title, "\t", float(hit.distance) * 5)
    finally:
        milvus.disconnect()
def test_not_connect(self):
    """Every client API call must raise NotConnectError before connect() is called."""
    client = Milvus()
    unconnected_calls = (
        lambda: client.create_collection({}),
        lambda: client.has_collection("a"),
        lambda: client.describe_collection("a"),
        lambda: client.drop_collection("a"),
        lambda: client.create_index("a"),
        lambda: client.insert("a", [], None),
        lambda: client.count_collection("a"),
        lambda: client.show_collections(),
        lambda: client.search("a", 1, 2, [], None),
        lambda: client.search_in_files("a", [], [], 2, 1, None),
        lambda: client._cmd(""),
        lambda: client.preload_collection("a"),
        lambda: client.describe_index("a"),
        lambda: client.drop_index(""),
    )
    # Check the calls one at a time, in order, each in its own raises-context.
    for call in unconnected_calls:
        with pytest.raises(NotConnectError):
            call()
def search_vectors(name, vector, topk, nprobe):
    """Search collection ``name`` for the ``topk`` nearest neighbours of ``vector``.

    Args:
        name: Milvus collection name.
        vector: list of query vectors.
        topk: number of neighbours to return per query.
        nprobe: number of IVF buckets to probe.

    Returns:
        The search result object returned by ``Milvus.search``.

    Raises:
        MilvusError: on a non-OK search status or any other failure. Failures
            are logged before raising.
    """
    search_param = {'nprobe': nprobe}
    err_msg = "There was some error when search vectors"
    milvus = None
    try:
        milvus = Milvus(host=MILVUS_ADDR, port=MILVUS_PORT)
        res, ids = milvus.search(collection_name=name, query_records=vector, top_k=topk, params=search_param)
        if not res.OK():
            raise MilvusError(err_msg, res)
        return ids
    except MilvusError as e:
        # Already a MilvusError: log it but do not wrap it a second time.
        logger.error(f"{err_msg} : {str(e)}", exc_info=True)
        raise
    except Exception as e:
        logger.error(f"{err_msg} : {str(e)}", exc_info=True)
        raise MilvusError(err_msg, e) from e
    finally:
        # A new client is created per call; release it instead of leaking it.
        if milvus is not None:
            try:
                milvus.close()
            except Exception:
                logger.warning("Failed to close Milvus client", exc_info=True)
def search_vectors(name, vector, topk, nprobe):
    """Search collection ``name`` for the ``topk`` nearest neighbours of ``vector``.

    Args:
        name: Milvus collection name.
        vector: list of query vectors.
        topk: number of neighbours to return per query.
        nprobe: number of IVF buckets to probe.

    Returns:
        The search result object returned by ``Milvus.search``.

    Raises:
        MilvusError: on a non-OK search status or any other failure.
    """
    milvus = Milvus()
    search_param = {'nprobe': nprobe}
    err_msg = "There has some error when search vectors"
    try:
        milvus.connect(MILVUS_ADDR, MILVUS_PORT)
        res, ids = milvus.search(collection_name=name, query_records=vector, top_k=topk, params=search_param)
        if not res.OK():
            raise MilvusError(err_msg, res)
        return ids
    except MilvusError:
        # Raised just above for a bad status; do not wrap it a second time.
        raise
    except Exception as e:
        raise MilvusError(err_msg, e) from e
    finally:
        try:
            milvus.disconnect()
        except Exception:
            # Best-effort cleanup: connect() may have failed, leaving nothing to close.
            pass
def milvus_test(usr_features, mov_features, ids):
    """Score one movie against one user with inner-product search in a scratch collection.

    Inserts the normalised user vector into ``paddle_demo1`` (with ``ids``),
    searches it with the normalised movie vector (top_k=1), and prints the hit
    id and ``distance * 5`` as the score. The collection is always dropped
    afterwards.

    NOTE(review): assumes both feature arguments are 200-d numpy arrays
    (``.tolist()`` is called and the dimension is 200) -- confirm.
    Exits the process (``sys.exit(1)``) if the server is unreachable.
    """
    _HOST = '127.0.0.1'
    _PORT = '19530'  # default value
    milvus = Milvus()
    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if status.OK():
        print("\nServer connected.")
    else:
        print("\nServer connect fail.")
        sys.exit(1)

    table_name = 'paddle_demo1'
    try:
        status, ok = milvus.has_collection(table_name)
        if not ok:
            param = {
                'collection_name': table_name,
                'dimension': 200,
                'index_file_size': 1024,  # optional
                'metric_type': MetricType.IP  # optional
            }
            milvus.create_collection(param)

        insert_vectors = normaliz_data([usr_features.tolist()])
        status, ids = milvus.insert(collection_name=table_name, records=insert_vectors, ids=ids)
        time.sleep(1)  # let the inserted vector become searchable
        status, result = milvus.count_collection(table_name)
        print("rows in table paddle_demo1:", result)

        search_vectors = normaliz_data([mov_features.tolist()])
        param = {
            'collection_name': table_name,
            'query_records': search_vectors,
            'top_k': 1,
            'params': {'nprobe': 16}
        }
        status, results = milvus.search(**param)
        # Do not index into results[0][0] when the search failed.
        if not status.OK():
            print("Search failed:", status)
            return
        print("Searched ids:", results[0][0].id)
        print("Score:", float(results[0][0].distance) * 5)
    finally:
        # Scratch collection: drop it even if something above raised.
        milvus.drop_collection(table_name)
        milvus.disconnect()
class MilvusConnection:
    """Milvus collection holding the embeddings of ``env.base``.

    The collection is created and populated only when it does not exist yet.
    Status objects returned by those setup calls, and by the most recent
    search, are kept in ``self.statuses``.
    """

    def __init__(self, env, name="movies_L2", port="19530", param=None):
        """Connect to Milvus on localhost and make sure the collection exists.

        Args:
            env: object exposing ``env.base.embeddings``, a torch tensor
                (NOTE(review): assumed 2-D ``(num_items, dim)`` with ``dim``
                equal to ``param["dimension"]`` -- confirm).
            name: default collection name.
            port: Milvus server port.
            param: optional overrides for the ``create_collection`` parameters.
        """
        if param is None:
            param = dict()
        # Caller-supplied entries override these defaults.
        param = {
            "collection_name": name,
            "dimension": 128,
            "index_file_size": 1024,
            "metric_type": MetricType.L2,
            **param,
        }
        # Use the effective collection name. A "collection_name" override in
        # ``param`` is then honoured by the existence check, insert and search,
        # not only by create_collection.
        self.name = param["collection_name"]
        self.client = Milvus(host="localhost", port=port)
        self.statuses = {}
        if not self.client.has_collection(self.name)[1]:
            status_created_collection = self.client.create_collection(param)
            vectors = env.base.embeddings.detach().cpu().numpy().astype(
                "float32")
            # Row positions are used as ids, so hit ids index straight back into the embeddings.
            target_ids = list(range(vectors.shape[0]))
            status_inserted, _ = self.client.insert(
                collection_name=self.name, records=vectors, ids=target_ids)
            status_flushed = self.client.flush([self.name])
            status_compacted = self.client.compact(collection_name=self.name)
            self.statuses["created_collection"] = status_created_collection
            self.statuses["inserted"] = status_inserted
            self.statuses["flushed"] = status_flushed
            self.statuses["compacted"] = status_compacted

    def search(self, search_vecs, topk=10, search_param=None):
        """Return a tensor of the ``topk`` nearest ids for each query vector.

        ``search_param`` entries override the default ``{"nprobe": 16}``.

        Raises:
            RuntimeError: if Milvus reports a failed search. The status is
                still recorded under ``statuses['last_search']``.
        """
        if search_param is None:
            search_param = dict()
        search_param = {"nprobe": 16, **search_param}
        status, results = self.client.search(
            collection_name=self.name,
            query_records=search_vecs,
            top_k=topk,
            params=search_param,
        )
        self.statuses['last_search'] = status
        if not status.OK():
            raise RuntimeError("Milvus search failed: {}".format(status))
        return torch.tensor(results.id_array)

    def get_log(self):
        """Return the dict of recorded Milvus statuses."""
        return self.statuses
} milvus.create_collection(param=param) ivf_param = {'nlist': 16384} milvus.create_index(collection_name=col_name, index_type=IndexType.IVF_FLAT, params=ivf_param) vectors = [[random.random() for _ in range(dim)] for _ in range(2000)] vector_ids = list(range(2000)) _, ids = milvus.insert(collection_name=col_name, records=vectors, ids=vector_ids) # print(ids) time.sleep(1) search_param = {'nprobe': 16} q_records = [[random.random() for _ in range(dim)] for _ in range(5)] _, result = milvus.search(collection_name=col_name, query_records=q_records, top_k=2, params=search_param) # for r in result: # print(r) print(result.id_array) print(result) milvus.drop_collection(collection_name=col_name) milvus.close()
id_query = {} with open("embed.txt", "r") as f: for line in f: line_lst = line.strip().split("\t") vectors.append(list(map(float, (line_lst[2].split())))) vector_ids.append(int(line_lst[1])) id_query[int(line_lst[1])] = line_lst[0] milvus.insert(collection_name='test01', records=vectors, ids=vector_ids) ivf_param = {'nlist': 16384} milvus.create_index('test01', IndexType.IVF_FLAT, ivf_param) search_param = {'nprobe': 16} q_records = vectors[:1] t = time.time() result = milvus.search(collection_name='test01', query_records=q_records, top_k=10, params=search_param)[1] print(result) print(time.time() - t) t = time.time() result = milvus.search(collection_name='test01', query_records=q_records, top_k=10, params=search_param)[1] print(result) print(time.time() - t)
def main():
    """Demo of the hybrid-entity API: create, insert, delete, fetch, index, search and drop a collection.

    Uses module-level ``_HOST``/``_PORT``. ``utils.entities_slice`` splits the
    insert into batches.
    """
    milvus = Milvus(_HOST, _PORT)
    num = 100000  # number of entities to generate

    # Start from a fresh collection: drop any leftover from a previous run.
    collection_name = 'example_hybrid_collections_{}'.format(num)
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name)

    collection_param = {
        "fields": [
            {"field": "A", "type": DataType.INT32},
            {"field": "B", "type": DataType.INT32},
            {"field": "C", "type": DataType.INT64},
            {
                "field": "Vec",
                "type": DataType.FLOAT_VECTOR,
                "params": {"dim": 128, "metric_type": "L2"},
            },
        ],
        "segment_size": 100,
    }
    milvus.create_collection(collection_name, collection_param)
    milvus.compact(collection_name)

    # The same random integers populate scalar fields A, B and C.
    A_list = [random.randint(0, 255) for _ in range(num)]
    vec = [[random.random() for _ in range(128)] for _ in range(num)]
    hybrid_entities = [
        {"field": "A", "values": A_list, "type": DataType.INT32},
        {"field": "B", "values": A_list, "type": DataType.INT32},
        {"field": "C", "values": A_list, "type": DataType.INT64},
        {"field": "Vec", "values": vec, "type": DataType.FLOAT_VECTOR, "params": {"dim": 128}},
    ]

    # Insert slice by slice and keep the ids of *every* slice. Previously each
    # iteration overwrote ``ids``, so only the last slice's ids survived.
    ids = []
    for slice_e in utils.entities_slice(hybrid_entities):
        ids.extend(milvus.insert(collection_name, slice_e))
    milvus.flush([collection_name])
    print("Flush ... ")

    entity_count = milvus.count_entities(collection_name)
    print("Entity count: {}".format(entity_count))

    milvus.delete_entity_by_id(collection_name, ids[:1])
    milvus.flush([collection_name])
    print("Get entity be id start ...... ")
    # Fetches the entity that was just deleted above.
    entities = milvus.get_entity_by_id(collection_name, ids[:1])
    print(entities)

    milvus.delete_entity_by_id(collection_name, ids[1:2])
    milvus.flush([collection_name])

    print("Create index ......")
    milvus.create_index(collection_name, "Vec", {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {
            "nlist": 100
        }
    })
    print("Create index done.")
    info = milvus.get_collection_info(collection_name)
    print(info)
    stats = milvus.get_collection_stats(collection_name)
    print("\nstats\n")
    print(stats)

    # Scalar filters on A and B combined with a vector search on Vec.
    query_hybrid = {
        "bool": {
            "must": [
                {"term": {"A": [1, 2, 5]}},
                {"range": {"B": {"GT": 1, "LT": 100}}},
                {
                    "vector": {
                        "Vec": {
                            "topk": 10,
                            "query": vec[: 10000],
                            "params": {"nprobe": 10}
                        }
                    }
                }
            ],
        },
    }

    t0 = time.time()
    hits = 0
    results = milvus.search(collection_name, query_hybrid, fields=["B"])
    for r in results:
        for _ in r:
            hits += 1
    print("Search cost {} s".format(time.time() - t0))
    print("Search returned {} hits".format(hits))

    milvus.drop_collection(collection_name)
# ------ # Basic hybrid search entities: # And we want to get all the fields back in results, so fields = ["duration", "release_year", "embedding"]. # If searching successfully, results will be returned. # `results` have `nq`(number of queries) separate results, since we only query for 1 film, The length of # `results` is 1. # We ask for top 3 in-return, but our condition is too strict while the database is too small, so we can # only get 1 film, which means length of `entities` in below is also 1. # # Now we've gotten the results, and known it's a 1 x 1 structure, how can we get ids, distances and fields? # It's very simple, for every `topk_film`, it has three properties: `id, distance and entity`. # All fields are stored in `entity`, so you can finally obtain these data as below: # And the result should be film with id = 3. # ------ results = client.search(collection_name, dsl, fields=["duration", "release_year", "embedding"]) print("\n----------search----------") for entities in results: for topk_film in entities: current_entity = topk_film.entity print("- id: {}".format(topk_film.id)) print("- distance: {}".format(topk_film.distance)) print("- release_year: {}".format(current_entity.release_year)) print("- duration: {}".format(current_entity.duration)) print("- embedding: {}".format(current_entity.embedding)) # ------ # Basic delete: # Now let's see how to delete things in Milvus. # You can simply delete entities by their ids.
def search(self, img_id):
    """Save the 100 training images whose latent codes are closest to image ``img_id``.

    Builds the model and restores the latest checkpoint. The query image (a
    grayscale file under the static homework directory) is then encoded to its
    latent vector. That vector is searched in the ``mnist`` Milvus collection,
    and every hit's ``self.train_images[id]`` is written to
    ``search_result/result_<rank>.png``.

    Returns:
        True on success, False if the Milvus search reports a failure.

    Raises:
        FileNotFoundError: if the query image cannot be read.
    """
    x_real, x_fake, cost, optimizer, z_noise, z_real, guessed_z = self.build_model()
    saver = tf.train.Saver()
    base_dir = "D:/py_project/untitled1/polls/static/polls/homework/"
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Load the trained parameters.
        model_file = tf.train.latest_checkpoint('D:/py_project/untitled1/polls/homework/vae/ckpt/')
        saver.restore(sess, model_file)
        search_img_path = base_dir + img_id
        search_img = cv2.imread(search_img_path, cv2.IMREAD_GRAYSCALE)
        if search_img is None:
            # cv2.imread returns None instead of raising when the file is missing or unreadable.
            raise FileNotFoundError(search_img_path)
        search_img = np.reshape(search_img, [-1, 784]) / 255
        # Latent feature vector of the query image.
        test_z = sess.run(guessed_z, feed_dict={x_real: search_img})
        print(img_id, test_z)

    # Connect to Milvus and query; always disconnect, even if the search raises.
    milvus = Milvus()
    milvus.connect(host='localhost', port='19530')
    try:
        status, result = milvus.search(collection_name='mnist',
                                       query_records=test_z,
                                       top_k=100,
                                       params={'nprobe': 16})
    finally:
        milvus.disconnect()
    if not status.OK():
        print('search failed:', status)
        return False

    # Save the matched training images.
    result_img_path = base_dir + "search_result/result_"
    for row in result:
        for i, item in enumerate(row):
            batch_test_pic = self.train_images[item.id]
            cv2.imwrite(result_img_path + str(i) + ".png", batch_test_pic)
    print('search success!')
    return True
# And we want to get all the fields back in reasults, so fields = ["duration", "release_year", "embedding"]. # If searching successfully, results will be returned. # `results` have `nq`(number of queries) seperate results, since we only query for 1 film, The length of # `results` is 1. # We ask for top 3 in-return, but our condition is too strict while the database is too small, so we can # only get 1 film, which means length of `entities` in below is also 1. # # Now we've gotten the results, and known it's a 1 x 1 structure, how can we get ids, distances and fields? # It's very simple, for every `topk_film`, it has three properties: `id, distance and entity`. # All fields are stored in `entity`, so you can finally obtain these data as below: # And the result should be film with id = 3. # # Here, we pass parameter '_async=True' to insert data asynchronously, and return a `Future` object. # ------ print("\n----------search----------") search_future = client.search(collection_name, dsl, _async=True) search_results = search_future.result() # ------ # Basic delete: # Now let's see how to delete things in Milvus. # You can simply delete entities by their ids. # # After deleted, we invoke compact collection in a asynchronous way. # ------ print("\n----------delete id = 1, id = 2----------") client.delete_entity_by_id(collection_name, ids=[1, 4]) client.flush() # flush is important compact_future = client.compact(collection_name, _async=True) compact_future.result()
class RecallServerServicer(object):
    """Recall service: embed the user with a Paddle model, then fetch the nearest films from Milvus."""

    def __init__(self):
        self.uv_client = LocalPredictor()
        self.uv_client.load_model_config(
            "user_vector_model/serving_server_dir")
        # Milvus endpoint and the collection holding the film embeddings.
        self.milvus_client = Milvus('127.0.0.1', '19530')
        self.collection_name = 'demo_films'

    def get_user_vector(self, user_info):
        """Return the user embedding predicted for ``user_info`` as a Python list."""
        feed = {
            "userid": [hash2(user_info.user_id)],
            "gender": [hash2(user_info.gender)],
            "age": [hash2(user_info.age)],
            "occupation": [hash2(user_info.job)],
        }
        # One sample per slot: every slot shares the same LoD offsets [0, 1].
        lod = [0, 1]
        for slot in list(feed):
            feed[slot + ".lod"] = lod
        feed = {
            key: np.array(values).astype(np.int64).reshape(len(values), 1)
            for key, values in feed.items()
        }
        output_name = "save_infer_model/scale_0.tmp_0"
        fetch_map = self.uv_client.predict(
            feed=feed, fetch=[output_name], batch=True)
        return fetch_map[output_name].tolist()[0]

    def recall(self, request, context):
        '''
        Return the top-100 films nearest (L2) to the requesting user's embedding.

        message RecallRequest{
            string log_id = 1;
            user_info.UserInfo user_info = 2;
            string recall_type= 3;
            uint32 request_num= 4;
        }

        message RecallResponse{
            message Error {
                uint32 code = 1;
                string text = 2;
            }
            message ScorePair {
                string nid = 1;
                float score = 2;
            };
            Error error = 1;
            repeated ScorePair score_pairs = 2;
        }
        '''
        recall_res = recall_pb2.RecallResponse()
        user_vector = self.get_user_vector(request.user_info)
        vector_clause = {
            "embedding": {
                "topk": 100,
                "query": [user_vector],
                "metric_type": "L2"
            }
        }
        query_hybrid = {"bool": {"must": [{"vector": vector_clause}]}}
        results = self.milvus_client.search(self.collection_name,
                                            query_hybrid,
                                            fields=["embedding"])
        for entities in results:
            # An empty hit list is reported to the caller as error 500.
            if len(entities) == 0:
                recall_res.error.code = 500
                recall_res.error.text = "Recall server get milvus fail. ({})".format(
                    str(request))
                return recall_res
            for topk_film in entities:
                current_entity = topk_film.entity
                pair = recall_res.score_pairs.add()
                pair.nid = str(topk_film.id)
                pair.score = float(topk_film.distance)
        recall_res.error.code = 200
        return recall_res
def main():
    """End-to-end demo of the collection API: connect, create, insert, index, search, drop.

    Relies on module-level ``_HOST``, ``_PORT``, ``_DIM`` and ``_INDEX_FILE_SIZE``.
    Exits the process with status 1 if the server cannot be reached.
    """
    milvus = Milvus()

    # Connect to the Milvus server (adjust _HOST/_PORT for your deployment).
    if not milvus.connect(host=_HOST, port=_PORT).OK():
        print("Server connect fail.")
        sys.exit(1)
    print("Server connected.")

    # Create the demo collection only if it does not exist yet.
    collection_name = 'example_collection'
    _, exists = milvus.has_collection(collection_name)
    if not exists:
        milvus.create_collection({
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2,  # optional
        })

    # List the server's collections (the result is not used further).
    milvus.show_collections()

    _, info = milvus.collection_info(collection_name)
    print(info)
    _, description = milvus.describe_collection(collection_name)
    print(description)

    # 10000 random vectors of _DIM floats each, as a 2-D list.
    # np.random.rand(10000, _DIM).astype(np.float32) would work as well.
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10000)]
    _, ids = milvus.insert(collection_name=collection_name, records=vectors)

    # Flush the inserted data before counting, indexing and searching.
    milvus.flush([collection_name])
    milvus.count_collection(collection_name)

    # Searching works without an index; an IVF_FLAT index makes it faster.
    index_param = {'nlist': 2048}
    print("Creating index: {}".format(index_param))
    milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param)
    _, index = milvus.describe_index(collection_name)
    print(index)

    # Query with the first 10 inserted vectors; each should find itself.
    print("Searching ... ")
    status, results = milvus.search(
        collection_name=collection_name,
        query_records=vectors[:10],
        top_k=1,
        params={"nprobe": 16},
    )
    if status.OK():
        top_hit = results[0][0]
        # Equivalent: results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]
        if top_hit.distance == 0.0 or top_hit.id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')
        print(results)

    milvus.drop_collection(collection_name)
    milvus.disconnect()
print("Start create index ......") status = client.create_index(collection_name, IndexType.ANNOY, index_param) if status.OK(): print("Create index ANNOY successfully\n") else: print("Create index ANNOY fail") # select top 10 vectors from inserted as query vectors query_vectors = vectors[:10] # specify search param search_param = {"search_k": 10} # specify topk is 1, search approximate nearest 1 neighbor status, result = client.search(collection_name, 1, query_vectors, params=search_param) if status.OK(): # show search result print("Search successfully. Result:\n", result) else: print("Search fail: ", status) # drop collection client.drop_collection(collection_name) # disconnect from server # client.disconnect()
def main():
    """End-to-end demo of the legacy table API over HTTP: connect, create, insert, index, search, drop.

    Relies on module-level ``_HOST``, ``_PORT``, ``_DIM`` and ``_INDEX_FILE_SIZE``.
    Exits the process with status 1 if the server cannot be reached.
    """
    milvus = Milvus(handler="HTTP")

    # Connect to the Milvus server (adjust _HOST/_PORT for your deployment).
    if not milvus.connect(host=_HOST, port=_PORT).OK():
        print("Server connect fail.")
        sys.exit(1)
    print("Server connected.")

    # Create the demo table only if it does not exist yet.
    table_name = 'demo_tables'
    _, exists = milvus.has_table(table_name)
    if not exists:
        milvus.create_table({
            'table_name': table_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2,  # optional
        })

    # List the server's tables (the result is not used further).
    milvus.show_tables()

    _, description = milvus.describe_table(table_name)
    print(description)

    # 100000 random vectors of _DIM floats each, as a 2-D list.
    # np.random.rand(100000, _DIM).astype(np.float32).tolist() would work as well.
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(100000)]
    _, ids = milvus.insert(table_name=table_name, records=vectors)

    # Give the server 6 seconds to persist the inserted vectors.
    time.sleep(6)
    milvus.count_table(table_name)

    # Searching works without an index; an IVFLAT index makes it faster.
    milvus.create_index(table_name, {
        'index_type': IndexType.IVFLAT,  # choice ivflat index
        'nlist': 2048,
    })
    _, index = milvus.describe_index(table_name)
    print(index)

    # Query with the first 10 inserted vectors; each should find itself.
    status, results = milvus.search(
        table_name=table_name,
        query_records=vectors[:10],
        top_k=1,
        nprobe=16,
    )
    if status.OK():
        top_hit = results[0][0]
        # Equivalent: results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]
        if top_hit.distance == 0.0 or top_hit.id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')
        print(results)

    milvus.drop_table(table_name)
    milvus.disconnect()
def main():
    """Index ``index_sentences`` in Milvus and print the nearest indexed sentence for each query sentence.

    Relies on module-level ``_HOST``, ``_PORT``, ``_DIM``, ``_INDEX_FILE_SIZE``,
    ``text2vec``, ``index_sentences`` and ``query_sentences``. If the insert
    fails, the collection is dropped and the function returns early.
    """
    # The client instance keeps a connection pool; ``pool_size`` caps its size.
    milvus = Milvus(_HOST, _PORT)

    # Create the demo collection if it doesn't exist.
    collection_name = 'example_collection_'
    status, ok = milvus.has_collection(collection_name)
    if not ok:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2  # optional
        }
        milvus.create_collection(param)

    # Show collections in the Milvus server.
    _, collections = milvus.list_collections()

    # Describe the demo collection.
    _, collection = milvus.get_collection_info(collection_name)
    print(collection)

    # Embed the sentences to index (a 2-D array of float vectors).
    vectors = text2vec(index_sentences)
    print(vectors)

    status, ids = milvus.insert(collection_name=collection_name, records=vectors)
    if not status.OK():
        print("Insert failed: {}".format(status))
        # Nothing usable was inserted; the id lookup and search below would fail.
        milvus.drop_collection(collection_name)
        return
    print(ids)

    # Lookup table from Milvus id to the indexed sentence (ids come back in record order).
    look_up = dict(zip(ids, index_sentences))
    for k, sentence in look_up.items():
        print(k, sentence)

    # Flush inserted data to disk.
    milvus.flush([collection_name])
    status, result = milvus.count_entities(collection_name)
    _, info = milvus.get_collection_stats(collection_name)
    print(info)
    # Fetch the raw vectors back by id.
    status, result_vectors = milvus.get_entity_by_id(collection_name, ids)

    # Searching works without an index; an IVF_FLAT index makes it faster.
    index_param = {'nlist': 2048}
    print("Creating index: {}".format(index_param))
    status = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param)
    status, index = milvus.get_index_info(collection_name)
    print(index)

    # Search with the query sentences.
    query_vectors = text2vec(query_sentences)
    search_param = {"nprobe": 16}
    print("Searching ... ")
    param = {
        'collection_name': collection_name,
        'query_records': query_vectors,
        'top_k': 1,
        'params': search_param,
    }
    status, results = milvus.search(**param)
    if status.OK():
        # Same check as: results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]
        if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')
        for res in results:
            for ele in res:
                # .get(): ids left over from an earlier run are not in look_up.
                print('id:{}, text:{}, distance: {}'.format(
                    ele.id, look_up.get(ele.id), ele.distance))
    else:
        print("Search failed. ", status)

    status = milvus.drop_collection(collection_name)
class Indexer:
    '''
    Indexer: wraps one Milvus collection (create, insert, search, delete, partitions).

    Every method raises ExertMilvusException when Milvus returns a non-zero status code.
    '''

    def __init__(self, name, host='127.0.0.1', port='19531'):
        '''
        Create the client for collection ``name``.
        '''
        self.client = Milvus(host=host, port=port)
        self.collection = name

    @staticmethod
    def _check(status):
        '''
        Raise ExertMilvusException unless ``status.code`` is 0 (success).
        '''
        if status.code != 0:
            raise ExertMilvusException(status)

    def init(self, lenient=False):
        '''
        Create the collection (512-d, L2) and an IVF_FLAT index on it.

        With ``lenient=True``, return early (None) if the collection already
        exists. Status code 9 from create_collection is also tolerated
        (NOTE(review): presumably "collection already exists" -- confirm).
        Returns the create_index status otherwise.
        '''
        if lenient:
            status, result = self.client.has_collection(
                collection_name=self.collection)
            self._check(status)
            if result:
                return
        status = self.client.create_collection({
            'collection_name': self.collection,
            'dimension': 512,
            'index_file_size': 1024,
            'metric_type': MetricType.L2
        })
        if status.code != 0 and not (lenient and status.code == 9):
            raise ExertMilvusException(status)
        # Create the index.
        status = self.client.create_index(collection_name=self.collection,
                                          index_type=IndexType.IVF_FLAT,
                                          params={'nlist': 16384})
        self._check(status)
        return status

    def drop(self):
        '''
        Drop the collection.
        '''
        self._check(self.client.drop_collection(collection_name=self.collection))

    def flush(self):
        '''
        Flush the collection to disk.
        '''
        self._check(self.client.flush([self.collection]))

    def compact(self):
        '''
        Compact the collection.
        '''
        self._check(self.client.compact(collection_name=self.collection))

    def close(self):
        '''
        Close the client connection.
        '''
        self.client.close()

    def new_tag(self, tag):
        '''
        Create a partition tag.
        '''
        self._check(self.client.create_partition(collection_name=self.collection,
                                                 partition_tag=tag))

    def list_tag(self):
        '''
        List the partition tags.
        '''
        status, result = self.client.list_partitions(
            collection_name=self.collection)
        self._check(status)
        return result

    def drop_tag(self, tag):
        '''
        Drop a partition tag.
        '''
        self._check(self.client.drop_partition(collection_name=self.collection,
                                               partition_tag=tag))

    def index(self, vectors, tag=None, ids=None):
        '''
        Insert ``vectors`` (optionally into partition ``tag``, with explicit ``ids``).

        Returns the ids assigned by Milvus.
        '''
        params = {}
        # ``is not None``: ``!= None`` on a numpy array is element-wise, and
        # using that result in ``if`` raises ValueError.
        if tag is not None:
            params['tag'] = tag
        if ids is not None:
            params['ids'] = ids
        status, result = self.client.insert(collection_name=self.collection,
                                            records=vectors,
                                            **params)
        self._check(status)
        return result

    def listing(self, ids):
        '''
        Fetch the stored entities for ``ids``.
        '''
        status, result = self.client.get_entity_by_id(
            collection_name=self.collection, ids=ids)
        self._check(status)
        return result

    def counting(self):
        '''
        Return the number of entities in the collection.
        '''
        status, result = self.client.count_entities(
            collection_name=self.collection)
        self._check(status)
        return result

    def unindex(self, ids):
        '''
        Delete the entities with the given ``ids``.
        '''
        self._check(self.client.delete_entity_by_id(
            collection_name=self.collection, id_array=ids))

    def search(self, vectors, top_count=100, tags=None):
        '''
        Search for the ``top_count`` nearest entities (nprobe=16), optionally limited to partitions ``tags``.
        '''
        params = {'params': {'nprobe': 16}}
        if tags is not None:
            params['partition_tags'] = tags
        status, results = self.client.search(collection_name=self.collection,
                                             query_records=vectors,
                                             top_k=top_count,
                                             **params)
        self._check(status)
        return results
class MilvusClient(object):
    """Wrapper around the pymilvus ``Milvus`` client bound to one default collection.

    Most RPC helpers pass the returned status to :meth:`check_status`, which
    logs diagnostic context and raises ``Exception("Status not ok")`` on failure.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Connection attempts are retried until ``timeout`` seconds (default 60)
        have elapsed, pausing briefly between failed attempts.

        :param collection_name: collection the helpers operate on by default.
        :param host: server host; falls back to ``SERVER_HOST_DEFAULT``.
        :param port: server port; falls back to ``SERVER_PORT_DEFAULT``.
        :param timeout: connection deadline in seconds.
        :raises Exception: "Server connect timeout" if no attempt succeeded.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # Retry the connection: a remote server may still be starting up.
        i = 0
        connected = False
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    connected = True
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
            # Pause between attempts instead of busy-spinning against the server.
            time.sleep(1)
        # Decide on the outcome of the loop, not on the clock: a connection that
        # succeeded right at the deadline must not be reported as a timeout.
        if not connected:
            raise Exception("Server connect timeout")
        self._metric_type = None
        self._dimension = None
        if self._collection_name and self.exists_collection():
            # Fetch the collection info once and reuse it for both attributes.
            info = self.describe()[1]
            self._metric_type = metric_type_to_str(info.metric_type)
            self._dimension = info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Change the default collection."""
        self._collection_name = name

    def check_status(self, status):
        """Log diagnostic context and raise ``Exception("Status not ok")`` if ``status`` is not OK."""
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's best hit is at distance >= ``epsilon`` (a module-level tolerance defined elsewhere)."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; ``metric_type`` must be a key of ``METRIC_MAP``."""
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP.keys():
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_type = METRIC_MAP[metric_type]
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        """Create a partition in the default collection."""
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        """Drop a partition of the default collection."""
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        """Return the partitions of the default collection."""
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """Insert vectors ``X`` (optionally with ``ids``); return ``(status, ids)``."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1-100 random vectors, flush, and assert the row count grew accordingly."""
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` ids sampled from a random non-empty segment of the first partition."""
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return ``(segment_count, ids)`` with the first ``length`` ids of every segment."""
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        # random choice from each segment
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        """Return ``(ids, entities)`` for up to ``length`` randomly chosen ids."""
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        """Return the entities for ``get_ids``."""
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete entities by id."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete 1-100 random ids, flush, and assert they are gone and the count dropped."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        """Flush a collection (default: the bound one)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact a collection (default: the bound one)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` must be a key of ``INDEX_MAP``."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the index of the bound collection and return the raw status."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """Search ``X`` for ``top_k`` neighbours; return the search result."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Run a search with random ``top_k``/``nq``/``nprobe`` using stored vectors as queries."""
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)

    def count(self, name=None):
        """Return the row count of a collection (0 when the server reports none)."""
        if name is None:
            name = self._collection_name
        # Issue the RPC once and reuse the reply for both logging and the result.
        reply = self._milvus.count_entities(name)
        logger.debug(reply)
        row_count = reply[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection and poll (1 s steps, up to ``timeout`` s) until its count reads 0."""
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        i = 0
        while i < timeout:
            if self.count(name=name):
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the raw ``get_collection_info`` reply for the bound collection."""
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        """Return the raw ``list_collections`` reply."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return whether the collection exists (status is not checked)."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        """Load the bound collection into memory (3000 s RPC timeout)."""
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the server version string."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        """Return the server's ``mode`` command output."""
        return self.cmd("mode")

    def get_server_commit(self):
        """Return the server's build commit id."""
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the full server config parsed from JSON."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": <GiB rounded to 2 places>}``."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a raw server command and return its reply."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
def main():
    """
    Walk through the basic synchronous pymilvus workflow.

    The steps are:

    - Connect to ``_HOST``:``_PORT``.
    - Create ``example_collection_`` if it is missing.
    - Insert 10 random ``_DIM``-dimensional vectors and flush.
    - Print counts and stats.
    - Build an IVF_FLAT index.
    - Search with the inserted vectors.
    - Drop the collection.

    The client connection is closed on every exit path.
    """
    # The client instance maintains a connection pool; the `pool_size`
    # parameter caps the number of connections.
    milvus = Milvus(_HOST, _PORT)
    try:
        collection_name = 'example_collection_'

        # Create the collection only if it does not exist yet.
        status, ok = milvus.has_collection(collection_name)
        if not ok:
            param = {
                'collection_name': collection_name,
                'dimension': _DIM,
                'index_file_size': _INDEX_FILE_SIZE,  # optional
                'metric_type': MetricType.L2  # optional
            }
            milvus.create_collection(param)

        # List every collection on the server.
        _, collections = milvus.list_collections()
        print(collections)

        # Describe this collection.
        _, collection = milvus.get_collection_info(collection_name)
        print(collection)

        # 10 random float vectors of dimension _DIM; records must be a 2-D list.
        vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
        print(vectors)
        # numpy alternative: vectors = np.random.rand(10, _DIM).astype(np.float32)

        # Insert the vectors; the server returns the generated ids.
        status, ids = milvus.insert(collection_name=collection_name, records=vectors)
        if not status.OK():
            print("Insert failed: {}".format(status))
            # Nothing to flush, index or search: clean up and stop here rather
            # than failing later on `ids[:10]`.
            milvus.drop_collection(collection_name)
            return
        print(ids)

        # Persist the inserted data.
        milvus.flush([collection_name])

        # Row count of the collection.
        status, result = milvus.count_entities(collection_name)
        print(status)
        print(result)

        # Collection statistics.
        _, info = milvus.get_collection_stats(collection_name)
        print(info)

        # Fetch the raw vectors back by id.
        status, result_vectors = milvus.get_entity_by_id(collection_name, ids[:10])
        print(result_vectors)

        # An IVF_FLAT index is optional for searching but makes it faster.
        index_param = {'nlist': 2048}
        print("Creating index: {}".format(index_param))
        status = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param)
        if not status.OK():
            # Searching still works without an index, so report and continue.
            print("Create index failed: {}".format(status))

        status, index = milvus.get_index_info(collection_name)
        print(index)

        # Use the inserted vectors as queries; nprobe is the number of
        # IVF clusters probed per query.
        query_vectors = vectors[0:10]
        search_param = {"nprobe": 16}

        print("Searching ... ")

        param = {
            'collection_name': collection_name,
            'query_records': query_vectors,
            'top_k': 1,
            'params': search_param,
        }

        status, results = milvus.search(**param)
        if status.OK():
            # Each query vector is stored in the collection, so its nearest hit
            # should be itself. Equivalent check:
            #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
            if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

            print(results)
        else:
            print("Search failed. ", status)

        # Remove the demo collection.
        status = milvus.drop_collection(collection_name)
    finally:
        milvus.close()
def main():
    """
    Walk through the asynchronous pymilvus API (``_async=True`` with callbacks).

    The steps are:

    - Create ``example_async_collection_`` if it is missing.
    - Insert 100000 random ``_DIM``-dimensional vectors.
    - Flush and compact the collection.
    - Build an IVF_FLAT index.
    - Search with the first 10 vectors.
    - Drop the collection.

    Each asynchronous call gets a callback that reports success or failure.
    Exits with status 1 if the collection cannot be created.
    """
    milvus = Milvus(_HOST, _PORT)

    # Create the collection if it does not exist yet.
    collection_name = 'example_async_collection_'
    status, ok = milvus.has_collection(collection_name)
    if not ok:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': 128,  # optional
            'metric_type': MetricType.L2  # optional
        }

        status = milvus.create_collection(param)
        if not status.OK():
            print("Create collection failed: {}".format(status.message), file=sys.stderr)
            print("exiting ...", file=sys.stderr)
            sys.exit(1)

    # List the collections on the server.
    _, collections = milvus.list_collections()

    # Describe this collection.
    _, collection = milvus.get_collection_info(collection_name)
    print(collection)

    # 100000 random float vectors of dimension _DIM; records must be a 2-D list.
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(100000)]
    # numpy alternative: `vectors = np.random.rand(100000, _DIM).astype(np.float32)`

    def _insert_callback(status, ids):
        if status.OK():
            print("Insert successfully")
        else:
            print("Insert failed.", status.message)

    # Insert asynchronously, reporting through the callback.
    insert_future = milvus.insert(collection_name=collection_name, records=vectors, _async=True,
                                  _callback=_insert_callback)
    # Or invoke result() to get results:
    #   insert_future = milvus.insert(collection_name=collection_name, records=vectors, _async=True)
    #   status, ids = insert_future.result()
    insert_future.done()

    # Flush the inserted data to disk.
    def _flush_callback(status):
        if status.OK():
            print("Flush successfully")
        else:
            print("Flush failed.", status.message)

    flush_future = milvus.flush([collection_name], _async=True, _callback=_flush_callback)
    # Or invoke result() to get results:
    #   flush_future = milvus.flush([collection_name], _async=True)
    #   status = flush_future.result()
    flush_future.done()

    def _compact_callback(status):
        if status.OK():
            print("Compact successfully")
        else:
            print("Compact failed.", status.message)

    # The keyword must be `_callback`; the previous `_cakkback` typo meant the
    # callback was never registered.
    compact_future = milvus.compact(collection_name, _async=True, _callback=_compact_callback)
    # Or invoke result() to get results:
    #   compact_future = milvus.compact(collection_name, _async=True)
    #   status = compact_future.result()
    compact_future.done()

    # Row count of the collection.
    status, result = milvus.count_entities(collection_name)

    # Collection statistics.
    _, info = milvus.get_collection_stats(collection_name)
    print(info)

    # An IVF_FLAT index is optional for searching but makes it faster.
    index_param = {'nlist': 2048}

    def _index_callback(status):
        if status.OK():
            print("Create index successfully")
        else:
            print("Create index failed.", status.message)

    print("Creating index: {}".format(index_param))
    index_future = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param, _async=True,
                                       _callback=_index_callback)
    # Or invoke result() to get results:
    #   index_future = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param, _async=True)
    #   status = index_future.result()
    index_future.done()

    # Show the index information.
    status, index = milvus.get_index_info(collection_name)
    print(index)

    # Use the first 10 vectors as queries.
    query_vectors = vectors[0:10]

    search_param = {"nprobe": 16}

    print("Searching ... ")

    def _search_callback(status, results):
        if status.OK():
            # Each query vector is stored in the collection, so its nearest hit
            # should be at distance 0. (The ids are not available here because
            # the insert result was consumed by its callback.)
            if results[0][0].distance == 0.0:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

            print(results)
        else:
            print("Search failed. ", status)

    param = {
        'collection_name': collection_name,
        'query_records': query_vectors,
        'top_k': 1,
        'params': search_param,
        "_async": True,
        "_callback": _search_callback
    }
    search_future = milvus.search(**param)
    # Or invoke result() to get results (drop "_callback" from param):
    #   search_future = milvus.search(**param)
    #   status, results = search_future.result()
    search_future.done()

    # Remove the demo collection and release the client.
    status = milvus.drop_collection(collection_name)
    milvus.close()
class SearchEngine:
    """
    Key/value style facade over one Milvus collection (inner-product metric).

    Keys are integer ids (or lists of them) and values are numpy vectors.
    Query methods require ``set_collection`` to have been called first.
    """

    def __init__(self, host, port):
        """Connect to Milvus; the MILVUS_HOST / MILVUS_PORT env vars override ``host`` / ``port``."""
        self.host = os.environ.get('MILVUS_HOST', host)
        self.port = os.environ.get('MILVUS_PORT', str(port))
        self.engine = Milvus(host=self.host, port=self.port)
        self.collection_name = None

    #################################################
    # HANDLE COLLECTION
    #################################################
    def create_collection(self, collection_name, dimension):
        """Create a collection of ``dimension``-wide vectors using the IP metric."""
        param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': 1000,
            'metric_type': MetricType.IP
        }
        self.engine.create_collection(param)
        print('[INFO] collection {}을 생성했습니다.'.format(collection_name))

    def drop_collection(self, collection_name):
        """Drop a collection."""
        self.engine.drop_collection(collection_name=collection_name)
        print('[INFO] collection {}을 삭제했습니다.'.format(collection_name))

    def get_collection_stats(self, collection_name):
        """Print the info and statistics of a collection."""
        print(self.engine.get_collection_info(collection_name))
        print(self.engine.get_collection_stats(collection_name))

    def set_collection(self, collection_name):
        """Select the collection that the data/query methods operate on."""
        self.collection_name = collection_name
        print('[INFO] setting collection {}'.format(self.collection_name))

    #################################################
    # UTILS
    #################################################
    def check_set_collection(self):
        """Fail (AssertionError) when no collection has been selected."""
        assert self.collection_name is not None, '[ERROR] collection을 setting해 주십시오!!'

    def check_exist_data_by_key(self, key):
        """Return whether the FIRST id in ``key`` has a stored vector."""
        self.check_set_collection()
        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)
        vector = vector if vector else [vector]
        return True if vector[0] else False

    def convert_key_format(self, key):
        """Wrap a single int key in a list; return lists unchanged."""
        return [key] if isinstance(key, int) else key

    def convert_value_format(self, value):
        """Reshape a 1-D array into a single-row 2-D array; rank >= 2 fails the assertion."""
        rank = len(value.shape)
        assert rank < 2, '[ERROR] value의 dim을 2 미만으로 입력해 주세요!!'
        return value.reshape(1, -1) if rank == 1 else value

    #################################################
    # INSERT
    #################################################
    def insert_data(self, key, value):
        """Insert ``value`` under ``key`` unless the key already exists, then flush."""
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)

        if self.check_exist_data_by_key(key):
            print("[ERROR] 이미 collection에 데이터가 존재합니다.")
            return

        self.engine.insert(collection_name=self.collection_name, records=value, ids=key)
        self.engine.flush([self.collection_name])
        print('[INFO] insert key {}'.format(key))

    #################################################
    # DELETE
    #################################################
    def delete_data(self, key):
        """Delete ``key`` if it exists, then flush."""
        key = self.convert_key_format(key)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])
        print('[INFO] delete key {}'.format(key))

    #################################################
    # UPDATE
    #################################################
    def update_data(self, key, value):
        """Replace the vector stored under an existing ``key`` (delete + insert)."""
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])

        self.engine.insert(collection_name=self.collection_name, records=value, ids=key)
        self.engine.flush([self.collection_name])
        print('[INFO] update key {}'.format(key))

    #################################################
    # SEARCH
    #################################################
    def search_by_feature(self, feature, top_k):
        """
        Search with a feature vector.

        :return: ``(ids, distances)``, one list of ``top_k`` entries per query row.
        """
        self.check_set_collection()
        feature = self.convert_value_format(feature)
        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=feature,
                                       top_k=top_k)
        # Row i of the result belongs to query i (previously every row was
        # built from result[0]).
        li_id = [[hit.id for hit in result[i]] for i in range(len(result))]
        li_dist = [[hit.distance for hit in result[i]] for i in range(len(result))]
        return li_id, li_dist

    def search_by_key(self, key, top_k):
        """
        Search with the vectors stored under ``key``.

        The first hit of each query is dropped (presumably the stored vector
        matching itself), so ``top_k + 1`` neighbours are requested.

        :return: ``(ids, distances)``, one list per key; None if the first key
            does not exist.
        """
        self.check_set_collection()
        key = self.convert_key_format(key)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)
        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=vector,
                                       top_k=top_k + 1)
        # One row per queried key; previously every row repeated result[0],
        # so multi-key searches returned the first key's neighbours only.
        li_id = [[hit.id for hit in result[i][1:]] for i in range(len(result))]
        li_dist = [[hit.distance for hit in result[i][1:]] for i in range(len(result))]
        return li_id, li_dist
class MilvusHelper:
    """
    Helper around a pymilvus client configured from MILVUS_HOST / MILVUS_PORT.

    Every operation logs failures through LOGGER and terminates the process
    with ``sys.exit(1)``.
    """

    def __init__(self):
        try:
            self.client = Milvus(host=MILVUS_HOST, port=MILVUS_PORT)
            LOGGER.debug(
                "Successfully connect to Milvus with IP:{} and PORT:{}".format(
                    MILVUS_HOST, MILVUS_PORT))
        except Exception as e:
            LOGGER.error("Failed to connect Milvus: {}".format(e))
            sys.exit(1)

    def has_collection(self, collection_name):
        """Return whether Milvus has the collection."""
        try:
            status, exists = self.client.has_collection(collection_name)
            # Checked like every other call in this class, so a server error is
            # not mistaken for "collection missing".
            if status.code != 0:
                raise Exception(status.message)
            return exists
        except Exception as e:
            LOGGER.error("Failed to check collection in Milvus: {}".format(e))
            sys.exit(1)

    def create_colllection(self, collection_name):
        """Create the collection (VECTOR_DIMENSION / INDEX_FILE_SIZE / METRIC_TYPE) if it does not exist."""
        try:
            if not self.has_collection(collection_name):
                collection_param = {
                    'collection_name': collection_name,
                    'dimension': VECTOR_DIMENSION,
                    'index_file_size': INDEX_FILE_SIZE,
                    'metric_type': METRIC_TYPE
                }
                status = self.client.create_collection(collection_param)
                if status.code != 0:
                    raise Exception(status.message)
                LOGGER.debug(
                    "Create Milvus collection: {}".format(collection_name))
        except Exception as e:
            LOGGER.error("Failed to create Milvus collection: {}".format(e))
            sys.exit(1)

    # Correctly spelled alias; `create_colllection` is kept for existing callers.
    create_collection = create_colllection

    def insert(self, collection_name, vectors):
        """Batch-insert vectors (creating the collection if needed); return their ids."""
        try:
            self.create_colllection(collection_name)
            status, ids = self.client.insert(collection_name=collection_name,
                                             records=vectors)
            if not status.code:
                LOGGER.debug(
                    "Insert vectors to Milvus in collection: {} with {} rows".
                    format(collection_name, len(vectors)))
                return ids
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to load data to Milvus: {}".format(e))
            sys.exit(1)

    def create_index(self, collection_name):
        """Create an IVF_FLAT index (nlist=16384) on the collection; return the status."""
        try:
            index_param = {'nlist': 16384}
            status = self.client.create_index(collection_name, IndexType.IVF_FLAT,
                                              index_param)
            if not status.code:
                LOGGER.debug(
                    "Successfully create index in collection:{} with param:{}".
                    format(collection_name, index_param))
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to create index: {}".format(e))
            sys.exit(1)

    def delete_collection(self, collection_name):
        """Drop the collection; return the status."""
        try:
            status = self.client.drop_collection(
                collection_name=collection_name)
            if not status.code:
                LOGGER.debug(
                    "Successfully drop collection: {}".format(collection_name))
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to drop collection: {}".format(e))
            sys.exit(1)

    def search_vectors(self, collection_name, vectors, top_k):
        """Search the collection (nprobe=16) and return the raw result."""
        try:
            search_param = {'nprobe': 16}
            status, result = self.client.search(
                collection_name=collection_name,
                query_records=vectors,
                top_k=top_k,
                params=search_param)
            if not status.code:
                LOGGER.debug("Successfully search in collection: {}".format(
                    collection_name))
                return result
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to search vectors in Milvus: {}".format(e))
            sys.exit(1)

    def count(self, collection_name):
        """Return the number of entities in the collection."""
        try:
            status, num = self.client.count_entities(
                collection_name=collection_name)
            if not status.code:
                LOGGER.debug(
                    "Successfully get the num:{} of the collection:{}".format(
                        num, collection_name))
                return num
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to count vectors in Milvus: {}".format(e))
            sys.exit(1)
class MilvusHelper(BaseVectorSimilarityHelper):
    """
    Milvus implementation of ``BaseVectorSimilarityHelper``.

    The client is created lazily by :meth:`init` and every RPC uses
    ``self.timeout``. Failed calls raise ``MilvusRuntimeException`` with the
    server's status message.
    """

    def __init__(self, _server_url, _server_port, _timeout=10):
        super().__init__()
        self.server_url = _server_url
        self.server_port = _server_port
        self.timeout = _timeout
        self.client = None
        # Map the project's generic enums onto pymilvus enums.
        self.metric_type_mapper = {
            VectorMetricType.L2: MetricType.L2,
            VectorMetricType.IP: MetricType.IP,
            VectorMetricType.JACCARD: MetricType.JACCARD,
            VectorMetricType.HAMMING: MetricType.HAMMING,
        }
        self.index_type_mapper = {
            VectorIndexType.FLAT: IndexType.FLAT,
            VectorIndexType.IVFLAT: IndexType.IVFLAT,
            VectorIndexType.IVF_SQ8: IndexType.IVF_SQ8,
            VectorIndexType.RNSG: IndexType.RNSG,
            VectorIndexType.IVF_SQ8H: IndexType.IVF_SQ8H,
            VectorIndexType.IVF_PQ: IndexType.IVF_PQ,
            VectorIndexType.HNSW: IndexType.HNSW,
            VectorIndexType.ANNOY: IndexType.ANNOY,
        }

    def init(self):
        """
        Create the client on first use.

        Raises:
            MilvusRuntimeException: when the URL or port is missing, or the
                client cannot be constructed.
        """
        if self.client is not None:
            return
        # Both halves of the address are required (the port was previously
        # never checked because the URL was tested twice).
        if self.server_url is None or self.server_port is None:
            raise MilvusRuntimeException('Milvus config is not correct')
        try:
            self.client = Milvus(host=self.server_url, port=self.server_port)
        except Exception as e:
            raise MilvusRuntimeException(f'cannot connect to {self.server_url}:{self.server_port}') from e

    def insert(self, _database_name, _to_insert_vector, _partition_tag=None, _params=None):
        """
        Insert a batch of feature vectors with server-assigned ids, then flush.

        Note: if you manage your own ids, use ``insert_with_id`` instead.
        ATTENTION!!! A collection must use either ``insert_with_id`` or
        ``insert`` exclusively; mixing the two raises an error.

        Args:
            _database_name: collection name
            _to_insert_vector: list of feature vectors to insert
            _partition_tag: partition tag
            _params: insert parameters

        Returns:
            ids of the inserted vectors
        """
        self.init()
        status, ids = self.client.insert(_database_name, _to_insert_vector, partition_tag=_partition_tag,
                                         params=_params, timeout=self.timeout)
        # Check the insert first so a flush error cannot mask the real failure.
        if not status.OK():
            raise MilvusRuntimeException(status.message)
        self.flush(_database_name)
        return ids

    def insert_with_id(self, _database_name, _to_insert_vector, _to_insert_ids, _partition_tag=None, _params=None):
        """
        Insert a batch of feature vectors with caller-supplied ids, then flush.

        ATTENTION!!! A collection must use either ``insert_with_id`` or
        ``insert`` exclusively; mixing the two raises an error.

        Args:
            _database_name: collection name
            _to_insert_vector: list of feature vectors to insert
            _to_insert_ids: ids of the vectors; each must be a positive integer within range
            _partition_tag: partition tag
            _params: insert parameters

        Returns:
            ids of the inserted vectors
        """
        self.init()
        status, ids = self.client.insert(_database_name, _to_insert_vector, ids=_to_insert_ids,
                                         partition_tag=_partition_tag,
                                         params=_params, timeout=self.timeout)
        if not status.OK():
            raise MilvusRuntimeException(status.message)
        self.flush(_database_name)
        return ids

    def delete(self, _database_name, _to_delete_ids):
        """
        Delete entities by id, then flush.

        Args:
            _database_name: collection name
            _to_delete_ids: ids to delete

        Returns:
            True on success
        """
        self.init()
        status = self.client.delete_entity_by_id(_database_name, _to_delete_ids, self.timeout)
        if not status.OK():
            raise MilvusRuntimeException(status.message)
        self.flush(_database_name)
        return True

    def database_exist(self, _database_name):
        """
        Check whether the collection exists.

        Args:
            _database_name: collection name

        Returns:
            whether the collection exists
        """
        self.init()
        status, is_exist = self.client.has_collection(_database_name, self.timeout)
        if status.OK():
            return is_exist
        else:
            raise MilvusRuntimeException(status.message)

    def create_database(self, _database_name, _dimension, _index_file_size, _metric_type):
        """
        Create the collection if it does not exist.

        Args:
            _database_name: collection name
            _dimension: feature vector dimension
            _index_file_size: index file size
            _metric_type: a key of ``metric_type_mapper``

        Returns:
            True if created or already present
        """
        self.init()
        if self.database_exist(_database_name):
            return True
        assert _metric_type in self.metric_type_mapper, f'{_metric_type} not support in milvus'
        status = self.client.create_collection({
            'collection_name': _database_name,
            'dimension': _dimension,
            'index_file_size': _index_file_size,
            'metric_type': self.metric_type_mapper[_metric_type]
        })
        if status.OK():
            return True
        else:
            raise MilvusRuntimeException(status.message)

    def create_index(self, _database_name, _index_type, _params=None):
        """
        Build an index on the collection.

        Args:
            _database_name: collection name
            _index_type: a key of ``index_type_mapper``
            _params: optional index build parameters (e.g. ``{'nlist': 2048}``);
                None keeps the previous behaviour of sending no parameters

        Returns:
            True on success; None when the collection does not exist
        """
        self.init()
        if self.database_exist(_database_name):
            assert _index_type in self.index_type_mapper, f'{_index_type} not support in milvus'
            status = self.client.create_index(_database_name, self.index_type_mapper[_index_type],
                                              params=_params, timeout=self.timeout)
            if status.OK():
                return True
            else:
                raise MilvusRuntimeException(status.message)

    def search(self, _database_name, _query_vector_list, _top_k, _partition_tag=None, _params=None):
        """
        Search the collection.

        Args:
            _database_name: collection name
            _query_vector_list: list of query feature vectors
            _top_k: number of neighbours per query
            _partition_tag: partition tags to restrict the search to
            _params: search parameters

        Returns:
            the search result, holding ids and distances

        Raises:
            DatabaseNotExist: when the collection does not exist
        """
        self.init()
        if self.database_exist(_database_name):
            status, search_result = self.client.search(_database_name, _top_k, _query_vector_list,
                                                       partition_tags=_partition_tag,
                                                       params=_params, timeout=self.timeout)
            if status.OK():
                return search_result
            else:
                raise MilvusRuntimeException(status.message)
        else:
            raise DatabaseNotExist(f'{_database_name} not exist')

    def flush(self, _database_name):
        """
        Flush buffered data of the collection to disk.

        Args:
            _database_name: collection name

        Returns:
            True on success
        """
        self.init()
        status = self.client.flush([_database_name, ], self.timeout)
        if status.OK():
            return True
        else:
            raise MilvusRuntimeException(status.message)

    def __del__(self):
        # getattr: __del__ can run on a partially initialised object
        # (e.g. when super().__init__() raised before `client` was set).
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()
"query": [query_embedding], "metric_type": "L2", "params": { "nprobe": 8 } } } }] } } # ------ # Basic hybrid search entities # ------ results = client.search(collection_name, query_hybrid, fields=["release_year", "embedding"]) for entities in results: for topk_film in entities: current_entity = topk_film.entity print("==") print("- id: {}".format(topk_film.id)) print("- title: {}".format(titles[topk_film.id])) print("- distance: {}".format(topk_film.distance)) print("- release_year: {}".format(current_entity.release_year)) print("- embedding: {}".format(current_entity.embedding)) # ------ # Basic delete index: # You can drop index for a field.
def main():
    """Demonstrate partition-scoped insert and search on a Milvus server.

    Steps:
        1. Connect to ``_HOST``/``_PORT``.
        2. Recreate ``demo_partition_collection`` with a ``"random"`` partition.
        3. Insert 10000 random ``_DIM``-dimensional vectors into that partition.
        4. Build an IVF_FLAT index.
        5. Search the partition with the first 10 vectors.
        6. Drop the collection.

    ``_HOST``, ``_PORT``, ``_DIM`` and ``_INDEX_FILE_SIZE`` are expected to be
    module-level constants defined elsewhere in this file.
    """
    # Connect to Milvus server
    # You may need to change _HOST and _PORT accordingly
    param = {'host': _HOST, 'port': _PORT}

    # You can create an instance for the specified server address and
    # invoke RPC methods directly
    client = Milvus(**param)

    collection_name = 'demo_partition_collection'
    partition_tag = "random"

    # Start from a clean slate: drop the collection if it already exists.
    status, ok = client.has_collection(collection_name)
    if status.OK() and ok:
        client.drop_collection(collection_name)

    param = {
        'collection_name': collection_name,
        'dimension': _DIM,
        'index_file_size': _INDEX_FILE_SIZE,  # optional
        'metric_type': MetricType.L2  # optional
    }
    client.create_collection(param)

    # Show collections in Milvus server
    _, collections = client.show_collections()
    print(collections)

    # Describe collection
    _, collection = client.describe_collection(collection_name)
    print(collection)

    # Create the partition, then display the collection's partitions
    client.create_partition(collection_name, partition_tag=partition_tag)
    _, partitions = client.show_partitions(collection_name)
    print(partitions)

    # 10000 vectors with _DIM dimensions; vectors must be a 2-D list of floats.
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10000)]
    # You can also use numpy to generate random vectors:
    #     `vectors = np.random.rand(10000, _DIM).astype(np.float32).tolist()`

    # Insert vectors into the partition; returns status and the vector id list
    status, ids = client.insert(collection_name=collection_name, records=vectors,
                                partition_tag=partition_tag)
    if not status.OK():
        # Without data the rest of the demo is meaningless: clean up and stop.
        print(status)
        client.drop_collection(collection_name)
        return

    # Wait for 6 seconds, until Milvus server persist vector data.
    time.sleep(6)

    # Get the row count of the collection
    status, num = client.count_collection(collection_name)
    print("rows in collection {}: {}".format(collection_name, num))

    # Build an IVF_FLAT index. Searching works without an index;
    # the index only makes the search faster.
    index_param = {
        'nlist': 2048
    }
    status = client.create_index(collection_name, IndexType.IVF_FLAT, index_param)

    # describe index, get information of index
    status, index = client.describe_index(collection_name)
    print(index)

    # Use the first 10 vectors as queries, searching only inside our partition.
    query_vectors = vectors[0:10]
    search_param = {
        "nprobe": 10
    }
    param = {
        'collection_name': collection_name,
        'query_records': query_vectors,
        'top_k': 1,
        'partition_tags': [partition_tag],
        'params': search_param
    }
    status, results = client.search(**param)
    if status.OK():
        # The first query vector was inserted, so its top hit should match it.
        # Equivalent check:
        #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
        if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

        # print results
        print(results)
    else:
        print(status)

    # To delete only the partition use:
    #     client.drop_partition(collection_name, partition_tag)
    # Dropping the collection also drops all of its partitions.
    status = client.drop_collection(collection_name)
class MilvusClient(object):
    """Convenience wrapper around a ``Milvus`` client bound to a default collection.

    Methods that accept an optional ``collection_name`` operate on the
    collection given to the constructor when that argument is omitted.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=180):
        """Create the underlying ``Milvus`` client, retrying until ``timeout`` seconds pass.

        Args:
            collection_name: default collection used by the other methods.
            host: server host; falls back to ``SERVER_HOST_DEFAULT`` when falsy.
            port: server port; falls back to ``SERVER_PORT_DEFAULT`` when falsy.
            timeout: total number of seconds to keep retrying.

        Raises:
            Exception: if no client could be created before the deadline.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # Retry connecting to the remote server; back off one extra second per failure.
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                break
            except Exception as e:
                i = i + 1
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % i)
                time.sleep(i)
        else:
            # The deadline passed without a successful ``break``. A connection
            # that succeeded right at the deadline is no longer treated as a timeout.
            raise Exception("Server connect timeout")

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def _resolve_collection_name(self, collection_name):
        """Return ``collection_name``, or the default collection when it is None."""
        return self._collection_name if collection_name is None else collection_name

    @staticmethod
    def _build_query(vector_query, filter_query=None):
        """Wrap a vector query (plus optional filter clauses) in a ``bool``/``must`` query."""
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        return {"bool": {"must": must_params}}

    def check_status(self, status):
        """Log diagnostics and raise if ``status.OK()`` is false."""
        if not status.OK():
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's top hit has a distance >= module-level ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    # only support the given field name
    def create_collection(self, dimension, data_type=DataType.FLOAT_VECTOR, auto_id=False, collection_name=None,
                          other_fields=None):
        """Create a collection with one vector field plus optional scalar fields.

        Args:
            dimension: dimension of the vector field.
            data_type: vector field type.
            auto_id: whether the server generates entity ids.
            collection_name: target collection; the default is used when falsy.
            other_fields: comma-separated extra fields; ``"int"`` and/or
                ``"float"`` add the default INT64/FLOAT fields.
        """
        self._dimension = dimension
        if not collection_name:
            collection_name = self._collection_name
        vec_field_name = utils.get_default_field_name(data_type)
        fields = [{
            "name": vec_field_name,
            "type": data_type,
            "params": {
                "dim": dimension
            }
        }]
        if other_fields:
            other_fields = other_fields.split(",")
            if "int" in other_fields:
                fields.append({
                    "name": utils.DEFAULT_INT_FIELD_NAME,
                    "type": DataType.INT64
                })
            if "float" in other_fields:
                fields.append({
                    "name": utils.DEFAULT_FLOAT_FIELD_NAME,
                    "type": DataType.FLOAT
                })
        create_param = {"fields": fields, "auto_id": auto_id}
        try:
            self._milvus.create_collection(collection_name, create_param)
            logger.info("Create collection: <%s> successfully" % collection_name)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        """Create partition ``tag`` (default collection when ``collection_name`` is falsy)."""
        if not collection_name:
            collection_name = self._collection_name
        self._milvus.create_partition(collection_name, tag)

    def generate_values(self, data_type, vectors, ids):
        """Pick the field values for ``data_type``.

        Returns ``ids`` for integer fields, ``ids`` as floats for float fields,
        ``vectors`` for vector fields, and None for any other type.
        """
        values = None
        if data_type in [DataType.INT32, DataType.INT64]:
            values = ids
        elif data_type in [DataType.FLOAT, DataType.DOUBLE]:
            values = [(i + 0.0) for i in ids]
        elif data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
            values = vectors
        return values

    def generate_entities(self, vectors, ids=None, collection_name=None):
        """Build one entity dict per field described by the collection info."""
        collection_name = self._resolve_collection_name(collection_name)
        info = self.get_info(collection_name)
        entities = []
        for field in info["fields"]:
            field_type = field["type"]
            entities.append({
                "name": field["name"],
                "type": field_type,
                "values": self.generate_values(field_type, vectors, ids)
            })
        return entities

    @time_wrapper
    def insert(self, entities, ids=None, collection_name=None):
        """Insert ``entities`` and return the ids; errors are logged and yield None."""
        tmp_collection_name = self._resolve_collection_name(collection_name)
        try:
            insert_ids = self._milvus.insert(tmp_collection_name, entities, ids=ids)
            return insert_ids
        except Exception as e:
            logger.error(str(e))
            return None

    def get_dimension(self):
        """Return the ``dim`` of the first vector field, or None if there is none."""
        info = self.get_info()
        for field in info["fields"]:
            if field["type"] in [
                    DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR
            ]:
                return field["params"]["dim"]

    def get_rand_ids(self, length):
        """Return up to ``length`` ids sampled from a random segment of the first partition.

        Keeps picking segments until a non-empty one is found. If that segment
        holds at most ``length`` ids, all of them are returned.
        """
        segment_ids = []
        while True:
            stats = self.get_stats()
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(
                    self._collection_name, segment["id"])
            except Exception as e:
                logger.error(str(e))
            if not len(segment_ids):
                continue
            elif len(segment_ids) > length:
                return random.sample(segment_ids, length)
            else:
                logger.debug("Reset length: %d" % len(segment_ids))
                return segment_ids

    def get(self):
        """Fetch one entity with a random id in [1, 1000000] (result discarded)."""
        get_ids = random.randint(1, 1000000)
        self._milvus.get_entity_by_id(self._collection_name, [get_ids])

    @time_wrapper
    def get_entities(self, get_ids):
        """Fetch entities by id from the default collection."""
        get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete entities by id."""
        tmp_collection_name = self._resolve_collection_name(collection_name)
        self._milvus.delete_entity_by_id(tmp_collection_name, ids)

    def delete_rand(self):
        """Delete 1-100 random ids, flush, and assert they can no longer be fetched."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.debug("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        for item in get_res:
            assert not item
        # Disabled sanity check that used ``count_before``:
        # if count_before - len(delete_ids) < self.count():
        #     logger.error(delete_ids)
        #     raise Exception("Error occured")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        """Flush the collection."""
        tmp_collection_name = self._resolve_collection_name(collection_name)
        self._milvus.flush([tmp_collection_name], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the collection and check the returned status."""
        tmp_collection_name = self._resolve_collection_name(collection_name)
        status = self._milvus.compact(tmp_collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, field_name, index_type, metric_type, _async=False, index_param=None):
        """Build an index on ``field_name`` of the default collection.

        ``index_type`` is looked up in ``INDEX_MAP``; ``metric_type`` is
        translated by ``utils.metric_type_trans``.
        """
        index_type = INDEX_MAP[index_type]
        metric_type = utils.metric_type_trans(metric_type)
        logger.info(
            "Building index start, collection_name: %s, index_type: %s, metric_type: %s" %
            (self._collection_name, index_type, metric_type))
        if index_param:
            logger.info(index_param)
        index_params = {
            "index_type": index_type,
            "metric_type": metric_type,
            "params": index_param
        }
        self._milvus.create_index(self._collection_name, field_name, index_params, _async=_async)

    # TODO: need to check
    def describe_index(self, field_name):
        """Return ``{"index_type", "index_param"}`` for the first index found in ``INDEX_MAP``.

        Defaults to ``{"index_type": "flat", "index_param": None}``.
        """
        info = self._milvus.describe_index(self._collection_name, field_name)
        index_info = {"index_type": "flat", "index_param": None}
        for field in info["fields"]:
            for index in field['indexes']:
                if not index or "index_type" not in index:
                    continue
                for k, v in INDEX_MAP.items():
                    if index['index_type'] == v:
                        index_info['index_type'] = k
                        index_info['index_param'] = index['params']
                        return index_info
        return index_info

    def drop_index(self, field_name):
        """Drop the index on ``field_name`` of the default collection."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name, field_name)

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None):
        """Run a search with ``vector_query`` plus optional filter clauses."""
        tmp_collection_name = self._resolve_collection_name(collection_name)
        query = self._build_query(vector_query, filter_query)
        result = self._milvus.search(tmp_collection_name, query)
        return result

    @time_wrapper
    def load_and_query(self, vector_query, filter_query=None, collection_name=None):
        """Load the collection into memory, then run the search as in ``query``."""
        tmp_collection_name = self._resolve_collection_name(collection_name)
        query = self._build_query(vector_query, filter_query)
        self.load_collection(tmp_collection_name)
        result = self._milvus.search(tmp_collection_name, query)
        return result

    def get_ids(self, result):
        """Split the flat id list of a search ``result`` into one list per query.

        Returns an empty list when ``result`` has no queries. Returns one empty
        list per query when no ids came back. (The previous version raised
        ZeroDivisionError/ValueError in both cases.)
        """
        idss = result._entities.ids
        len_idss = len(idss)
        len_r = len(result)
        if not len_r:
            return []
        if not len_idss:
            return [[] for _ in range(len_r)]
        # max(..., 1) guards against a zero range step when ids < queries.
        top_k = max(len_idss // len_r, 1)
        return [idss[offset:min(offset + top_k, len_idss)]
                for offset in range(0, len_idss, top_k)]

    def _rand_vector_query(self, nq_max):
        """Build a random 128-dim IVF vector query and log its nq/top_k/nprobe."""
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        return {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }

    def query_rand(self, nq_max=100):
        """Run one random IVF search (see ``_rand_vector_query``)."""
        self.query(self._rand_vector_query(nq_max))

    def load_query_rand(self, nq_max=100):
        """Load the collection, then run one random IVF search."""
        self.load_and_query(self._rand_vector_query(nq_max))

    # TODO: need to check
    def count(self, collection_name=None):
        """Return the ``row_count`` reported by the collection stats."""
        collection_name = self._resolve_collection_name(collection_name)
        row_count = self._milvus.get_collection_stats(
            collection_name)["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (row_count, collection_name))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop the collection and poll once per second until its row count is falsy.

        Polling stops early if ``count`` raises (e.g. the collection is gone).
        """
        timeout = int(timeout)
        collection_name = self._resolve_collection_name(collection_name)
        logger.info("Start delete collection: %s" % collection_name)
        self._milvus.drop_collection(collection_name)
        i = 0
        while i < timeout:
            try:
                row_count = self.count(collection_name=collection_name)
                if row_count:
                    time.sleep(1)
                    i = i + 1
                    continue
                else:
                    break
            except Exception as e:
                logger.debug(str(e))
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        """Return the stats of the default collection."""
        return self._milvus.get_collection_stats(self._collection_name)

    def get_info(self, collection_name=None):
        """Return the collection info."""
        collection_name = self._resolve_collection_name(collection_name)
        return self._milvus.get_collection_info(collection_name)

    def show_collections(self):
        """Return the list of collections on the server."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return whether the collection exists."""
        collection_name = self._resolve_collection_name(collection_name)
        res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()
        for name in collection_names:
            self.drop(collection_name=name)

    @time_wrapper
    def load_collection(self, collection_name=None):
        """Load the collection into memory."""
        collection_name = self._resolve_collection_name(collection_name)
        return self._milvus.load_collection(collection_name, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        """Release the collection from memory."""
        collection_name = self._resolve_collection_name(collection_name)
        return self._milvus.release_collection(collection_name, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        """Load the given partitions into memory."""
        collection_name = self._resolve_collection_name(collection_name)
        return self._milvus.load_partitions(collection_name, tag_names, timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        """Release the given partitions from memory."""
        collection_name = self._resolve_collection_name(collection_name)
        return self._milvus.release_partitions(collection_name, tag_names, timeout=3000)
class MilvusDBHandler:
    """Milvus DB handler

    This class is intended to abstract the access and communication with external MilvusDB from Executors

    For more information about Milvus:
        - https://github.com/milvus-io/milvus/
    """

    @staticmethod
    def get_index_type(index_type):
        """Map a human-readable index name to a ``milvus.IndexType``.

        Unknown names fall back to ``IndexType.FLAT``.
        """
        from milvus import IndexType

        return {
            'Flat': IndexType.FLAT,
            'IVF,Flat': IndexType.IVFLAT,
            'IVF,SQ8': IndexType.IVF_SQ8,
            'RNSG': IndexType.RNSG,
            'IVF,SQ8H': IndexType.IVF_SQ8H,
            'IVF,PQ': IndexType.IVF_PQ,
            # Was mapped to IndexType.IVF_PQ, so 'HNSW' silently built an IVF_PQ index.
            'HNSW': IndexType.HNSW,
            'Annoy': IndexType.ANNOY
        }.get(index_type, IndexType.FLAT)

    @staticmethod
    def _flatten_keys(keys):
        """Flatten a numpy array of ids (1-D or 2-D, e.g. shape (n, 1)) into a flat Python list.

        Elements are taken in row-major order. This replaces
        ``reduce(operator.concat, keys.tolist())``, which was quadratic, failed
        on 1-D arrays, and failed on empty input.
        """
        return keys.ravel().tolist()

    class MilvusDBInserter:
        """Milvus DB Inserter

        This class is an inner class and provides a context manager to insert vectors into Milvus while ensuring
        data is flushed.

        For more information about Milvus:
            - https://github.com/milvus-io/milvus/
        """

        def __init__(self, client, collection_name: str):
            self.logger = get_logger(self.__class__.__name__)
            self.client = client
            self.collection_name = collection_name

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Flush on exit (also when an exception occurred) so inserted data is persisted.
            self.logger.info(f'Sending flush command to Milvus Server for collection: {self.collection_name}')
            self.client.flush([self.collection_name])

        def insert(self, keys: list, vectors: 'np.ndarray'):
            """Insert ``vectors`` with ids ``keys``; raise MilvusDBException on failure."""
            status, _ = self.client.insert(collection_name=self.collection_name, records=vectors, ids=keys)
            if not status.OK():
                self.logger.error('Insert failed: {}'.format(status))
                raise MilvusDBException(status.message)

    def __init__(self, host: str, port: int, collection_name: str):
        """
        Initialize an MilvusDBHandler

        :param host: Host of the Milvus Server
        :param port: Port to connect to the Milvus Server
        :param collection_name: Name of the collection where the Handler will insert and query vectors.
        """
        self.logger = get_logger(self.__class__.__name__)
        self.host = host
        self.port = str(port)
        self.collection_name = collection_name
        self.milvus_client = None

    def __enter__(self):
        return self.connect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def connect(self):
        """(Re)create the Milvus client unless an existing one reports an OK server status."""
        from milvus import Milvus
        if self.milvus_client is None or not self.milvus_client.server_status()[0].OK():
            self.logger.info(f'Setting connection to Milvus Server at {self.host}:{self.port}')
            self.milvus_client = Milvus(self.host, self.port)
        return self

    def close(self):
        """Close the connection to the Milvus server, if one was opened.

        Safe to call when ``connect`` was never called. The client reference is
        cleared so a later ``connect`` creates a fresh client.
        """
        if self.milvus_client is None:
            return
        self.logger.info(f'Closing connection to Milvus Server at {self.host}:{self.port}')
        self.milvus_client.close()
        self.milvus_client = None

    def insert(self, keys: 'np.ndarray', vectors: 'np.ndarray'):
        """Insert ``vectors`` under ids ``keys`` and flush the collection afterwards."""
        with MilvusDBHandler.MilvusDBInserter(self.milvus_client, self.collection_name) as db:
            db.insert(self._flatten_keys(keys), vectors)

    def build_index(self, index_type: str, index_params: dict):
        """Create an index of ``index_type`` (see ``get_index_type``) on the collection."""
        milvus_index_type = self.get_index_type(index_type)
        self.logger.info(f'Creating index of type: {index_type} at'
                         f' Milvus Server. collection: {self.collection_name} with index params: {index_params}')
        status = self.milvus_client.create_index(self.collection_name, milvus_index_type, index_params)
        if not status.OK():
            self.logger.error('Creating index failed: {}'.format(status))
            raise MilvusDBException(status.message)

    def search(self, query_vectors: 'np.ndarray', top_k: int, search_params: dict = None):
        """Search the collection.

        Returns:
            A ``(distance_array, id_array)`` tuple from the search result.

        Raises:
            MilvusDBException: if the server reports a failed search.
        """
        self.logger.info(f'Querying collection: {self.collection_name} with search params: {search_params}')
        status, results = self.milvus_client.search(collection_name=self.collection_name,
                                                    query_records=query_vectors, top_k=top_k,
                                                    params=search_params)
        if not status.OK():
            self.logger.error('Querying index failed: {}'.format(status))
            raise MilvusDBException(status.message)
        return results.distance_array, results.id_array