Ejemplo n.º 1
0
class Ingester:
    """Prepare a Milvus collection (and optional partition) and insert into it.

    Args:
        host, port: server address, forwarded positionally to ``Milvus``.
        collection: name of the target collection.
        collection_param: parameters handed to ``create_collection`` when the
            collection does not exist yet.
        partition: optional partition tag; created on demand and used as the
            ``partition_tag`` of every insert.
        drop: when true, an existing collection with this name is dropped first.
        batch_size: stored on the instance (not used by the methods shown here).
        dtype: stored on the instance (not used by the methods shown here).
    """

    def __init__(self,
                 host,
                 port,
                 collection,
                 collection_param,
                 partition=None,
                 drop=False,
                 batch_size=100,
                 dtype=np.float32):
        self.collection = collection
        self.partition = partition
        self.client = Milvus(host, port)
        self.dtype = dtype
        self.batch_size = batch_size

        client = self.client
        if drop and client.has_collection(collection):
            client.drop_collection(collection)
        if not client.has_collection(collection):
            client.create_collection(collection, collection_param)
        needs_partition = partition and not client.has_partition(collection, partition)
        if needs_partition:
            client.create_partition(collection, partition)

    def ingest(self, entities, ids):
        """Insert ``entities`` under ``ids`` and return the client's insert result.

        When a partition was configured, the insert is routed to it.
        """
        if not self.partition:
            return self.client.insert(self.collection, entities, ids=ids)
        return self.client.insert(self.collection,
                                  entities,
                                  ids=ids,
                                  partition_tag=self.partition)
Ejemplo n.º 2
0
def validate_insert(_collection_name):
    """Check that ``_collection_name`` holds exactly 100,000 entities, then drop it.

    The collection is flushed first so the count includes pending inserts.
    The client is closed even when the count check fails.

    Raises:
        AssertionError: if the entity count is not ``10 * 10000`` (the
            collection is left in place in that case).
    """
    milvus = Milvus(**server_config)
    try:
        milvus.flush([_collection_name])
        _, count = milvus.count_entities(_collection_name)
        assert count == 10 * 10000, "Insert validate fail. Vectors num is not matched."

        # drop collection
        print("Drop collection ...")
        milvus.drop_collection(_collection_name)
    finally:
        milvus.close()
Ejemplo n.º 3
0
def delete_milvus():
    """Print stats and info for the ``ideaman`` collection, drop it, then recreate it empty.

    The recreated collection is 128-dimensional, uses the L2 metric and an
    ``index_file_size`` of 1024.
    """
    name = "ideaman"
    client = Milvus(host=milvus_ip, port='19530')

    stats = client.get_collection_stats(collection_name=name)
    print(stats)
    info = client.get_collection_info(name)
    print(info)

    client.drop_collection(name)
    client.create_collection({
        'collection_name': name,
        'dimension': 128,
        'index_file_size': 1024,
        'metric_type': MetricType.L2,
    })
Ejemplo n.º 4
0
def milvus_test(usr_features, IS_INFER, mov_features=None, ids=None):
    """Run the movie-recommendation demo against a local Milvus server.

    Connects to 127.0.0.1:19530 and ensures the ``recommender_demo``
    collection exists (creating it and inserting ``mov_features`` when it
    does not). It then searches with ``usr_features`` and prints the top-5
    movie titles.

    Args:
        usr_features: query features, normalized with ``normaliz_data``.
        IS_INFER: when true, the collection is dropped first so it is rebuilt.
        mov_features: movie features inserted when the collection is created;
            required in that case (the process exits if it is ``None``).
        ids: optional ids for the inserted movie vectors.
    """
    _HOST = '127.0.0.1'
    _PORT = '19530'  # default value
    table_name = 'recommender_demo'
    milvus = Milvus()

    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if status.OK():
        print("Server connected.")
    else:
        print("Server connect fail.")
        sys.exit(1)

    if IS_INFER:
        # Rebuild from scratch; the sleep presumably lets the server finish the drop.
        milvus.drop_collection(table_name)
        time.sleep(3)

    status, ok = milvus.has_collection(table_name)
    if not ok:
        if mov_features is None:
            print("Insert vectors is none!")
            sys.exit(1)
        param = {
            'collection_name': table_name,
            'dimension': 200,
            'index_file_size': 1024,  # optional
            'metric_type': MetricType.IP  # optional
        }

        print(milvus.create_collection(param))

        insert_vectors = normaliz_data(mov_features)
        status, ids = milvus.insert(collection_name=table_name,
                                    records=insert_vectors,
                                    ids=ids)

        time.sleep(1)

    status, result = milvus.count_collection(table_name)
    print("rows in table recommender_demo:", result)

    search_vectors = normaliz_data(usr_features)
    param = {
        'collection_name': table_name,
        'query_records': search_vectors,
        'top_k': 5,
        'params': {
            'nprobe': 16
        }
    }
    status, results = milvus.search(**param)
    if not status.OK():
        # Don't index into results of a failed search.
        print("Search failed:", status)
        return

    # Load the movie metadata once rather than once per hit.
    movie_info = paddle.dataset.movielens.movie_info()
    print("Top\t", "Ids\t", "Title\t", "Score")
    for rank, hit in enumerate(results[0]):
        title = movie_info[int(hit.id)].title
        # Scaling by 5 presumably maps the inner-product score onto a 5-point rating.
        print(rank, "\t", hit.id, "\t", title, "\t", float(hit.distance) * 5)
Ejemplo n.º 5
0
    def test_not_connect(self):
        """Every client call made before ``connect()`` must raise NotConnectError."""
        client = Milvus()

        # Each call is invoked inside its own pytest.raises context, in order.
        calls_without_connection = [
            lambda: client.create_collection({}),
            lambda: client.has_collection("a"),
            lambda: client.describe_collection("a"),
            lambda: client.drop_collection("a"),
            lambda: client.create_index("a"),
            lambda: client.insert("a", [], None),
            lambda: client.count_collection("a"),
            lambda: client.show_collections(),
            lambda: client.search("a", 1, 2, [], None),
            lambda: client.search_in_files("a", [], [], 2, 1, None),
            lambda: client._cmd(""),
            lambda: client.preload_collection("a"),
            lambda: client.describe_index("a"),
            lambda: client.drop_index(""),
        ]
        for call in calls_without_connection:
            with pytest.raises(NotConnectError):
                call()
Ejemplo n.º 6
0
 def del_milvus_collection(name):
     """Drop the Milvus collection ``name``.

     Raises:
         MilvusError: when the server reports a non-OK status for the drop
             (raised as-is, carrying the status), or when any other error
             occurs while connecting or dropping (wrapped, with the original
             exception chained as the cause).
     """
     try:
         milvus = Milvus(host=MILVUS_ADDR, port=MILVUS_PORT)
         res = milvus.drop_collection(collection_name=name)
         if not res.OK():
             raise MilvusError(
                 "There was some error when drop milvus collection", res)
     except MilvusError as e:
         # Already descriptive; log and re-raise unchanged instead of burying
         # it inside the generic error below.
         logger.error(str(e), exc_info=True)
         raise
     except Exception as e:
         err_msg = "There was some error when delete milvus collection"
         logger.error(f"{err_msg} : {str(e)}", exc_info=True)
         raise MilvusError(err_msg, e) from e
Ejemplo n.º 7
0
 def del_milvus_collection(name):
     """Connect to Milvus and drop the collection ``name``.

     Raises:
         MilvusError: when the drop returns a non-OK status (raised as-is,
             carrying the status), or when any other error occurs (wrapped,
             with the original exception chained as the cause).
     """
     milvus = Milvus()
     try:
         milvus.connect(MILVUS_ADDR, MILVUS_PORT)
         res = milvus.drop_collection(collection_name=name)
         if not res.OK():
             raise MilvusError(
                 "There has some error when drop milvus collection", res)
     except MilvusError:
         # Keep the specific status-bearing error instead of re-wrapping it.
         raise
     except Exception as e:
         raise MilvusError(
             "There has some error when delete milvus collection", e) from e
Ejemplo n.º 8
0
def milvus_test(usr_features, mov_features, ids):
    """Insert one user vector into ``paddle_demo1``, search it with one movie vector, then drop the collection.

    Both feature arguments are converted with ``.tolist()`` and normalized with
    ``normaliz_data`` before use. The best hit's id is printed, along with its
    score (distance scaled by 5). The process exits if the server connection
    fails.

    Args:
        usr_features: user feature vector to insert.
        mov_features: movie feature vector used as the query.
        ids: ids passed to the insert call.
    """
    _HOST = '127.0.0.1'
    _PORT = '19530'  # default value
    milvus = Milvus()

    param = {'host': _HOST, 'port': _PORT}
    status = milvus.connect(**param)
    if status.OK():
        print("\nServer connected.")
    else:
        print("\nServer connect fail.")
        sys.exit(1)

    table_name = 'paddle_demo1'

    status, ok = milvus.has_collection(table_name)
    if not ok:
        param = {
            'collection_name': table_name,
            'dimension': 200,
            'index_file_size': 1024,  # optional
            'metric_type': MetricType.IP  # optional
        }

        milvus.create_collection(param)

    insert_vectors = normaliz_data([usr_features.tolist()])
    status, ids = milvus.insert(collection_name=table_name, records=insert_vectors, ids=ids)

    time.sleep(1)

    status, result = milvus.count_collection(table_name)
    print("rows in table paddle_demo1:", result)

    search_vectors = normaliz_data([mov_features.tolist()])
    param = {
        'collection_name': table_name,
        'query_records': search_vectors,
        'top_k': 1,
        'params': {'nprobe': 16}
    }
    status, results = milvus.search(**param)
    if status.OK():
        print("Searched ids:", results[0][0].id)
        print("Score:", float(results[0][0].distance)*5)
    else:
        # Don't index into results of a failed search; still drop below.
        print("Search failed:", status)

    milvus.drop_collection(table_name)
Ejemplo n.º 9
0
    def _create_collection(_collection_param):
        """(Re)create the collection named by ``_collection_name`` from ``_collection_param``.

        Steps:
            1. Drop any existing collection with that name.
            2. Wait 5 s and verify the collection is gone.
            3. Create it from the given parameters.

        ``_collection_name`` and ``server_config`` are names defined outside
        this function. The client is closed on every exit path.

        Raises:
            Exception: if dropping the existing collection fails, or it still
                exists after the wait.
        """
        milvus = Milvus(**server_config)
        try:
            status, ok = milvus.has_collection(_collection_name)
            if ok:
                print("Collection {} found, now going to delete it".format(
                    _collection_name))
                status = milvus.drop_collection(_collection_name)
                if not status.OK():
                    raise Exception("Delete collection error")
                print(
                    "delete collection {} successfully!".format(_collection_name))
            time.sleep(5)

            status, ok = milvus.has_collection(_collection_name)
            if ok:
                raise Exception("Delete collection error")

            # Create from the caller-supplied parameters.
            status = milvus.create_collection(_collection_param)
            if not status.OK():
                print("Create collection {} failed".format(_collection_name))
        finally:
            milvus.close()
Ejemplo n.º 10
0
# Demo script: recreate the 256-dimensional L2 collection "test01" with the
# partition "tag01" on a local Milvus server, then prepare empty vector/id lists.
from milvus import Milvus, IndexType, MetricType, Status
# Connect to the server
milvus = Milvus(host='localhost', port='19530')
# Collection parameters
param = {
    'collection_name': 'test01',
    'dimension': 256,
    'index_file_size': 1024,
    'metric_type': MetricType.L2
}
#milvus.create_collection(param)
# Drop the collection first so it can be recreated below
milvus.drop_collection(collection_name="test01")
print("collection:", milvus.list_collections())
milvus.create_collection(param)
#print(milvus.list_partitions("test01"))
# Create a partition
milvus.create_partition('test01', 'tag01')
#print(milvus.list_partitions("test01"))
# Drop the partition (disabled)
#milvus.drop_partition('test01','tag01')
print("partition:", milvus.list_partitions("test01"))

