Ejemplo n.º 1
0
def main():
    """Train word2vec on text8, store the vectors in Milvus and print the
    words closest to 'king'.

    The vectors are inserted into the 'word2vec' collection, an IVF_FLAT
    index is built, and the 10 nearest neighbours of the 'king' vector are
    looked up and printed, one word per line.
    """
    sentences = word2vec.Text8Corpus("text8")  # load the corpus
    model = word2vec.Word2Vec(sentences, size=200, window=5,
                              min_count=5)  # train the model
    word_set = model.wv.index2word  # vocabulary
    word_vec = model.wv.vectors  # word2vec result vectors
    milvus = Milvus()
    milvus.connect(host='localhost', port='19530')
    param = {
        'collection_name': 'word2vec',
        'dimension': 200,
        'index_file_size': 1024,
        'metric_type': MetricType.L2
    }
    milvus.create_collection(param)

    status, ids = milvus.insert(collection_name='word2vec', records=word_vec)

    # Build the id -> word lookup once instead of scanning `ids` for every
    # hit; row i of word_vec belongs to word_set[i] (see the 'king' lookup
    # below), and ids are returned in insertion order.
    id_to_word = dict(zip(ids, word_set))

    # Partition the vectors into 100 clusters by building an IVF_FLAT index.
    ivf_param = {'nlist': 100}
    milvus.create_index('word2vec', IndexType.IVF_FLAT, ivf_param)
    status, index = milvus.describe_index('word2vec')

    # Look up the words most similar to 'king'.
    res = milvus.search(collection_name='word2vec',
                        query_records=[list(word_vec[word_set.index('king')])],
                        top_k=10,
                        params={'nprobe': 16})
    # res is (status, results); results[0] holds the hits of the only query.
    # Iterating the hits (rather than a fixed range(10)) avoids an
    # IndexError when fewer than top_k results come back.
    for hit in res[1][0]:
        print(id_to_word[hit.id])

    print(1)
Ejemplo n.º 2
0
def run_offline_paper():
    """Precompute neighbour recommendations for every paper.

    Reads ``(ID, doc_vector)`` rows from the ``paper`` table through the
    module-level cursor ``cur``, searches the 'ideaman' Milvus collection
    with each vector (top_k=51) and inserts the returned ids, minus the
    first one, into ``offline_paper`` as a comma-separated string.

    A failure on one paper is logged and the loop moves on to the next.
    """
    import logging

    log = logging.getLogger(__name__)
    client = Milvus(host=milvus_ip, port='19530')
    cur.execute("SELECT ID ,doc_vector FROM paper")
    for row in cur.fetchall():
        try:
            paper_id = row[0]
            # doc_vector is a comma-separated list of numbers. Parse it with
            # float() instead of eval() so database content is never
            # executed as code.
            vec = [float(part) for part in row[1].split(",")]
            res = client.search(collection_name='ideaman',
                                query_records=[vec],
                                top_k=51)

            if res[0].code != 0:
                continue

            neighbour_ids = [str(j) for j in res[-1]._id_array[0]]
            # Skip the first hit -- presumably the query paper itself;
            # TODO confirm.
            recs = ",".join(neighbour_ids[1:])
            # NOTE(review): this statement is built by string formatting.
            # The values come from our own table and from Milvus ids, but
            # it should move to the DB driver's parameter placeholders.
            sql = 'INSERT INTO offline_paper(paper_id , recs) VALUES({} , "{}")'.format(
                paper_id, recs)
            cur.execute(sql)
            try:
                conn.commit()
            except Exception:
                conn.rollback()
        except Exception:
            # Best effort per paper: record the failure instead of hiding it.
            log.exception("offline recommendation failed for paper row %r",
                          row[:1])
Ejemplo n.º 3
0
def milvus_test(usr_features, IS_INFER, mov_features=None, ids=None):
    """Search the 'recommender_demo' collection with user vectors and print
    the top-5 movies of the first query.

    Args:
        usr_features: user feature vectors; passed through ``normaliz_data``
            and used as the search queries.
        IS_INFER: when truthy, the collection is dropped first so that it is
            rebuilt from ``mov_features``.
        mov_features: movie feature vectors inserted when the collection
            does not exist. The process exits with status 1 if they are
            needed but missing.
        ids: optional ids handed to ``insert`` for the movie vectors.
    """
    _HOST = '127.0.0.1'
    _PORT = '19530'  # default value
    table_name = 'recommender_demo'
    milvus = Milvus()

    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if status.OK():
        print("Server connected.")
    else:
        print("Server connect fail.")
        sys.exit(1)

    if IS_INFER:
        status = milvus.drop_collection(table_name)
        time.sleep(3)

    status, ok = milvus.has_collection(table_name)
    if not ok:
        if mov_features is None:
            print("Insert vectors is none!")
            sys.exit(1)
        param = {
            'collection_name': table_name,
            'dimension': 200,
            'index_file_size': 1024,  # optional
            'metric_type': MetricType.IP  # optional
        }

        print(milvus.create_collection(param))

        insert_vectors = normaliz_data(mov_features)
        status, ids = milvus.insert(collection_name=table_name,
                                    records=insert_vectors,
                                    ids=ids)

        time.sleep(1)

    status, result = milvus.count_collection(table_name)
    print("rows in table recommender_demo:", result)

    search_vectors = normaliz_data(usr_features)
    param = {
        'collection_name': table_name,
        'query_records': search_vectors,
        'top_k': 5,
        'params': {
            'nprobe': 16
        }
    }
    status, results = milvus.search(**param)

    # Load the movie metadata once; it does not change between hits.
    movie_info = paddle.dataset.movielens.movie_info()

    print("Top\t", "Ids\t", "Title\t", "Score")
    for rank, hit in enumerate(results[0]):
        title = movie_info[int(hit.id)].title
        print(rank, "\t", hit.id, "\t", title, "\t", float(hit.distance) * 5)
Ejemplo n.º 4
0
    def test_not_connect(self):
        """Each server-facing call on an unconnected client raises NotConnectError."""
        client = Milvus()

        # (method name, positional arguments), in the order they are checked.
        calls = [
            ("create_collection", ({},)),
            ("has_collection", ("a",)),
            ("describe_collection", ("a",)),
            ("drop_collection", ("a",)),
            ("create_index", ("a",)),
            ("insert", ("a", [], None)),
            ("count_collection", ("a",)),
            ("show_collections", ()),
            ("search", ("a", 1, 2, [], None)),
            ("search_in_files", ("a", [], [], 2, 1, None)),
            ("_cmd", ("",)),
            ("preload_collection", ("a",)),
            ("describe_index", ("a",)),
            ("drop_index", ("",)),
        ]

        for method_name, args in calls:
            with pytest.raises(NotConnectError):
                getattr(client, method_name)(*args)
Ejemplo n.º 5
0
 def search_vectors(name, vector, topk, nprobe):
     """Search collection ``name`` with ``vector`` and return the hits.

     Returns the second element of the ``search`` result. Any failure,
     including a non-OK status, is logged and re-raised as ``MilvusError``.
     """
     try:
         client = Milvus(host=MILVUS_ADDR, port=MILVUS_PORT)
         status, hits = client.search(
             collection_name=name,
             query_records=vector,
             top_k=topk,
             params={'nprobe': nprobe},
         )
         if status.OK():
             return hits
         raise MilvusError("There was some error when search vectors", status)
     except Exception as e:
         err_msg = "There was some error when search vectors"
         logger.error(f"{err_msg} : {str(e)}", exc_info=True)
         raise MilvusError(err_msg, e)
Ejemplo n.º 6
0
 def search_vectors(name, vector, topk, nprobe):
     """Connect to Milvus, search collection ``name`` and return the hits.

     Returns the second element of the ``search`` result. Any failure,
     including a non-OK status, is re-raised as ``MilvusError``.
     """
     client = Milvus()
     try:
         client.connect(MILVUS_ADDR, MILVUS_PORT)
         status, hits = client.search(
             collection_name=name,
             query_records=vector,
             top_k=topk,
             params={'nprobe': nprobe},
         )
         if status.OK():
             return hits
         raise MilvusError("There has some error when search vectors",
                           status)
     except Exception as e:
         raise MilvusError("There has some error when search vectors", e)
Ejemplo n.º 7
0
def milvus_test(usr_features, mov_features, ids):
    """Insert one user vector into Milvus and search it with one movie vector.

    The user vector is normalised and inserted into 'paddle_demo1' (created
    on demand), the normalised movie vector is used as the single query,
    the best hit's id and ``distance * 5`` are printed, and the collection
    is dropped at the end. The process exits with status 1 when the server
    cannot be reached.
    """
    milvus = Milvus()

    status = milvus.connect(host='127.0.0.1', port='19530')
    if not status.OK():
        print("\nServer connect fail.")
        sys.exit(1)
    print("\nServer connected.")

    table_name = 'paddle_demo1'

    # Create the collection only when it is missing.
    status, exists = milvus.has_collection(table_name)
    if not exists:
        milvus.create_collection({
            'collection_name': table_name,
            'dimension': 200,
            'index_file_size': 1024,  # optional
            'metric_type': MetricType.IP,  # optional
        })

    user_records = normaliz_data([usr_features.tolist()])
    status, ids = milvus.insert(collection_name=table_name,
                                records=user_records,
                                ids=ids)

    time.sleep(1)

    status, row_count = milvus.count_collection(table_name)
    print("rows in table paddle_demo1:", row_count)

    # The count is requested a second time; its result is not used.
    status, _ = milvus.count_collection(table_name)

    movie_queries = normaliz_data([mov_features.tolist()])
    status, results = milvus.search(collection_name=table_name,
                                    query_records=movie_queries,
                                    top_k=1,
                                    params={'nprobe': 16})
    best = results[0][0]
    print("Searched ids:", best.id)
    print("Score:", float(best.distance)*5)

    status = milvus.drop_collection(table_name)
Ejemplo n.º 8
0
class MilvusConnection:
    """Milvus client bound to a single collection of embedding vectors.

    On construction the collection is created and filled from
    ``env.base.embeddings`` if the server does not have it yet. The status
    of every server operation is kept in ``self.statuses``.
    """

    def __init__(self, env, name="movies_L2", port="19530", param=None,
                 host="localhost"):
        """Connect to Milvus and make sure collection ``name`` is populated.

        Args:
            env: object exposing ``base.embeddings``, a tensor whose rows
                are inserted with ids ``0 .. n-1`` when the collection has
                to be created.
            name: collection name.
            port: Milvus server port.
            param: optional overrides merged over the default collection
                parameters (dimension 128, index_file_size 1024, L2 metric).
            host: Milvus server host.
        """
        overrides = {} if param is None else param
        collection_param = {
            "collection_name": name,
            "dimension": 128,
            "index_file_size": 1024,
            "metric_type": MetricType.L2,
            **overrides,
        }
        self.name = name
        self.client = Milvus(host=host, port=port)
        self.statuses = {}
        if not self.client.has_collection(name)[1]:
            status_created_collection = self.client.create_collection(
                collection_param)
            vectors = env.base.embeddings.detach().cpu().numpy().astype(
                "float32")
            # Row index doubles as the vector id.
            target_ids = list(range(vectors.shape[0]))
            status_inserted, inserted_vector_ids = self.client.insert(
                collection_name=name, records=vectors, ids=target_ids)
            status_flushed = self.client.flush([name])
            status_compacted = self.client.compact(collection_name=name)
            self.statuses["created_collection"] = status_created_collection
            self.statuses["inserted"] = status_inserted
            self.statuses["flushed"] = status_flushed
            self.statuses["compacted"] = status_compacted

    def search(self, search_vecs, topk=10, search_param=None):
        """Return the ids of the ``topk`` nearest vectors for each query.

        Args:
            search_vecs: query vectors handed to the client as
                ``query_records``.
            topk: number of neighbours requested per query.
            search_param: optional overrides merged over ``{"nprobe": 16}``.

        Returns:
            ``torch.tensor`` built from the result's ``id_array``. The
            search status is stored under ``statuses['last_search']``.
        """
        if search_param is None:
            search_param = dict()
        search_param = {"nprobe": 16, **search_param}
        status, results = self.client.search(
            collection_name=self.name,
            query_records=search_vecs,
            top_k=topk,
            params=search_param,
        )
        self.statuses['last_search'] = status
        return torch.tensor(results.id_array)

    def get_log(self):
        """Return the dict of recorded operation statuses."""
        return self.statuses
