Example #1
0
def create_collection(collection_name):
    """Create a Milvus collection named *collection_name* if it does not exist.

    Connects using the module-level ``host``/``port``; the collection is
    created with a fixed dimension of 3.
    """
    client = Milvus(host, str(port))
    try:
        status, ok = client.has_collection(collection_name)
        if not ok:
            param = {
                'collection_name': collection_name,
                'dimension': 3,
            }
            client.create_collection(param)
    finally:
        # Always release the connection, even if has/create raised.
        client.close()
Example #2
0
    def test_connect_repeatedly(self, args):
        """
        target: test connect repeatedly
        method: connect again
        expected: status.code is 0, and status.message shows have connected already
        """
        uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
        client = Milvus()
        # Connecting twice must leave the client connected.
        for _ in range(2):
            client.connect(uri=uri_value)
        assert client.connected()
Example #3
0
    def _test_search_concurrent_multiprocessing(self, args):
        '''
        target: test concurrent search with multiprocessess
        method: search with 10 processes, each process uses dependent connection
        expected: status ok and the returned vectors should be query_records
        '''
        # Build a FLAT-indexed table, seed it with `nb` vectors, then run the
        # same search from several worker processes, each with its own client.
        nb = 100
        top_k = 10
        # NOTE(review): docstring says 10 processes but process_num is 4 —
        # confirm which is intended.
        process_num = 4
        processes = []
        table = gen_unique_str("test_search_concurrent_multiprocessing")
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        param = {'table_name': table,
             'dimension': dim,
             'index_type': IndexType.FLAT,
             'store_raw_vector': False}
        # create table
        milvus = Milvus()
        milvus.connect(uri=uri)
        milvus.create_table(param)
        vectors, ids = self.init_data(milvus, table, nb=nb)
        # Query with the second half of the inserted vectors, so every top hit
        # should be an exact (distance 0.0) match on a known id.
        query_vecs = vectors[nb//2:nb]
        def search(milvus):
            status, result = milvus.search_vectors(table, top_k, query_vecs)
            assert len(result) == len(query_vecs)
            for i in range(len(query_vecs)):
                assert result[i][0].id in ids
                assert result[i][0].distance == 0.0

        for i in range(process_num):
            # Each worker process carries its own connection; the 0.2 s stagger
            # spreads out the connection storm.
            milvus = Milvus()
            milvus.connect(uri=uri)
            p = Process(target=search, args=(milvus, ))
            processes.append(p)
            p.start()
            time.sleep(0.2)
        for p in processes:
            p.join()
def milvus_test(usr_features, mov_features, ids):
    """Insert a normalised user-feature vector into the 'paddle_demo1'
    collection, search it with a movie-feature vector, print the matched id
    and scaled score, then drop the collection."""
    milvus = Milvus('127.0.0.1', '19530')

    table_name = 'paddle_demo1'
    _, exists = milvus.has_collection(table_name)
    if not exists:
        milvus.create_collection({
            'collection_name': table_name,
            'dimension': 200,
            'index_file_size': 1024,  # optional
            'metric_type': MetricType.IP  # optional
        })

    insert_vectors = normaliz_data([usr_features.tolist()])
    status, ids = milvus.insert(collection_name=table_name,
                                records=insert_vectors,
                                ids=ids)

    # Give the server a moment to make the insert visible.
    time.sleep(1)

    status, result = milvus.count_entities(table_name)
    print("rows in table paddle_demo1:", result)

    # Search the single movie-feature vector with IP metric (top-1).
    search_vectors = normaliz_data([mov_features.tolist()])
    status, results = milvus.search(collection_name=table_name,
                                    query_records=search_vectors,
                                    top_k=1,
                                    params={'nprobe': 16})
    print("Searched ids:", results[0][0].id)
    print("Score:", float(results[0][0].distance) * 5)

    status = milvus.drop_collection(table_name)
Example #5
0
    def test_connect_uri_null(self, args):
        """
        target: test connect with null uri
        method: uri set null
        expected: connected is True
        """
        client = Milvus()
        empty_uri = ""

        # An empty uri only works when the server is local; otherwise the
        # connect attempt must raise and leave the client disconnected.
        if not self.local_ip(args):
            with pytest.raises(Exception) as e:
                client.connect(uri=empty_uri, timeout=1)
            assert not client.connected()
        else:
            client.connect(uri=empty_uri, timeout=1)
            assert client.connected()
Example #6
0
    def _create_collection(_collection_param):
        """Drop any existing collection named ``_collection_name`` and
        re-create it from ``_collection_param``.

        Raises Exception if the old collection cannot be deleted.
        """
        milvus = Milvus()
        milvus.connect(**server_config)
        status, ok = milvus.has_collection(_collection_name)
        if ok:
            print("Table {} found, now going to delete it".format(
                _collection_name))
            status = milvus.drop_collection(_collection_name)
            if not status.OK():
                raise Exception("Delete collection error")
            print(
                "delete collection {} successfully!".format(_collection_name))
        time.sleep(5)

        status, ok = milvus.has_collection(_collection_name)
        if ok:
            raise Exception("Delete collection error")

        # Bug fix: the `_collection_param` argument was previously ignored in
        # favour of an outer-scope `param`; use the argument that was passed in.
        status = milvus.create_collection(_collection_param)
        if not status.OK():
            print("Create collection {} failed".format(_collection_name))

        milvus.disconnect()
def main():
    """Load BVEC batches from FILE_PATH and insert each one into Milvus,
    one partition per batch."""
    milvus = Milvus(host=SERVER_ADDR, port=SERVER_PORT)
    create_milvus_collection(milvus)
    partition_tag = get_partition_tag()
    count = 0
    while count < (VEC_NUM // BASE_LEN):
        vectors = load_bvecs_data(FILE_PATH, BASE_LEN, count)
        # Sequential ids for this batch; list(range(...)) avoids shadowing the
        # built-in `id` as the old comprehension did.
        vectors_ids = list(range(count * BASE_LEN, (count + 1) * BASE_LEN))
        create_partition(partition_tag[count], milvus)
        add_vectors(vectors, vectors_ids, partition_tag[count], milvus)

        count = count + 1
Example #8
0
 def create(self, name, **kwargs):
     """Create a new topo object and register it if not already present.

     The topo object is a Pymilvus client instance; a non-empty ``uri``
     keyword argument is mandatory.
     """
     if not kwargs.get('uri', None):
         raise RuntimeError('\"uri\" is required to create connection pool')
     client_kwargs = copy.deepcopy(kwargs)
     client_kwargs["max_retry"] = settings.MAX_RETRY
     pool = Milvus(name=name, **client_kwargs)
     status = self.add(pool)
     if status != topology.StatusType.OK:
         # Registration failed — do not hand back an unregistered pool.
         pool = None
     return status, pool
Example #9
0
def gcon(request, ghandler):
    """Pytest fixture: Milvus client built from the --ip/--port CLI options."""
    host = request.config.getoption("--ip")
    port_opt = request.config.getoption("--port")
    client = Milvus(host=host, port=port_opt, handler=ghandler)

    def fin():
        # Finalizer currently does nothing; kept for symmetry with fixtures
        # that close their connection at teardown.
        try:
            pass
        except Exception as e:
            print(e)

    request.addfinalizer(fin)
    return client
    def _add():
        """Connect, insert 10000 random 128-d vectors, then disconnect."""
        client = Milvus()
        status = client.connect(**server_config)

        batch = _generate_vectors(128, 10000)
        print('\n\tPID: {}, insert {} vectors'.format(os.getpid(), 10000))
        status, _ = client.add_vectors(_table_name, batch)

        client.disconnect()
Example #11
0
    def multi_thread_opr(table_name, utid):
        """Create a 64-d table over the HTTP handler and insert 10000 random
        vectors twenty times, logging progress per round."""
        print("[{}] | T{} | Running .....".format(datetime.datetime.now(),
                                                  utid))

        http_client = Milvus(handler="HTTP")
        table_param = {'table_name': table_name, 'dimension': 64}
        rand_vectors = [[random.random() for _ in range(64)]
                        for _ in range(10000)]

        http_client.connect()
        http_client.create_table(table_param)

        for round_no in range(20):
            print("[{}] | T{} | O{} | Start insert data .....".format(
                datetime.datetime.now(), utid, round_no))
            http_client.insert(table_name, rand_vectors)
            print("[{}] | T{} | O{} | Stop insert data .....".format(
                datetime.datetime.now(), utid, round_no))

        http_client.disconnect()

        print("[{}] | T{} | Stopping .....".format(datetime.datetime.now(),
                                                   utid))
def main():
    """Load BVEC batches and insert them, with random attribute columns,
    into the 'mixed06' hybrid collection."""
    # connect_milvus_server()
    milvus = Milvus(host=SERVER_ADDR, port=SERVER_PORT)
    create_milvus_collection(milvus)
    build_collection(milvus)
    count = 0
    while count < (VEC_NUM // BASE_LEN):
        vectors = load_bvecs_data(FILE_PATH, BASE_LEN, count)
        # list(range(...)) avoids shadowing the built-in `id` as the old
        # comprehension did.
        vectors_ids = list(range(count * BASE_LEN, (count + 1) * BASE_LEN))
        # NOTE(review): the attribute columns are generated for a fixed 10000
        # rows — confirm BASE_LEN == 10000, otherwise their lengths will not
        # match the vector column.
        sex = [random.randint(0, 2) for _ in range(10000)]
        get_time = [random.randint(2017, 2020) for _ in range(10000)]
        is_glasses = [random.randint(10, 13) for _ in range(10000)]
        hybrid_entities = [{
            "name": "sex",
            "values": sex,
            "type": DataType.INT32
        }, {
            "name": "is_glasses",
            "values": is_glasses,
            "type": DataType.INT32
        }, {
            "name": "get_time",
            "values": get_time,
            "type": DataType.INT32
        }, {
            "name": "Vec",
            "values": vectors,
            "type": DataType.FLOAT_VECTOR
        }]
        time_start = time.time()
        result = milvus.insert('mixed06', hybrid_entities, ids=vectors_ids)
        time_end = time.time()
        print("insert milvue time: ", time_end - time_start)
        count = count + 1
Example #13
0
    def __init__(self, collection_name=None, ip=None, port=None, timeout=60):
        """Connect to Milvus, retrying remote connections for up to *timeout* s.

        When *ip* is falsy the default local server is used without retries.
        The useless ``try/except Exception as e: raise e`` wrapper of the
        original was removed — it re-raised everything unchanged.
        """
        self._collection_name = collection_name
        attempt = 1
        start_time = time.time()
        if not ip:
            self._milvus = Milvus(host=SERVER_HOST_DEFAULT,
                                  port=SERVER_PORT_DEFAULT)
        else:
            # Retry connecting to the remote server until the deadline passes.
            while time.time() < start_time + timeout:
                try:
                    self._milvus = Milvus(host=ip, port=port)
                    if self._milvus.server_status():
                        logger.debug(
                            "Try connect times: %d, %s" %
                            (attempt, round(time.time() - start_time, 2)))
                        break
                except Exception:
                    logger.debug("Milvus connect failed")
                    attempt = attempt + 1
        # NOTE(review): unlike the sibling wrapper, no exception is raised when
        # the retry loop exhausts the timeout — callers may proceed without
        # self._milvus being set; confirm this is intended.
Example #14
0
    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Default timeout set 60s.  Raises Exception("Server connect timeout")
        when no connection is established before the deadline.  The no-op
        ``try/except Exception as e: raise e`` wrapper was removed.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # Retry connecting to the remote server until the deadline passes.
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1

        if time.time() > start_time + timeout:
            raise Exception("Server connect timeout")

        self._metric_type = None
        if self._collection_name and self.exists_collection():
            # Describe once instead of twice; both fields come from the same
            # server response.
            info = self.describe()[1]
            self._metric_type = metric_type_to_str(info.metric_type)
            self._dimension = info.dimension
Example #15
0
    def _add():
        """Connect and insert the shared ``vectors`` into ``_table_name``."""
        milvus = Milvus()
        status = milvus.connect()

        # Bug fix: `status.OK` is a bound method and therefore always truthy,
        # so the failure branch could never execute; it must be called.
        if not status.OK():
            print(f'PID: {os.getpid()}, connect failed')

        status, _ = milvus.add_vectors(_table_name, vectors)

        milvus.disconnect()
Example #16
0
def connect(request, handler):
    """Pytest fixture: Milvus client for the CLI-selected handler and port.

    The *handler* argument is immediately overridden by the ``--handler`` CLI
    option; the default port depends on whether the handler is HTTP or gRPC.
    """
    ip = request.config.getoption("--ip")
    handler = request.config.getoption("--handler")
    port_default = default_http_port if handler == "HTTP" else default_grpc_port
    port = request.config.getoption("--port", default=port_default)

    client = Milvus(host=ip, port=port, handler=handler)

    def fin():
        # Best-effort close; a failure here must not fail the test session.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        try:
            client.close()
        except Exception:
            pass

    request.addfinalizer(fin)
    return client
def create():
    """Create the 512-d face-feature collection, build an IVF_SQ8 index, and
    bulk-insert vectors from ../chs_stars_features.csv while writing a label
    file (index, star, fname) to ../chs_stars_labels.csv."""
    _HOST = 'localhost'
    _PORT = '19530'
    _collection_name = 'chs_stars_faces_512'
    _DIM = 512  # dimension of vector
    _INDEX_FILE_SIZE = 256  # max file size of stored index

    milvus = Milvus(_HOST, _PORT)
    param = {
        'collection_name': _collection_name,
        'dimension': _DIM,
        'index_file_size': _INDEX_FILE_SIZE,  # optional
        'metric_type': MetricType.IP  # optional
    }


    milvus.create_collection(param)
    index_param = {
        'nlist': 2048  # recommended: 4 * sqrt(n)
    }

    status = milvus.create_index(_collection_name, IndexType.IVF_SQ8, index_param)

    # Earlier PCA-based pipeline, kept disabled:
    # with open("chs_stars_features_pca.pickle", "rb") as f:
    #     pca = pickle.load(f)
    #
    # with open("../chs_stars_features_pca.csv", "w") as fw, open("../chs_stars_features.csv", "r") as fr:
    #     reader = csv.reader(fr)
    #     writer = csv.writer(fw)
    #     for index, line in enumerate(tqdm(reader)):
    #         star, fname, features = line
    #         features = np.array(json.loads(features))
    #         features = np.resize(features, (1, 512))
    #         features = normalize(features)
    #         features = pca.transform(features).squeeze()
    #         status, ids = milvus.insert(collection_name=_collection_name, records=[features.tolist()], ids=[index])
    #         if not status.OK():
    #             print(status)
    #             continue
    #         writer.writerow([index, star, fname, features])

    with open("../chs_stars_labels.csv", "w") as fw, open("../chs_stars_features.csv", "r") as fr:
        reader = csv.reader(fr)
        writer = csv.writer(fw)
        for index, line in enumerate(tqdm(reader)):
            star, fname, features = line
            # features = np.array(json.loads(features))
            # features = np.resize(features, (1, 512))
            #features = normalize(features)
            # Each CSV row stores the feature vector as a JSON array.
            features = json.loads(features)
            status, ids = milvus.insert(collection_name=_collection_name, records=[features], ids=[index])
            if not status.OK():
                # Skip rows the server rejected; keep the label file aligned
                # with what was actually inserted.
                print(status)
                continue
            writer.writerow([index, star, fname])
Example #18
0
def main():
    """Run a date-range-restricted vector search against 'test_search' and
    print the matching row count."""
    milvus = Milvus()
    milvus.connect(host=_HOST, port=_PORT)

    table_name = 'test_search'
    dimension = 256
    # Restrict the search to a single day.
    ranges = [['2019-06-25', '2019-06-25']]

    vectors = Prepare.records([[random.random() for _ in range(dimension)]
                               for _ in range(1)])
    LOGGER.info(ranges)
    status, result = milvus.search_vectors(table_name=table_name,
                                           query_records=vectors,
                                           top_k=5,
                                           query_ranges=ranges)
    if status.OK():
        pprint(result)

    _, result = milvus.get_table_row_count(table_name)
    print('# Count: {}'.format(result))
    milvus.disconnect()
 def test_search_multi_table_IP(search, args):
     '''
     target: test search multi tables of IP
     method: add vectors into 10 tables, and search
     expected: search status ok, the length of result
     '''
     # NOTE(review): name/docstring say IP, but metric_type below is
     # MetricType.L2 — confirm which metric is intended.
     num = 10
     top_k = 10
     nprobe = 1
     tables = []
     idx = []
     for i in range(num):
         table = gen_unique_str("test_add_multitable_%d" % i)
         uri = "tcp://%s:%s" % (args["ip"], args["port"])
         param = {
             'table_name': table,
             'dimension': dim,
             'index_file_size': 10,
             'metric_type': MetricType.L2
         }
         # create table
         milvus = Milvus()
         milvus.connect(uri=uri)
         milvus.create_table(param)
         status, ids = milvus.add_vectors(table, vectors)
         assert status.OK()
         assert len(ids) == len(vectors)
         tables.append(table)
         # Remember three known ids per table to validate results later.
         idx.append(ids[0])
         idx.append(ids[10])
         idx.append(ids[20])
     time.sleep(6)
     query_vecs = [vectors[0], vectors[10], vectors[20]]
     # start query from random table
     # All searches reuse the last client created in the loop above.
     for i in range(num):
         table = tables[i]
         status, result = milvus.search_vectors(table, top_k, nprobe,
                                                query_vecs)
         assert status.OK()
         assert len(result) == len(query_vecs)
         for j in range(len(query_vecs)):
             assert len(result[j]) == top_k
         for j in range(len(query_vecs)):
             assert check_result(result[j], idx[3 * i + j])
Example #20
0
def connect(request):
    """Pytest fixture: client for a hard-coded Milvus host; exits pytest if
    the server cannot be reached."""
    host = '192.168.1.238'
    port = 19530
    try:
        client = Milvus(host=host, port=port)
    except Exception as e:
        logging.getLogger().error(str(e))
        pytest.exit("Milvus server can not connected, exit pytest ...")

    def fin():
        # Best-effort close at session teardown; log any failure.
        try:
            client.close()
        except Exception as e:
            logging.getLogger().info(str(e))

    request.addfinalizer(fin)
    return client
Example #21
0
    def _add_milvus_question(self, question_vector, collection: str, partition: str,
                             milvus: mv.Milvus) -> int:
        """
        Insert one standard-question vector into Milvus.

        @param {object} question_vector - vector representing the question
        @param {str} collection - question category (collection name)
        @param {str} partition - scenario (partition tag)
        @param {mv.Milvus} milvus - connected Milvus client

        @returns {int} - the generated milvus_id
        """
        status_, inserted_ids = milvus.insert(
            collection, [question_vector, ], partition_tag=partition)
        self.confirm_milvus_status(status_, 'insert')
        self._log_debug('insert _milvus_ids: %s' % str(inserted_ids))

        return inserted_ids[0]
Example #22
0
def dis_connect(request):
    """Pytest fixture: a Milvus client that has been connected and then
    disconnected, for tests that need a dead connection."""
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    milvus = Milvus()
    milvus.connect(host=ip, port=port)
    milvus.disconnect()

    def fin():
        # A second disconnect may raise; ignore it at teardown.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        try:
            milvus.disconnect()
        except Exception:
            pass

    request.addfinalizer(fin)
    return milvus
class TestToServer:
    """Mock-based tests for the server-facing wrapper methods.

    NOTE(review): `fake_milvus.connect` below executes at class-definition
    (import) time and appears to target a server at 127.0.0.1:9090 — confirm
    that is intended for a unit-test module.
    """
    fake_milvus = Milvus()
    fake_milvus.connect(host='127.0.0.1', port='9090')

    @mock.patch.object(Ms, 'server_status')
    def test_ping(self, server_status):
        """server_status is patched on Ms, so any argument returns the mock value."""
        server_status.return_value = 'OK'
        ans = self.fake_milvus.server_status('fake_ping')
        assert ans == 'OK'

        ans = self.fake_milvus.server_status('version')
        assert ans == 'OK'

    @mock.patch.object(Ms, 'create_table')
    def test_crate_table(self, create_table):
        """create_table is patched; expect the canned SUCCESS status."""
        create_table.return_value = Status.SUCCESS
        ans = self.fake_milvus.create_table('fakeparam')
        assert ans == Status.SUCCESS

    @mock.patch.object(Ms, 'add_vectors')
    def test_add_vector(self, add_vectors):
        """add_vectors is patched to return a fixed id list."""
        add_vectors.return_value = ['aaaa']
        ans = self.fake_milvus.add_vectors('fake1', 'fake2')
        assert ans == ['aaaa']

    @mock.patch.object(Ms, 'describe_table')
    def test_describe_table(self, describe_table):
        """describe_table is patched to echo a fixed table name."""
        describe_table.return_value = 'fake_table_name'
        ans = self.fake_milvus.describe_table('fake_param')
        assert ans == 'fake_table_name'

    @mock.patch.object(Ms, 'show_tables')
    def test_show_tables(self, show_tables):
        """show_tables is patched to return a fixed listing."""
        show_tables.return_value = 'some_table'
        ans = self.fake_milvus.show_tables()
        assert ans == 'some_table'

    @mock.patch.object(Ms, 'get_table_row_count')
    def test_get_table_row_count(self, get_table_row_count):
        """get_table_row_count is patched to return a fixed count."""
        get_table_row_count.return_value = 666
        ans = self.fake_milvus.get_table_row_count('fake_table')
        assert ans == 666
Example #24
0
 def test_disconnect_repeatedly(self, connect, args):
     """
     target: test disconnect repeatedly
     method: disconnect a connected client, disconnect again
     expected: raise an error after disconnected
     """
     if connect.connected():
         res = connect.disconnect()
         with pytest.raises(Exception) as e:
             res = connect.disconnect()
     else:
         # Fixture client is not connected: build a fresh one to exercise
         # the same disconnect-twice behaviour.
         client = Milvus()
         uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
         client.connect(uri=uri_value)
         res = client.disconnect()
         with pytest.raises(Exception) as e:
             res = client.disconnect()
Example #25
0
    def _test_connect_disconnect_with_multiprocess(self, args):
        """
        target: test uri connect and disconnect repeatly with multiprocess
        method: set correct uri, test with multiprocessing connecting and disconnecting
        expected: all connection is connected after 10 times operation
        """
        uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
        worker_count = 4
        workers = []

        def reconnect(client):
            # Connect, drop the connection, reconnect, then verify the state.
            client.connect(uri=uri_value)
            client.disconnect()
            client.connect(uri=uri_value)
            assert client.connected()

        for _ in range(worker_count):
            worker = Process(target=reconnect, args=(Milvus(), ))
            workers.append(worker)
            worker.start()
        for worker in workers:
            worker.join()
def predict(start_date_str, end_date_str):
    """Vectorise papers in [start_date_str, end_date_str], persisting each
    vector to SQL and batching inserts to Milvus.

    Vectors are flushed to Milvus every 200 papers; the final partial batch is
    now flushed after the loop (previously it was silently dropped).
    """
    print("加载模型")
    model = gensim.models.doc2vec.Doc2Vec.load("./doc2vec.model")
    print("建立milvus链接")
    client = Milvus(host=milvus_ip, port='19530')
    print("读取数据ing")
    start_date = datetime.strptime(start_date_str,
                                   '%Y-%m-%d').timestamp() * 1000
    end_date = datetime.strptime(end_date_str, '%Y-%m-%d').timestamp() * 1000
    res = Paper.query_by_time_interval(start_date, end_date)

    num = 0
    start = time.time()
    id_list = []
    user_id_list = []
    vecs = []

    def _flush_batch():
        # Push the accumulated vectors to Milvus and reset the buffers.
        client.insert('ideaman', records=vecs, ids=id_list)
        client.flush(collection_name_array=["ideaman"])
        user_id_list.clear()
        id_list.clear()
        vecs.clear()

    for i in res:
        paper_id = i.id
        paper_user_id = i.user_id
        paper_str = i.title + " . " + i.description
        vec = get_vector(model, [paper_str])
        # Buffer the vector for the next Milvus batch insert.
        id_list.append(paper_id)
        user_id_list.append(paper_user_id)
        vecs.append(list(vec))
        # Serialise the vector to a comma-separated string for the SQL column.
        paper_vec = str(vec).replace('\n', '').replace('[', '').replace(
            ']', '').replace("  ", " ").replace(" ", ",")[1:]
        paper_vec = paper_vec.replace(",,", ",0,")
        Paper.update_SQL('doc_vector', paper_vec, paper_user_id)

        num += 1
        if num % 200 == 0:
            print("完成了", num, '篇', '--用时:', time.time() - start)
            start = time.time()
            _flush_batch()

    # Bug fix: flush the final partial batch (< 200 rows) that the original
    # loop left unsent.
    if vecs:
        _flush_batch()
Example #27
0
    def _test_connect_with_multiprocess(self, args):
        """
        target: test uri connect with multiprocess
        method: set correct uri, test with multiprocessing connecting
        expected: all connection is connected
        """
        uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
        worker_count = 4
        workers = []

        def do_connect(client):
            # First connect succeeds; a second connect must raise, leaving the
            # client still connected.
            client.connect(uri=uri_value)
            with pytest.raises(Exception) as e:
                client.connect(uri=uri_value)
            assert client.connected()

        for _ in range(worker_count):
            worker = Process(target=do_connect, args=(Milvus(), ))
            workers.append(worker)
            worker.start()
        for worker in workers:
            worker.join()
Example #28
0
class MilvusConnection:
    """Thin wrapper around a local Milvus collection of environment embeddings."""

    def __init__(self, env, name="movies_L2", port="19530", param=None):
        # Merge caller overrides on top of the collection defaults.
        overrides = dict() if param is None else param
        param = {
            "collection_name": name,
            "dimension": 128,
            "index_file_size": 1024,
            "metric_type": MetricType.L2,
            **overrides,
        }
        self.name = name
        self.client = Milvus(host="localhost", port=port)
        self.statuses = {}
        _, exists = self.client.has_collection(name)
        if not exists:
            # First run: create the collection and load every base embedding.
            created = self.client.create_collection(param)
            vectors = env.base.embeddings.detach().cpu().numpy().astype(
                "float32")
            target_ids = list(range(vectors.shape[0]))
            inserted, inserted_vector_ids = self.client.insert(
                collection_name=name, records=vectors, ids=target_ids)
            flushed = self.client.flush([name])
            compacted = self.client.compact(collection_name=name)
            self.statuses["created_collection"] = created
            self.statuses["inserted"] = inserted
            self.statuses["flushed"] = flushed
            self.statuses["compacted"] = compacted

    def search(self, search_vecs, topk=10, search_param=None):
        """Return the nearest-neighbour id matrix for *search_vecs* as a tensor."""
        overrides = dict() if search_param is None else search_param
        merged = {"nprobe": 16, **overrides}
        status, results = self.client.search(
            collection_name=self.name,
            query_records=search_vecs,
            top_k=topk,
            params=merged,
        )
        self.statuses['last_search'] = status
        return torch.tensor(results.id_array)

    def get_log(self):
        """Return the dict of recorded operation statuses."""
        return self.statuses
Example #29
0
    return status.OK() and ok


if __name__ == "__main__":
    import numpy

    # Ad-hoc driver: load 10k query vectors from an .npy file and reset the
    # "test" table.  (This snippet appears truncated: `param` below is built
    # but never used within the visible code.)
    dim = 128
    nq = 10000
    table = "test"

    file_name = '/poc/yuncong/ann_1000m/query.npy'
    # NOTE(review): uses `np.load` although only `numpy` is imported here —
    # confirm `np` is aliased elsewhere in the file.
    data = np.load(file_name)
    vectors = data[0:nq].tolist()
    # print(vectors)

    connect = Milvus()
    # connect.connect(host="192.168.1.27")
    # print(connect.show_tables())
    # print(connect.get_table_row_count(table))
    # sys.exit()
    connect.connect(host="127.0.0.1")
    connect.delete_table(table)
    # sys.exit()
    # time.sleep(2)
    print(connect.get_table_row_count(table))
    param = {
        'table_name': table,
        'dimension': dim,
        'metric_type': MetricType.L2,
        'index_file_size': 10
    }
Example #30
0
class MilvusClient(object):
    def __init__(self,
                 collection_name=None,
                 host=None,
                 port=None,
                 timeout=180):
        """Connect to a Milvus server, retrying with linear backoff.

        :param collection_name: default collection used by the helper methods
        :param host: server host (falls back to SERVER_HOST_DEFAULT)
        :param port: server port (falls back to SERVER_PORT_DEFAULT)
        :param timeout: overall connect deadline in seconds
        :raises Exception: when no connection is made before the deadline
        """
        self._collection_name = collection_name
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect remote server
        deadline = time.time() + timeout
        attempt = 0
        connected = False
        while time.time() < deadline:
            try:
                self._milvus = Milvus(host=host,
                                      port=port,
                                      try_connect=False,
                                      pre_ping=False)
                connected = True
                break
            except Exception as e:
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % attempt)
                attempt = attempt + 1
                time.sleep(attempt)  # linear backoff between attempts
        if not connected:
            # Bug fix: previously a connect that succeeded just before the
            # deadline could still raise, because the post-loop check
            # compared wall-clock time instead of tracking success.
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        """Human-readable identifier used in logs."""
        return 'Milvus collection %s' % self._collection_name

    def check_status(self, status):
        """Raise if *status* is not OK, logging server-side diagnostics first."""
        if status.OK():
            return
        logger.error(status.message)
        logger.error(self._milvus.server_status())
        logger.error(self.count())
        raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Verify every query's top-1 hit is an exact match (distance < epsilon)."""
        for idx, hits in enumerate(result):
            top_hit = hits[0]
            if top_hit.distance < epsilon:
                continue
            logger.error(idx)
            logger.error(top_hit.distance)
            raise Exception("Distance wrong")

    # only support the given field name
    def create_collection(self,
                          dimension,
                          data_type=DataType.FLOAT_VECTOR,
                          auto_id=False,
                          collection_name=None,
                          other_fields=None):
        """Create a collection with one vector field plus optional scalars.

        *other_fields* is a comma-separated string; "int" and/or "float"
        add an INT64 / FLOAT scalar field with the module default names.
        Raises whatever the underlying client raises, after logging it.
        """
        self._dimension = dimension
        if not collection_name:
            collection_name = self._collection_name
        fields = [{
            "name": utils.get_default_field_name(data_type),
            "type": data_type,
            "params": {"dim": dimension},
        }]
        extras = other_fields.split(",") if other_fields else []
        if "int" in extras:
            fields.append({
                "name": utils.DEFAULT_INT_FIELD_NAME,
                "type": DataType.INT64,
            })
        if "float" in extras:
            fields.append({
                "name": utils.DEFAULT_FLOAT_FIELD_NAME,
                "type": DataType.FLOAT,
            })
        try:
            self._milvus.create_collection(collection_name,
                                           {"fields": fields,
                                            "auto_id": auto_id})
            logger.info("Create collection: <%s> successfully" %
                        collection_name)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        """Add partition *tag* to the given (or default) collection."""
        target = collection_name if collection_name else self._collection_name
        self._milvus.create_partition(target, tag)

    def generate_values(self, data_type, vectors, ids):
        """Pick field values for *data_type*: the ids for int fields, the
        ids cast to float for float fields, the raw vectors for vector
        fields, and ``None`` for anything else."""
        if data_type in (DataType.INT32, DataType.INT64):
            return ids
        if data_type in (DataType.FLOAT, DataType.DOUBLE):
            return [float(i) for i in ids]
        if data_type in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
            return vectors
        return None

    def generate_entities(self, vectors, ids=None, collection_name=None):
        """Build the insert payload: one entry per field of the collection,
        with values generated to match each field's type."""
        name = self._collection_name if collection_name is None else collection_name
        info = self.get_info(name)
        return [{
            "name": field["name"],
            "type": field["type"],
            "values": self.generate_values(field["type"], vectors, ids),
        } for field in info["fields"]]

    @time_wrapper
    def insert(self, entities, ids=None, collection_name=None):
        """Insert *entities*; returns the new ids, or ``None`` when the
        call fails (the error is logged, not raised)."""
        name = self._collection_name if collection_name is None else collection_name
        try:
            return self._milvus.insert(name, entities, ids=ids)
        except Exception as e:
            logger.error(str(e))

    def get_dimension(self):
        """Return the dim of the first vector field, or ``None`` if absent."""
        vector_types = (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR)
        for field in self.get_info()["fields"]:
            if field["type"] in vector_types:
                return field["params"]["dim"]

    def get_rand_ids(self, length):
        """Sample up to *length* entity ids from one randomly chosen segment.

        Retries (potentially forever) until a segment with at least one id
        is found; when the chosen segment holds fewer than *length* ids,
        all of its ids are returned instead.
        """
        segment_ids = []
        while True:
            stats = self.get_stats()
            # NOTE(review): only the first partition is ever sampled —
            # confirm this is intended for multi-partition collections.
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(
                    self._collection_name, segment["id"])
            except Exception as e:
                # errors are logged and the loop retries with a fresh segment
                logger.error(str(e))
            if not len(segment_ids):
                continue
            elif len(segment_ids) > length:
                return random.sample(segment_ids, length)
            else:
                logger.debug("Reset length: %d" % len(segment_ids))
                return segment_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    def get(self):
        """Fetch a single entity by a random id (read-path smoke test)."""
        rand_id = random.randint(1, 1000000)
        self._milvus.get_entity_by_id(self._collection_name, [rand_id])

    @time_wrapper
    def get_entities(self, get_ids):
        """Fetch the entities for *get_ids* from the default collection."""
        return self._milvus.get_entity_by_id(self._collection_name, get_ids)

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete entities by id from the given (or default) collection."""
        name = self._collection_name if collection_name is None else collection_name
        self._milvus.delete_entity_by_id(name, ids)

    def delete_rand(self):
        """Delete a random batch of ids, flush, and verify they are gone.

        Raises AssertionError if any deleted id can still be fetched.
        """
        delete_id_length = random.randint(1, 100)
        # Fix: dropped the unused `count_before = self.count()` — it cost a
        # server round-trip and its commented-out consistency check was dead.
        logger.debug("%s: length to delete: %d" %
                     (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()  # make the deletion visible before verifying
        logger.info("%s: count after delete: %d" %
                    (self._collection_name, self.count()))
        get_res = self._milvus.get_entity_by_id(self._collection_name,
                                                delete_ids)
        for item in get_res:
            # a deleted entity must come back empty
            assert not item

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        """Flush pending writes for the given (or default) collection."""
        name = self._collection_name if collection_name is None else collection_name
        self._milvus.flush([name], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the collection, failing fast on a bad status."""
        name = self._collection_name if collection_name is None else collection_name
        self.check_status(self._milvus.compact(name))

    @time_wrapper
    def create_index(self,
                     field_name,
                     index_type,
                     metric_type,
                     _async=False,
                     index_param=None):
        """Build an index on *field_name*.

        *index_type* is translated through INDEX_MAP and *metric_type*
        through utils.metric_type_trans before hitting the server.
        """
        resolved_index = INDEX_MAP[index_type]
        resolved_metric = utils.metric_type_trans(metric_type)
        logger.info(
            "Building index start, collection_name: %s, index_type: %s, metric_type: %s"
            % (self._collection_name, resolved_index, resolved_metric))
        if index_param:
            logger.info(index_param)
        self._milvus.create_index(self._collection_name,
                                  field_name,
                                  {"index_type": resolved_index,
                                   "metric_type": resolved_metric,
                                   "params": index_param},
                                  _async=_async)

    # TODO: need to check
    def describe_index(self, field_name):
        """Reverse-map server-side index info to this module's naming.

        Returns ``{"index_type": <INDEX_MAP key>, "index_param": dict}``
        for the first recognised index found; defaults to a flat index
        when no index metadata is present.

        NOTE(review): *field_name* is passed to the server call, yet the
        scan below walks every field — confirm that is intended.
        """
        # stats = self.get_stats()
        info = self._milvus.describe_index(self._collection_name, field_name)
        index_info = {"index_type": "flat", "index_param": None}
        for field in info["fields"]:
            for index in field['indexes']:
                if not index or "index_type" not in index:
                    continue
                else:
                    # translate the server's index_type value back to the
                    # INDEX_MAP key used throughout this client
                    for k, v in INDEX_MAP.items():
                        if index['index_type'] == v:
                            index_info['index_type'] = k
                            index_info['index_param'] = index['params']
                            return index_info
        return index_info

    def drop_index(self, field_name):
        """Remove the index built on *field_name*."""
        logger.info("Drop index: %s" % self._collection_name)
        result = self._milvus.drop_index(self._collection_name, field_name)
        return result

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None):
        """Run a boolean search combining *vector_query* with optional filters."""
        name = self._collection_name if collection_name is None else collection_name
        must = [vector_query]
        if filter_query:
            must.extend(filter_query)
        return self._milvus.search(name, {"bool": {"must": must}})

    @time_wrapper
    def load_and_query(self,
                       vector_query,
                       filter_query=None,
                       collection_name=None):
        """Load the collection into memory, then run the boolean search."""
        name = self._collection_name if collection_name is None else collection_name
        must = [vector_query]
        if filter_query:
            must.extend(filter_query)
        self.load_collection(name)
        return self._milvus.search(name, {"bool": {"must": must}})

    def get_ids(self, result):
        """Split the flat id list of *result* into one id list per query.

        The server returns all ids concatenated; the per-query top_k is
        recovered as ``len(ids) // len(result)``.
        """
        flat_ids = result._entities.ids
        num_queries = len(result)
        if not num_queries or not flat_ids:
            # Bug fix: an empty result previously raised ZeroDivisionError,
            # and an empty id list made range() step 0 (ValueError).
            return []
        top_k = len(flat_ids) // num_queries
        return [flat_ids[offset:offset + top_k]
                for offset in range(0, len(flat_ids), top_k)]

    def query_rand(self, nq_max=100):
        # for ivf search
        """Issue one randomised ivf search (random nq, top_k, nprobe, metric)."""
        dimension = 128  # fixed query dimensionality for the random probe
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vector_query = {
            "vector": {
                utils.get_default_field_name(): {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param,
                }
            }
        }
        self.query(vector_query)

    def load_query_rand(self, nq_max=100):
        # for ivf search
        """Like query_rand, but loads the collection before searching."""
        dimension = 128  # fixed query dimensionality for the random probe
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vector_query = {
            "vector": {
                utils.get_default_field_name(): {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param,
                }
            }
        }
        self.load_and_query(vector_query)

    # TODO: need to check
    def count(self, collection_name=None):
        """Return the server-reported row count for the collection."""
        name = self._collection_name if collection_name is None else collection_name
        stats = self._milvus.get_collection_stats(name)
        row_count = stats["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop the collection and poll until the server reports it gone.

        Polls count() once per second for up to *timeout* seconds; a
        count() failure is treated as "collection already gone".
        """
        timeout = int(timeout)
        name = self._collection_name if collection_name is None else collection_name
        logger.info("Start delete collection: %s" % name)
        self._milvus.drop_collection(name)
        waited = 0
        while waited < timeout:
            try:
                if not self.count(collection_name=name):
                    break
            except Exception as e:
                logger.debug(str(e))
                break
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        """Return collection statistics for the default collection."""
        return self._milvus.get_collection_stats(self._collection_name)

    def get_info(self, collection_name=None):
        """Return the schema/info for the given (or default) collection."""
        name = self._collection_name if collection_name is None else collection_name
        return self._milvus.get_collection_info(name)

    def show_collections(self):
        """List all collection names known to the server."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return whether the collection exists on the server."""
        name = self._collection_name if collection_name is None else collection_name
        return self._milvus.has_collection(name)

    def clean_db(self):
        """Drop every collection currently present on the server."""
        for name in self.show_collections():
            self.drop(collection_name=name)

    @time_wrapper
    def load_collection(self, collection_name=None):
        """Load the collection into memory (generous 3000 s timeout)."""
        name = self._collection_name if collection_name is None else collection_name
        return self._milvus.load_collection(name, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        """Release the collection from memory (generous 3000 s timeout)."""
        name = self._collection_name if collection_name is None else collection_name
        return self._milvus.release_collection(name, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        """Load the named partitions into memory (3000 s timeout)."""
        name = self._collection_name if collection_name is None else collection_name
        return self._milvus.load_partitions(name, tag_names, timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        """Release the named partitions from memory (3000 s timeout)."""
        name = self._collection_name if collection_name is None else collection_name
        return self._milvus.release_partitions(name, tag_names, timeout=3000)