import time
import random
import numpy as np
#vectors = [[random.random() for _ in range(256)] for _ in range(3)]
#print(np.shape(np.array(vectors)))
#vector_ids = [1,2,3]
# NOTE(review): presumably filled by code after this excerpt — confirm.
vectors = []
vector_ids = []
Ejemplo n.º 11
0
def main():
    """Hybrid (scalar + vector) collection demo for the entity-based Milvus client.

    Steps:
        1. Create a collection with three integer fields and a 128-d float
           vector field.
        2. Insert ``num`` random entities in slices.
        3. Delete and fetch entities by id.
        4. Build an IVF_FLAT index.
        5. Run a filtered vector search.

    The collection is dropped at the end, including when a step in between
    raises.
    """
    milvus = Milvus(_HOST, _PORT)

    num = 100000
    # Create the collection, dropping a leftover one from a previous run first.
    collection_name = 'example_hybrid_collections_{}'.format(num)
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name)

    collection_param = {
        "fields": [{
            "field": "A",
            "type": DataType.INT32
        }, {
            "field": "B",
            "type": DataType.INT32
        }, {
            "field": "C",
            "type": DataType.INT64
        }, {
            "field": "Vec",
            "type": DataType.FLOAT_VECTOR,
            "params": {
                "dim": 128,
                "metric_type": "L2"
            }
        }],
        "segment_size": 100
    }
    milvus.create_collection(collection_name, collection_param)

    try:
        milvus.compact(collection_name)

        A_list = [random.randint(0, 255) for _ in range(num)]
        vec = [[random.random() for _ in range(128)] for _ in range(num)]
        # The same random integers populate all three scalar fields.
        hybrid_entities = [{
            "field": "A",
            "values": A_list,
            "type": DataType.INT32
        }, {
            "field": "B",
            "values": A_list,
            "type": DataType.INT32
        }, {
            "field": "C",
            "values": A_list,
            "type": DataType.INT64
        }, {
            "field": "Vec",
            "values": vec,
            "type": DataType.FLOAT_VECTOR,
            "params": {
                "dim": 128
            }
        }]

        # Keep the ids of every inserted slice, not only the last one.
        ids = []
        for slice_e in utils.entities_slice(hybrid_entities):
            ids.extend(milvus.insert(collection_name, slice_e))
        milvus.flush([collection_name])
        print("Flush ... ")
        print("Entity count: {}".format(milvus.count_entities(collection_name)))

        milvus.delete_entity_by_id(collection_name, ids[:1])
        milvus.flush([collection_name])
        print("Get entity be id start ...... ")
        entities = milvus.get_entity_by_id(collection_name, ids[:1])
        print(entities.dict())
        milvus.delete_entity_by_id(collection_name, ids[1:2])
        milvus.flush([collection_name])

        print("Create index ......")
        milvus.create_index(collection_name, "Vec", {
            "index_type": "IVF_FLAT",
            "metric_type": "L2",
            "params": {
                "nlist": 100
            }
        })
        print("Create index done.")

        info = milvus.get_collection_info(collection_name)
        print(info)
        stats = milvus.get_collection_stats(collection_name)
        print("\nstats\n")
        print(stats)

        # Filter on A (term) and B (range), then take the top 10 vector hits
        # for each of the first 10000 vectors.
        query_hybrid = {
            "bool": {
                "must": [
                    {
                        "term": {
                            "A": [1, 2, 5]
                        }
                    },
                    {
                        "range": {
                            "B": {"GT": 1, "LT": 100}
                        }
                    },
                    {
                        "vector": {
                            "Vec": {
                                "topk": 10, "query": vec[: 10000], "params": {"nprobe": 10}
                            }
                        }
                    }
                ],
            },
        }

        t0 = time.time()
        results = milvus.search(collection_name, query_hybrid, fields=["B"])
        hit_count = sum(1 for row in results for _ in row)
        print("Search cost {} s".format(time.time() - t0))
        print("Search returned {} hits".format(hit_count))
    finally:
        milvus.drop_collection(collection_name)
Ejemplo n.º 12
0
# Demo script: recreate the 1000-dimensional L2 collection "test01" with an
# IVF_FLAT index and the partition "tag01". It inserts random vectors plus one
# all-ones vector into that partition, then prints collection/index information.
import numpy as np
from milvus import Milvus, IndexType, MetricType
import time

t1 = time.time()  # start time (not used further in this excerpt)
# Create a Milvus client; every later operation goes through it
milvus = Milvus(host='localhost', port='19530')
vec_dim = 1000 # vector dimension
num_vec=100 # number of random vectors; one extra all-ones vector is appended below for verification
# Drop the collection
milvus.drop_collection('test01')
# Create the collection
param = {'collection_name':'test01', 'dimension':vec_dim, 'index_file_size':1024, 'metric_type':MetricType.L2}
milvus.create_collection(param)
# Create a partition
milvus.create_partition('test01', 'tag01')

# nlist: number of clusters for the IVF index
ivf_param = {'nlist': 16}
# NOTE(review): the index is created before any data is inserted. The original
# note says this call is optional and the default is IndexType.FLAT — confirm.
milvus.create_index('test01', IndexType.IVF_FLAT, ivf_param)
# Generate a batch of random vectors
vectors_array = np.random.rand(num_vec,vec_dim)
vectors_list = vectors_array.tolist()
vectors_list.append([1 for _ in range(vec_dim)])
ids_list = [i for i in range(len(vectors_list))]
# NOTE(review): the original note says a single insert must not exceed 256 MB.
# It also says inserted data is buffered up to index_file_size (1024 here).
# Confirm both against the Milvus documentation.
milvus.insert(collection_name='test01', records=vectors_list, partition_tag="tag01",ids=ids_list)
# Print some information
print(milvus.list_collections())
print(milvus.get_collection_info('test01'))
print(milvus.get_index_info('test01'))
Ejemplo n.º 13
0
def main():
    """Partition demo for the status-returning Milvus client API.

    Steps:
        1. Recreate ``demo_partition_collection`` and add the partition ``random``.
        2. Insert 10000 random ``_DIM``-dimensional vectors into that partition.
        3. Build an IVF_FLAT index.
        4. Search that partition with the first 10 vectors.

    The collection is dropped at the end.
    """
    # Connect to Milvus server
    # You may need to change _HOST and _PORT accordingly
    param = {'host': _HOST, 'port': _PORT}

    # You can create an instance for the server address and
    # invoke rpc methods directly
    client = Milvus(**param)
    # (Re)create collection demo_partition_collection.
    collection_name = 'demo_partition_collection'
    partition_tag = "random"

    status, ok = client.has_collection(collection_name)
    # if collection exists, then drop it
    if status.OK() and ok:
        client.drop_collection(collection_name)

    param = {
        'collection_name': collection_name,
        'dimension': _DIM,
        'index_file_size': _INDEX_FILE_SIZE,  # optional
        'metric_type': MetricType.L2  # optional
    }

    client.create_collection(param)

    # Show collections in Milvus server
    _, collections = client.show_collections()
    print(collections)

    # Describe collection
    _, collection = client.describe_collection(collection_name)
    print(collection)

    # create partition
    client.create_partition(collection_name, partition_tag=partition_tag)
    # display partitions
    _, partitions = client.show_partitions(collection_name)
    print(partitions)

    # 10000 vectors with _DIM float elements each; vectors should be a 2-D array
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10000)]
    # You can also use numpy to generate random vectors:
    #     `vectors = np.random.rand(10000, _DIM).astype(np.float32).tolist()`

    # Insert vectors into the partition; returns status and the list of vector ids
    status, ids = client.insert(collection_name=collection_name, records=vectors, partition_tag=partition_tag)

    # Wait 6 seconds to give the server time to persist the inserted data.
    time.sleep(6)

    # Get the collection's row count
    status, num = client.count_collection(collection_name)
    print("row count:", num)

    # create index of vectors, search more rapidly
    index_param = {
        'nlist': 2048
    }

    # Create ivflat index in the collection.
    # You can search vectors without creating an index; however, creating an
    # index helps to search faster
    status = client.create_index(collection_name, IndexType.IVF_FLAT, index_param)

    # describe index, get information of index
    status, index = client.describe_index(collection_name)
    print(index)

    # Use the top 10 vectors for similarity search
    query_vectors = vectors[0:10]

    search_param = {
        "nprobe": 10
    }

    param = {
        'collection_name': collection_name,
        'query_records': query_vectors,
        'top_k': 1,
        # Restrict the search to the partition the vectors were inserted into.
        'partition_tags': [partition_tag],
        'params': search_param
    }
    status, results = client.search(**param)

    if status.OK():
        # indicate search result
        # also use by:
        #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
        if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

    # print results
    print(results)

    # Delete partition. You can also invoke `drop_collection()`, so that all of the
    # partitions belonging to the collection are deleted.
    # status = client.drop_partition(collection_name, partition_tag)

    # Delete collection. All of partitions of this collection will be dropped.
    status = client.drop_collection(collection_name)
Ejemplo n.º 14
0
def main():
    """Insert/delete/search demo for the ``list_collections``/``count_entities`` client API.

    Steps:
        1. Create ``collection_name`` with an L2 metric and add ``partition_tag``.
        2. Insert two batches of 10 random vectors (ids 0-9 and 10-19).
        3. Delete ids 0-9, flush and count.
        4. Build an IVF_FLAT index and search with two vectors of the second batch.
        5. Drop the collection.

    The client is closed even if a step raises.
    """
    milvus = Milvus(uri=uri)
    try:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': 32,
            # use MetricType.IP here for inner-product similarity
            'metric_type': MetricType.L2
        }
        # show collections in Milvus server (listed before ours is created)
        _, collections = milvus.list_collections()

        # create the collection
        milvus.create_collection(param)
        # create the collection partition
        milvus.create_partition(collection_name, partition_tag)

        print(f'collections in Milvus: {collections}')
        # Describe demo_collection
        _, collection = milvus.get_collection_info(collection_name)
        print(f'descript demo_collection: {collection}')

        # build fake vectors
        vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
        vectors1 = [[random.random() for _ in range(_DIM)] for _ in range(10)]

        status, insert_ids = milvus.insert(collection_name=collection_name,
                                           records=vectors,
                                           ids=list(range(10)),
                                           partition_tag=partition_tag)
        print(f'status: {status} | id: {insert_ids}')
        if not status.OK():
            print(f"insert failded: {status}")

        status1, insert_ids1 = milvus.insert(collection_name=collection_name,
                                             records=vectors1,
                                             ids=list(range(10, 20)),
                                             partition_tag=partition_tag)
        print(f'status1: {status1} | id1: {insert_ids1}')
        if not status1.OK():
            print(f"insert failded: {status1}")

        ids_deleted = list(range(10))

        status_delete = milvus.delete_entity_by_id(collection_name=collection_name,
                                                   id_array=ids_deleted)
        if status_delete.OK():
            print('delete successful')

        # Flush inserted data to disk
        milvus.flush([collection_name])
        # Get demo_collection row count
        status, result = milvus.count_entities(collection_name)
        print(f"demo_collection row count: {result}")

        # Obtain raw vectors by providing vector ids (10-19 were not deleted)
        status, _ = milvus.get_entity_by_id(collection_name,
                                            list(range(10, 20)))
        if not status.OK():
            print(f"get_entity_by_id failed: {status}")

        # create index of vectors, search more rapidly
        index_param = {'nlist': 2}

        # create ivflat index in demo_collection
        status = milvus.create_index(collection_name, IndexType.IVF_FLAT,
                                     index_param)
        if status.OK():
            print("create index ivf_flat succeeed")

        # Query with the first 2 vectors of the second batch (ids 10-11, not
        # deleted); the check below expects the top hit at distance 0.0.
        query_vectors = vectors1[0:2]

        # execute vector similarity search
        search_param = {"nprobe": 16}

        param = {
            'collection_name': collection_name,
            'query_records': query_vectors,
            'top_k': 1,
            'params': search_param
        }

        status, results = milvus.search(**param)
        if status.OK():
            if results[0][0].distance == 0.0:
                print('query result is correct')
            else:
                print('not correct')
            print(results)
        else:
            print(f'search failed: {status}')

        # drop the collection created above
        milvus.drop_collection(collection_name=collection_name)
    finally:
        milvus.close()
Ejemplo n.º 15
0
class SearchEngine:
    """Key/value style wrapper around a Milvus collection of vectors.

    Keys are entity ids (an int or a list of ints). Values are array-like
    objects with ``.shape``/``.reshape`` (numpy arrays, presumably). Query and
    mutation methods act on the collection chosen with ``set_collection`` and
    assert that one has been chosen. The ``MILVUS_HOST`` and ``MILVUS_PORT``
    environment variables override the constructor's address.
    """

    def __init__(self, host, port):
        self.host = os.environ.get('MILVUS_HOST', host)
        self.port = os.environ.get('MILVUS_PORT', str(port))
        self.engine = Milvus(host=self.host, port=self.port)
        self.collection_name = None

    #################################################
    # HANDLE COLLECTION
    #################################################

    def create_collection(self, collection_name, dimension):
        """Create ``collection_name`` with ``dimension``, the IP metric and index_file_size 1000."""
        param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': 1000,
            'metric_type': MetricType.IP
        }

        self.engine.create_collection(param)

        print('[INFO] collection {}을 생성했습니다.'.format(collection_name))

    def drop_collection(self, collection_name):
        """Drop ``collection_name``."""
        self.engine.drop_collection(collection_name=collection_name)

        print('[INFO] collection {}을 삭제했습니다.'.format(collection_name))

    def get_collection_stats(self, collection_name):
        """Print the info and the stats of ``collection_name``."""
        print(self.engine.get_collection_info(collection_name))
        print(self.engine.get_collection_stats(collection_name))

    def set_collection(self, collection_name):
        """Select the collection that later query/mutation calls operate on."""
        self.collection_name = collection_name
        print('[INFO] setting collection {}'.format(self.collection_name))

    #################################################
    # UTILS
    #################################################

    def check_set_collection(self):
        """Assert that a collection has been selected with ``set_collection``."""
        assert self.collection_name is not None, '[ERROR] collection을 setting해 주십시오!!'

    def check_exist_data_by_key(self, key):
        """Return True if the first entity fetched for ``key`` is non-empty.

        ``key`` is passed as the ``ids`` of ``get_entity_by_id``; only the first
        returned entry is inspected.
        """
        self.check_set_collection()

        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)

        # A falsy reply (e.g. None or an empty list) becomes [reply], so [0]
        # below is safe and yields a falsy value.
        vector = vector if vector else [vector]

        return True if vector[0] else False

    def convert_key_format(self, key):
        """Wrap a single int key in a list; return any other key unchanged."""
        return [key] if isinstance(key, int) else key

    def convert_value_format(self, value):
        """Reshape a 1-D ``value`` into one row of shape ``(1, n)``; rank 0 passes through.

        Raises:
            AssertionError: if ``value`` has rank 2 or higher.
        """
        rank = len(value.shape)

        assert rank < 2, '[ERROR] value의 dim을 2 미만으로 입력해 주세요!!'
        return value.reshape(1, -1) if rank == 1 else value

    #################################################
    # INSERT
    #################################################

    def insert_data(self, key, value):
        """Insert ``value`` under ``key`` and flush; prints an error and skips if the key exists."""
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)

        if self.check_exist_data_by_key(key):
            print("[ERROR] 이미 collection에 데이터가 존재합니다.")
            return

        self.engine.insert(collection_name=self.collection_name,
                           records=value,
                           ids=key)
        self.engine.flush([self.collection_name])

        print('[INFO] insert key {}'.format(key))

    #################################################
    # DELETE
    #################################################

    def delete_data(self, key):
        """Delete ``key`` and flush; prints an error and skips if the key is missing."""
        key = self.convert_key_format(key)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])

        print('[INFO] delete key {}'.format(key))

    #################################################
    # UPDATE
    #################################################

    def update_data(self, key, value):
        """Replace the vector stored under ``key`` (delete + flush, then insert + flush).

        Prints an error and does nothing if the key is missing.
        """
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])

        self.engine.insert(collection_name=self.collection_name,
                           records=value,
                           ids=key)
        self.engine.flush([self.collection_name])

        print('[INFO] update key {}'.format(key))

    #################################################
    # SEARCH
    #################################################

    def search_by_feature(self, feature, top_k):
        """Search the current collection with ``feature``; return ``(ids, distances)``.

        Both are lists holding one inner list per query row of the search result.
        """
        self.check_set_collection()

        feature = self.convert_value_format(feature)

        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=feature,
                                       top_k=top_k)

        # Row i of the output must come from result[i], not always result[0].
        li_id = [[hit.id for hit in result[i]] for i in range(len(result))]
        li_dist = [[hit.distance for hit in result[i]] for i in range(len(result))]

        return li_id, li_dist

    def search_by_key(self, key, top_k):
        """Search with the stored vector(s) of ``key``; return ``(ids, distances)``.

        The search asks for ``top_k + 1`` hits and drops the first hit of each
        row, presumably because that hit is the query vector itself. Prints an
        error and returns None when the key does not exist.
        """
        self.check_set_collection()
        key = self.convert_key_format(key)

        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return

        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)
        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=vector,
                                       top_k=top_k + 1)

        # Row i of the output must come from result[i], not always result[0].
        li_id = [[hit.id for hit in result[i][1:]] for i in range(len(result))]
        li_dist = [[hit.distance for hit in result[i][1:]] for i in range(len(result))]

        return li_id, li_dist