Ejemplo n.º 9
0
}
# Create the collection described by `param` (defined above this fragment).
milvus.create_collection(param=param)

# Build an IVF_FLAT index with nlist=16384 on the collection.
ivf_param = {'nlist': 16384}
milvus.create_index(collection_name=col_name,
                    index_type=IndexType.IVF_FLAT,
                    params=ivf_param)

# Insert 2000 random vectors of length `dim` with explicit ids 0..1999.
vectors = [[random.random() for _ in range(dim)] for _ in range(2000)]
vector_ids = list(range(2000))
_, ids = milvus.insert(collection_name=col_name,
                       records=vectors,
                       ids=vector_ids)
# print(ids)

# Presumably gives the server time to persist the insert -- TODO confirm.
time.sleep(1)
# Search with 5 random query vectors, asking for the 2 nearest hits each.
search_param = {'nprobe': 16}
q_records = [[random.random() for _ in range(dim)] for _ in range(5)]
_, result = milvus.search(collection_name=col_name,
                          query_records=q_records,
                          top_k=2,
                          params=search_param)
# for r in result:
#     print(r)
print(result.id_array)
print(result)

# Clean up: drop the collection and close the client.
milvus.drop_collection(collection_name=col_name)

milvus.close()
Ejemplo n.º 10
0
# Maps each vector id to the first column of its line in embed.txt.
id_query = {}
# Each line is tab-separated: column 0 is the value stored in id_query,
# column 1 the integer id, column 2 a whitespace-separated list of floats.
with open("embed.txt", "r") as f:
    for line in f:
        line_lst = line.strip().split("\t")
        vectors.append(list(map(float, (line_lst[2].split()))))
        vector_ids.append(int(line_lst[1]))
        id_query[int(line_lst[1])] = line_lst[0]

# Insert all vectors into 'test01' under the ids read from the file.
milvus.insert(collection_name='test01', records=vectors, ids=vector_ids)

# Build an IVF_FLAT index with nlist=16384.
ivf_param = {'nlist': 16384}
milvus.create_index('test01', IndexType.IVF_FLAT, ivf_param)

# Query with the first inserted vector and print the hits and elapsed time.
search_param = {'nprobe': 16}
q_records = vectors[:1]
t = time.time()
result = milvus.search(collection_name='test01',
                       query_records=q_records,
                       top_k=10,
                       params=search_param)[1]
print(result)
print(time.time() - t)

# Same search again -- presumably to compare first-call and repeat latency;
# TODO confirm.
t = time.time()
result = milvus.search(collection_name='test01',
                       query_records=q_records,
                       top_k=10,
                       params=search_param)[1]
print(result)
print(time.time() - t)
Ejemplo n.º 11
0
def main():
    """Exercise the hybrid-collection API end to end against a Milvus server.

    Creates a collection with three integer fields and one 128-dimensional
    float-vector field, inserts ``num`` random entities, deletes and fetches
    entities by id, builds an IVF_FLAT index, prints collection info and
    stats, times one hybrid (term + range + vector) search and finally
    drops the collection.
    """
    milvus = Milvus(_HOST, _PORT)

    num = 100000
    # Start from a clean slate: drop the collection if it already exists.
    collection_name = 'example_hybrid_collections_{}'.format(num)
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name)

    collection_param = {
        "fields": [{
            "field": "A",
            "type": DataType.INT32
        }, {
            "field": "B",
            "type": DataType.INT32
        }, {
            "field": "C",
            "type": DataType.INT64
        }, {
            "field": "Vec",
            "type": DataType.FLOAT_VECTOR,
            "params": {
                "dim": 128,
                "metric_type": "L2"
            }
        }],
        "segment_size":
        100
    }
    milvus.create_collection(collection_name, collection_param)

    milvus.compact(collection_name)

    # The same random integers are used for all three scalar fields.
    A_list = [random.randint(0, 255) for _ in range(num)]
    vec = [[random.random() for _ in range(128)] for _ in range(num)]
    hybrid_entities = [{
        "field": "A",
        "values": A_list,
        "type": DataType.INT32
    }, {
        "field": "B",
        "values": A_list,
        "type": DataType.INT32
    }, {
        "field": "C",
        "values": A_list,
        "type": DataType.INT64
    }, {
        "field": "Vec",
        "values": vec,
        "type": DataType.FLOAT_VECTOR,
        "params": {
            "dim": 128
        }
    }]

    # Insert slice by slice; `ids` ends up holding the ids of the LAST
    # slice only, which is what the delete/get calls below operate on.
    for slice_e in utils.entities_slice(hybrid_entities):
        ids = milvus.insert(collection_name, slice_e)
    milvus.flush([collection_name])
    print("Flush ... ")
    # The count is requested but its value is not used.
    milvus.count_entities(collection_name)

    milvus.delete_entity_by_id(collection_name, ids[:1])
    milvus.flush([collection_name])
    print("Get entity be id start ...... ")
    entities = milvus.get_entity_by_id(collection_name, ids[:1])
    # Converted to a dict but not inspected further.
    entities.dict()
    milvus.delete_entity_by_id(collection_name, ids[1:2])
    milvus.flush([collection_name])

    print("Create index ......")
    milvus.create_index(collection_name, "Vec", {
        "index_type": "IVF_FLAT",
        "metric_type": "L2",
        "params": {
            "nlist": 100
        }
    })
    print("Create index done.")

    info = milvus.get_collection_info(collection_name)
    print(info)
    stats = milvus.get_collection_stats(collection_name)
    print("\nstats\n")
    print(stats)

    # Hybrid query: A in {1, 2, 5}, 1 < B < 100, and the 10 nearest
    # neighbours for each of the first 10000 inserted vectors.
    query_hybrid = {
        "bool": {
            "must": [
                {
                    "term": {
                        "A": [1, 2, 5]
                    }
                },
                {
                    "range": {
                        "B": {"GT": 1, "LT": 100}
                    }
                },
                {
                    "vector": {
                        "Vec": {
                            "topk": 10, "query": vec[: 10000], "params": {"nprobe": 10}
                        }
                    }
                }
            ],
        },
    }

    # Time the search together with a full walk over every returned hit.
    t0 = time.time()
    hit_count = 0
    results = milvus.search(collection_name, query_hybrid, fields=["B"])
    for hits in results:
        for _ in hits:
            hit_count += 1

    print("Search cost {} s".format(time.time() - t0))

    milvus.drop_collection(collection_name)
Ejemplo n.º 12
0
# ------
# Basic hybrid search entities:
#     And we want to get all the fields back in results, so fields = ["duration", "release_year", "embedding"].
#     If searching successfully, results will be returned.
#     `results` has `nq` (number of queries) separate results; since we only query for 1 film, the length of
#     `results` is 1.
#     We ask for top 3 in-return, but our condition is too strict while the database is too small, so we can
#     only get 1 film, which means length of `entities` in below is also 1.
#
#     Now that we have the results and know it's a 1 x 1 structure, how can we get ids, distances and fields?
#     It's very simple, for every `topk_film`, it has three properties: `id, distance and entity`.
#     All fields are stored in `entity`, so you can finally obtain these data as below:
#     And the result should be film with id = 3.
# ------
# Run the search described by `dsl`, asking for three fields back per hit.
results = client.search(collection_name, dsl, fields=["duration", "release_year", "embedding"])
print("\n----------search----------")
# Outer loop: one entry per query; inner loop: the hits of that query.
for entities in results:
    for topk_film in entities:
        current_entity = topk_film.entity
        print("- id: {}".format(topk_film.id))
        print("- distance: {}".format(topk_film.distance))

        # The requested fields are read as attributes of the hit's entity.
        print("- release_year: {}".format(current_entity.release_year))
        print("- duration: {}".format(current_entity.duration))
        print("- embedding: {}".format(current_entity.embedding))

# ------
# Basic delete:
#     Now let's see how to delete things in Milvus.
#     You can simply delete entities by their ids.
Ejemplo n.º 13
0
    def search(self, img_id):
        """Find stored images similar to ``img_id`` and save them as PNG files.

        Restores the latest checkpoint, runs the query image through the
        model to obtain its feature vector, asks the 'mnist' Milvus
        collection for the 100 nearest vectors and writes
        ``self.train_images[hit.id]`` for each hit to
        ``search_result/result_<rank>.png``.

        Args:
            img_id: file name of the query image inside the homework
                static directory.

        Returns:
            True once all result images have been written.
        """
        x_real, x_fake, cost, optimizer, z_noise, z_real, guessed_z = self.build_model()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            model_file = tf.train.latest_checkpoint('D:/py_project/untitled1/polls/homework/vae/ckpt/')
            # Load the trained parameters.
            saver.restore(sess, model_file)

            search_img_path = "D:/py_project/untitled1/polls/static/polls/homework/"+img_id
            search_img = cv2.imread(search_img_path, cv2.IMREAD_GRAYSCALE)
            search_img = np.reshape(search_img, [-1, 784])/255

            # Feature vector of the image being searched for.
            test_z = sess.run(guessed_z, feed_dict={x_real: search_img})

            print(img_id, test_z)

            # Connect to Milvus and run the query. The connection is
            # released even when the search itself raises.
            milvus = Milvus()
            milvus.connect(host='localhost', port='19530')
            try:
                collection_name = 'mnist'
                search_param = {'nprobe': 16}
                status, result = milvus.search(collection_name=collection_name, query_records=test_z, top_k=100,
                                               params=search_param)
            finally:
                milvus.disconnect()

            # Save the hits; the file name carries the rank within the row.
            result_img_path = "D:/py_project/untitled1/polls/static/polls/homework/search_result/result_"
            for row in result:
                for rank, item in enumerate(row):
                    batch_test_pic = self.train_images[item.id]
                    pic_name = result_img_path + str(rank) + ".png"
                    cv2.imwrite(pic_name, batch_test_pic)

            print('search success!')

            return True
Ejemplo n.º 14
0
#     And we want to get all the fields back in results, so fields = ["duration", "release_year", "embedding"].
#     If searching successfully, results will be returned.
#     `results` has `nq` (number of queries) separate results; since we only query for 1 film, the length of
#     `results` is 1.
#     We ask for top 3 in-return, but our condition is too strict while the database is too small, so we can
#     only get 1 film, which means length of `entities` in below is also 1.
#
#     Now we've gotten the results, and known it's a 1 x 1 structure, how can we get ids, distances and fields?
#     It's very simple, for every `topk_film`, it has three properties: `id, distance and entity`.
#     All fields are stored in `entity`, so you can finally obtain these data as below:
#     And the result should be film with id = 3.
#
#      Here, we pass parameter '_async=True' to run the search asynchronously; the call returns a `Future` object.
# ------
print("\n----------search----------")
# With _async=True the call returns a future; result() waits for the hits.
search_future = client.search(collection_name, dsl, _async=True)
search_results = search_future.result()

