Пример #1
0
class MilvusClient(object):
    """Wrapper around the python-sdk ``Milvus`` client bound to a default collection.

    Methods that talk to the server hand the returned status to
    :meth:`check_status`, which raises a plain ``Exception`` when it is not OK.
    """

    # Pause (seconds) between two connection attempts, so a refused connection
    # does not turn the retry loop into a busy loop.
    _CONNECT_RETRY_INTERVAL = 1

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Default timeout set 60s

        :param collection_name: default collection used by the other methods.
        :param host: server host; ``SERVER_HOST_DEFAULT`` when falsy.
        :param port: server port; ``SERVER_PORT_DEFAULT`` when falsy.
        :param timeout: seconds during which the connection is retried.
        :raises Exception: "Server connect timeout" when no attempt succeeded
            before the deadline.
        """
        self._collection_name = collection_name
        # Always defined, even when the collection does not exist yet.
        self._metric_type = None
        self._dimension = None
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        self._connect(host, port, timeout)
        if self._collection_name and self.exists_collection():
            self._load_collection_info()

    def _connect(self, host, port, timeout):
        """Create ``self._milvus``, retrying until ``timeout`` seconds have elapsed."""
        start_time = time.time()
        deadline = start_time + timeout
        attempts = 0
        # retry connect for remote server
        while time.time() < deadline:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (attempts, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % attempts)
            attempts += 1
            # Never sleep past the deadline.
            time.sleep(max(0, min(self._CONNECT_RETRY_INTERVAL, deadline - time.time())))
        else:
            # The loop ended without ``break``: no attempt succeeded in time.
            raise Exception("Server connect timeout")

    def _load_collection_info(self):
        """Cache the metric type and dimension of the bound collection (one ``describe`` call)."""
        status, info = self.describe()
        self.check_status(status)
        self._metric_type = metric_type_to_str(info.metric_type)
        self._dimension = info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Change the default collection name used by the other methods."""
        self._collection_name = name

    def check_status(self, status):
        """Log diagnostics and raise ``Exception("Status not ok")`` unless ``status.OK()``."""
        if status.OK():
            return
        logger.error(self._collection_name)
        logger.error(status.message)
        # Diagnostics are best effort: a failure while collecting them must not
        # hide the original problem.
        try:
            logger.error(self._milvus.server_status())
            logger.error(self.count())
        except Exception as e:
            logger.error("Failed to collect diagnostics: %s" % str(e))
        raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if the first hit of any query has a distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; it becomes the default one if none is set yet.

        :param metric_type: a key of ``METRIC_MAP``.
        :raises Exception: for an unknown ``metric_type`` or a non-OK status.
        """
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP:
            raise Exception("Not supported metric_type: %s" % metric_type)
        create_param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": METRIC_MAP[metric_type],
        }
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        """Create partition ``tag_name`` in the default collection."""
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        """Drop partition ``tag_name`` from the default collection."""
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        """Return the partitions reported by the server for the default collection."""
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """Insert ``X`` (optionally with explicit ``ids``); return ``(status, result)``."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1..100 random vectors, flush, and verify the row count grew accordingly."""
        if self._dimension is None:
            # The collection did not exist (or no name was set) at construction time.
            self._load_collection_info()
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        # insert() already validates the returned status.
        self.insert(X)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` random ids taken from one random segment.

        Only the first partition listed in the collection stats is considered.
        Retries until a segment with at least one id is found.
        """
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            if not segments:
                # random.choice() would otherwise fail with an obscure error.
                raise Exception("No segments found in collection: %s" % self._collection_name)
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return ``(segment count, ids)`` with the first ``length`` ids of every segment.

        Only the first partition listed in the collection stats is considered.
        """
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        # Treat a missing segment list as "no segments".
        segments = stats["partitions"][0]["segments"] or []
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return len(segments), res

    def get_rand_entities(self, length):
        """Return ``(ids, entities)`` for up to ``length`` randomly chosen ids."""
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        """Fetch the entities stored under ``get_ids``."""
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with the given ``ids``."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete 1..100 random entities, flush, and verify they are gone."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        if any(get_res):
            raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        """Flush the given (or default) collection."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the given (or default) collection."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` is a key of ``INDEX_MAP``."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        # Reverse lookup: translate the SDK value back to its INDEX_MAP key.
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the index of the default collection; return the raw SDK result (unchecked)."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """Search the ``top_k`` nearest entities for every record in ``X``."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Run one search with random nq / top_k / nprobe (each 1..100) on stored entities."""
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        # NOTE: only the status is validated; the search result itself is not inspected.
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)

    def count(self, name=None):
        """Return the row count of ``name`` (default collection if omitted); 0 when falsy."""
        if name is None:
            name = self._collection_name
        # One RPC serves both the debug log and the returned value.
        count_result = self._milvus.count_entities(name)
        logger.debug(count_result)
        row_count = count_result[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection, then poll once per second (at most ``timeout`` times) until it is empty."""
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        waited = 0
        while waited < timeout and self.count(name=name):
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the raw ``get_collection_info`` result for the default collection."""
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        """Return the raw ``list_collections`` result."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return the ``has_collection`` answer; the accompanying status is not checked."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection listed by the server."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        """Call ``load_collection`` (timeout=3000) for the default collection; return the status."""
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the server version; the accompanying status is not checked."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the parsed JSON answer of the ``get_config *`` command."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": x}`` where x is the reported value divided by 1024**3."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Send a raw server command, log the answer, and return it if the status is OK."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Пример #2
0
class MilvusClient(object):
    """Wrapper around the ``Milvus`` SDK client bound to one collection.

    Server statuses are validated with :meth:`check_status`, which raises a
    plain ``Exception`` when the status is not OK.
    """

    # Pause (seconds) between two connection attempts to avoid a busy loop.
    _CONNECT_RETRY_INTERVAL = 1

    # Metric names accepted by create_collection -> ``MetricType`` member names.
    _METRIC_TYPE_NAMES = {
        "l2": "L2",
        "ip": "IP",
        "jaccard": "JACCARD",
        "hamming": "HAMMING",
        "sub": "SUBSTRUCTURE",
        "super": "SUPERSTRUCTURE",
    }

    def __init__(self, collection_name=None, ip=None, port=None, timeout=60):
        """Connect to the server.

        :param collection_name: collection used by the other methods.
        :param ip: server host. When falsy, the default host AND default port
            are used and no retry is attempted.
        :param port: server port, only used together with ``ip``.
        :param timeout: seconds during which the connection is retried.
        :raises Exception: "Server connect timeout" when ``ip`` was given and
            no attempt succeeded before the deadline.
        """
        self._collection_name = collection_name
        if not ip:
            self._milvus = Milvus(host=SERVER_HOST_DEFAULT,
                                  port=SERVER_PORT_DEFAULT)
            return
        start_time = time.time()
        deadline = start_time + timeout
        attempt = 1
        # retry connect for remote server
        while time.time() < deadline:
            try:
                self._milvus = Milvus(host=ip, port=port)
                if self._milvus.server_status():
                    logger.debug(
                        "Try connect times: %d, %s" %
                        (attempt, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed")
            attempt += 1
            # Never sleep past the deadline.
            time.sleep(max(0, min(self._CONNECT_RETRY_INTERVAL,
                                  deadline - time.time())))
        else:
            # Without this, a half-initialised client (possibly without
            # ``_milvus``) would be handed back to the caller.
            raise Exception("Server connect timeout")

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def check_status(self, status):
        """Log the message and raise ``Exception("Status not ok")`` unless ``status.OK()``."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if the first hit of any query has a distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size,
                          metric_type):
        """Create a collection; it becomes the bound one if none is set yet.

        :param metric_type: one of "l2", "ip", "jaccard", "hamming", "sub",
            "super".
        :raises Exception: for any other ``metric_type`` (nothing is sent to
            the server in that case) or for a non-OK status.
        """
        if not self._collection_name:
            self._collection_name = collection_name
        try:
            member_name = self._METRIC_TYPE_NAMES[metric_type]
        except (KeyError, TypeError):
            # Fail early instead of sending the raw value to the server.
            logger.error("Not supported metric_type: %s" % metric_type)
            raise Exception("Not supported metric_type: %s" % metric_type)
        create_param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": getattr(MetricType, member_name)
        }
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        """Add vectors ``X`` (optionally with explicit ``ids``); return ``(status, result)``."""
        status, result = self._milvus.add_vectors(self._collection_name, X,
                                                  ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def delete_vectors(self, ids):
        """Delete the vectors with the given ``ids``."""
        status = self._milvus.delete_by_id(self._collection_name, ids)
        self.check_status(status)

    @time_wrapper
    def flush(self):
        """Flush the bound collection."""
        status = self._milvus.flush([self._collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self):
        """Compact the bound collection."""
        status = self._milvus.compact(self._collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` is a key of ``INDEX_MAP``."""
        index_type = INDEX_MAP[index_type]
        logger.info(
            "Building index start, collection_name: %s, index_type: %s" %
            (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type,
                                           index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.describe_index(self._collection_name)
        self.check_status(status)
        index_type = None
        # Reverse lookup: translate the SDK value back to its INDEX_MAP key.
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the index; return the raw SDK result (unchecked)."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    @time_wrapper
    def query(self, X, top_k, search_param=None):
        """Search the ``top_k`` nearest vectors for every record in ``X``."""
        status, result = self._milvus.search_vectors(self._collection_name,
                                                     top_k,
                                                     query_records=X,
                                                     params=search_param)
        self.check_status(status)
        return result

    @time_wrapper
    def query_ids(self, top_k, ids, search_param=None):
        """Search by stored ``ids`` and verify the first hit of each query."""
        status, result = self._milvus.search_by_ids(self._collection_name,
                                                    ids,
                                                    top_k,
                                                    params=search_param)
        # Validate the call itself before looking into the result.
        self.check_status(status)
        self.check_result_ids(result)
        return result

    def count(self):
        """Return the second element of the ``count_collection`` answer."""
        return self._milvus.count_collection(self._collection_name)[1]

    def delete(self, timeout=120):
        """Drop the collection, then poll once per second (at most ``timeout`` times) until it is empty."""
        timeout = int(timeout)
        logger.info("Start delete collection: %s" % self._collection_name)
        status = self._milvus.drop_collection(self._collection_name)
        if not status.OK():
            # Report the failure but keep the best-effort behaviour (no raise).
            logger.error(status.message)
        waited = 0
        while waited < timeout and self.count():
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the raw ``describe_collection`` result."""
        return self._milvus.describe_collection(self._collection_name)

    def show_collections(self):
        """Return the raw ``show_collections`` result."""
        return self._milvus.show_collections()

    def exists_collection(self, collection_name=None):
        """Return the ``has_collection`` answer; the accompanying status is not checked."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        return res

    @time_wrapper
    def preload_collection(self):
        """Call ``preload_collection`` (timeout=3000); return the status."""
        status = self._milvus.preload_collection(self._collection_name,
                                                 timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the server version; the accompanying status is not checked."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the parsed JSON answer of the ``get_config *`` command."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": x}`` where x is the reported value divided by 1024**3."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used":
            round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Send a raw server command, log the answer, and return it if the status is OK."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Пример #3
0
def main():
    """Walk through the basic table workflow against a Milvus server.

    Connects to ``_HOST``/``_PORT``, makes sure ``table01`` exists, inserts 20
    random vectors, searches with 2 random queries and prints every result.
    The connection is closed even if one of the steps raises.
    """
    milvus = Milvus()

    # Print client version
    print('# Client version: {}'.format(milvus.client_version()))

    # Connect milvus server
    # Please change HOST and PORT to the correct one
    param = {'host': _HOST, 'port': _PORT}
    cnn_status = milvus.connect(**param)
    print('# Connect Status: {}'.format(cnn_status))

    try:
        # Check if connected
        print('# Is connected: {}'.format(milvus.connected))

        # Print milvus server version
        print('# Server version: {}'.format(milvus.server_version()))

        # Describe table
        table_name = 'table01'
        res_status, table = milvus.describe_table(table_name)
        print('# Describe table status: {}'.format(res_status))
        print('# Describe table:{}'.format(table))

        # Create table
        # Check if `table01` exists, if not, create a table `table01`
        dimension = 256
        if not table:
            param = {
                'table_name': table_name,
                'dimension': dimension,
                'index_type': IndexType.IDMAP,
                'store_raw_vector': False
            }

            res_status = milvus.create_table(Prepare.table_schema(**param))
            print('# Create table status: {}'.format(res_status))

        # Show tables and their description
        _, tables = milvus.show_tables()
        pprint(tables)

        # Add vectors
        # Prepare vector with 256 dimension
        vectors = Prepare.records([[random.random() for _ in range(dimension)]
                                   for _ in range(20)])

        # Insert vectors into table 'table01'
        status, ids = milvus.add_vectors(table_name=table_name,
                                         records=vectors)
        print('# Add vector status: {}'.format(status))
        pprint(ids)

        # Search vectors
        # When adding vectors for the first time, server will take at least 5s
        # to persist vector data, so you have to wait for 6s after adding
        # vectors for the first time.
        print('# Waiting for 6s...')
        time.sleep(6)

        q_records = Prepare.records(
            [[random.random() for _ in range(dimension)] for _ in range(2)])

        param = {
            'table_name': table_name,
            'query_records': q_records,
            'top_k': 10,
        }
        status, results = milvus.search_vectors(**param)
        print('# Search vectors status: {}'.format(status))
        pprint(results)

        # Get table row count
        status, result = milvus.get_table_row_count(table_name)
        print('# Status: {}'.format(status))
        print('# Count: {}'.format(result))
    finally:
        # Disconnect, also when a step above failed, so the connection
        # opened by connect() is not leaked.
        status = milvus.disconnect()
        print('# Disconnect Status: {}'.format(status))
Пример #4
0
class MilvusClient(object):
    """Wrapper around the ``Milvus`` SDK client bound to one table.

    This variant is lenient: :meth:`check_status` only logs a non-OK status
    and never raises.
    """

    # Pause (seconds) between two connection attempts to avoid a busy loop.
    _CONNECT_RETRY_INTERVAL = 1

    # Metric names accepted by create_table -> ``MetricType`` member names.
    _METRIC_TYPE_NAMES = {
        "l2": "L2",
        "ip": "IP",
        "jaccard": "JACCARD",
        "hamming": "HAMMING",
    }

    def __init__(self, table_name=None, ip=None, port=None, timeout=60):
        """Connect to the server.

        :param table_name: table used by the other methods.
        :param ip: server host. When falsy, the default host AND default port
            are used and no retry is attempted.
        :param port: server port, only used together with ``ip``.
        :param timeout: seconds during which the connection is retried; a
            timeout is logged, not raised.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        if not ip:
            self._milvus.connect(host=SERVER_HOST_DEFAULT,
                                 port=SERVER_PORT_DEFAULT)
            return
        start_time = time.time()
        deadline = start_time + timeout
        attempt = 1
        # retry connect for remote server
        while time.time() < deadline:
            try:
                self._milvus.connect(host=ip, port=port)
                if self._milvus.connected() is True:
                    logger.debug(
                        "Try connect times: %d, %s" %
                        (attempt, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed")
            attempt += 1
            # Never sleep past the deadline.
            time.sleep(max(0, min(self._CONNECT_RETRY_INTERVAL,
                                  deadline - time.time())))
        else:
            # Stay lenient like check_status, but do not fail silently.
            logger.error("Milvus connect timeout")

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log the status message when the status is not OK (never raises)."""
        if not status.OK():
            logger.error(status.message)

    def create_table(self, table_name, dimension, index_file_size,
                     metric_type):
        """Create a table; it becomes the bound one if none is set yet.

        :param metric_type: one of "l2", "ip", "jaccard", "hamming". Any other
            value is logged as an error and passed to the server unchanged.
        """
        if not self._table_name:
            self._table_name = table_name
        try:
            member_name = self._METRIC_TYPE_NAMES.get(metric_type)
        except TypeError:
            # Unhashable value: it cannot be one of the supported names.
            member_name = None
        if member_name is None:
            logger.error("Not supported metric_type: %s" % metric_type)
        else:
            metric_type = getattr(MetricType, member_name)
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        """Add vectors ``X`` (optionally with explicit ``ids``); return ``(status, result)``."""
        status, result = self._milvus.add_vectors(self._table_name, X, ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index; ``index_type`` is a key of ``INDEX_MAP``."""
        index_params = {
            "index_type": INDEX_MAP[index_type],
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" %
                    (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name,
                                           index=index_params)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "nlist": ...}``."""
        status, result = self._milvus.describe_index(self._table_name)
        # Log a failed call like every other method does.
        self.check_status(status)
        index_type = None
        # Reverse lookup: translate the SDK value back to its INDEX_MAP key.
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "nlist": result._nlist}

    def drop_index(self):
        """Drop the index; return the raw SDK result (unchecked)."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search the ``top_k`` nearest vectors for every record in ``X``."""
        status, result = self._milvus.search_vectors(self._table_name, top_k,
                                                     nprobe, X)
        self.check_status(status)
        return result

    def count(self):
        """Return the second element of the ``get_table_row_count`` answer."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Delete the table, then poll once per second (at most ``timeout`` times) until it is empty."""
        logger.info("Start delete table: %s" % self._table_name)
        status = self._milvus.delete_table(self._table_name)
        # Log a failed delete instead of dropping the status.
        self.check_status(status)
        waited = 0
        while waited < timeout and self.count():
            time.sleep(1)
            waited += 1
        if waited >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        """Return the raw ``describe_table`` result."""
        return self._milvus.describe_table(self._table_name)

    def show_tables(self):
        """Return the raw ``show_tables`` result."""
        return self._milvus.show_tables()

    def exists_table(self, table_name=None):
        """Return the ``has_table`` answer for ``table_name`` (bound table if omitted)."""
        if table_name is None:
            table_name = self._table_name
        status, res = self._milvus.has_table(table_name)
        self.check_status(status)
        return res

    @time_wrapper
    def preload_table(self):
        """Call ``preload_table`` (timeout=3000); return the raw SDK result."""
        return self._milvus.preload_table(self._table_name, timeout=3000)

    def get_server_version(self):
        """Return the server version; the accompanying status is not checked."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the parsed JSON answer of the ``get_config *`` command."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": x}`` where x is the reported value divided by 1024**3."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used":
            round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Send a raw server command, log the answer and its status, and return the answer."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Пример #5
0
class MilvusANN(object):
    """Small demo driver for a Milvus server: create, insert, search, drop."""

    def __init__(self, host='10.46.5.98', port='19530'):
        """Connect to the server; exits the process with status 1 on failure."""
        # Defined up front so search_vectors_demo() can run even when
        # insert_vectors_demo() was not called on this instance.
        self.ids = []
        self._query_vectors = []

        self.milvus = Milvus()

        print("Client Version:", self.milvus.client_version())

        status = self.milvus.connect(host, port)

        if status.OK():
            print("Server connected.")
        else:
            print("Server connect fail.")
            sys.exit(1)

        print("Server Version:", self.milvus.server_version()[-1])

    def desc(self, tabel_name=None):
        """Print the description and the vector count of ``tabel_name`` (if given)."""
        milvus = self.milvus
        milvus.show_collections()
        if tabel_name:
            print(f"Describe: {milvus.describe_collection(tabel_name)[-1]}")
            print(
                f"Vector number in {tabel_name}: {milvus.count_collection(tabel_name)}"
            )

    def create_tabel_demo(self):
        """Create the collection ``demo_table`` if it is missing and print its description."""
        milvus = self.milvus
        table_name = 'demo_table'

        _, ok = milvus.has_collection(table_name)
        if not ok:
            param = {
                'collection_name': table_name,
                'dimension': 16,
                # optional index_file_size: once a file reaches this size,
                # Milvus starts building an index for it.
                'index_file_size': 1024,
                'metric_type': MetricType.L2  # optional
            }

            milvus.create_collection(param)

        # Show tables in Milvus server
        _, collections = milvus.show_collections()

        # Describe demo_table
        _, table = milvus.describe_collection(table_name)
        print(table)

    def insert_vectors_demo(self, collection_name):
        """Insert 10000 random 16-dimensional vectors and build an IVFLAT index.

        Stores the returned ids in ``self.ids`` and the first 10 vectors in
        ``self._query_vectors``.
        """
        milvus = self.milvus

        # 10000 vectors with 16 dimension
        # element per dimension is float32 type
        # vectors should be a 2-D array
        vectors = np.random.rand(10000, 16).astype(np.float32).tolist()

        # Insert vectors into the collection, return status and vectors id list
        # (original note: timestamp 1581655102 786 118)
        _, self.ids = milvus.insert(collection_name, vectors)

        # Wait for 6 seconds, until Milvus server persist vector data.
        time.sleep(6)

        # Get the collection row count
        milvus.count_collection(collection_name)

        # create index of vectors, search more rapidly
        index_param = {'nlist': 2048}

        # Create ivflat index in the collection
        # You can search vectors without creating index. however, Creating index help to
        # search faster
        milvus.create_index(collection_name,
                            index_type=IndexType.IVFLAT,
                            params=index_param)

        # describe index, get information of index
        _, index = milvus.describe_index(collection_name)
        print(index)

        # Use the top 10 vectors for similarity search
        self._query_vectors = vectors[0:10]

    def search_vectors_demo(self, query_vectors, collection_name):
        """Search the single nearest neighbour of ``query_vectors`` and print the result.

        The result counts as correct when the best hit of the first query has
        distance 0.0 or carries the first id returned by the last insert.
        """
        milvus = self.milvus

        # execute vector similarity search
        status, results = milvus.search_vectors(collection_name,
                                                top_k=1,
                                                query_records=query_vectors,
                                                params={'nprobe': 16})
        if status.OK():
            # ``ids`` is only known after insert_vectors_demo() ran.
            inserted_ids = getattr(self, "ids", None) or []
            expected_id = inserted_ids[0] if inserted_ids else None
            try:
                best = results[0][0]
            except IndexError:
                # Empty result set: nothing matched.
                best = None
            matched = best is not None and (
                best.distance == 0.0
                or (expected_id is not None and best.id == expected_id))
            if matched:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

        # print results
        print(results)

    def drop_table(self, collection_name):
        """Drop ``collection_name`` and disconnect from the server."""
        milvus = self.milvus

        # Delete the collection
        milvus.drop_collection(collection_name)

        # Disconnect from Milvus
        milvus.disconnect()
Пример #6
0
class MilvusClient(object):
    """Convenience wrapper around the Milvus python SDK bound to one table."""

    def __init__(self, table_name=None, host=None, port=None):
        """
        Connect to a Milvus server.

        :param table_name: table the other methods operate on; when empty it
            is filled in by the first ``create_table`` call.
        :param host: server host; falls back to ``SERVER_HOST_DEFAULT``.
        :param port: server port; falls back to ``SERVER_PORT_DEFAULT``.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        # Defined up front so insert()/query() do not fail with
        # AttributeError when create_table() was never called on this
        # instance (e.g. the table already exists on the server). In that
        # case no inner-product normalisation is applied.
        self._metric_type = None
        # Host and port fall back to their defaults independently of each
        # other, so supplying only one of them is honoured.
        self._milvus.connect(host=host or SERVER_HOST_DEFAULT,
                             port=port or SERVER_PORT_DEFAULT)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log the status message and raise ``Exception`` if it is not OK."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size,
                     metric_type):
        """
        Create a table and remember its metric type on this client.

        :param table_name: name of the table to create; also becomes the
            working table when none was set before.
        :param dimension: forwarded as ``dimension``.
        :param index_file_size: forwarded as ``index_file_size``.
        :param metric_type: ``"l2"`` or ``"ip"``; any other value is logged
            as an error and forwarded unchanged.
        :raises Exception: when the returned status is not OK.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        else:
            # NOTE(review): the unsupported value is only logged; the request
            # is still sent with it.
            logger.error("Not supported metric_type: %s" % metric_type)
        self._metric_type = metric_type
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type,
        }
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    def _prepare_vectors(self, X):
        """L2-normalise ``X`` for inner-product tables and cast to float32."""
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        return X.astype(numpy.float32)

    @time_wrapper
    def insert(self, X, ids):
        """
        Add the rows of ``X`` to the working table.

        :param X: array of vectors (must support ``astype`` and ``tolist``).
        :param ids: forwarded as ``ids`` to ``add_vectors``.
        :return: the ``(status, result)`` pair produced by ``add_vectors``.
        :raises Exception: when the returned status is not OK.
        """
        X = self._prepare_vectors(X)
        status, result = self._milvus.add_vectors(self._table_name,
                                                  X.tolist(),
                                                  ids=ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """
        Build an index on the working table.

        :param index_type: one of ``"flat"``, ``"ivf_flat"``, ``"ivf_sq8"``,
            ``"ivf_sq8h"``, ``"nsg"``, ``"ivf_pq"``; unrecognised values are
            forwarded unchanged.
        :param nlist: forwarded as ``nlist``.
        :raises Exception: when the returned status is not OK.
        """
        index_types = {
            "flat": IndexType.FLAT,
            "ivf_flat": IndexType.IVFLAT,
            "ivf_sq8": IndexType.IVF_SQ8,
            "ivf_sq8h": IndexType.IVF_SQ8H,
            "nsg": IndexType.NSG,
            "ivf_pq": IndexType.IVF_PQ,
        }
        index_params = {
            "index_type": index_types.get(index_type, index_type),
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" %
                    (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name,
                                           index=index_params,
                                           timeout=6 * 3600)
        self.check_status(status)

    def describe_index(self):
        """Return the SDK's ``describe_index`` response for the working table."""
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        """Drop the index of the working table and return the SDK response."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """
        Search the working table for the nearest neighbours of ``X``.

        :param X: array of query vectors (must support ``astype``/``tolist``).
        :param top_k: forwarded positionally to ``search_vectors``.
        :param nprobe: forwarded positionally to ``search_vectors``.
        :return: one list of hit ids per entry of the search results.
        :raises Exception: when the returned status is not OK.
        """
        X = self._prepare_vectors(X)
        status, results = self._milvus.search_vectors(self._table_name, top_k,
                                                      nprobe, X.tolist())
        self.check_status(status)
        return [[item.id for item in result] for result in results]

    def count(self):
        """Return the second element of the ``get_table_row_count`` response."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, table_name):
        """Delete ``table_name`` (not necessarily the working table)."""
        logger.info("Start delete table: %s" % table_name)
        return self._milvus.delete_table(table_name)

    def describe(self):
        """Return the SDK's ``describe_table`` response for the working table."""
        return self._milvus.describe_table(self._table_name)

    def exists_table(self, table_name):
        """Return the SDK's ``has_table`` response for ``table_name``."""
        return self._milvus.has_table(table_name)

    def get_server_version(self):
        """
        Return the server version reported by the SDK.

        :raises Exception: when the returned status is not OK.
        """
        status, res = self._milvus.server_version()
        self.check_status(status)
        return res

    @time_wrapper
    def preload_table(self):
        """Ask the server to preload the working table; returns the response."""
        return self._milvus.preload_table(self._table_name)
Пример #7
0
class ANN(object):
    def __init__(self, host='10.119.33.90', port='19530', show_info=False):
        """
        Create the Milvus client used by every other method.

        :param host: server host passed positionally to ``Milvus``.
        :param port: server port passed positionally to ``Milvus``.
        :param show_info: when true, log the client and server versions.
        """
        self.client = Milvus(host, port)

        if not show_info:
            return

        versions = {
            "ClientVersion": self.client.client_version(),
            "ServerVersion": self.client.server_version()
        }
        logger.info(versions)

    def create_collection(self,
                          collection_name,
                          collection_param,
                          partition_tag=None,
                          overwrite=True):
        """
        Create a collection, optionally replacing an existing one.

        :param collection_name: name of the collection to create.
        :param collection_param: schema passed to ``client.create_collection``,
            for example:
            collection_param = {
                "fields": [
                    #  Milvus doesn't support string type now, but we are considering supporting it soon.
                    #  {"name": "title", "type": DataType.STRING},
                    {"name": "category_", "type": DataType.INT32},
                    {"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 768}},
                ],
                "segment_row_limit": 4096,
                "auto_id": False
            }
        :param partition_tag: when not None, a partition with this tag is
            created in the collection afterwards (also when the collection
            already existed and was left untouched).
        :param overwrite: when True an existing collection with the same name
            is dropped and recreated; when False it is kept and a notice is
            printed.
        :return: None
        """
        # Query the server a single time so both branches act on the same
        # answer and no second round-trip is needed.
        exists = self.client.has_collection(collection_name)

        if exists and not overwrite:
            print(f"{collection_name} already exist !!!")
        else:
            if exists:
                self.client.drop_collection(collection_name)
                self.client.flush()
                # Presumably gives the server time to finish the drop before
                # the name is reused -- TODO confirm the required delay.
                time.sleep(5)
            self.client.create_collection(collection_name, collection_param)

        if partition_tag is not None:
            self.client.create_partition(collection_name,
                                         partition_tag=partition_tag)

    def create_index(self,
                     collection_name,
                     field_name,
                     index_type='IVF_FLAT',
                     metric_type='IP',
                     index_params=None):
        """
        Build an index on one field of a collection.

        :param collection_name: collection that owns the field.
        :param field_name: field to index (e.g. 'embedding').
        :param index_type: index type name, sent as ``index_type``.
        :param metric_type: metric name, sent as ``metric_type``.
        :param index_params: index build parameters, sent as ``params``;
            ``{'nlist': 1024}`` is used when None.
        :return: None

        SDK enum values noted by the original author (verify against the
        installed SDK version):

        MetricType:
            INVALID = 0, L2 = 1, IP = 2,
            HAMMING = 3, JACCARD = 4, TANIMOTO = 5  (byte vectors only),
            SUBSTRUCTURE = 6, SUPERSTRUCTURE = 7
        IndexType:
            INVALID = 0, FLAT = 1, IVFLAT = 2, IVF_SQ8 = 3, RNSG = 4,
            IVF_SQ8H = 5, IVF_PQ = 6, HNSW = 11, ANNOY = 12
            alternative names: IVF_FLAT = IVFLAT, IVF_SQ8_H = IVF_SQ8H
        DataType(IntEnum):
            NULL = 0, INT8 = 1, INT16 = 2, INT32 = 3, INT64 = 4,
            STRING = 20, BOOL = 30, FLOAT = 40, DOUBLE = 41,
            VECTOR = 100, UNKNOWN = 9999
        RangeType(IntEnum):
            LT = 0 (less than), LTE = 1 (less than or equal), EQ = 2 (equal),
            GT = 3 (greater than), GTE = 4 (greater than or equal),
            NE = 5 (not equal)
        """
        build_params = {'nlist': 1024} if index_params is None else index_params

        index_spec = dict(
            index_type=index_type,
            params=build_params,
            metric_type=metric_type,
        )
        self.client.create_index(collection_name, field_name, index_spec)

    def batch_insert(self, collection_name, entities, batch_size=100000):
        """
        Insert ``entities`` in row batches, flushing after every batch.

        :param collection_name: collection to insert into.
        :param entities: list of field dicts, each holding its column under
            the ``'values'`` key. The input is left unmodified.
        :param batch_size: maximum number of rows per ``client.insert`` call.
        :return: concatenation of what ``client.insert`` returned per batch.
        :raises ValueError: when ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        if not entities:
            return []

        # The row count is read from the first field; all fields are sliced
        # with the same bounds.
        total_rows = len(entities[0]['values'])

        ids = []
        # Stepping by batch_size yields exactly ceil(total_rows / batch_size)
        # batches, so no empty trailing batch is sent.
        for start in range(0, total_rows, batch_size):
            stop = start + batch_size
            # Build fresh dicts per batch instead of overwriting the caller's
            # 'values', which would leave `entities` truncated afterwards.
            batch = [
                {**field, 'values': field['values'][start:stop]}
                for field in entities
            ]
            ids += self.client.insert(collection_name, batch)
            self.client.flush()
        return ids

    def search(self):  # TODO: fetch the same information
        """Placeholder for search; not implemented yet and returns None."""
        return None

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` when the client reports that it exists."""
        if not self.client.has_collection(collection_name):
            return
        self.client.drop_collection(collection_name)

    def drop_partition(self, collection_name, partition_tag):
        """
        Drop a partition when the client reports that it exists.

        The drop request is issued with ``timeout=30``.
        """
        if not self.client.has_partition(collection_name, partition_tag):
            return
        self.client.drop_partition(collection_name, partition_tag, timeout=30)