Ejemplo n.º 16
0
def main():
    """End-to-end demo: connect, create, insert, index, search, then drop and disconnect.

    Works on ``example_collection``. When it is missing, the collection is
    created with ``_DIM`` dimensions, the L2 metric and ``_INDEX_FILE_SIZE``.
    The demo inserts 10000 random vectors, builds an IVF_FLAT index
    (nlist=2048) and checks the top hit of the first query. Exits with status 1
    if the server connection fails.
    """
    client = Milvus()

    # Connect to Milvus server; you may need to change _HOST and _PORT accordingly.
    connect_status = client.connect(host=_HOST, port=_PORT)
    if not connect_status.OK():
        print("Server connect fail.")
        sys.exit(1)
    print("Server connected.")

    name = 'example_collection'

    # Create the collection only when it does not exist yet.
    _, exists = client.has_collection(name)
    if not exists:
        client.create_collection({
            'collection_name': name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2,  # optional
        })

    # The collection listing is fetched but not used further.
    _, _ = client.show_collections()

    _, info = client.collection_info(name)
    print(info)

    _, description = client.describe_collection(name)
    print(description)

    # 10000 random float vectors with _DIM components each (a 2-D list).
    records = [[random.random() for _ in range(_DIM)] for _ in range(10000)]

    _, ids = client.insert(collection_name=name, records=records)
    # Flush the inserted data to disk before counting and indexing.
    client.flush([name])
    _, _ = client.count_collection(name)

    index_param = {'nlist': 2048}
    print("Creating index: {}".format(index_param))
    client.create_index(name, IndexType.IVF_FLAT, index_param)

    _, index = client.describe_index(name)
    print(index)

    print("Searching ... ")
    search_status, results = client.search(
        collection_name=name,
        query_records=records[0:10],
        top_k=1,
        params={"nprobe": 16},
    )

    if search_status.OK():
        top_hit = results[0][0]
        matched = top_hit.distance == 0.0 or top_hit.id == ids[0]
        print('Query result is correct' if matched else 'Query result isn\'t correct')

    print(results)

    client.drop_collection(name)
    client.disconnect()
Ejemplo n.º 17
0
class MilvusHelper:
    """Application-level helper around a Milvus 1.x client.

    Connection settings and collection parameters come from the module-level
    constants MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION, INDEX_FILE_SIZE and
    METRIC_TYPE. Every method logs its outcome through LOGGER; on any failure
    the error is logged and the process is terminated with ``sys.exit(1)``.
    """

    def __init__(self):
        """Open a client connection to the configured Milvus server."""
        try:
            self.client = Milvus(host=MILVUS_HOST, port=MILVUS_PORT)
            LOGGER.debug(
                "Successfully connect to Milvus with IP:{} and PORT:{}".format(
                    MILVUS_HOST, MILVUS_PORT))
        except Exception as e:
            LOGGER.error("Failed to connect Milvus: {}".format(e))
            sys.exit(1)

    def has_collection(self, collection_name):
        """Return whether ``collection_name`` exists in Milvus.

        The returned status is checked first so that a failed request is not
        silently reported as "collection does not exist".
        """
        try:
            status, exists = self.client.has_collection(collection_name)
            if status.code != 0:
                raise Exception(status.message)
            return exists
        except Exception as e:
            LOGGER.error("Failed to check collection in Milvus: {}".format(e))
            sys.exit(1)

    def create_colllection(self, collection_name):
        """Create ``collection_name`` unless it already exists.

        The misspelled method name is kept for backward compatibility; the
        correctly spelled alias ``create_collection`` is declared right after
        this method.
        """
        try:
            if not self.has_collection(collection_name):
                collection_param = {
                    'collection_name': collection_name,
                    'dimension': VECTOR_DIMENSION,
                    'index_file_size': INDEX_FILE_SIZE,
                    'metric_type': METRIC_TYPE
                }
                status = self.client.create_collection(collection_param)
                if status.code != 0:
                    raise Exception(status.message)
                LOGGER.debug(
                    "Create Milvus collection: {}".format(collection_name))
        except Exception as e:
            LOGGER.error("Failed to create Milvus collection: {}".format(e))
            sys.exit(1)

    # Correctly spelled alias; existing callers may keep using the typo.
    create_collection = create_colllection

    def insert(self, collection_name, vectors):
        """Insert ``vectors`` (creating the collection if needed).

        Returns the list of ids assigned by Milvus.
        """
        try:
            self.create_colllection(collection_name)
            status, ids = self.client.insert(collection_name=collection_name,
                                             records=vectors)
            if not status.code:
                LOGGER.debug(
                    "Insert vectors to Milvus in collection: {} with {} rows".
                    format(collection_name, len(vectors)))
                return ids
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to load data to Milvus: {}".format(e))
            sys.exit(1)

    def create_index(self, collection_name):
        """Build an IVF_FLAT index (nlist=16384) and return the status."""
        try:
            index_param = {'nlist': 16384}
            status = self.client.create_index(collection_name,
                                              IndexType.IVF_FLAT, index_param)
            if not status.code:
                LOGGER.debug(
                    "Successfully create index in collection:{} with param:{}".
                    format(collection_name, index_param))
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to create index: {}".format(e))
            sys.exit(1)

    def delete_collection(self, collection_name):
        """Drop ``collection_name`` and return the status."""
        try:
            status = self.client.drop_collection(
                collection_name=collection_name)
            if not status.code:
                LOGGER.debug(
                    "Successfully drop collection: {}".format(collection_name))
                return status
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to drop collection: {}".format(e))
            sys.exit(1)

    def search_vectors(self, collection_name, vectors, top_k):
        """Search ``vectors`` (nprobe=16) and return the raw search result."""
        try:
            search_param = {'nprobe': 16}
            status, result = self.client.search(
                collection_name=collection_name,
                query_records=vectors,
                top_k=top_k,
                params=search_param)
            if not status.code:
                LOGGER.debug("Successfully search in collection: {}".format(
                    collection_name))
                return result
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to search vectors in Milvus: {}".format(e))
            sys.exit(1)

    def count(self, collection_name):
        """Return the number of entities stored in ``collection_name``."""
        try:
            status, num = self.client.count_entities(
                collection_name=collection_name)
            if not status.code:
                LOGGER.debug(
                    "Successfully get the num:{} of the collection:{}".format(
                        num, collection_name))
                return num
            else:
                raise Exception(status.message)
        except Exception as e:
            LOGGER.error("Failed to count vectors in Milvus: {}".format(e))
            sys.exit(1)
Ejemplo n.º 18
0
def main():
    """Run the Milvus 1.x Python SDK walkthrough end to end.

    Steps:

    * Ensure ``example_collection_`` exists, then list and describe the
      collections.
    * Insert 10 random vectors of dimension ``_DIM``, flush, and print the
      row count and statistics.
    * Fetch the inserted vectors by id and build an IVF_FLAT index.
    * Search with the same vectors (top-1), then drop the collection.

    The client connection is closed on every exit path.
    """
    # Specify the server address when creating the client instance. The
    # client maintains a connection pool; `pool_size` caps its size.
    milvus = Milvus(_HOST, _PORT)
    collection_name = 'example_collection_'
    try:
        # Create the collection only if it does not exist yet.
        status, ok = milvus.has_collection(collection_name)
        if not ok:
            param = {
                'collection_name': collection_name,
                'dimension': _DIM,
                'index_file_size': _INDEX_FILE_SIZE,  # optional
                'metric_type': MetricType.L2  # optional
            }
            milvus.create_collection(param)

        # Show all collections on the Milvus server.
        _, collections = milvus.list_collections()
        print(collections)

        # Describe this collection.
        _, collection = milvus.get_collection_info(collection_name)
        print(collection)

        # 10 random vectors of dimension _DIM, as a 2-D list of floats.
        # numpy alternative:
        #   vectors = np.random.rand(10, _DIM).astype(np.float32)
        vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
        print(vectors)

        # Insert the vectors; returns the status and the assigned ids.
        status, ids = milvus.insert(collection_name=collection_name,
                                    records=vectors)
        if not status.OK():
            # Every later step needs the ids, so stop here, but still remove
            # the collection as the normal path does.
            print("Insert failed: {}".format(status))
            milvus.drop_collection(collection_name)
            return
        print(ids)

        # Flush the inserted data to disk.
        milvus.flush([collection_name])

        # Row count of the collection.
        status, result = milvus.count_entities(collection_name)
        print(status)
        print(result)

        # Collection statistics.
        _, info = milvus.get_collection_stats(collection_name)
        print(info)

        # Fetch raw vectors back by id (the first ten ids).
        status, result_vectors = milvus.get_entity_by_id(collection_name,
                                                         ids[:10])
        print(result_vectors)

        # Build an IVF_FLAT index. Searching works without an index, but an
        # index makes it faster.
        index_param = {'nlist': 2048}
        print("Creating index: {}".format(index_param))
        status = milvus.create_index(collection_name, IndexType.IVF_FLAT,
                                     index_param)

        # Print information about the index.
        status, index = milvus.get_index_info(collection_name)
        print(index)

        # Query with the first 10 inserted vectors.
        query_vectors = vectors[0:10]

        # nprobe: number of index centroids to search.
        search_param = {"nprobe": 16}

        print("Searching ... ")

        param = {
            'collection_name': collection_name,
            'query_records': query_vectors,
            'top_k': 1,
            'params': search_param,
        }
        status, results = milvus.search(**param)
        if status.OK():
            print(results)
            # Each query vector was inserted, so its nearest neighbour should
            # be itself. Equivalent check:
            #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
            if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

            print(results)
        else:
            print("Search failed. ", status)

        # Delete the collection.
        status = milvus.drop_collection(collection_name)
    finally:
        # Release the client's connections.
        milvus.close()