# ------
# Basic delete:
#     Now let's see how to delete things in Milvus.
#     You can simply delete entities by their ids.
#
#     After deleting, we invoke compact on the collection in an asynchronous way.
# ------
# NOTE(review): the banner below announces ids 1 and 2, but the call passes
# ids=[1, 4] -- confirm which one is intended.
print("\n----------delete id = 1, id = 2----------")
client.delete_entity_by_id(collection_name, ids=[1, 4])
client.flush()  # flush is important
# Compact asynchronously and wait for it to finish.
compact_future = client.compact(collection_name, _async=True)
compact_future.result()
Ejemplo n.º 15
0
class RecallServerServicer(object):
    """Recall service: encodes a user into a vector and looks up the nearest
    films in the 'demo_films' Milvus collection."""

    def __init__(self):
        """Load the user-vector model and create the Milvus client."""
        self.uv_client = LocalPredictor()
        self.uv_client.load_model_config(
            "user_vector_model/serving_server_dir")
        milvus_host = '127.0.0.1'
        milvus_port = '19530'
        self.milvus_client = Milvus(milvus_host, milvus_port)
        self.collection_name = 'demo_films'

    def get_user_vector(self, user_info):
        """Return the model's vector for one user as a plain list.

        The user's id, gender, age and job are hashed with ``hash2`` and
        fed, together with matching ``.lod`` entries, to the predictor as
        int64 column arrays.
        """
        dic = {"userid": [], "gender": [], "age": [], "occupation": []}
        lod = [0]
        dic["userid"].append(hash2(user_info.user_id))
        dic["gender"].append(hash2(user_info.gender))
        dic["age"].append(hash2(user_info.age))
        dic["occupation"].append(hash2(user_info.job))
        lod.append(1)

        dic["userid.lod"] = lod
        dic["gender.lod"] = lod
        dic["age.lod"] = lod
        dic["occupation.lod"] = lod
        # Every entry, the .lod ones included, becomes an (n, 1) int64 array.
        for key in dic:
            dic[key] = np.array(dic[key]).astype(np.int64).reshape(
                len(dic[key]), 1)

        fetch_map = self.uv_client.predict(
            feed=dic, fetch=["save_infer_model/scale_0.tmp_0"], batch=True)
        return fetch_map["save_infer_model/scale_0.tmp_0"].tolist()[0]

    def recall(self, request, context):
        """Return the films nearest to the requesting user's vector.

        Message layout::

            message RecallRequest{
                string log_id = 1;
                user_info.UserInfo user_info = 2;
                string recall_type= 3;
                uint32 request_num= 4;
            }

            message RecallResponse{
                message Error {
                    uint32 code = 1;
                    string text = 2;
                }
                message ScorePair {
                    string nid = 1;
                    float score = 2;
                };
                Error error = 1;
                repeated ScorePair score_pairs = 2;
            }

        The response carries one ScorePair (id, distance) per hit and error
        code 200, or error code 500 as soon as a query comes back with no
        hits.
        """
        recall_res = recall_pb2.RecallResponse()
        user_vector = self.get_user_vector(request.user_info)

        query_hybrid = {
            "bool": {
                "must": [{
                    "vector": {
                        "embedding": {
                            "topk": 100,
                            "query": [user_vector],
                            "metric_type": "L2"
                        }
                    }
                }]
            }
        }

        results = self.milvus_client.search(self.collection_name,
                                            query_hybrid,
                                            fields=["embedding"])
        for entities in results:
            if len(entities) == 0:
                recall_res.error.code = 500
                recall_res.error.text = "Recall server get milvus fail. ({})".format(
                    str(request))
                return recall_res
            for topk_film in entities:
                score_pair = recall_res.score_pairs.add()
                score_pair.nid = str(topk_film.id)
                score_pair.score = float(topk_film.distance)
        recall_res.error.code = 200
        return recall_res
Ejemplo n.º 16
0
def main():
    """Walk through the basic Milvus workflow on a throwaway collection.

    Connects, creates 'example_collection' if it is missing, inserts 10000
    random vectors, builds an IVF_FLAT index, searches with the first ten
    vectors, reports whether the top hit looks right, then drops the
    collection and disconnects. Exits with status 1 when the server cannot
    be reached.
    """
    milvus = Milvus()

    # Connect to the server; adjust _HOST and _PORT for your deployment.
    status = milvus.connect(host=_HOST, port=_PORT)
    if not status.OK():
        print("Server connect fail.")
        sys.exit(1)
    print("Server connected.")

    collection_name = 'example_collection'

    # Create the collection only when the server does not have it yet.
    status, exists = milvus.has_collection(collection_name)
    if not exists:
        milvus.create_collection({
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2,  # optional
        })

    # List the collections on the server.
    _, collections = milvus.show_collections()

    # Print collection statistics, then the collection description.
    _, info = milvus.collection_info(collection_name)
    print(info)

    _, description = milvus.describe_collection(collection_name)
    print(description)

    # 10000 random vectors of length _DIM, as a 2-D list of floats.
    # numpy works as well:
    #     `vectors = np.random.rand(10000, 16).astype(np.float32)`
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10000)]

    # Insert returns a status and the list of assigned ids.
    status, ids = milvus.insert(collection_name=collection_name, records=vectors)

    # Flush the inserted data to disk, then read the row count.
    milvus.flush([collection_name])
    status, row_count = milvus.count_collection(collection_name)

    # Searching works without an index; an IVF_FLAT index makes it faster.
    index_param = {'nlist': 2048}
    print("Creating index: {}".format(index_param))
    status = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param)

    status, index = milvus.describe_index(collection_name)
    print(index)

    # Search with the first 10 inserted vectors, one hit per query.
    print("Searching ... ")
    status, results = milvus.search(
        collection_name=collection_name,
        query_records=vectors[0:10],
        top_k=1,
        params={"nprobe": 16},
    )

    if status.OK():
        # The first query is the first inserted vector, so its best hit
        # should be itself. The same check can be written as
        #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
        top_hit = results[0][0]
        if top_hit.distance == 0.0 or top_hit.id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

    print(results)

    # Clean up: drop the collection and disconnect.
    status = milvus.drop_collection(collection_name)
    status = milvus.disconnect()
Ejemplo n.º 17
0
    print("Start create index ......")
    status = client.create_index(collection_name, IndexType.ANNOY, index_param)
    if status.OK():
        print("Create index ANNOY successfully\n")
    else:
        print("Create index ANNOY fail")

    # select top 10 vectors from inserted as query vectors
    query_vectors = vectors[:10]

    # specify search param
    search_param = {"search_k": 10}

    # specify topk is 1, search approximate nearest 1 neighbor
    status, result = client.search(collection_name,
                                   1,
                                   query_vectors,
                                   params=search_param)

    if status.OK():
        # show search result
        print("Search successfully. Result:\n", result)
    else:
        print("Search fail: ", status)

    # drop collection
    client.drop_collection(collection_name)

    # disconnect from server
    # client.disconnect()
Ejemplo n.º 18
0
def main():
    """Walk through the table-based Milvus workflow over the HTTP handler.

    Connects, creates 'demo_tables' if it is missing, inserts 100000 random
    vectors, builds an IVFLAT index, searches with the first ten vectors,
    reports whether the top hit looks right, then drops the table and
    disconnects. Exits with status 1 when the server cannot be reached.
    """
    milvus = Milvus(handler="HTTP")

    # Connect to the server; adjust _HOST and _PORT for your deployment.
    status = milvus.connect(host=_HOST, port=_PORT)
    if not status.OK():
        print("Server connect fail.")
        sys.exit(1)
    print("Server connected.")

    table_name = 'demo_tables'

    # Create the table only when the server does not have it yet.
    status, exists = milvus.has_table(table_name)
    if not exists:
        milvus.create_table({
            'table_name': table_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2,  # optional
        })

    # List the tables on the server.
    _, tables = milvus.show_tables()

    # Print the table description.
    _, description = milvus.describe_table(table_name)
    print(description)

    # 100000 random vectors of length _DIM, as a 2-D list of floats.
    # numpy works as well:
    #     `vectors = np.random.rand(10000, 16).astype(np.float32).tolist()`
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(100000)]

    # Insert returns a status and the list of assigned ids.
    status, ids = milvus.insert(table_name=table_name, records=vectors)

    # Give the server 6 seconds to persist the vectors.
    time.sleep(6)

    status, row_count = milvus.count_table(table_name)

    # Searching works without an index; an IVFLAT index makes it faster.
    index_param = {
        'index_type': IndexType.IVFLAT,
        'nlist': 2048,
    }
    status = milvus.create_index(table_name, index_param)

    status, index = milvus.describe_index(table_name)
    print(index)

    # Search with the first 10 inserted vectors, one hit per query.
    status, results = milvus.search(
        table_name=table_name,
        query_records=vectors[0:10],
        top_k=1,
        nprobe=16,
    )

    if status.OK():
        # The first query is the first inserted vector, so its best hit
        # should be itself. The same check can be written as
        #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
        top_hit = results[0][0]
        if top_hit.distance == 0.0 or top_hit.id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

    print(results)

    # Clean up: drop the table and disconnect.
    status = milvus.drop_table(table_name)
    status = milvus.disconnect()
def main():
    """Run the sentence-search demo end to end against a Milvus server.

    Creates ``example_collection_`` when the server does not have it, inserts
    the vectors returned by ``text2vec(index_sentences)``, builds an IVF_FLAT
    index, searches with the vectors of ``query_sentences``, prints every hit
    together with its source sentence and finally drops the collection.
    """
    # Specify the server address when creating the client instance.
    client = Milvus(_HOST, _PORT)

    name = 'example_collection_'

    # Create the collection only when it does not exist yet.
    _, exists = client.has_collection(name)
    if not exists:
        client.create_collection({
            'collection_name': name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2,  # optional
        })

    # Ask the server for its collections (the answer is not used further).
    client.list_collections()

    # Describe the demo collection.
    _, description = client.get_collection_info(name)
    print(description)

    # Vectors to index; text2vec is expected to return a 2-D array
    # (one vector per sentence) -- TODO confirm against its definition.
    embeddings = text2vec(index_sentences)
    print(embeddings)

    # Insert the vectors; the server answers with a status and the vector ids.
    insert_status, ids = client.insert(collection_name=name,
                                       records=embeddings)
    if insert_status.OK():
        print(ids)
    else:
        print("Insert failed: {}".format(insert_status))

    # Quick lookup table: vector id -> indexed sentence.
    look_up = dict(zip(ids, index_sentences))
    for vector_id, sentence in look_up.items():
        print(vector_id, sentence)

    # Flush the inserted data to disk, then query the row count.
    client.flush([name])
    client.count_entities(name)

    # Present collection statistics.
    _, stats = client.get_collection_stats(name)
    print(stats)

    # Fetch the raw vectors back by id (result is not used further).
    client.get_entity_by_id(name, ids)

    # Searching works without an index; the index only makes it faster.
    index_param = {'nlist': 2048}
    print("Creating index: {}".format(index_param))
    client.create_index(name, IndexType.IVF_FLAT, index_param)

    # Describe the index that was just built.
    _, index_info = client.get_index_info(name)
    print(index_info)

    # Vectors of the query sentences drive the similarity search.
    queries = text2vec(query_sentences)

    print("Searching ... ")
    search_status, hits_per_query = client.search(
        collection_name=name,
        query_records=queries,
        top_k=1,
        params={"nprobe": 16},
    )

    if not search_status.OK():
        print("Search failed. ", search_status)
    else:
        # Sanity check on the best hit of the first query only: either an
        # exact match (distance 0.0) or the id of the first inserted vector.
        best = hits_per_query[0][0]
        matched = best.distance == 0.0 or best.id == ids[0]
        print('Query result is correct' if matched
              else "Query result isn't correct")

        # Print every hit with the sentence it was indexed from.
        for hits in hits_per_query:
            for hit in hits:
                print('id:{}, text:{}, distance: {}'.format(
                    hit.id, look_up[hit.id], hit.distance))

    # Delete the demo collection.
    client.drop_collection(name)