Ejemplo n.º 19
0
class MilvusClient(object):
    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Retries connecting for up to ``timeout`` seconds (default 60s) and
        raises ``Exception("Server connect timeout")`` if no attempt succeeds.
        When ``collection_name`` is given and exists, its metric type and
        dimension are cached on the instance.
        """
        self._collection_name = collection_name
        # Defined up-front so these attributes exist even when the collection
        # is absent (insert_rand reads both).
        self._metric_type = None
        self._dimension = None
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # Retry connecting; useful for a remote server that is still starting.
        start_time = time.time()
        connected = False
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    connected = True
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
            # Back off briefly instead of retrying in a tight loop.
            time.sleep(1)

        # Decide on the outcome, not on the clock: a success that happened to
        # finish just after the deadline is still a success.
        if not connected:
            raise Exception("Server connect timeout")

        if self._collection_name and self.exists_collection():
            info = self.describe()[1]
            self._metric_type = metric_type_to_str(info.metric_type)
            self._dimension = info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Switch the default collection used by the other methods."""
        self._collection_name = name

    def check_status(self, status):
        """Raise if ``status`` is not OK, after logging diagnostics.

        The diagnostics are the collection name, the status message, the
        server status and the current row count.
        """
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any row's best distance is not below ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; ``metric_type`` must be a key of METRIC_MAP."""
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP.keys():
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_type = METRIC_MAP[metric_type]
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        """Return the partition tags of the current collection."""
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """Insert vectors ``X`` (optionally with ``ids``); return (status, ids)."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1-100 random vectors, flush, and verify the row count grew."""
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` ids sampled from one random segment.

        NOTE: retries until it picks a non-empty segment; it never returns if
        every segment of the first partition is empty.
        """
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return (segment count, first ``length`` ids of every segment)."""
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        """Return (ids, entities) for up to ``length`` random ids."""
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete 1-100 random ids, flush, and verify they are gone."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` must be a key of INDEX_MAP."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return {"index_type": INDEX_MAP key or None, "index_param": params}."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """Search ``X`` for the ``top_k`` nearest neighbours; return results."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Run one search with random nq/top_k/nprobe over stored entities."""
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        # for i, item in enumerate(search_res):
        #     if item[0].id != ids[i]:
        #         logger.warning("The index of search result: %d" % i)
        #         raise Exception("Query failed")

    # @time_wrapper
    # def query_ids(self, top_k, ids, search_param=None):
    #     status, result = self._milvus.search_by_id(self._collection_name, ids, top_k, params=search_param)
    #     self.check_result_ids(result)
    #     return result

    def count(self, name=None):
        """Return the row count of ``name`` (default collection); 0 if falsy."""
        if name is None:
            name = self._collection_name
        # One round-trip: the same response is logged and then used.
        response = self._milvus.count_entities(name)
        logger.debug(response)
        row_count = response[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection and poll (1s steps) until its count reads 0."""
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        i = 0
        while i < timeout:
            if self.count(name=name):
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the raw (status, info) pair from get_collection_info."""
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return whether the collection exists (status is not checked)."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return {"memory_used": value / 1024**3, rounded to 2 places}."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb (assumes the server reports bytes — TODO confirm)
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Send a raw server command and return its result string."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Ejemplo n.º 20
0
def main():
    """Run the pymilvus 2.0-preview walkthrough end to end.

    Steps:

    * Recreate ``example_collection`` with one float-vector field.
    * Insert 10 random vectors, flush, and build an IVF_FLAT index.
    * Search with the inserted vectors and check that the top hit matches.
    * Clean up: drop the index, release the collection, then drop it.
    """
    # Specify the server address when creating the client instance. The
    # client maintains a connection pool; `pool_size` caps its size.
    milvus = Milvus(_HOST, _PORT)

    collection_name = 'example_collection'
    field_name = 'example_field'

    # Start from a fresh collection. Drop any leftover from an earlier run,
    # then always create it, so the insert below has a target.
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name=collection_name)

    fields = {
        "fields": [{
            "name": field_name,
            "type": DataType.FLOAT_VECTOR,
            "metric_type": "L2",
            "params": {
                "dim": _DIM
            },
            "indexes": [{
                "metric_type": "L2"
            }]
        }]
    }
    milvus.create_collection(collection_name=collection_name,
                             fields=fields)

    # Show collections in Milvus server
    collections = milvus.list_collections()
    print(collections)

    # Describe the collection
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # 10 random vectors of dimension _DIM, as a 2-D list of floats.
    # numpy alternative:
    #   vectors = np.random.rand(10, _DIM).astype(np.float32)
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
    print(vectors)

    # Insert the vectors; returns the assigned ids.
    entities = [{
        "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }]

    res_ids = milvus.insert(collection_name=collection_name, entities=entities)
    print("ids:", res_ids)

    # Flush the inserted data to disk.
    milvus.flush([collection_name])

    # Present collection statistics.
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # Build an IVF_FLAT index. Searching works without an index, but an index
    # makes it faster.
    index_param = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {
            "nlist": 1024
        }
    }
    print("Creating index: {}".format(index_param))
    milvus.create_index(collection_name, field_name, index_param)

    # Execute a vector similarity search.
    print("Searching ... ")

    dsl = {
        "bool": {
            "must": [{
                "vector": {
                    field_name: {
                        "metric_type": "L2",
                        "query": vectors,
                        "topk": 10,
                        "params": {
                            "nprobe": 16
                        }
                    }
                }
            }]
        }
    }

    milvus.load_collection(collection_name)
    results = milvus.search(collection_name, dsl)
    # The first query vector was inserted, so its nearest neighbour should be
    # itself: distance 0 or the id assigned to it at insert time.
    if results[0][0].distance == 0.0 or results[0][0].id == res_ids[0]:
        print('Query result is correct')
    else:
        print('Query result isn\'t correct')

    milvus.drop_index(collection_name, field_name)
    milvus.release_collection(collection_name)

    # Delete the collection.
    milvus.drop_collection(collection_name)
Ejemplo n.º 21
0
class MilvusClient:
    """Per-bot vector store on top of a Milvus 1.x client.

    Each bot gets its own collection named ``<prefix><botid>``, created on
    first use with L2 metric and an IVF_FLAT index. Any non-zero status code
    is raised as an ``Exception`` carrying the status message.
    """

    def __init__(self, host, port, collection_name, vector_dim):
        self.client = Milvus(host, port)
        # Prefix combined with a bot id to form the per-bot collection name.
        self.collection_name_prefix = collection_name
        self.vector_dim = vector_dim

    def save_vectors(self, vectors, botid):
        """Insert ``vectors`` into the bot's collection; return their ids."""
        collection_name = self.decide_collection_name(botid)
        self.create_collection_if_need(collection_name)
        status, ids = self.client.insert(collection_name=collection_name,
                                         records=vectors)
        self.check_status(status)
        LogService.info('%d vectors saved into collection %s', len(vectors),
                        collection_name)
        return ids

    def search_vectors(self, vectors, botid):
        """Return the id of the top-1 hit for the first query vector, or None.

        A missing first row *or* an empty first row (no hits) both count as
        "not found" instead of raising IndexError.
        """
        search_param = {'nprobe': 16}
        collection_name = self.decide_collection_name(botid)
        LogService.info('search vector in collection %s', collection_name)
        status, results = self.client.search(collection_name=collection_name,
                                             query_records=vectors,
                                             top_k=1,
                                             params=search_param)
        self.check_status(status)
        if len(results) > 0 and len(results[0]) > 0:
            vid = results[0][0].id
            LogService.info('vector found with id %d in collection %s', vid,
                            collection_name)
            return vid
        LogService.info('vector not found in collection %s',
                        collection_name)
        return None

    def empty_vectors(self, botid):
        """Drop the bot's collection."""
        collection_name = self.decide_collection_name(botid)
        status = self.client.drop_collection(collection_name)
        self.check_status(status)
        LogService.info('collection %s dropped', collection_name)

    def create_collection_if_need(self, collection_name):
        """Create ``collection_name`` with an IVF_FLAT index if it is missing."""
        LogService.info('prepare collection %s', collection_name)
        status, exists = self.client.has_collection(collection_name)
        self.check_status(status)
        if not exists:
            LogService.info('collection %s not exists, create it.',
                            collection_name)
            create_param = {
                'collection_name': collection_name,
                'dimension': self.vector_dim,
                'index_file_size': 1024,
                'metric_type': MetricType.L2
            }
            status = self.client.create_collection(create_param)
            self.check_status(status)
            LogService.info('collection %s created', collection_name)

            status = self.client.create_index(collection_name,
                                              IndexType.IVF_FLAT,
                                              {'nlist': 16384})
            self.check_status(status)
            LogService.info('index for collection %s created', collection_name)
        else:
            LogService.info('collection %s already exists', collection_name)

    def drop_all_collections(self):
        """Drop every collection reported by the server."""
        LogService.info('drop all collections')
        status, collection_names = self.client.show_collections()
        self.check_status(status)
        # Iterating an empty list is a no-op, so no length check is needed.
        for collection_name in collection_names:
            status = self.client.drop_collection(collection_name)
            self.check_status(status)
            LogService.info('%s dropped', collection_name)

    @classmethod
    def check_status(cls, status):
        """Raise ``Exception(status.message)`` when ``status.code`` is non-zero."""
        if status.code != 0:
            raise Exception(status.message)

    def decide_collection_name(self, botid):
        """Return the collection name for ``botid`` (prefix + botid)."""
        return '%s%s' % (self.collection_name_prefix, botid)
Ejemplo n.º 22
0
}
milvus.create_collection(param=param)

# Build an IVF_FLAT index on the (still empty) collection.
ivf_param = {'nlist': 16384}
milvus.create_index(collection_name=col_name,
                    index_type=IndexType.IVF_FLAT,
                    params=ivf_param)

# Insert 2000 random `dim`-dimensional vectors with explicit ids 0..1999.
vectors = [[random.random() for _ in range(dim)] for _ in range(2000)]
vector_ids = list(range(2000))
_, ids = milvus.insert(collection_name=col_name,
                       records=vectors,
                       ids=vector_ids)

# Flush explicitly so the inserted vectors are persisted before searching,
# rather than relying on a fixed sleep being long enough.
milvus.flush([col_name])

# Search 5 random query vectors for their 2 nearest neighbours.
search_param = {'nprobe': 16}
q_records = [[random.random() for _ in range(dim)] for _ in range(5)]
_, result = milvus.search(collection_name=col_name,
                          query_records=q_records,
                          top_k=2,
                          params=search_param)
print(result.id_array)
print(result)

milvus.drop_collection(collection_name=col_name)

milvus.close()
Ejemplo n.º 23
0
class Indexer:
    '''
    Indexer: wraps one Milvus 1.x collection.

    Every operation raises ExertMilvusException(status) when the server
    returns a non-zero status code.
    '''
    def __init__(self, name, host='127.0.0.1', port='19531'):
        '''
        Connect to Milvus and remember the target collection name.
        '''
        self.client = Milvus(host=host, port=port)
        self.collection = name

    def init(self, lenient=False):
        '''
        Create the collection (dim 512, L2) and an IVF_FLAT index on it.

        With ``lenient=True``, an already existing collection is left
        untouched and None is returned. Otherwise the index-creation status
        is returned.
        '''
        if lenient:
            status, result = self.client.has_collection(
                collection_name=self.collection)
            if status.code != 0:
                raise ExertMilvusException(status)
            if result:
                return

        status = self.client.create_collection({
            'collection_name': self.collection,
            'dimension': 512,
            'index_file_size': 1024,
            'metric_type': MetricType.L2
        })
        # In lenient mode status code 9 is tolerated (presumably "collection
        # already exists", e.g. created concurrently — TODO confirm).
        if status.code != 0 and not (lenient and status.code == 9):
            raise ExertMilvusException(status)

        # Create the index.
        status = self.client.create_index(collection_name=self.collection,
                                          index_type=IndexType.IVF_FLAT,
                                          params={'nlist': 16384})
        if status.code != 0:
            raise ExertMilvusException(status)

        return status

    def drop(self):
        '''
        Drop the collection.
        '''
        status = self.client.drop_collection(collection_name=self.collection)
        if status.code != 0:
            raise ExertMilvusException(status)

    def flush(self):
        '''
        Flush buffered data of the collection to disk.
        '''
        status = self.client.flush([self.collection])
        if status.code != 0:
            raise ExertMilvusException(status)

    def compact(self):
        '''
        Compact the collection.
        '''
        status = self.client.compact(collection_name=self.collection)
        if status.code != 0:
            raise ExertMilvusException(status)

    def close(self):
        '''
        Close the client connection.
        '''
        self.client.close()

    def new_tag(self, tag):
        '''
        Create a partition with the given tag.
        '''
        status = self.client.create_partition(collection_name=self.collection,
                                              partition_tag=tag)
        if status.code != 0:
            raise ExertMilvusException(status)

    def list_tag(self):
        '''
        List the partitions of the collection.
        '''
        status, result = self.client.list_partitions(
            collection_name=self.collection)
        if status.code != 0:
            raise ExertMilvusException(status)
        return result

    def drop_tag(self, tag):
        '''
        Drop the partition with the given tag.
        '''
        status = self.client.drop_partition(collection_name=self.collection,
                                            partition_tag=tag)
        if status.code != 0:
            raise ExertMilvusException(status)

    def index(self, vectors, tag=None, ids=None):
        '''
        Insert vectors, optionally into partition ``tag`` and with explicit
        ``ids``. Returns the ids reported by the server.
        '''
        params = {}
        if tag is not None:
            # The client's insert() takes the partition as `partition_tag`
            # (the same keyword new_tag/drop_tag use); passing `tag=` did not
            # select the partition.
            params['partition_tag'] = tag
        if ids is not None:
            params['ids'] = ids
        status, result = self.client.insert(collection_name=self.collection,
                                            records=vectors,
                                            **params)
        if status.code != 0:
            raise ExertMilvusException(status)

        return result

    def listing(self, ids):
        '''
        Fetch the stored vectors for ``ids``.
        '''
        status, result = self.client.get_entity_by_id(
            collection_name=self.collection, ids=ids)
        if status.code != 0:
            raise ExertMilvusException(status)
        return result

    def counting(self):
        '''
        Return the number of entities in the collection.
        '''
        status, result = self.client.count_entities(
            collection_name=self.collection)
        if status.code != 0:
            raise ExertMilvusException(status)
        return result

    def unindex(self, ids):
        '''
        Delete the entities with the given ids.
        '''
        status = self.client.delete_entity_by_id(
            collection_name=self.collection, id_array=ids)
        if status.code != 0:
            raise ExertMilvusException(status)

    def search(self, vectors, top_count=100, tags=None):
        '''
        Search ``vectors`` (nprobe=16) for ``top_count`` neighbours each,
        optionally restricted to the partitions in ``tags``.
        '''
        params = {'params': {'nprobe': 16}}
        if tags is not None:
            params['partition_tags'] = tags
        status, results = self.client.search(collection_name=self.collection,
                                             query_records=vectors,
                                             top_k=top_count,
                                             **params)
        if status.code != 0:
            raise ExertMilvusException(status)
        return results