Ejemplo n.º 20
0
class Indexer:
    '''
    Indexer: a thin wrapper around one Milvus collection.

    Unless stated otherwise, every method raises ExertMilvusException when
    the server answers with a non-zero status code.
    '''
    def __init__(self, name, host='127.0.0.1', port='19531'):
        '''
        Create the Milvus client for host/port and remember the collection
        name that all other methods operate on.
        '''
        self.client = Milvus(host=host, port=port)
        self.collection = name

    @staticmethod
    def _ensure_ok(status):
        '''
        Raise ExertMilvusException unless status.code is 0.
        '''
        if status.code != 0:
            raise ExertMilvusException(status)

    def init(self, lenient=False):
        '''
        Create the collection (512 dimensions, L2 metric) and its IVF_FLAT
        index.

        With lenient=True the method returns None without doing anything
        when the collection already exists, and a create_collection status
        code of 9 is tolerated (presumably "already exists" -- confirm
        against the server's status codes).

        Returns the status of the create_index call otherwise.
        '''
        if lenient:
            status, result = self.client.has_collection(
                collection_name=self.collection)
            self._ensure_ok(status)
            if result:
                return

        status = self.client.create_collection({
            'collection_name': self.collection,
            'dimension': 512,
            'index_file_size': 1024,
            'metric_type': MetricType.L2
        })
        if status.code != 0 and not (lenient and status.code == 9):
            raise ExertMilvusException(status)

        # Build the index.
        status = self.client.create_index(collection_name=self.collection,
                                          index_type=IndexType.IVF_FLAT,
                                          params={'nlist': 16384})
        self._ensure_ok(status)

        return status

    def drop(self):
        '''
        Drop the collection.
        '''
        status = self.client.drop_collection(collection_name=self.collection)
        self._ensure_ok(status)

    def flush(self):
        '''
        Flush the collection's data to disk.
        '''
        status = self.client.flush([self.collection])
        self._ensure_ok(status)

    def compact(self):
        '''
        Compact the collection.
        '''
        status = self.client.compact(collection_name=self.collection)
        self._ensure_ok(status)

    def close(self):
        '''
        Close the client connection.
        '''
        self.client.close()

    def new_tag(self, tag):
        '''
        Create a partition with the given tag.
        '''
        status = self.client.create_partition(collection_name=self.collection,
                                              partition_tag=tag)
        self._ensure_ok(status)

    def list_tag(self):
        '''
        Return the partitions of the collection.
        '''
        status, result = self.client.list_partitions(
            collection_name=self.collection)
        self._ensure_ok(status)
        return result

    def drop_tag(self, tag):
        '''
        Drop the partition with the given tag.
        '''
        status = self.client.drop_partition(collection_name=self.collection,
                                            partition_tag=tag)
        self._ensure_ok(status)

    def index(self, vectors, tag=None, ids=None):
        '''
        Insert vectors into the collection and return the insert result.

        tag: optional partition tag to insert into.
        ids: optional explicit ids for the vectors.
        '''
        params = {}
        # `is not None` (rather than `!= None`) also works for array-like
        # ids, whose element-wise comparison has no single truth value.
        if tag is not None:
            # The SDK keyword is `partition_tag`; an unknown keyword such as
            # `tag` is silently ignored and the vectors end up in the
            # default partition.
            params['partition_tag'] = tag
        if ids is not None:
            params['ids'] = ids
        status, result = self.client.insert(collection_name=self.collection,
                                            records=vectors,
                                            **params)
        self._ensure_ok(status)

        return result

    def listing(self, ids):
        '''
        Return the entities stored under the given ids.
        '''
        status, result = self.client.get_entity_by_id(
            collection_name=self.collection, ids=ids)
        self._ensure_ok(status)
        return result

    def counting(self):
        '''
        Return the number of entities in the collection.
        '''
        status, result = self.client.count_entities(
            collection_name=self.collection)
        self._ensure_ok(status)
        return result

    def unindex(self, ids):
        '''
        Delete the entities with the given ids.
        '''
        status = self.client.delete_entity_by_id(
            collection_name=self.collection, id_array=ids)
        self._ensure_ok(status)

    def search(self, vectors, top_count=100, tags=None):
        '''
        Search the top_count nearest entities for each query vector
        (nprobe fixed at 16), optionally restricted to the given
        partition tags, and return the search results.
        '''
        params = {'params': {'nprobe': 16}}
        if tags is not None:
            params['partition_tags'] = tags
        status, results = self.client.search(collection_name=self.collection,
                                             query_records=vectors,
                                             top_k=top_count,
                                             **params)
        self._ensure_ok(status)
        return results
Ejemplo n.º 21
0
class MilvusClient(object):
    """Wrapper around the python-sdk ``Milvus`` client bound to one collection.

    Most methods forward to the SDK call of the same purpose and pass the
    returned status through :meth:`check_status`, which raises on failure.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Keeps retrying the connection until ``timeout`` seconds have passed
        (default 60s) and raises ``Exception("Server connect timeout")`` when
        no attempt succeeded. When the collection already exists, its metric
        type and dimension are cached on the instance.
        """
        self._collection_name = collection_name
        # Defined up front so the attributes exist even when the collection
        # has not been created yet.
        self._metric_type = None
        self._dimension = None

        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect for remote server
        # NOTE(review): there is no pause between attempts -- confirm that
        # the SDK call itself blocks long enough to avoid a busy loop.
        connected = False
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    connected = True
                    break
            except Exception as e:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1

        # Decide on the outcome of the loop rather than on the clock, so a
        # connection made right before the deadline is not reported as a
        # timeout.
        if not connected:
            raise Exception("Server connect timeout")

        if self._collection_name and self.exists_collection():
            # One describe() round trip serves both cached values.
            info = self.describe()[1]
            self._metric_type = metric_type_to_str(info.metric_type)
            self._dimension = info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Switch the wrapper to another collection name."""
        self._collection_name = name

    def check_status(self, status):
        """Log diagnostics and raise ``Exception`` when ``status`` is not OK."""
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise when the first hit of any query has a distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; ``metric_type`` must be a key of ``METRIC_MAP``.

        The wrapper binds itself to ``collection_name`` only when it is not
        bound to a collection yet.
        """
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP:
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_type = METRIC_MAP[metric_type]
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        """Create partition ``tag_name`` in the bound collection."""
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        """Drop partition ``tag_name`` from the bound collection."""
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        """Return the partitions of the bound collection."""
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """Insert vectors ``X`` (optionally with explicit ``ids``).

        Returns the ``(status, result)`` pair of the SDK call.
        """
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1-100 random vectors and verify the row count grew accordingly."""
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` ids sampled from one randomly chosen segment.

        Only segments of the first partition in the collection stats are
        considered. Keeps retrying until a segment with at least one id is
        found.
        """
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return ``(segment count, ids)`` with the first ``length`` ids of each segment.

        Only segments of the first partition in the collection stats are used.
        """
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        # take the leading ids from each segment
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        """Return ``(ids, entities)`` for ids picked by :meth:`get_rand_ids`."""
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        """Return the entities stored under ``get_ids``."""
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with the given ids."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete a random batch of ids and verify they are gone afterwards."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        # Every deleted id must now resolve to an empty entity.
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        """Flush the given (or bound) collection."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the given (or bound) collection."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` must be a key of ``INDEX_MAP``."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        # Reverse lookup: SDK index type -> the key used in INDEX_MAP.
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the index of the bound collection and return the SDK status."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """Search the ``top_k`` nearest entities for each vector in ``X``."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Run one search with random ``top_k``, ``nq`` and ``nprobe`` (each 1-100)."""
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        # for i, item in enumerate(search_res):
        #     if item[0].id != ids[i]:
        #         logger.warning("The index of search result: %d" % i)
        #         raise Exception("Query failed")

    # @time_wrapper
    # def query_ids(self, top_k, ids, search_param=None):
    #     status, result = self._milvus.search_by_id(self._collection_name, ids, top_k, params=search_param)
    #     self.check_result_ids(result)
    #     return result

    def count(self, name=None):
        """Return the row count of the given (or bound) collection, 0 when falsy."""
        if name is None:
            name = self._collection_name
        # A single round trip serves both the debug log and the result.
        count_result = self._milvus.count_entities(name)
        logger.debug(count_result)
        row_count = count_result[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection and poll once per second until its count is 0.

        Logs an error (does not raise) when ``timeout`` polls were not enough.
        """
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        i = 0
        while i < timeout and self.count(name=name):
            time.sleep(1)
            i = i + 1
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the SDK's ``get_collection_info`` result for the bound collection."""
        # logger.info(self._milvus.get_collection_info(self._collection_name))
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        """Return the SDK's ``list_collections`` result."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return whether the given (or bound) collection exists.

        The status of the SDK call is deliberately not checked here.
        """
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        # self.check_status(status)
        return res

    def clean_db(self):
        """Drop every collection listed by the server."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        """Load the bound collection (SDK timeout 3000) and return the status."""
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the server version reported by the SDK."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        """Return the result of the server command ``mode``."""
        return self.cmd("mode")

    def get_server_commit(self):
        """Return the result of the server command ``build_commit_id``."""
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the parsed JSON result of the server command ``get_config *``."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": <value>}`` converted from bytes to GiB."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: GiB (bytes / 1024**3), rounded to two decimals
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Send a raw command to the server, log it and return its result."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Ejemplo n.º 22
0
def main():
    """Walk through the basic Milvus workflow with ten random vectors.

    Creates ``example_collection_`` when missing, inserts ten random vectors
    of ``_DIM`` dimensions, flushes, prints collection information, builds an
    IVF_FLAT index, searches with the inserted vectors themselves, prints the
    outcome and finally drops the collection.
    """
    # Connect to the server by giving its address to the client instance.
    client = Milvus(_HOST, _PORT)

    name = 'example_collection_'

    # Check whether the collection exists and create it if it does not.
    _, exists = client.has_collection(name)
    if not exists:
        client.create_collection({
            'collection_name': name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2,  # optional
        })

    # Show all collections on the server.
    _, all_collections = client.list_collections()
    print(all_collections)

    # Describe the demo collection.
    _, description = client.get_collection_info(name)
    print(description)

    # Ten random vectors with _DIM components each, as a 2-D list.
    # numpy works as well:
    #   vectors = np.random.rand(10, _DIM).astype(np.float32)
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
    print(vectors)

    # Insert all ten vectors; the server returns a status and the vector ids.
    insert_status, ids = client.insert(collection_name=name, records=vectors)
    if not insert_status.OK():
        print("Insert failed: {}".format(insert_status))
    print(ids)

    # Flush the inserted data to disk.
    client.flush([name])

    # Current row count of the collection.
    count_status, row_count = client.count_entities(name)
    print(count_status)
    print(row_count)

    # Collection statistics.
    _, stats = client.get_collection_stats(name)
    print(stats)

    # Fetch the raw vectors of the first ten ids.
    _, stored_vectors = client.get_entity_by_id(name, ids[:10])
    print(stored_vectors)

    # Searching works without an index; the IVF_FLAT index only makes it
    # faster.
    index_param = {'nlist': 2048}
    print("Creating index: {}".format(index_param))
    client.create_index(name, IndexType.IVF_FLAT, index_param)

    # Information about the index that was just built.
    _, index_info = client.get_index_info(name)
    print(index_info)

    # Query with the first ten inserted vectors.
    queries = vectors[0:10]

    print("Searching ... ")

    # nprobe is the search parameter passed to the index.
    search_status, results = client.search(
        collection_name=name,
        query_records=queries,
        top_k=1,
        params={"nprobe": 16},
    )

    if not search_status.OK():
        print("Search failed. ", search_status)
    else:
        print(results)
        # Sanity check on the best hit of the first query only: either an
        # exact match (distance 0.0) or the id of the first inserted vector.
        best = results[0][0]
        matched = best.distance == 0.0 or best.id == ids[0]
        print('Query result is correct' if matched
              else "Query result isn't correct")

        # print results
        print(results)

    # Drop the demo collection.
    client.drop_collection(name)
Ejemplo n.º 23
0
def main():
    """Walk through the Milvus workflow using the SDK's asynchronous calls.

    Insert, flush, compact, index creation and search are issued with
    ``_async=True`` and a callback; ``done()`` is then invoked on each
    returned future before the next step starts. The demo collection is
    dropped at the end. The process exits with status 1 when the collection
    cannot be created.
    """
    # Specify server addr when create milvus client instance
    milvus = Milvus(_HOST, _PORT)

    # Create collection demo_collection if it doesn't exist.
    collection_name = 'example_async_collection_'

    status, ok = milvus.has_collection(collection_name)
    if not ok:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': 128,  # optional
            'metric_type': MetricType.L2  # optional
        }

        status = milvus.create_collection(param)
        if not status.OK():
            print("Create collection failed: {}".format(status.message),
                  file=sys.stderr)
            print("exiting ...", file=sys.stderr)
            sys.exit(1)

    # Show collections in Milvus server
    _, collections = milvus.list_collections()

    # Describe demo_collection
    _, collection = milvus.get_collection_info(collection_name)
    print(collection)

    # 100000 vectors with _DIM dimensions
    # element per dimension is float32 type
    # vectors should be a 2-D array
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(100000)]

    # You can also use numpy to generate random vectors:
    #     `vectors = np.random.rand(100000, _DIM).astype(np.float32)`

    def _insert_callback(status, ids):
        if status.OK():
            print("Insert successfully")
        else:
            print("Insert failed.", status.message)

    # Insert vectors into demo_collection, adding callback function
    insert_future = milvus.insert(collection_name=collection_name,
                                  records=vectors,
                                  _async=True,
                                  _callback=_insert_callback)
    # Or invoke result() to get results:
    #   insert_future = milvus.insert(collection_name=collection_name, records=vectors, _async=True)
    #   status, ids = insert_future.result()
    insert_future.done()

    # Flush collection  inserted data to disk.
    def _flush_callback(status):
        if status.OK():
            print("Flush successfully")
        else:
            print("Flush failed.", status.message)

    flush_future = milvus.flush([collection_name],
                                _async=True,
                                _callback=_flush_callback)
    # Or invoke result() to get results:
    #   flush_future = milvus.flush([collection_name], _async=True)
    #   status = flush_future.result()
    flush_future.done()

    def _compact_callback(status):
        if status.OK():
            print("Compact successfully")
        else:
            print("Compact failed.", status.message)

    # The keyword must be spelled `_callback`, matching the other async
    # calls here; a misspelled keyword leaves the callback unregistered.
    compact_future = milvus.compact(collection_name,
                                    _async=True,
                                    _callback=_compact_callback)
    # Or invoke result() to get results:
    #   compact_future = milvus.compact(collection_name, _async=True)
    #   status = compact_future.result()
    compact_future.done()

    # Get demo_collection row count
    status, result = milvus.count_entities(collection_name)

    # present collection info
    _, info = milvus.get_collection_stats(collection_name)
    print(info)

    # create index of vectors, search more rapidly
    index_param = {'nlist': 2048}

    def _index_callback(status):
        if status.OK():
            print("Create index successfully")
        else:
            print("Create index failed.", status.message)

    # Create ivflat index in demo_collection
    # You can search vectors without creating index. however, Creating index help to
    # search faster
    print("Creating index: {}".format(index_param))
    index_future = milvus.create_index(collection_name,
                                       IndexType.IVF_FLAT,
                                       index_param,
                                       _async=True,
                                       _callback=_index_callback)
    # Or invoke result() to get results:
    #   index_future = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param, _async=True)
    #   status = index_future.result()
    index_future.done()

    # describe index, get information of index
    status, index = milvus.get_index_info(collection_name)
    print(index)

    # Use the top 10 vectors for similarity search
    query_vectors = vectors[0:10]

    # execute vector similarity search
    search_param = {"nprobe": 16}

    print("Searching ... ")

    def _search_callback(status, results):
        if status.OK():
            # indicate search result
            # also use by:
            #   `results.distance_array[0][0] == 0.0`
            # The inserted ids are not available in this callback, so only
            # the distance of the best hit is checked.
            if results[0][0].distance == 0.0:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

            # print results
            print(results)
        else:
            print("Search failed. ", status)

    param = {
        'collection_name': collection_name,
        'query_records': query_vectors,
        'top_k': 1,
        'params': search_param,
        "_async": True,
        "_callback": _search_callback
    }
    search_future = milvus.search(**param)
    # Or invoke result() to get results:
    #
    #   param = {
    #       'collection_name': collection_name,
    #       'query_records': query_vectors,
    #       'top_k': 1,
    #       'params': search_param,
    #       "_async": True,
    #   }
    #   search_future = milvus.search(**param)
    #   status, results = search_future.result()

    search_future.done()

    # Delete demo_collection
    status = milvus.drop_collection(collection_name)
Ejemplo n.º 24
0
class SearchEngine:
    """Small wrapper around a Milvus client offering key-based CRUD and search.

    A working collection must be selected with ``set_collection`` before any
    of the insert / delete / update / search helpers are used.
    """

    def __init__(self, host, port):
        """Create the underlying Milvus client.

        Args:
            host: Fallback host, used when ``MILVUS_HOST`` is not set.
            port: Fallback port, used when ``MILVUS_PORT`` is not set.
        """
        # Environment variables take precedence over the arguments.
        self.host = os.environ.get('MILVUS_HOST', host)
        self.port = os.environ.get('MILVUS_PORT', str(port))
        self.engine = Milvus(host=self.host, port=self.port)
        self.collection_name = None

#################################################
# HANDLE COLLECTION
#################################################

    def create_collection(self, collection_name, dimension):
        """Create a collection with the given vector dimension (inner-product metric)."""
        param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': 1000,
            'metric_type': MetricType.IP
        }

        self.engine.create_collection(param)

        print('[INFO] collection {}을 생성했습니다.'.format(collection_name))

    def drop_collection(self, collection_name):
        """Drop the named collection."""
        self.engine.drop_collection(collection_name=collection_name)

        print('[INFO] collection {}을 삭제했습니다.'.format(collection_name))

    def get_collection_stats(self, collection_name):
        """Print the collection info and statistics reported by the client."""
        print(self.engine.get_collection_info(collection_name))
        print(self.engine.get_collection_stats(collection_name))

    def set_collection(self, collection_name):
        """Select the collection that subsequent data operations act on."""
        self.collection_name = collection_name
        print('[INFO] setting collection {}'.format(self.collection_name))

#################################################
# UTILS
#################################################

    def check_set_collection(self):
        """Assert that a working collection has been selected.

        Raises:
            AssertionError: If ``set_collection`` has not been called yet.
        """
        assert self.collection_name is not None, '[ERROR] collection을 setting해 주십시오!!'

    def check_exist_data_by_key(self, key):
        """Return True when the first requested key resolves to a stored vector.

        Only the first entry of the lookup result is inspected, so for a list
        of keys this reports on the first key only.
        """
        self.check_set_collection()

        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)

        # A falsy result (None / empty) is wrapped so indexing below is safe.
        vector = vector if vector else [vector]

        return True if vector[0] else False

    def convert_key_format(self, key):
        """Wrap a single integer key in a list; pass anything else through."""
        return [key] if isinstance(key, int) else key

    def convert_value_format(self, value):
        """Turn a rank-1 array into a single-row 2-D array.

        Args:
            value: Array-like exposing ``shape`` and ``reshape``.

        Raises:
            AssertionError: If ``value`` has rank 2 or higher.
        """
        rank = len(value.shape)

        assert rank < 2, '[ERROR] value의 dim을 2 미만으로 입력해 주세요!!'
        return value.reshape(1, -1) if rank == 1 else value

    @staticmethod
    def _collect_hits(result, skip=0):
        """Split a search result into per-query id and distance lists.

        Args:
            result: Indexable search result, one row of hits per query.
            skip: Number of leading hits to drop from every row.

        Returns:
            ``(ids, distances)`` with one inner list per query row.
        """
        li_id = []
        li_dist = []
        for i in range(len(result)):
            # Use row i for query i (not row 0 for every query).
            hits = result[i][skip:] if skip else result[i]
            li_id.append([hit.id for hit in hits])
            li_dist.append([hit.distance for hit in hits])
        return li_id, li_dist

#################################################
# INSERT
#################################################

    def insert_data(self, key, value):
        """Insert ``value`` under ``key`` unless the key is already present."""
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)

        if self.check_exist_data_by_key(key):
            print("[ERROR] 이미 collection에 데이터가 존재합니다.")
            return

        self.engine.insert(collection_name=self.collection_name,
                           records=value,
                           ids=key)
        self.engine.flush([self.collection_name])

        print('[INFO] insert key {}'.format(key))

#################################################
# DELETE
#################################################

    def delete_data(self, key):
        """Delete the entity stored under ``key`` if it exists."""
        key = self.convert_key_format(key)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])

        print('[INFO] delete key {}'.format(key))

#################################################
# UPDATE
#################################################

    def update_data(self, key, value):
        """Replace the vector stored under ``key`` (delete, then re-insert)."""
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        # Flush between the two steps so the delete is applied before the
        # same id is inserted again.
        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])

        self.engine.insert(collection_name=self.collection_name,
                           records=value,
                           ids=key)
        self.engine.flush([self.collection_name])

        print('[INFO] update key {}'.format(key))

#################################################
# SEARCH
#################################################

    def search_by_feature(self, feature, top_k):
        """Search the working collection with a feature vector.

        Returns:
            ``(ids, distances)``, each a list with one inner list per query.
        """
        self.check_set_collection()

        feature = self.convert_value_format(feature)

        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=feature,
                                       top_k=top_k)

        return self._collect_hits(result)

    def search_by_key(self, key, top_k):
        """Search with the vector(s) stored under ``key``.

        One extra hit is requested and the first hit of every row is dropped;
        presumably that first hit is the query entity itself.

        Returns:
            ``(ids, distances)`` with one inner list per key, or ``None`` when
            the key is not present in the collection.
        """
        self.check_set_collection()
        key = self.convert_key_format(key)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)
        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=vector,
                                       top_k=top_k + 1)

        return self._collect_hits(result, skip=1)
Ejemplo n.º 25
0
class MilvusHelper:
    """Convenience wrapper around a Milvus client.

    Every method logs and terminates the process with ``sys.exit(1)`` when the
    underlying client call raises or reports a non-zero status code.
    """

    def __init__(self):
        """Create the client from ``MILVUS_HOST`` / ``MILVUS_PORT``."""
        try:
            self.client = Milvus(host=MILVUS_HOST, port=MILVUS_PORT)
            LOGGER.debug(
                "Successfully connect to Milvus with IP:{} and PORT:{}".format(
                    MILVUS_HOST, MILVUS_PORT))
        except Exception as e:
            LOGGER.error("Failed to connect Milvus: {}".format(e))
            sys.exit(1)

    def has_collection(self, collection_name):
        """Return the existence flag reported by the client for the collection."""
        try:
            # The client returns (status, flag); only the flag is used.
            status = self.client.has_collection(collection_name)[1]
            return status
        except Exception as e:
            # Message names the operation that actually failed.
            LOGGER.error("Failed to check collection in Milvus: {}".format(e))
            sys.exit(1)

    def create_colllection(self, collection_name):
        """Create the collection if it does not exist yet.

        The misspelled method name is kept because callers depend on it.
        """
        try:
            if not self.has_collection(collection_name):
                collection_param = {
                    'collection_name': collection_name,
                    'dimension': VECTOR_DIMENSION,
                    'index_file_size': INDEX_FILE_SIZE,
                    'metric_type': METRIC_TYPE
                }
                status = self.client.create_collection(collection_param)
                if status.code != 0:
                    raise Exception(status.message)
                LOGGER.debug(
                    "Create Milvus collection: {}".format(collection_name))
        except Exception as e:
            # Message names the operation that actually failed.
            LOGGER.error("Failed to create collection in Milvus: {}".format(e))
            sys.exit(1)

    def insert(self, collection_name, vectors):
        """Batch-insert vectors, creating the collection first if needed.

        Returns:
            The ids returned by the client for the inserted vectors.
        """
        try:
            self.create_colllection(collection_name)
            status, ids = self.client.insert(collection_name=collection_name,
                                             records=vectors)
            if not status.code:
                LOGGER.debug(
                    "Insert vectors to Milvus in collection: {} with {} rows".
                    format(collection_name, len(vectors)))
                return ids
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to load data to Milvus: {}".format(e))
            sys.exit(1)

    def create_index(self, collection_name):
        """Build an IVF_FLAT index (nlist=16384) on the collection.

        Returns:
            The status object returned by the client.
        """
        try:
            index_param = {'nlist': 16384}
            status = self.client.create_index(collection_name,
                                              IndexType.IVF_FLAT, index_param)
            if not status.code:
                LOGGER.debug(
                    "Successfully create index in collection:{} with param:{}".
                    format(collection_name, index_param))
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to create index: {}".format(e))
            sys.exit(1)

    def delete_collection(self, collection_name):
        """Drop the collection and return the client's status object."""
        try:
            status = self.client.drop_collection(
                collection_name=collection_name)
            if not status.code:
                LOGGER.debug(
                    "Successfully drop collection: {}".format(collection_name))
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to drop collection: {}".format(e))
            sys.exit(1)

    def search_vectors(self, collection_name, vectors, top_k):
        """Search the collection with ``vectors`` (nprobe=16).

        Returns:
            The search result object returned by the client.
        """
        try:
            search_param = {'nprobe': 16}
            status, result = self.client.search(
                collection_name=collection_name,
                query_records=vectors,
                top_k=top_k,
                params=search_param)
            if not status.code:
                LOGGER.debug("Successfully search in collection: {}".format(
                    collection_name))
                return result
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to search vectors in Milvus: {}".format(e))
            sys.exit(1)

    def count(self, collection_name):
        """Return the number of entities in the collection."""
        try:
            status, num = self.client.count_entities(
                collection_name=collection_name)
            if not status.code:
                LOGGER.debug(
                    "Successfully get the num:{} of the collection:{}".format(
                        num, collection_name))
                return num
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to count vectors in Milvus: {}".format(e))
            sys.exit(1)
Ejemplo n.º 26
0
class MilvusHelper(BaseVectorSimilarityHelper):
    """Milvus-backed implementation of the vector similarity helper.

    The client is created lazily by ``init()``, which every public method
    calls first. Failures reported by Milvus are raised as
    ``MilvusRuntimeException``.
    """

    def __init__(self, _server_url, _server_port, _timeout=10):
        """Store connection settings; no connection is opened here.

        Args:
            _server_url: Milvus server host.
            _server_port: Milvus server port.
            _timeout: Timeout passed to the client calls that accept one.
        """
        super().__init__()
        self.server_url = _server_url
        self.server_port = _server_port
        self.timeout = _timeout
        self.client = None
        # Translate the project's metric enum into the Milvus one.
        self.metric_type_mapper = {
            VectorMetricType.L2: MetricType.L2,
            VectorMetricType.IP: MetricType.IP,
            VectorMetricType.JACCARD: MetricType.JACCARD,
            VectorMetricType.HAMMING: MetricType.HAMMING,
        }

        # Translate the project's index enum into the Milvus one.
        self.index_type_mapper = {
            VectorIndexType.FLAT: IndexType.FLAT,
            VectorIndexType.IVFLAT: IndexType.IVFLAT,
            VectorIndexType.IVF_SQ8: IndexType.IVF_SQ8,
            VectorIndexType.RNSG: IndexType.RNSG,
            VectorIndexType.IVF_SQ8H: IndexType.IVF_SQ8H,
            VectorIndexType.IVF_PQ: IndexType.IVF_PQ,
            VectorIndexType.HNSW: IndexType.HNSW,
            VectorIndexType.ANNOY: IndexType.ANNOY,
        }

    def init(self):
        """Create the Milvus client on first use.

        Raises:
            MilvusRuntimeException: If host or port is missing, or the client
                cannot be created.
        """
        if self.client is not None:
            return
        # Both host and port are required (the port used to go unchecked).
        if self.server_url is None or self.server_port is None:
            raise MilvusRuntimeException('Milvus config is not correct')
        try:
            self.client = Milvus(host=self.server_url, port=self.server_port)
        except Exception as e:
            raise MilvusRuntimeException(
                f'cannot connect to {self.server_url}:{self.server_port}') from e

    def insert(self, _database_name, _to_insert_vector, _partition_tag=None, _params=None):
        """
        Insert a batch of feature vectors and let Milvus assign the ids.

        notes: if you have your own ids, prefer ``insert_with_id``.

        ATTENTION!!!

        A single database must use either ``insert_with_id`` or ``insert``,
        never both, otherwise an error is reported.

        Args:
            _database_name:     database (collection) name
            _to_insert_vector:  list of feature vectors to insert
            _partition_tag:     partition tag
            _params:    insert parameters

        Returns:    ids of the inserted vectors

        """
        self.init()
        status, ids = self.client.insert(_database_name, _to_insert_vector, partition_tag=_partition_tag,
                                         params=_params,
                                         timeout=self.timeout)
        # Report the insert failure itself rather than a follow-up flush error.
        if not status.OK():
            raise MilvusRuntimeException(status.message)
        self.flush(_database_name)
        return ids

    def insert_with_id(self, _database_name, _to_insert_vector, _to_insert_ids, _partition_tag=None, _params=None):
        """
        Insert a batch of feature vectors with caller-supplied ids.

        ATTENTION!!!

        A single database must use either ``insert_with_id`` or ``insert``,
        never both, otherwise an error is reported.

        Args:
            _database_name:     database (collection) name
            _to_insert_vector:  list of feature vectors to insert
            _to_insert_ids:  list of ids for the vectors; each must be a
                             positive integer within range
            _partition_tag:     partition tag
            _params:    insert parameters

        Returns:    ids of the inserted vectors

        """
        self.init()
        status, ids = self.client.insert(_database_name, _to_insert_vector,
                                         ids=_to_insert_ids, partition_tag=_partition_tag,
                                         params=_params,
                                         timeout=self.timeout)
        # Report the insert failure itself rather than a follow-up flush error.
        if not status.OK():
            raise MilvusRuntimeException(status.message)
        self.flush(_database_name)
        return ids

    def delete(self, _database_name, _to_delete_ids):
        """
        Delete entities by id.

        Args:
            _database_name:     database (collection) name
            _to_delete_ids:     ids to delete

        Returns:    True when the deletion succeeded

        """
        self.init()
        status = self.client.delete_entity_by_id(_database_name, _to_delete_ids, self.timeout)
        # Report the delete failure itself rather than a follow-up flush error.
        if not status.OK():
            raise MilvusRuntimeException(status.message)
        self.flush(_database_name)
        return True

    def database_exist(self, _database_name):
        """
        Check whether the database (collection) exists.

        Args:
            _database_name:     database (collection) name

        Returns:    whether it exists

        """
        self.init()
        status, is_exist = self.client.has_collection(_database_name, self.timeout)
        if status.OK():
            return is_exist
        else:
            raise MilvusRuntimeException(status.message)

    def create_database(self, _database_name, _dimension, _index_file_size, _metric_type):
        """
        Create the database (collection) unless it already exists.

        Args:
            _database_name:     database (collection) name
            _dimension:     feature vector dimension
            _index_file_size:   index file size
            _metric_type:   metric type

        Returns:    True when the database exists afterwards

        """
        self.init()
        if not self.database_exist(_database_name):
            assert _metric_type in self.metric_type_mapper, f'{_metric_type} not support in milvus'
            status = self.client.create_collection({
                'collection_name': _database_name,
                'dimension': _dimension,
                'index_file_size': _index_file_size,
                'metric_type': self.metric_type_mapper[_metric_type]
            })
            if status.OK():
                return True
            else:
                raise MilvusRuntimeException(status.message)
        else:
            return True

    def create_index(self, _database_name, _index_type):
        """
        Create an index on the database (collection).

        Args:
            _database_name:     database (collection) name
            _index_type:    index type

        Returns:    True when the index was created; None when the database
                    does not exist

        """
        self.init()
        if self.database_exist(_database_name):
            assert _index_type in self.index_type_mapper, f'{_index_type} not support in milvus'
            status = self.client.create_index(_database_name, self.index_type_mapper[_index_type], timeout=self.timeout)
            if status.OK():
                return True
            else:
                raise MilvusRuntimeException(status.message)

    def search(self, _database_name, _query_vector_list, _top_k, _partition_tag=None, _params=None):
        """
        Run a similarity search.

        Args:
            _database_name:     database (collection) name
            _query_vector_list:      list of query feature vectors
            _top_k:     top k
            _partition_tag:     partition tag (a single tag or a list of tags)
            _params:    search parameters

        Returns:    search result containing ids and distances

        """
        self.init()
        if self.database_exist(_database_name):
            # The client takes a list of tags; accept a single tag as well.
            partition_tags = [_partition_tag] if isinstance(_partition_tag, str) else _partition_tag
            status, search_result = self.client.search(_database_name, _top_k, _query_vector_list,
                                                       partition_tags=partition_tags,
                                                       params=_params,
                                                       timeout=self.timeout)
            if status.OK():
                return search_result
            else:
                raise MilvusRuntimeException(status.message)
        else:
            raise DatabaseNotExist(f'{_database_name} not exist')

    def flush(self, _database_name):
        """
        Flush the database (collection).

        Args:
            _database_name:     database (collection) name

        Returns:    True when the flush succeeded

        """
        self.init()
        status = self.client.flush([_database_name, ], self.timeout)
        if status.OK():
            return True
        else:
            raise MilvusRuntimeException(status.message)

    def __del__(self):
        # getattr guards against __init__ having failed before `client` was set.
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()
Ejemplo n.º 27
0
                    "query": [query_embedding],
                    "metric_type": "L2",
                    "params": {
                        "nprobe": 8
                    }
                }
            }
        }]
    }
}

# ------
# Basic hybrid search entities
# ------
# Run the search with the query built above and ask for the two named fields
# to be returned with every hit.
results = client.search(collection_name,
                        query_hybrid,
                        fields=["release_year", "embedding"])
# Outer loop: one group of hits per entry in `results`; inner loop: the hits.
for entities in results:
    for topk_film in entities:
        current_entity = topk_film.entity
        print("==")
        print("- id: {}".format(topk_film.id))
        # `titles` is indexed by hit id; it is defined outside this fragment.
        print("- title: {}".format(titles[topk_film.id]))
        print("- distance: {}".format(topk_film.distance))

        # Fields requested through `fields=` are read as entity attributes.
        print("- release_year: {}".format(current_entity.release_year))
        print("- embedding: {}".format(current_entity.embedding))

# ------
# Basic delete index:
#     You can drop index for a field.
Ejemplo n.º 28
0
def main():
    """Run the partition demo end to end against a Milvus server.

    Recreates ``demo_partition_collection``, adds a ``random`` partition,
    inserts 10000 random vectors into it, builds an IVF_FLAT index, searches
    that partition with the first ten vectors and finally drops the
    collection.
    """
    # Adjust _HOST and _PORT to point at your Milvus server.
    client = Milvus(host=_HOST, port=_PORT)

    collection_name = 'demo_partition_collection'
    partition_tag = "random"

    # Start from a clean slate: drop a collection left over from a prior run.
    status, exists = client.has_collection(collection_name)
    if status.OK() and exists:
        client.drop_collection(collection_name)

    client.create_collection({
        'collection_name': collection_name,
        'dimension': _DIM,
        'index_file_size': _INDEX_FILE_SIZE,  # optional
        'metric_type': MetricType.L2,  # optional
    })

    # List the collections on the server (the result is not used further).
    _, _collections = client.show_collections()

    # Describe the freshly created collection.
    _, collection = client.describe_collection(collection_name)
    print(collection)

    # Create the partition, then list partitions (result not used further).
    client.create_partition(collection_name, partition_tag=partition_tag)
    _, _partitions = client.show_partitions(collection_name)

    # 10000 vectors with _DIM float components each, as a 2-D list.
    # numpy alternative:
    #     `vectors = np.random.rand(10000, 16).astype(np.float32).tolist()`
    vectors = []
    for _ in range(10000):
        vectors.append([random.random() for _ in range(_DIM)])

    # Insert into the partition; returns a status and the assigned ids.
    status, ids = client.insert(collection_name=collection_name,
                                records=vectors,
                                partition_tag=partition_tag)

    # Give the server 6 seconds to persist the vectors.
    time.sleep(6)

    # Row count of the collection (value not used further).
    status, _num = client.count_collection(collection_name)

    # Searching works without an index; building one makes it faster.
    status = client.create_index(collection_name, IndexType.IVF_FLAT,
                                 {'nlist': 2048})

    # Show the index information.
    status, index = client.describe_index(collection_name)
    print(index)

    # Search the demo partition with the first ten inserted vectors.
    status, results = client.search(
        collection_name=collection_name,
        query_records=vectors[0:10],
        top_k=1,
        partition_tags=[partition_tag],
        params={"nprobe": 10},
    )

    if status.OK():
        # The nearest hit of the first query should be the vector itself.
        # Equivalent check:
        #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
        best = results[0][0]
        if best.distance == 0.0 or best.id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

    print(results)

    # To remove only the partition instead:
    # status = client.drop_partition(collection_name, partition_tag)

    # Dropping the collection removes all of its partitions too.
    status = client.drop_collection(collection_name)
Ejemplo n.º 29
0
class MilvusClient(object):
    def __init__(self,
                 collection_name=None,
                 host=None,
                 port=None,
                 timeout=180):
        """Create the Milvus client, retrying until ``timeout`` seconds pass.

        Args:
            collection_name: Default collection used by the other methods.
            host: Server host; falls back to ``SERVER_HOST_DEFAULT``.
            port: Server port; falls back to ``SERVER_PORT_DEFAULT``.
            timeout: Overall retry budget in seconds.

        Raises:
            Exception: If no client could be created before the deadline.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        deadline = start_time + timeout
        # retry connect remote server
        attempts = 0
        while time.time() < deadline:
            try:
                self._milvus = Milvus(host=host,
                                      port=port,
                                      try_connect=False,
                                      pre_ping=False)
                break
            except Exception as e:
                attempts += 1
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % attempts)
                # Linear back-off: wait one second longer after each failure.
                time.sleep(attempts)
        else:
            # Reached only when the loop ran out of time without a `break`,
            # so a success that lands after the deadline is not rejected.
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        """Return a short label naming the default collection."""
        name = self._collection_name
        return 'Milvus collection %s' % name

    def check_status(self, status):
        """Raise when ``status`` is not OK, logging diagnostics first.

        Raises:
            Exception: With message "Status not ok" when ``status.OK()`` is falsy.
        """
        if status.OK():
            return
        # Log the failure message, the server status and the entity count.
        logger.error(status.message)
        logger.error(self._milvus.server_status())
        logger.error(self.count())
        raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Verify that the nearest hit of every row is closer than ``epsilon``.

        Raises:
            Exception: On the first row whose top distance is >= ``epsilon``.
        """
        for row_no, hits in enumerate(result):
            top_hit = hits[0]
            if top_hit.distance >= epsilon:
                logger.error(row_no)
                logger.error(hits[0].distance)
                raise Exception("Distance wrong")

    # only support the given field name
    def create_collection(self,
                          dimension,
                          data_type=DataType.FLOAT_VECTOR,
                          auto_id=False,
                          collection_name=None,
                          other_fields=None):
        """Create a collection with one vector field and optional scalar fields.

        Args:
            dimension: Vector dimension; also remembered on ``self._dimension``.
            data_type: Type of the vector field.
            auto_id: Forwarded as the ``auto_id`` creation parameter.
            collection_name: Target collection; defaults to the instance's.
            other_fields: Comma-separated extras; "int" and "float" are
                recognised.

        Raises:
            Exception: Whatever the client raises, after logging it.
        """
        self._dimension = dimension
        target = collection_name if collection_name else self._collection_name

        # The vector field always comes first.
        schema_fields = [{
            "name": utils.get_default_field_name(data_type),
            "type": data_type,
            "params": {
                "dim": dimension
            }
        }]
        if other_fields:
            requested = other_fields.split(",")
            if "int" in requested:
                schema_fields += [{
                    "name": utils.DEFAULT_INT_FIELD_NAME,
                    "type": DataType.INT64
                }]
            if "float" in requested:
                schema_fields += [{
                    "name": utils.DEFAULT_FLOAT_FIELD_NAME,
                    "type": DataType.FLOAT
                }]

        schema = {"fields": schema_fields, "auto_id": auto_id}
        try:
            self._milvus.create_collection(target, schema)
            logger.info("Create collection: <%s> successfully" % target)
        except Exception as err:
            logger.error(str(err))
            raise

    def create_partition(self, tag, collection_name=None):
        """Create partition ``tag`` in the given (or default) collection."""
        target = collection_name if collection_name else self._collection_name
        self._milvus.create_partition(target, tag)

    def generate_values(self, data_type, vectors, ids):
        """Pick the column values matching a field's data type.

        Returns:
            ``ids`` for integer fields, ``ids`` converted to floats for
            float fields, ``vectors`` for vector fields, otherwise ``None``.
        """
        if data_type in [DataType.INT32, DataType.INT64]:
            return ids
        if data_type in [DataType.FLOAT, DataType.DOUBLE]:
            return [(value + 0.0) for value in ids]
        if data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
            return vectors
        return None

    def generate_entities(self, vectors, ids=None, collection_name=None):
        """Build one entity column per field of the collection.

        Returns:
            A list of dicts with ``name``, ``type`` and ``values`` keys, in
            the order the fields are reported by ``get_info``.
        """
        columns = []
        target = self._collection_name if collection_name is None else collection_name
        for spec in self.get_info(target)["fields"]:
            kind = spec["type"]
            columns.append(
                dict(name=spec["name"],
                     type=kind,
                     values=self.generate_values(kind, vectors, ids)))
        return columns

    @time_wrapper
    def insert(self, entities, ids=None, collection_name=None):
        """Insert ``entities`` into the given (or default) collection.

        Returns:
            Whatever the client returns, or ``None`` when the client raised
            (the error is logged, not propagated).
        """
        target = collection_name if collection_name is not None else self._collection_name
        try:
            return self._milvus.insert(target, entities, ids=ids)
        except Exception as err:
            logger.error(str(err))
        return None

    def get_dimension(self):
        """Return the ``dim`` of the first vector field, or ``None`` if absent."""
        info = self.get_info()
        return next(
            (spec["params"]["dim"] for spec in info["fields"]
             if spec["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]),
            None)

    def get_rand_ids(self, length):
        """Return up to ``length`` ids taken from one randomly chosen segment.

        Segments are drawn from the first partition until one yields ids.
        When the segment holds more than ``length`` ids a random sample is
        returned, otherwise all ids of that segment.
        """
        candidate_ids = []
        while True:
            partition_segments = self.get_stats()["partitions"][0]["segments"]
            # random choice one segment
            chosen = random.choice(partition_segments)
            try:
                candidate_ids = self._milvus.list_id_in_segment(
                    self._collection_name, chosen["id"])
            except Exception as err:
                logger.error(str(err))
            if len(candidate_ids) == 0:
                # Nothing usable in this segment; draw again.
                continue
            if len(candidate_ids) > length:
                return random.sample(candidate_ids, length)
            logger.debug("Reset length: %d" % len(candidate_ids))
            return candidate_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    def get(self):
        """Fetch one entity by a random id in [1, 1000000]; the result is discarded."""
        probe_id = random.randint(1, 1000000)
        self._milvus.get_entity_by_id(self._collection_name, [probe_id])

    @time_wrapper
    def get_entities(self, get_ids):
        """Return the client's lookup result for ``get_ids`` in the default collection."""
        return self._milvus.get_entity_by_id(self._collection_name, get_ids)

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with the given ids from the given (or default) collection."""
        if collection_name is None:
            target = self._collection_name
        else:
            target = collection_name
        self._milvus.delete_entity_by_id(target, ids)

    def delete_rand(self):
        """Delete a random batch of 1-100 ids and assert they are gone.

        Raises:
            AssertionError: If any deleted id still resolves to an entity.
        """
        batch_size = random.randint(1, 100)
        # Read before deleting; only the commented-out check below uses it.
        count_before = self.count()
        logger.debug("%s: length to delete: %d" %
                     (self._collection_name, batch_size))
        delete_ids = self.get_rand_ids(batch_size)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" %
                    (self._collection_name, self.count()))
        leftovers = self._milvus.get_entity_by_id(self._collection_name,
                                                  delete_ids)
        for leftover in leftovers:
            assert not leftover
        # if count_before - len(delete_ids) < self.count():
        #     logger.error(delete_ids)
        #     raise Exception("Error occured")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        """Flush the given (or default) collection, optionally asynchronously."""
        if collection_name is None:
            target = self._collection_name
        else:
            target = collection_name
        self._milvus.flush([target], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the given (or default) collection and validate the status."""
        if collection_name is None:
            target = self._collection_name
        else:
            target = collection_name
        self.check_status(self._milvus.compact(target))

    @time_wrapper
    def create_index(self,
                     field_name,
                     index_type,
                     metric_type,
                     _async=False,
                     index_param=None):
        """Build an index on ``field_name`` of the default collection.

        Args:
            field_name: Field to index.
            index_type: Key into ``INDEX_MAP``.
            metric_type: Metric name, translated by ``utils.metric_type_trans``.
            _async: Forwarded to the client.
            index_param: Index parameters placed under ``"params"``.
        """
        mapped_index = INDEX_MAP[index_type]
        mapped_metric = utils.metric_type_trans(metric_type)
        logger.info(
            "Building index start, collection_name: %s, index_type: %s, metric_type: %s"
            % (self._collection_name, mapped_index, mapped_metric))
        if index_param:
            logger.info(index_param)
        self._milvus.create_index(
            self._collection_name,
            field_name,
            {
                "index_type": mapped_index,
                "metric_type": mapped_metric,
                "params": index_param
            },
            _async=_async)

    # TODO: need to check
    def describe_index(self, field_name):
        """Return the first recognised index of the described fields.

        Returns:
            ``{"index_type": <INDEX_MAP key>, "index_param": <params>}`` for
            the first index whose type appears in ``INDEX_MAP``; otherwise the
            default ``{"index_type": "flat", "index_param": None}``.
        """
        described = self._milvus.describe_index(self._collection_name, field_name)
        for field in described["fields"]:
            for entry in field['indexes']:
                if entry and "index_type" in entry:
                    # Map the server's index type back to its INDEX_MAP key.
                    for alias, native in INDEX_MAP.items():
                        if entry['index_type'] == native:
                            return {
                                "index_type": alias,
                                "index_param": entry['params']
                            }
        return {"index_type": "flat", "index_param": None}

    def drop_index(self, field_name):
        """Drop the index on ``field_name`` of ``self._collection_name``.

        :param field_name: field whose index is dropped.
        :return: whatever the client's ``drop_index`` returns.
        """
        message = "Drop index: %s" % self._collection_name
        logger.info(message)
        outcome = self._milvus.drop_index(self._collection_name, field_name)
        return outcome

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None):
        """Run a boolean ``must`` search built from the given clauses.

        :param vector_query: vector clause; always the first ``must`` entry.
        :param filter_query: optional iterable of extra clauses appended
            after the vector clause (ignored when falsy).
        :param collection_name: collection to search; ``None`` selects
            ``self._collection_name``.
        :return: whatever the client's ``search`` returns.
        """
        if collection_name is None:
            target = self._collection_name
        else:
            target = collection_name
        clauses = [vector_query]
        if filter_query:
            clauses += filter_query
        return self._milvus.search(target, {"bool": {"must": clauses}})

    @time_wrapper
    def load_and_query(self,
                       vector_query,
                       filter_query=None,
                       collection_name=None):
        """Load a collection, then run a boolean ``must`` search on it.

        :param vector_query: vector clause; always the first ``must`` entry.
        :param filter_query: optional iterable of extra clauses appended
            after the vector clause (ignored when falsy).
        :param collection_name: collection to load and search; ``None``
            selects ``self._collection_name``.
        :return: whatever the client's ``search`` returns.
        """
        if collection_name is None:
            target = self._collection_name
        else:
            target = collection_name
        clauses = [vector_query]
        if filter_query:
            clauses += filter_query
        body = {"bool": {"must": clauses}}
        # Load first so the search below runs against a loaded collection.
        self.load_collection(target)
        return self._milvus.search(target, body)

    def get_ids(self, result):
        idss = result._entities.ids
        ids = []
        len_idss = len(idss)
        len_r = len(result)
        top_k = len_idss // len_r
        for offset in range(0, len_idss, top_k):
            ids.append(idss[offset:min(offset + top_k, len_idss)])
        return ids

    def query_rand(self, nq_max=100, dimension=128):
        """Run one search with randomly drawn parameters and vectors.

        Draws, in this order, ``top_k`` in [1, 100], ``nq`` in [1, nq_max],
        ``nprobe`` in [1, 100], ``nq`` random vectors and a metric out of
        ``"l2"``/``"ip"``, then searches the default vector field through
        :meth:`query`.  The ``nprobe`` search parameter targets IVF indexes.

        :param nq_max: upper bound (inclusive) for the number of query
            vectors.
        :param dimension: length of every generated query vector; defaults
            to the previously hard-coded 128.
        """
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }
        self.query(vector_query)

    def load_query_rand(self, nq_max=100, dimension=128):
        """Load the collection and run one search with random parameters.

        Draws, in this order, ``top_k`` in [1, 100], ``nq`` in [1, nq_max],
        ``nprobe`` in [1, 100], ``nq`` random vectors and a metric out of
        ``"l2"``/``"ip"``, then searches the default vector field through
        :meth:`load_and_query`.  The ``nprobe`` search parameter targets IVF
        indexes.

        :param nq_max: upper bound (inclusive) for the number of query
            vectors.
        :param dimension: length of every generated query vector; defaults
            to the previously hard-coded 128.
        """
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }
        self.load_and_query(vector_query)

    # TODO: need to check
    def count(self, collection_name=None):
        """Return the ``row_count`` entry of a collection's statistics.

        :param collection_name: collection to inspect; ``None`` selects
            ``self._collection_name``.
        :return: the ``"row_count"`` value from ``get_collection_stats``.
        """
        target = self._collection_name if collection_name is None else collection_name
        stats = self._milvus.get_collection_stats(target)
        total = stats["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (total, target))
        return total

    def drop(self, timeout=120, collection_name=None):
        """Drop a collection and wait until it no longer reports rows.

        After issuing the drop, the row count is polled up to ``timeout``
        times with a one-second sleep after every non-zero count.  Polling
        stops as soon as the count is zero or the count query raises (the
        exception is logged at debug level and treated as "done").  If every
        poll still sees rows, an error is logged; nothing is raised.

        :param timeout: maximum number of polls; converted with ``int()``.
        :param collection_name: collection to drop; ``None`` selects
            ``self._collection_name``.
        """
        max_polls = int(timeout)
        target = self._collection_name if collection_name is None else collection_name
        logger.info("Start delete collection: %s" % target)
        self._milvus.drop_collection(target)
        for _ in range(max_polls):
            try:
                if not self.count(collection_name=target):
                    break
                time.sleep(1)
            except Exception as e:
                logger.debug(str(e))
                break
        else:
            # Reached only when no poll ended the loop early.
            logger.error("Delete collection timeout")

    def get_stats(self):
        """Return the client's statistics for ``self._collection_name``."""
        stats = self._milvus.get_collection_stats(self._collection_name)
        return stats

    def get_info(self, collection_name=None):
        """Return the client's collection info.

        :param collection_name: collection to describe; ``None`` selects
            ``self._collection_name``.
        :return: whatever the client's ``get_collection_info`` returns.
        """
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.get_collection_info(target)

    def show_collections(self):
        """Return the result of the client's ``list_collections`` call."""
        listing = self._milvus.list_collections()
        return listing

    def exists_collection(self, collection_name=None):
        """Ask the client whether a collection exists.

        :param collection_name: collection to look up; ``None`` selects
            ``self._collection_name``.
        :return: whatever the client's ``has_collection`` returns.
        """
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.has_collection(target)

    def clean_db(self):
        """Drop every collection reported by :meth:`show_collections`."""
        for existing in self.show_collections():
            self.drop(collection_name=existing)

    @time_wrapper
    def load_collection(self, collection_name=None, timeout=3000):
        """Load a collection through the client.

        :param collection_name: collection to load; ``None`` selects
            ``self._collection_name``.
        :param timeout: forwarded as the client's ``timeout`` argument;
            defaults to the previously hard-coded 3000.
        :return: whatever the client's ``load_collection`` returns.
        """
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_collection(collection_name, timeout=timeout)

    @time_wrapper
    def release_collection(self, collection_name=None, timeout=3000):
        """Release a collection through the client.

        :param collection_name: collection to release; ``None`` selects
            ``self._collection_name``.
        :param timeout: forwarded as the client's ``timeout`` argument;
            defaults to the previously hard-coded 3000.
        :return: whatever the client's ``release_collection`` returns.
        """
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_collection(collection_name,
                                               timeout=timeout)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None, timeout=3000):
        """Load the given partitions of a collection through the client.

        :param tag_names: partition tags, forwarded unchanged to the client.
        :param collection_name: owning collection; ``None`` selects
            ``self._collection_name``.
        :param timeout: forwarded as the client's ``timeout`` argument;
            defaults to the previously hard-coded 3000.
        :return: whatever the client's ``load_partitions`` returns.
        """
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_partitions(collection_name,
                                            tag_names,
                                            timeout=timeout)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None,
                           timeout=3000):
        """Release the given partitions of a collection through the client.

        :param tag_names: partition tags, forwarded unchanged to the client.
        :param collection_name: owning collection; ``None`` selects
            ``self._collection_name``.
        :param timeout: forwarded as the client's ``timeout`` argument;
            defaults to the previously hard-coded 3000.
        :return: whatever the client's ``release_partitions`` returns.
        """
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_partitions(collection_name,
                                               tag_names,
                                               timeout=timeout)
Ejemplo n.º 30
0
class MilvusDBHandler:
    """Milvus DB handler.

    This class is intended to abstract the access and communication with an
    external MilvusDB from Executors.  It is bound to one collection and can
    be used as a context manager: entering connects to the server, leaving
    closes the connection.

    For more information about Milvus:
        - https://github.com/milvus-io/milvus/
    """

    @staticmethod
    def get_index_type(index_type):
        """Translate an index name into a ``milvus.IndexType`` member.

        :param index_type: one of ``'Flat'``, ``'IVF,Flat'``, ``'IVF,SQ8'``,
            ``'RNSG'``, ``'IVF,SQ8H'``, ``'IVF,PQ'``, ``'HNSW'``, ``'Annoy'``.
        :return: the matching ``IndexType``; ``IndexType.FLAT`` for any
            other name.
        """
        # Imported lazily so the module can be loaded without pymilvus.
        from milvus import IndexType

        return {
            'Flat': IndexType.FLAT,
            'IVF,Flat': IndexType.IVFLAT,
            'IVF,SQ8': IndexType.IVF_SQ8,
            'RNSG': IndexType.RNSG,
            'IVF,SQ8H': IndexType.IVF_SQ8H,
            'IVF,PQ': IndexType.IVF_PQ,
            # Previously mapped to IVF_PQ by mistake.
            'HNSW': IndexType.HNSW,
            'Annoy': IndexType.ANNOY
        }.get(index_type, IndexType.FLAT)

    class MilvusDBInserter:
        """Milvus DB Inserter.

        This class is an inner class and provides a context manager to insert
        vectors into Milvus while ensuring data is flushed: leaving the
        ``with`` block always sends a flush for the collection.

        For more information about Milvus:
            - https://github.com/milvus-io/milvus/
        """

        def __init__(self, client, collection_name: str):
            """
            :param client: connected Milvus client used for insert and flush.
            :param collection_name: collection the vectors are written to.
            """
            self.logger = get_logger(self.__class__.__name__)
            self.client = client
            self.collection_name = collection_name

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Flush even when the block raised; the exception still
            # propagates because nothing truthy is returned.
            self.logger.info(f'Sending flush command to Milvus Server for collection: {self.collection_name}')
            self.client.flush([self.collection_name])

        def insert(self, keys: list, vectors: 'np.ndarray'):
            """Insert ``vectors`` under the ids ``keys``.

            :param keys: ids, passed to the client as ``ids``.
            :param vectors: records, passed to the client as ``records``.
            :raises MilvusDBException: if the returned status is not OK.
            """
            status, _ = self.client.insert(collection_name=self.collection_name, records=vectors, ids=keys)
            if not status.OK():
                self.logger.error('Insert failed: {}'.format(status))
                raise MilvusDBException(status.message)

    def __init__(self, host: str, port: int, collection_name: str):
        """
        Initialize an MilvusDBHandler

        :param host: Host of the Milvus Server
        :param port: Port to connect to the Milvus Server
        :param collection_name: Name of the collection where the Handler will insert and query vectors.
        """
        self.logger = get_logger(self.__class__.__name__)
        self.host = host
        # Stored as a string; it is passed to the Milvus client in that form.
        self.port = str(port)
        self.collection_name = collection_name
        # Created lazily by connect().
        self.milvus_client = None

    def __enter__(self):
        return self.connect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def connect(self):
        """Create the Milvus client unless a healthy one already exists.

        :return: ``self``, so the handler can be used in a ``with`` statement.
        """
        from milvus import Milvus
        if self.milvus_client is None or not self.milvus_client.server_status()[0].OK():
            self.logger.info(f'Setting connection to Milvus Server at {self.host}:{self.port}')
            self.milvus_client = Milvus(self.host, self.port)
        return self

    def close(self):
        """Close the client connection, if any.

        Calling this without an open connection is a no-op.  The client
        reference is cleared afterwards so a later :meth:`connect` creates a
        fresh client.
        """
        if self.milvus_client is None:
            return
        self.logger.info(f'Closing connection to Milvus Server at {self.host}:{self.port}')
        self.milvus_client.close()
        self.milvus_client = None

    def insert(self, keys: 'np.ndarray', vectors: 'np.ndarray'):
        """Insert vectors and flush the collection afterwards.

        :param keys: ids whose ``tolist()`` yields a list of lists; the inner
            lists are concatenated into one flat id list.
        :param vectors: records to insert.
        :raises MilvusDBException: if the insert status is not OK.
        """
        with MilvusDBHandler.MilvusDBInserter(self.milvus_client, self.collection_name) as db:
            # The initial [] keeps reduce from failing on empty keys.
            db.insert(reduce(operator.concat, keys.tolist(), []), vectors)

    def build_index(self, index_type: str, index_params: dict):
        """Create an index on the collection.

        :param index_type: index name understood by :meth:`get_index_type`.
        :param index_params: parameters forwarded to the client.
        :raises MilvusDBException: if the returned status is not OK.
        """
        milvus_index_type = self.get_index_type(index_type)

        self.logger.info(f'Creating index of type: {index_type} at'
                         f' Milvus Server. collection: {self.collection_name} with index params: {index_params}')
        status = self.milvus_client.create_index(self.collection_name, milvus_index_type, index_params)
        if not status.OK():
            self.logger.error('Creating index failed: {}'.format(status))
            raise MilvusDBException(status.message)

    def search(self, query_vectors: 'np.ndarray', top_k: int, search_params: dict = None):
        """Search the collection for the nearest neighbours of the queries.

        :param query_vectors: query records forwarded to the client.
        :param top_k: number of results requested per query.
        :param search_params: search parameters forwarded as ``params``.
        :return: tuple ``(results.distance_array, results.id_array)``.
        :raises MilvusDBException: if the returned status is not OK.
        """
        self.logger.info(f'Querying collection: {self.collection_name} with search params: {search_params}')
        status, results = self.milvus_client.search(collection_name=self.collection_name,
                                                    query_records=query_vectors, top_k=top_k, params=search_params)
        if not status.OK():
            self.logger.error('Querying index failed: {}'.format(status))
            raise MilvusDBException(status.message)
        return results.distance_array, results.id_array