Ejemplo n.º 24
0
class Test:
    """End-to-end Milvus benchmark.

    ``run`` creates the ``benchmark`` collection, inserts ``nvec`` vectors
    loaded from ``.npy`` shards, flushes, builds an IVF_FLAT index, loads the
    collection and times one search per (nq, topk, nprobe) combination.
    """

    def __init__(self, nvec):
        """Configure the benchmark and connect to a local Milvus server.

        Args:
            nvec: number of vectors to insert; must be a positive multiple
                of ``insert_bulk_size``.

        Raises:
            AssertionError: if ``nvec`` is not a positive multiple of the
                insert bulk size.
        """
        self.cname = "benchmark"
        self.fname = "feature"
        self.dim = 128
        self.client = Milvus("localhost", 19530)
        self.prefix = '/sift1b/binary_128d_'
        self.suffix = '.npy'
        self.vecs_per_file = 100000
        self.maxfiles = 1000
        self.insert_bulk_size = 5000
        self.nvec = nvec
        self.insert_cost = 0
        self.flush_cost = 0
        self.create_index_cost = 0
        self.search_cost = 0
        self._validate_sizes(self.nvec, self.insert_bulk_size)

    @staticmethod
    def _validate_sizes(nvec, bulk_size):
        """Assert that ``nvec`` is a positive whole number of insert bulks."""
        # ``and`` (not ``&``): ``&`` binds tighter than the comparisons and
        # would check an unrelated bitwise expression instead.
        assert nvec >= bulk_size and nvec % bulk_size == 0, (
            f'nvec ({nvec}) must be a positive multiple of the insert bulk '
            f'size ({bulk_size})')

    @staticmethod
    def _timed(step, **kwargs):
        """Call ``step(**kwargs)`` and return the elapsed seconds."""
        # perf_counter is monotonic, unlike time.time, so durations cannot go
        # negative if the wall clock is adjusted mid-run.
        start = time.perf_counter()
        step(**kwargs)
        return time.perf_counter() - start

    def run(self, suite):
        """Run every benchmark step and return the collected metrics.

        Args:
            suite: mapping with iterables under ``"nq"``, ``"topk"`` and
                ``"nprobe"``; every combination is searched once.

        Returns:
            dict mapping metric names to ``{"value": str, "unit": str}``.
            If a step fails (``Exception``), the failure is logged with its
            traceback and the partial report gathered so far is returned.
        """
        report = dict()
        try:
            # step 1 create collection
            logging.info('step 1 create collection')
            self._create_collection()
            logging.info('step 1 complete')

            # step 2 fill data
            logging.info('step 2 insert')
            self.insert_cost = self._timed(self._insert)
            report["insert-speed"] = {
                "value": format(self.nvec / self.insert_cost, ".4f"),
                "unit": "vec/sec"
            }
            logging.info('step 2 complete')

            # step 3 flush
            logging.info('step 3 flush')
            self.flush_cost = self._timed(self._flush)
            report["flush-cost"] = {
                "value": format(self.flush_cost, ".4f"),
                "unit": "s"
            }
            logging.info('step 3 complete')

            # step 4 create index
            logging.info('step 4 create index')
            self.create_index_cost = self._timed(self._create_index)
            report["create-index-cost"] = {
                "value": format(self.create_index_cost, ".4f"),
                "unit": "s"
            }
            logging.info('step 4 complete')

            # step 5 load
            logging.info('step 5 load')
            self._load_collection()
            logging.info('step 5 complete')

            # step 6 search
            logging.info('step 6 search')
            for nq in suite["nq"]:
                for topk in suite["topk"]:
                    for nprobe in suite["nprobe"]:
                        self.search_cost = self._timed(
                            self._search, nq=nq, topk=topk, nprobe=nprobe)
                        report[f"search-q{nq}-k{topk}-p{nprobe}-cost"] = {
                            "value": format(self.search_cost, ".4f"),
                            "unit": "s"
                        }
            logging.info('step 6 complete')
        except AssertionError as ae:
            logging.exception(ae)
        except Exception as e:
            # logging.exception keeps the traceback that logging.error dropped.
            logging.exception(f'test failed: {e}')
        # Returning here rather than in a ``finally`` lets KeyboardInterrupt /
        # SystemExit propagate instead of being silently swallowed.
        return report

    def _create_collection(self):
        """(Re)create the benchmark collection with one FLOAT_VECTOR field."""
        logging.debug('create_collection() start')

        if self.client.has_collection(self.cname):
            logging.debug(f'collection {self.cname} existed')

            self.client.drop_collection(self.cname)
            logging.info(f'drop collection {self.cname}')

        logging.debug(f'before create collection: {self.cname}')
        self.client.create_collection(self.cname, {
            "fields": [{
                "name": self.fname,
                "type": DataType.FLOAT_VECTOR,
                "metric_type": "L2",
                "params": {"dim": self.dim},
                "indexes": [{"metric_type": "L2"}]
            }]
        })
        logging.info(f'created collection: {self.cname}')

        assert self.client.has_collection(self.cname)
        logging.debug('create_collection() finished')

    def _insert(self):
        """Insert ``nvec`` vectors, ``insert_bulk_size`` at a time.

        Shards are read in order from
        ``<prefix><5-digit file index><suffix>``; each file is assumed to hold
        ``vecs_per_file`` vectors.
        """
        logging.debug('insert() start')

        count = 0
        step = self.insert_bulk_size
        for i in range(0, self.maxfiles):
            filename = self.prefix + str(i).zfill(5) + self.suffix
            logging.debug(f'filename: {filename}')

            array = np.load(filename)
            logging.debug(f'numpy array shape: {array.shape}')

            for p in range(0, self.vecs_per_file, step):
                entities = [{
                    "name": self.fname,
                    "type": DataType.FLOAT_VECTOR,
                    "values": array[p:p + step].tolist()
                }]
                logging.debug(f'before insert slice: {p}, {p + step}')

                self.client.insert(self.cname, entities)
                logging.info(f'after insert slice: {p}, {p + step}')

                count += step
                logging.debug(f'insert count: {count}')

                if count == self.nvec:
                    logging.debug('inner break')
                    break
            if count == self.nvec:
                logging.debug('outer break')
                break
        logging.debug('insert() finished')

    def _flush(self):
        """Flush the collection and assert that it holds exactly ``nvec`` rows."""
        logging.debug('flush() start')

        logging.debug(f'before flush: {self.cname}')
        self.client.flush([self.cname])
        logging.info('after flush')

        stats = self.client.get_collection_stats(self.cname)
        logging.debug(stats)

        assert stats["row_count"] == self.nvec
        logging.debug('flush() finished')

    def _create_index(self):
        """Build an IVF_FLAT (nlist=1024, L2) index on the vector field."""
        logging.debug('create_index() start')

        index_params = {
            "metric_type": "L2",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024}
        }
        self.client.create_index(self.cname, self.fname, index_params)
        logging.debug(f'create index {self.cname} : {self.fname} : {index_params}')
        logging.debug('create_index() finished')

    def _load_collection(self):
        """Load the collection before searching."""
        logging.debug('load_collection() start')

        logging.debug(f'before load collection: {self.cname}')
        self.client.load_collection(self.cname)
        logging.debug('load_collection() finished')

    def _search(self, nq, topk, nprobe):
        """Run one L2 search with ``nq`` generated query vectors."""
        logging.debug('search() start')

        result = self.client.search(self.cname,
                                    {"bool": {"must": [{"vector": {
                                        self.fname: {
                                            "metric_type": "L2",
                                            "query": _gen_vectors(nq, self.dim),
                                            "topk": topk,
                                            "params": {"nprobe": nprobe}
                                        }
                                    }}]}}
                                    )
        logging.debug(f'{result}')
        logging.debug('search() finished')
Ejemplo n.º 25
0
def main():
    """Demonstrate the asynchronous (``_async=True``) client API.

    Creates ``example_async_collection_`` if it does not exist, then
    inserts, flushes, compacts, indexes and searches it, registering a
    callback for each asynchronous call, and finally drops the collection.
    """
    # Specify server addr when create milvus client instance
    milvus = Milvus(_HOST, _PORT)

    # Create collection demo_collection if it doesn't exist.
    collection_name = 'example_async_collection_'

    status, ok = milvus.has_collection(collection_name)
    if not ok:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': 128,  # optional
            'metric_type': MetricType.L2  # optional
        }

        status = milvus.create_collection(param)
        if not status.OK():
            print("Create collection failed: {}".format(status.message),
                  file=sys.stderr)
            print("exiting ...", file=sys.stderr)
            sys.exit(1)

    # Show collections in Milvus server
    _, collections = milvus.list_collections()

    # Describe demo_collection
    _, collection = milvus.get_collection_info(collection_name)
    print(collection)

    # 100000 random vectors with _DIM dimensions each (a 2-D list of floats).
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(100000)]

    # You can also use numpy to generate random vectors:
    #     `vectors = np.random.rand(100000, _DIM).astype(np.float32)`

    def _insert_callback(status, ids):
        if status.OK():
            print("Insert successfully")
        else:
            print("Insert failed.", status.message)

    # Insert vectors into demo_collection, adding callback function
    insert_future = milvus.insert(collection_name=collection_name,
                                  records=vectors,
                                  _async=True,
                                  _callback=_insert_callback)
    # Or invoke result() to get results:
    #   insert_future = milvus.insert(collection_name=collection_name, records=vectors, _async=True)
    #   status, ids = insert_future.result()
    # NOTE(review): assumes Future.done() waits for the request to finish
    # before the next step runs — confirm for the client version in use.
    insert_future.done()

    # Flush collection inserted data to disk.
    def _flush_callback(status):
        if status.OK():
            print("Flush successfully")
        else:
            print("Flush failed.", status.message)

    flush_future = milvus.flush([collection_name],
                                _async=True,
                                _callback=_flush_callback)
    # Or invoke result() to get results:
    #   flush_future = milvus.flush([collection_name], _async=True)
    #   status = flush_future.result()
    flush_future.done()

    def _compact_callback(status):
        if status.OK():
            print("Compact successfully")
        else:
            print("Compact failed.", status.message)

    # ``_callback`` (previously misspelled ``_cakkback``, so the callback
    # was never registered).
    compact_future = milvus.compact(collection_name,
                                    _async=True,
                                    _callback=_compact_callback)
    # Or invoke result() to get results:
    #   compact_future = milvus.compact(collection_name, _async=True)
    #   status = compact_future.result()
    compact_future.done()

    # Get demo_collection row count
    status, result = milvus.count_entities(collection_name)

    # present collection info
    _, info = milvus.get_collection_stats(collection_name)
    print(info)

    # create index of vectors, search more rapidly
    index_param = {'nlist': 2048}

    def _index_callback(status):
        if status.OK():
            print("Create index successfully")
        else:
            print("Create index failed.", status.message)

    # Create ivflat index in demo_collection
    # You can search vectors without creating index. however, Creating index help to
    # search faster
    print("Creating index: {}".format(index_param))
    index_future = milvus.create_index(collection_name,
                                       IndexType.IVF_FLAT,
                                       index_param,
                                       _async=True,
                                       _callback=_index_callback)
    # Or invoke result() to get results:
    #   index_future = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param, _async=True)
    #   status = index_future.result()
    index_future.done()

    # describe index, get information of index
    status, index = milvus.get_index_info(collection_name)
    print(index)

    # Use the top 10 vectors for similarity search
    query_vectors = vectors[0:10]

    # execute vector similarity search
    search_param = {"nprobe": 16}

    print("Searching ... ")

    def _search_callback(status, results):
        if status.OK():
            # The first query vector was itself inserted, so its best hit
            # should be at distance 0.
            if results[0][0].distance == 0.0:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

            # print results
            print(results)
        else:
            print("Search failed. ", status)

    param = {
        'collection_name': collection_name,
        'query_records': query_vectors,
        'top_k': 1,
        'params': search_param,
        "_async": True,
        "_callback": _search_callback
    }
    search_future = milvus.search(**param)
    # Or invoke result() to get results:
    #
    #   param = {
    #       'collection_name': collection_name,
    #       'query_records': query_vectors,
    #       'top_k': 1,
    #       'params': search_param,
    #       "_async": True,
    #   }
    #   search_future = milvus.search(**param)
    #   status, results = search_future.result()

    search_future.done()

    # Delete demo_collection
    status = milvus.drop_collection(collection_name)
def main():
    """Index ``index_sentences`` in Milvus and find the closest one per query.

    Sentences are embedded with ``text2vec`` (defined elsewhere in the
    file), inserted into ``example_collection_``, indexed with IVF_FLAT and
    searched with the embeddings of ``query_sentences``. The collection is
    dropped at the end, and also when the insert fails.
    """
    # Specify server addr when create milvus client instance
    # milvus client instance maintain a connection pool, param
    # `pool_size` specify the max connection num.
    milvus = Milvus(_HOST, _PORT)

    # Create collection demo_collection if it doesn't exist.
    collection_name = 'example_collection_'

    status, ok = milvus.has_collection(collection_name)
    if not ok:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2  # optional
        }

        milvus.create_collection(param)

    # Show collections in Milvus server
    _, collections = milvus.list_collections()

    # Describe demo_collection
    _, collection = milvus.get_collection_info(collection_name)
    print(collection)

    # vectors should be a 2-D array: one embedding per sentence
    vectors = text2vec(index_sentences)
    print(vectors)

    # Insert vectors into demo_collection, return status and vectors id list
    status, ids = milvus.insert(collection_name=collection_name,
                                records=vectors)
    if not status.OK():
        print("Insert failed: {}".format(status))
        # Nothing usable was indexed; the id-based steps below (ids[0],
        # id -> sentence lookups) would fail, so clean up and stop here.
        milvus.drop_collection(collection_name)
        return
    print(ids)

    # Quick lookup table from vector id to the indexed sentence.
    look_up = dict(zip(ids, index_sentences))

    for k in look_up:
        print(k, look_up[k])

    # Flush collection inserted data to disk.
    milvus.flush([collection_name])
    # Get demo_collection row count
    status, result = milvus.count_entities(collection_name)

    # present collection statistics info
    _, info = milvus.get_collection_stats(collection_name)
    print(info)

    # Obtain raw vectors by providing vector ids
    status, result_vectors = milvus.get_entity_by_id(collection_name, ids)

    # create index of vectors, search more rapidly
    index_param = {'nlist': 2048}

    # Create ivflat index in demo_collection
    # You can search vectors without creating index. however, Creating index help to
    # search faster
    print("Creating index: {}".format(index_param))
    status = milvus.create_index(collection_name, IndexType.IVF_FLAT,
                                 index_param)

    # describe index, get information of index
    status, index = milvus.get_index_info(collection_name)
    print(index)

    # Use the query sentences for similarity search
    query_vectors = text2vec(query_sentences)

    # execute vector similarity search
    search_param = {"nprobe": 16}

    print("Searching ... ")

    param = {
        'collection_name': collection_name,
        'query_records': query_vectors,
        'top_k': 1,
        'params': search_param,
    }

    status, results = milvus.search(**param)
    if status.OK():
        # indicate search result
        # also use by:
        #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
        if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

        # print results; ids the lookup table doesn't know (e.g. a
        # placeholder for "no hit") print as None instead of raising KeyError.
        for res in results:
            for ele in res:
                print('id:{}, text:{}, distance: {}'.format(
                    ele.id, look_up.get(ele.id), ele.distance))

    else:
        print("Search failed. ", status)

    # Delete demo_collection
    status = milvus.drop_collection(collection_name)
Ejemplo n.º 27
0
This example is runnable with Milvus 0.11.x and pymilvus 0.3.x.
"""
import random
import csv
from pprint import pprint

from milvus import Milvus, DataType

# Address of the Milvus server this example talks to.
_HOST = '127.0.0.1'
_PORT = '19530'
# Client used by the example script.
client = Milvus(_HOST, _PORT)

# Start from a clean slate: drop the demo collection if it already exists.
collection_name = 'demo_index'
if collection_name in client.list_collections():
    client.drop_collection(collection_name)

collection_param = {
    "fields": [
        {
            "name": "release_year",
            "type": DataType.INT32
        },
        {
            "name": "embedding",
            "type": DataType.FLOAT_VECTOR,
            "params": {
                "dim": 8
            }
        },
    ],
Ejemplo n.º 28
0
class MilvusClient(object):
    """Wrapper around the legacy Milvus client used by the benchmark code.

    Most methods act on ``self._collection_name`` and raise
    ``Exception("Status not ok")`` when the server returns a non-OK status.
    """

    def __init__(self, collection_name=None, ip=None, port=None, timeout=60):
        """Create the underlying client.

        Without ``ip`` a single client for ``SERVER_HOST_DEFAULT`` /
        ``SERVER_PORT_DEFAULT`` is created. With ``ip`` the connection is
        retried, one second apart, for up to ``timeout`` seconds until
        ``server_status()`` answers.

        Raises:
            Exception: if no client could be created within ``timeout``.
        """
        self._collection_name = collection_name
        if not ip:
            self._milvus = Milvus(host=SERVER_HOST_DEFAULT,
                                  port=SERVER_PORT_DEFAULT)
            return

        # retry connect for remote server
        self._milvus = None
        attempt = 1
        start_time = time.time()
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=ip, port=port)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" %
                                 (attempt, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed")
            attempt += 1
            # Pause between attempts instead of busy-looping on the server.
            time.sleep(1)
        if self._milvus is None:
            # Fail here rather than with an AttributeError on first use.
            raise Exception("Server connect timeout")

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def check_status(self, status):
        """Log the message and raise if ``status`` is not OK."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's top hit has ``distance >= epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size,
                          metric_type):
        """Create a collection.

        ``metric_type`` is one of ``"l2"``, ``"ip"``, ``"jaccard"``,
        ``"hamming"``, ``"sub"``, ``"super"``; any other value is logged as
        unsupported and passed to the server unchanged.
        """
        if not self._collection_name:
            self._collection_name = collection_name
        # Names are resolved lazily so only the requested MetricType member
        # is looked up.
        metric_names = {
            "l2": "L2",
            "ip": "IP",
            "jaccard": "JACCARD",
            "hamming": "HAMMING",
            "sub": "SUBSTRUCTURE",
            "super": "SUPERSTRUCTURE",
        }
        if metric_type in metric_names:
            metric_type = getattr(MetricType, metric_names[metric_type])
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
        create_param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        """Insert vectors ``X`` (optionally with ``ids``); return ``(status, ids)``."""
        status, result = self._milvus.add_vectors(self._collection_name, X,
                                                  ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def delete_vectors(self, ids):
        """Delete the vectors with the given ids."""
        status = self._milvus.delete_by_id(self._collection_name, ids)
        self.check_status(status)

    @time_wrapper
    def flush(self):
        """Flush the collection."""
        status = self._milvus.flush([self._collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self):
        """Compact the collection."""
        status = self._milvus.compact(self._collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` is a key of ``INDEX_MAP``."""
        index_type = INDEX_MAP[index_type]
        logger.info(
            "Building index start, collection_name: %s, index_type: %s" %
            (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type,
                                           index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.describe_index(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the collection's index and return the client's result."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    @time_wrapper
    def query(self, X, top_k, search_param=None):
        """Search for the ``top_k`` nearest neighbours of each vector in ``X``."""
        status, result = self._milvus.search_vectors(self._collection_name,
                                                     top_k,
                                                     query_records=X,
                                                     params=search_param)
        self.check_status(status)
        return result

    @time_wrapper
    def query_ids(self, top_k, ids, search_param=None):
        """Search by stored ids and verify each top hit is an exact match."""
        status, result = self._milvus.search_by_ids(self._collection_name,
                                                    ids,
                                                    top_k,
                                                    params=search_param)
        # Check the status first, as every other call does: a failed search
        # could otherwise return an empty result that passes the id check.
        self.check_status(status)
        self.check_result_ids(result)
        return result

    def count(self):
        """Return the row count (second element of ``count_collection``'s result)."""
        return self._milvus.count_collection(self._collection_name)[1]

    def delete(self, timeout=120):
        """Drop the collection, polling ``count()`` once a second until it is 0.

        Logs an error if the count is still non-zero after ``timeout`` polls.
        """
        timeout = int(timeout)
        logger.info("Start delete collection: %s" % self._collection_name)
        self._milvus.drop_collection(self._collection_name)
        for _ in range(timeout):
            if not self.count():
                break
            time.sleep(1)
        else:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the client's description of the collection."""
        return self._milvus.describe_collection(self._collection_name)

    def show_collections(self):
        """Return the client's list of collections."""
        return self._milvus.show_collections()

    def exists_collection(self, collection_name=None):
        """Return whether ``collection_name`` (default: this collection) exists."""
        if collection_name is None:
            collection_name = self._collection_name
        status, res = self._milvus.has_collection(collection_name)
        # self.check_status(status)
        return res

    @time_wrapper
    def preload_collection(self):
        """Preload the collection (3000s client timeout) and return the status."""
        status = self._milvus.preload_collection(self._collection_name,
                                                 timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the server version string."""
        status, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        """Return the result of the ``mode`` server command."""
        return self.cmd("mode")

    def get_server_commit(self):
        """Return the result of the ``build_commit_id`` server command."""
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the server configuration parsed from JSON."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": <GiB rounded to 2 places>}``."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used":
            round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a raw server command, log it, and return its result."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Ejemplo n.º 29
0
class MilvusClient(object):
    def __init__(self,
                 collection_name=None,
                 host=None,
                 port=None,
                 timeout=180):
        """Create the Milvus client, retrying until ``timeout`` seconds elapse.

        ``host``/``port`` default to ``SERVER_HOST_DEFAULT`` /
        ``SERVER_PORT_DEFAULT``. After the n-th failed attempt the loop
        sleeps n seconds before retrying.

        Raises:
            Exception: if no client could be created within ``timeout``.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect remote server
        self._milvus = None
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host,
                                      port=port,
                                      try_connect=False,
                                      pre_ping=False)
                break
            except Exception as e:
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % i)
                i = i + 1
                time.sleep(i)

        # Decide on whether a client exists, not on the clock: a connection
        # that succeeded right at the deadline is still usable.
        if self._milvus is None:
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        """Return a short description naming the wrapped collection."""
        template = 'Milvus collection %s'
        return template % self._collection_name

    def check_status(self, status):
        """Raise ``Exception("Status not ok")`` if ``status`` is not OK.

        Before raising, logs the status message plus the server status and
        row count as diagnostics. Collecting those diagnostics is
        best-effort: if it fails, that failure is logged and the original
        error is still raised instead of being masked.
        """
        if status.OK():
            return
        logger.error(status.message)
        try:
            logger.error(self._milvus.server_status())
            logger.error(self.count())
        except Exception as e:
            logger.error("Failed to collect diagnostics: %s" % e)
        raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Ensure every query's best hit lies closer than ``epsilon``.

        At the first row whose top hit has ``distance >= epsilon``, logs the
        row position and the distance, then raises ``Exception``.
        """
        for position, hits in enumerate(result):
            best = hits[0]
            if best.distance >= epsilon:
                logger.error(position)
                logger.error(best.distance)
                raise Exception("Distance wrong")

    def create_collection(self,
                          dimension,
                          data_type=DataType.FLOAT_VECTOR,
                          auto_id=False,
                          collection_name=None,
                          other_fields=None):
        """Create a collection with one vector field and optional scalar fields.

        Only the default field names from ``utils`` are supported.

        Args:
            dimension: dimension of the vector field (also stored on
                ``self._dimension``).
            data_type: type of the vector field.
            auto_id: passed through in the collection parameters.
            collection_name: defaults to this client's collection.
            other_fields: comma-separated extra fields: ``"int"`` adds an
                INT64 field, ``"float"`` a FLOAT field. Whitespace around
                entries is ignored, so ``"int, float"`` works too.

        Raises:
            Whatever the client raises; the error is logged first.
        """
        self._dimension = dimension
        if not collection_name:
            collection_name = self._collection_name
        vec_field_name = utils.get_default_field_name(data_type)
        fields = [{
            "name": vec_field_name,
            "type": data_type,
            "params": {
                "dim": dimension
            }
        }]
        extra_fields = ({name.strip() for name in other_fields.split(",")}
                        if other_fields else set())
        if "int" in extra_fields:
            fields.append({
                "name": utils.DEFAULT_INT_FIELD_NAME,
                "type": DataType.INT64
            })
        if "float" in extra_fields:
            fields.append({
                "name": utils.DEFAULT_FLOAT_FIELD_NAME,
                "type": DataType.FLOAT
            })
        create_param = {"fields": fields, "auto_id": auto_id}
        try:
            self._milvus.create_collection(collection_name, create_param)
            logger.info("Create collection: <%s> successfully" %
                        collection_name)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        """Create partition ``tag`` in ``collection_name`` (default: this client's collection)."""
        target = collection_name or self._collection_name
        self._milvus.create_partition(target, tag)

    def generate_values(self, data_type, vectors, ids):
        """Return the column values for a field of type ``data_type``.

        Integer fields reuse ``ids``; floating-point fields get each id
        plus ``0.0``; vector fields get ``vectors``; any other type yields
        ``None``.
        """
        if data_type in [DataType.INT32, DataType.INT64]:
            return ids
        if data_type in [DataType.FLOAT, DataType.DOUBLE]:
            return [i + 0.0 for i in ids]
        if data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
            return vectors
        return None

    def generate_entities(self, vectors, ids=None, collection_name=None):
        """Build one entity column per field in the collection's schema.

        Each column is ``{"name", "type", "values"}``, with values produced
        by ``generate_values`` for that field's type.
        """
        target = (self._collection_name
                  if collection_name is None else collection_name)
        schema_fields = self.get_info(target)["fields"]
        return [{
            "name": field["name"],
            "type": field["type"],
            "values": self.generate_values(field["type"], vectors, ids),
        } for field in schema_fields]

    @time_wrapper
    def insert(self, entities, ids=None, collection_name=None):
        """Insert ``entities`` and return whatever the client returns.

        Exceptions are logged and swallowed; the method then returns ``None``.
        """
        if collection_name is None:
            collection_name = self._collection_name
        try:
            return self._milvus.insert(collection_name, entities, ids=ids)
        except Exception as e:
            logger.error(str(e))

    def get_dimension(self):
        """Return ``params["dim"]`` of the first vector field, or ``None`` if there is none."""
        vector_types = [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
        for field in self.get_info()["fields"]:
            if field["type"] in vector_types:
                return field["params"]["dim"]
        return None

    def get_rand_ids(self, length):
        """Return up to ``length`` random ids from one randomly chosen segment.

        Segments of the first partition are picked at random until one
        yields ids. If that segment has more than ``length`` ids, a random
        sample of ``length`` is returned; otherwise all of its ids are.
        NOTE(review): loops forever if no segment ever yields ids.
        """
        segment_ids = []
        while True:
            stats = self.get_stats()
            chosen = random.choice(stats["partitions"][0]["segments"])
            try:
                segment_ids = self._milvus.list_id_in_segment(
                    self._collection_name, chosen["id"])
            except Exception as e:
                logger.error(str(e))
            available = len(segment_ids)
            if available == 0:
                continue
            if available > length:
                return random.sample(segment_ids, length)
            logger.debug("Reset length: %d" % available)
            return segment_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    def get(self):
        """Fetch one entity with a random id in [1, 1000000]; the result is discarded."""
        random_id = random.randint(1, 1000000)
        self._milvus.get_entity_by_id(self._collection_name, [random_id])

    @time_wrapper
    def get_entities(self, get_ids):
        """Return the entities whose ids are ``get_ids`` from this client's collection."""
        return self._milvus.get_entity_by_id(self._collection_name, get_ids)

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete entities ``ids`` from ``collection_name`` (default: this client's collection)."""
        if collection_name is None:
            collection_name = self._collection_name
        self._milvus.delete_entity_by_id(collection_name, ids)

    def delete_rand(self):
        """Delete 1-100 random ids, flush, and assert they can no longer be fetched."""
        delete_id_length = random.randint(1, 100)
        # Read before deleting; only the disabled check at the end uses it.
        count_before = self.count()
        logger.debug("%s: length to delete: %d" %
                     (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" %
                    (self._collection_name, self.count()))
        fetched = self._milvus.get_entity_by_id(self._collection_name,
                                                delete_ids)
        assert all(not item for item in fetched)
        # if count_before - len(delete_ids) < self.count():
        #     logger.error(delete_ids)
        #     raise Exception("Error occured")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        """Flush one collection on the server.

        :param _async: forwarded unchanged to the client's ``flush``.
        :param collection_name: target collection; defaults to this client's collection.
        """
        target = collection_name if collection_name is not None else self._collection_name
        self._milvus.flush([target], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact one collection and validate the returned status via ``check_status``.

        :param collection_name: target collection; defaults to this client's collection.
        """
        target = collection_name if collection_name is not None else self._collection_name
        self.check_status(self._milvus.compact(target))

    @time_wrapper
    def create_index(self,
                     field_name,
                     index_type,
                     metric_type,
                     _async=False,
                     index_param=None):
        """Build an index on ``field_name`` of the default collection.

        :param field_name: vector field to index.
        :param index_type: key into ``INDEX_MAP`` naming the index type.
        :param metric_type: metric name, translated via ``utils.metric_type_trans``.
        :param _async: forwarded unchanged to the client's ``create_index``.
        :param index_param: build parameters, sent as the ``"params"`` entry.
        """
        milvus_index_type = INDEX_MAP[index_type]
        milvus_metric_type = utils.metric_type_trans(metric_type)
        logger.info(
            "Building index start, collection_name: %s, index_type: %s, metric_type: %s"
            % (self._collection_name, milvus_index_type, milvus_metric_type))
        if index_param:
            logger.info(index_param)
        self._milvus.create_index(
            self._collection_name,
            field_name,
            {
                "index_type": milvus_index_type,
                "metric_type": milvus_metric_type,
                "params": index_param,
            },
            _async=_async)

    # TODO: need to check
    def describe_index(self, field_name):
        """Describe the index on ``field_name`` using this module's index-type names.

        Scans every index entry reported by the server and returns the first one whose
        ``index_type`` matches a value in ``INDEX_MAP``.

        :return: ``{"index_type": <INDEX_MAP key>, "index_param": <params>}``, or
                 ``{"index_type": "flat", "index_param": None}`` if nothing matches.
        """
        info = self._milvus.describe_index(self._collection_name, field_name)
        for field in info["fields"]:
            for index in field['indexes']:
                # Skip empty entries and entries without an index type.
                if not index or "index_type" not in index:
                    continue
                for name, milvus_type in INDEX_MAP.items():
                    if index['index_type'] == milvus_type:
                        return {"index_type": name, "index_param": index['params']}
        return {"index_type": "flat", "index_param": None}

    def drop_index(self, field_name):
        """Drop the index on ``field_name`` in the default collection.

        :return: the value returned by the client's ``drop_index`` call.
        """
        collection = self._collection_name
        logger.info("Drop index: %s" % collection)
        return self._milvus.drop_index(collection, field_name)

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None):
        """Search a collection with a vector clause plus optional filter clauses.

        :param vector_query: the vector clause, placed first in the ``bool.must`` list.
        :param filter_query: optional iterable of extra clauses appended after it.
        :param collection_name: target collection; defaults to this client's collection.
        :return: the value returned by the client's ``search`` call.
        """
        target = self._collection_name if collection_name is None else collection_name
        clauses = [vector_query, *(filter_query or [])]
        search_body = {"bool": {"must": clauses}}
        return self._milvus.search(target, search_body)

    @time_wrapper
    def load_and_query(self,
                       vector_query,
                       filter_query=None,
                       collection_name=None):
        """Load a collection, then search it with a vector clause plus optional filters.

        :param vector_query: the vector clause, placed first in the ``bool.must`` list.
        :param filter_query: optional iterable of extra clauses appended after it.
        :param collection_name: target collection; defaults to this client's collection.
        :return: the value returned by the client's ``search`` call.
        """
        target = self._collection_name if collection_name is None else collection_name
        clauses = [vector_query, *(filter_query or [])]
        search_body = {"bool": {"must": clauses}}
        self.load_collection(target)
        return self._milvus.search(target, search_body)

    def get_ids(self, result):
        """Split the flat id list of a search result into one chunk per query.

        The code assumes ``result._entities.ids`` lists ids query after query, and cuts
        it into consecutive chunks of ``len(ids) // len(result)`` ids.

        :param result: search result supporting ``len()`` (number of queries) and
                       exposing ``_entities.ids`` (flat id sequence).
        :return: list of id chunks. A result with no queries yields ``[]``. A result
                 with queries but no ids yields one empty list per query.
        """
        flat_ids = result._entities.ids
        num_ids = len(flat_ids)
        num_queries = len(result)
        # Degenerate cases previously crashed: no queries -> ZeroDivisionError,
        # no ids -> range() with a zero step -> ValueError.
        if num_queries == 0:
            return []
        if num_ids == 0:
            return [[] for _ in range(num_queries)]
        # With fewer ids than queries the per-query split is ambiguous; use chunks of
        # one id rather than a zero step.
        top_k = max(num_ids // num_queries, 1)
        # Slicing already clamps at the end, so no explicit min() is needed.
        return [flat_ids[offset:offset + top_k] for offset in range(0, num_ids, top_k)]

    def query_rand(self, nq_max=100, dimension=128):
        """Run one randomized IVF-style search against the default collection.

        Draws top_k in [1, 100], nq in [1, nq_max] and nprobe in [1, 100]. It then builds
        nq random vectors and picks the metric at random from "l2" / "ip".

        :param nq_max: inclusive upper bound for the number of query vectors.
        :param dimension: length of each random query vector. It must match the
                          collection's vector field. The default of 128 is the value
                          that was previously hard-coded.
        """
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }
        self.query(vector_query)

    def load_query_rand(self, nq_max=100, dimension=128):
        """Load the default collection and run one randomized IVF-style search on it.

        Draws top_k in [1, 100], nq in [1, nq_max] and nprobe in [1, 100]. It then builds
        nq random vectors, picks the metric at random from "l2" / "ip", and delegates to
        ``load_and_query``.

        :param nq_max: inclusive upper bound for the number of query vectors.
        :param dimension: length of each random query vector. It must match the
                          collection's vector field. The default of 128 is the value
                          that was previously hard-coded.
        """
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }
        self.load_and_query(vector_query)

    # TODO: need to check
    def count(self, collection_name=None):
        """Return the ``row_count`` reported by collection stats.

        :param collection_name: target collection; defaults to this client's collection.
        """
        target = self._collection_name if collection_name is None else collection_name
        stats = self._milvus.get_collection_stats(target)
        row_count = stats["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (row_count, target))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop a collection, then poll its row count until it reports empty.

        Polls once per second for up to ``timeout`` seconds. It stops early when the
        count is falsy or when counting raises (logged at debug level). It logs an
        error if the timeout is reached.

        :param timeout: maximum number of one-second polls (coerced with ``int``).
        :param collection_name: target collection; defaults to this client's collection.
        """
        timeout = int(timeout)
        target = self._collection_name if collection_name is None else collection_name
        logger.info("Start delete collection: %s" % target)
        self._milvus.drop_collection(target)
        waited = 0
        while waited < timeout:
            try:
                remaining = self.count(collection_name=target)
            except Exception as e:
                # Counting a dropped collection is expected to fail; treat it as done.
                logger.debug(str(e))
                break
            if not remaining:
                break
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        """Return the server-side statistics of the default collection."""
        collection = self._collection_name
        return self._milvus.get_collection_stats(collection)

    def get_info(self, collection_name=None):
        """Return the collection info for ``collection_name`` (default: this client's)."""
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.get_collection_info(target)

    def show_collections(self):
        """Return the collection names listed by the server."""
        names = self._milvus.list_collections()
        return names

    def exists_collection(self, collection_name=None):
        """Return the server's answer to ``has_collection`` for the given collection.

        :param collection_name: collection to check; defaults to this client's collection.
        """
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.has_collection(target)

    def clean_db(self):
        """Drop every collection returned by ``show_collections``, one at a time."""
        existing = self.show_collections()
        for collection in existing:
            self.drop(collection_name=collection)

    @time_wrapper
    def load_collection(self, collection_name=None):
        """Load a collection with a 3000 timeout passed to the client.

        :param collection_name: target collection; defaults to this client's collection.
        :return: the value returned by the client's ``load_collection`` call.
        """
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.load_collection(target, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        """Release a loaded collection with a 3000 timeout passed to the client.

        :param collection_name: target collection; defaults to this client's collection.
        :return: the value returned by the client's ``release_collection`` call.
        """
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.release_collection(target, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        """Load the given partitions of a collection (client timeout 3000).

        :param tag_names: partition tags to load.
        :param collection_name: target collection; defaults to this client's collection.
        :return: the value returned by the client's ``load_partitions`` call.
        """
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.load_partitions(target, tag_names, timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        """Release the given partitions of a collection (client timeout 3000).

        :param tag_names: partition tags to release.
        :param collection_name: target collection; defaults to this client's collection.
        :return: the value returned by the client's ``release_partitions`` call.
        """
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.release_partitions(target, tag_names, timeout=3000)
Ejemplo n.º 30
0
class MilvusDocumentStore(SQLDocumentStore):
    """
    Milvus (https://milvus.io/) is a highly reliable, scalable Document Store specialized on storing and processing vectors.
    Therefore, it is particularly suited for Haystack users that work with dense retrieval methods (like DPR).
    In contrast to FAISS, Milvus ...
     - runs as a separate service (e.g. a Docker container) and can scale easily in a distributed environment
     - allows dynamic data management (i.e. you can insert/delete vectors without recreating the whole index)
     - encapsulates multiple ANN libraries (FAISS, ANNOY ...)

    This class uses Milvus for all vector related storage, processing and querying.
    The meta-data (e.g. for filtering) and the document text are however stored in a separate SQL Database as Milvus
    does not allow these data types (yet).

    Usage:
    1. Start a Milvus server (see https://milvus.io/docs/v0.10.5/install_milvus.md)
    2. Init a MilvusDocumentStore in Haystack
    """
    def __init__(
        self,
        sql_url: str = "sqlite:///",
        milvus_url: str = "tcp://localhost:19530",
        connection_pool: str = "SingletonThread",
        index: str = "document",
        vector_dim: int = 768,
        index_file_size: int = 1024,
        similarity: str = "dot_product",
        index_type: IndexType = IndexType.FLAT,
        index_param: Optional[Dict[str, Any]] = None,
        search_param: Optional[Dict[str, Any]] = None,
        update_existing_documents: bool = False,
        return_embedding: bool = False,
        embedding_field: str = "embedding",
        **kwargs,
    ):
        """
        :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
                        deployment, Postgres is recommended. If using MySQL then same server can also be used for
                        Milvus metadata. For more details see https://milvus.io/docs/v0.10.5/data_manage.md.
        :param milvus_url: Milvus server connection URL for storing and processing vectors.
                           Protocol, host and port will automatically be inferred from the URL.
                           See https://milvus.io/docs/v0.10.5/install_milvus.md for instructions to start a Milvus instance.
        :param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
        :param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
        :param vector_dim: The embedding vector size. Default: 768.
        :param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
         When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
         Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
         As a rule of thumb, we would see a 30% ~ 50% increase in the search performance after changing the value of index_file_size from 1024 to 2048.
         Note that an overly large index_file_size value may cause failure to load a segment into the memory or graphics memory.
         (From https://milvus.io/docs/v0.10.5/performance_faq.md#How-can-I-get-the-best-performance-from-Milvus-through-setting-index_file_size)
        :param similarity: The similarity function used to compare document vectors. Only 'dot_product' (the default,
                           recommended for DPR embeddings) is supported. It is mapped to Milvus' inner-product metric
                           (`MetricType.IP`). The metric is applied only when this store creates a new collection;
                           an already existing collection keeps the metric it was created with.
                           'cosine' is recommended for Sentence Transformers, but is not directly supported by Milvus.
                           However, you can normalize your embeddings and use `dot_product` to get the same results.
                           See https://milvus.io/docs/v0.10.5/metric.md?Inner-product-(IP)#floating.
        :param index_type: Type of approximate nearest neighbour (ANN) index used. The choice here determines your tradeoff between speed and accuracy.
                           Some popular options:
                           - FLAT (default): Exact method, slow
                           - IVF_FLAT, inverted file based heuristic, fast
                           - HNSW: Graph based, fast
                           - ANNOY: Tree based, fast
                           See: https://milvus.io/docs/v0.10.5/index.md
        :param index_param: Configuration parameters for the chosen index_type needed at indexing time.
                            For example: {"nlist": 16384} as the number of cluster units to create for index_type IVF_FLAT.
                            Defaults to {"nlist": 16384}.
                            See https://milvus.io/docs/v0.10.5/index.md
        :param search_param: Configuration parameters for the chosen index_type needed at query time.
                             For example: {"nprobe": 10} as the number of cluster units to query for index_type IVF_FLAT.
                             Defaults to {"nprobe": 10}.
                             See https://milvus.io/docs/v0.10.5/index.md
        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
                                          documents. When set as True, any document with an existing ID gets updated.
                                          If set to False, an error is raised if the document ID of the document being
                                          added already exists.
        :param return_embedding: To return document embedding.
        :param embedding_field: Name of field containing an embedding vector.
        :param kwargs: Accepted for signature compatibility; currently ignored.
        :raises ValueError: If `similarity` is anything other than "dot_product".
        """
        self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
        self.vector_dim = vector_dim
        self.index_file_size = index_file_size

        if similarity == "dot_product":
            # Dot-product similarity is Milvus' inner-product metric. The previous mapping
            # to MetricType.L2 ranked hits by Euclidean distance instead, and made the
            # expit(score / 100) probability used at query time grow with *worse* matches.
            self.metric_type = MetricType.IP
        else:
            raise ValueError(
                "The Milvus document store can currently only support dot_product similarity. "
                "Please set similarity=\"dot_product\"")

        self.index_type = index_type
        self.index_param = index_param or {"nlist": 16384}
        self.search_param = search_param or {"nprobe": 10}
        self.index = index
        self._create_collection_and_index_if_not_exist(self.index)
        self.return_embedding = return_embedding
        self.embedding_field = embedding_field

        super().__init__(url=sql_url,
                         update_existing_documents=update_existing_documents,
                         index=index)

    def __del__(self):
        """Close the Milvus connection when the store is finalized.

        ``__init__`` may fail before ``milvus_server`` is assigned (e.g. if the
        ``Milvus(...)`` constructor raises). In that case there is nothing to close.
        Without this guard, finalization would emit a spurious AttributeError.
        """
        milvus_server = getattr(self, "milvus_server", None)
        if milvus_server is None:
            return None
        return milvus_server.close()

    def _create_collection_and_index_if_not_exist(
            self,
            index: Optional[str] = None,
            index_param: Optional[Dict[str, Any]] = None):
        """Create the Milvus collection ``index`` and its vector index unless it already exists.

        :param index: Collection name; defaults to ``self.index``.
        :param index_param: Index build parameters; defaults to ``self.index_param``.
        :raises RuntimeError: If the existence check, collection creation or index creation fails.
        """
        index = index or self.index
        index_param = index_param or self.index_param

        status, ok = self.milvus_server.has_collection(collection_name=index)
        # Check the status like the other has_collection call sites do. Otherwise a failed
        # check looks like "collection missing" and triggers a doomed create attempt.
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if ok:
            return

        collection_param = {
            'collection_name': index,
            'dimension': self.vector_dim,
            'index_file_size': self.index_file_size,
            'metric_type': self.metric_type
        }

        status = self.milvus_server.create_collection(collection_param)
        if status.code != Status.SUCCESS:
            raise RuntimeError(
                f'Collection creation on Milvus server failed: {status}')

        status = self.milvus_server.create_index(index, self.index_type,
                                                 index_param)
        if status.code != Status.SUCCESS:
            raise RuntimeError(
                f'Index creation on Milvus server failed: {status}')

    def _create_document_field_map(self) -> Dict:
        """Return the field map passed to ``Document.from_dict``: index name -> embedding field."""
        field_map = {self.index: self.embedding_field}
        return field_map

    def write_documents(self,
                        documents: Union[List[dict], List[Document]],
                        index: Optional[str] = None,
                        batch_size: int = 10_000):
        """
        Add new documents to the DocumentStore.

        :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
                                  them right away in Milvus. If not, you can later call update_embeddings() to create & index them.
                                  Whether embeddings are written is decided from the first document only.
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :raises AttributeError: If an embedding is neither a list nor a numpy.ndarray.
        :raises RuntimeError: If inserting the embeddings into Milvus fails.
        :return: None
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)
        field_map = self._create_document_field_map()

        if len(documents) == 0:
            logger.warning(
                "Calling DocumentStore.write_documents() with empty list")
            return

        document_objects = [
            Document.from_dict(d, field_map=field_map)
            if isinstance(d, dict) else d for d in documents
        ]

        add_vectors = document_objects[0].embedding is not None

        batched_documents = get_batches_from_generator(document_objects,
                                                       batch_size)
        with tqdm(total=len(document_objects)) as progress_bar:
            for document_batch in batched_documents:
                vector_ids = []
                if add_vectors:
                    doc_ids = []
                    embeddings = []
                    for doc in document_batch:
                        doc_ids.append(doc.id)
                        if isinstance(doc.embedding, np.ndarray):
                            embeddings.append(doc.embedding.tolist())
                        elif isinstance(doc.embedding, list):
                            embeddings.append(doc.embedding)
                        else:
                            raise AttributeError(
                                f'Format of supplied document embedding {type(doc.embedding)} is not '
                                f'supported. Please use list or numpy.ndarray')

                    if self.update_existing_documents:
                        # Remove the vectors of documents that are about to be overwritten.
                        existing_docs = super().get_documents_by_id(
                            ids=doc_ids, index=index)
                        self._delete_vector_ids_from_milvus(
                            documents=existing_docs, index=index)

                    status, vector_ids = self.milvus_server.insert(
                        collection_name=index, records=embeddings)
                    if status.code != Status.SUCCESS:
                        raise RuntimeError(
                            f'Vector embedding insertion failed: {status}')

                docs_to_write_in_sql = []
                for idx, doc in enumerate(document_batch):
                    meta = doc.meta
                    if add_vectors:
                        # Link the SQL row to its Milvus vector (ids returned in insert order).
                        meta["vector_id"] = vector_ids[idx]
                    docs_to_write_in_sql.append(doc)

                super().write_documents(docs_to_write_in_sql, index=index)
                # Advance by the actual batch length: the final batch is usually shorter
                # than batch_size, so update(batch_size) overshot the total.
                progress_bar.update(len(document_batch))

        self.milvus_server.flush([index])
        if self.update_existing_documents:
            self.milvus_server.compact(collection_name=index)

    def update_embeddings(self,
                          retriever: BaseRetriever,
                          index: Optional[str] = None,
                          batch_size: int = 10_000):
        """
        Updates the embeddings in the document store using the encoding model specified in the retriever.
        This can be useful if you want to add or change the embeddings for your documents (e.g. after changing the retriever config).

        :param retriever: Retriever to use to get embeddings for text
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :raises RuntimeError: If the retriever returns a different number of embeddings than documents,
                              or if inserting the embeddings into Milvus fails.
        :return: None
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)

        document_count = self.get_document_count(index=index)
        if document_count == 0:
            logger.warning(
                "Calling DocumentStore.update_embeddings() on an empty index")
            return

        logger.info(f"Updating embeddings for {document_count} docs...")

        result = self.get_all_documents_generator(index=index,
                                                  batch_size=batch_size,
                                                  return_embedding=False)
        batched_documents = get_batches_from_generator(result, batch_size)
        with tqdm(total=document_count) as progress_bar:
            for document_batch in batched_documents:
                # Presumably removes the batch's previous vectors so re-embedding does not
                # leave stale ones behind — confirm against _delete_vector_ids_from_milvus.
                self._delete_vector_ids_from_milvus(documents=document_batch,
                                                    index=index)

                embeddings = retriever.embed_passages(
                    document_batch)  # type: ignore
                embeddings_list = [
                    embedding.tolist() for embedding in embeddings
                ]
                # Explicit check instead of `assert` (stripped under `python -O`): a length
                # mismatch would make the zip() below silently skip documents.
                if len(embeddings_list) != len(document_batch):
                    raise RuntimeError(
                        f'Retriever returned {len(embeddings_list)} embeddings for '
                        f'{len(document_batch)} documents')

                status, vector_ids = self.milvus_server.insert(
                    collection_name=index, records=embeddings_list)
                if status.code != Status.SUCCESS:
                    raise RuntimeError(
                        f'Vector embedding insertion failed: {status}')

                vector_id_map = {
                    doc.id: vector_id
                    for vector_id, doc in zip(vector_ids, document_batch)
                }

                self.update_vector_ids(vector_id_map, index=index)
                # The last batch is usually shorter than batch_size; count real documents.
                progress_bar.update(len(document_batch))

        self.milvus_server.flush([index])
        self.milvus_server.compact(collection_name=index)

    def query_by_embedding(
            self,
            query_emb: np.array,
            filters: Optional[dict] = None,
            top_k: int = 10,
            index: Optional[str] = None,
            return_embedding: Optional[bool] = None) -> List[Document]:
        """
        Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Not supported by this store; passing any filters raises.
        :param top_k: How many documents to return
        :param index: (SQL) index name for storing the docs and metadata
        :param return_embedding: To return document embedding
        :raises Exception: If filters are given or the collection does not exist.
        :raises RuntimeError: If the Milvus existence check or search fails.
        :return: Documents in the order Milvus returned their vector ids, each with
                 `score` (Milvus distance) and `probability` (expit(score / 100)) set.
        """
        if filters:
            raise Exception(
                "Query filters are not implemented for the MilvusDocumentStore."
            )

        index = index or self.index
        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if not ok:
            raise Exception(
                "No index exists. Use 'update_embeddings()` to create an index."
            )

        if return_embedding is None:
            return_embedding = self.return_embedding

        query_emb = query_emb.reshape(1, -1).astype(np.float32)
        status, search_result = self.milvus_server.search(
            collection_name=index,
            query_records=query_emb,
            top_k=top_k,
            params=self.search_param)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Vector embedding search failed: {status}')

        vector_ids_for_query = []
        scores_for_vector_ids: Dict[str, float] = {}
        rank_for_vector_ids: Dict[str, int] = {}
        for vector_id_list, distance_list in zip(search_result.id_array,
                                                 search_result.distance_array):
            for vector_id, distance in zip(vector_id_list, distance_list):
                vector_id = str(vector_id)
                vector_ids_for_query.append(vector_id)
                scores_for_vector_ids[vector_id] = distance
                # Remember Milvus' ranking position of each hit.
                rank_for_vector_ids.setdefault(vector_id, len(rank_for_vector_ids))

        documents = self.get_documents_by_vector_ids(vector_ids_for_query,
                                                     index=index)
        # The SQL lookup is not guaranteed to return rows in the order of the ids passed
        # in (presumably an `IN (...)` query), so restore the Milvus ranking.
        documents = sorted(
            documents,
            key=lambda doc: rank_for_vector_ids[doc.meta["vector_id"]])

        if return_embedding:
            self._populate_embeddings_to_docs(index=index, docs=documents)

        for doc in documents:
            doc.score = scores_for_vector_ids[doc.meta["vector_id"]]
            doc.probability = float(expit(np.asarray(doc.score / 100)))

        return documents

    def delete_all_documents(self,
                             index: Optional[str] = None,
                             filters: Optional[Dict[str, List[str]]] = None):
        """
        Delete all documents (from SQL AND Milvus).

        Without filters the whole Milvus collection is dropped. With filters, only the
        vectors of the matching documents are removed from Milvus, so documents kept in
        SQL keep their vectors.

        :param index: (SQL) index name for storing the docs and metadata
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :raises RuntimeError: If a Milvus collection check or drop fails.
        :return: None
        """
        index = index or self.index
        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if ok:
            if not filters:
                status = self.milvus_server.drop_collection(collection_name=index)
                if status.code != Status.SUCCESS:
                    raise RuntimeError(f'Milvus drop collection failed: {status}')
                # No flush/compact here: the collection no longer exists.
            else:
                # Dropping the whole collection would orphan the vectors of documents
                # that the filters keep in SQL. Remove only the affected vectors instead.
                # Materialize the generator before any deletion starts.
                affected_docs = list(super().get_all_documents_generator(
                    index=index, filters=filters))
                self._delete_vector_ids_from_milvus(documents=affected_docs,
                                                    index=index)
                self.milvus_server.flush([index])
                self.milvus_server.compact(collection_name=index)

        # SQL rows are deleted last so the filtered lookup above still finds them.
        super().delete_all_documents(index=index, filters=filters)

    def get_all_documents_generator(
        self,
        index: Optional[str] = None,
        filters: Optional[Dict[str, List[str]]] = None,
        return_embedding: Optional[bool] = None,
        batch_size: int = 10_000,
    ) -> Generator[Document, None, None]:
        """
        Lazily yield every document of an index, optionally with its embedding attached.

        Documents come from the SQL store's generator (fetched in batches of `batch_size`).
        When embeddings are requested, each document gets them filled in from Milvus just
        before it is yielded.

        :param index: Index to read from; defaults to `self.index`.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to attach embeddings; defaults to `self.return_embedding`.
        :param batch_size: Batch size used for fetching from the SQL store.
        """
        index = index or self.index
        sql_documents = super().get_all_documents_generator(index=index,
                                                            filters=filters,
                                                            batch_size=batch_size)
        with_embeddings = self.return_embedding if return_embedding is None else return_embedding

        for document in sql_documents:
            if with_embeddings:
                self._populate_embeddings_to_docs(index=index, docs=[document])
            yield document