Пример #1
0
    def test_not_connect(self):
        """Check that each listed API raises NotConnectError on a fresh client."""
        client = Milvus()

        # (method name, positional arguments), kept in the original call order
        api_calls = (
            ("create_collection", ({},)),
            ("has_collection", ("a",)),
            ("describe_collection", ("a",)),
            ("drop_collection", ("a",)),
            ("create_index", ("a",)),
            ("insert", ("a", [], None)),
            ("count_collection", ("a",)),
            ("show_collections", ()),
            ("search", ("a", 1, 2, [], None)),
            ("search_in_files", ("a", [], [], 2, 1, None)),
            ("_cmd", ("",)),
            ("preload_collection", ("a",)),
            ("describe_index", ("a",)),
            ("drop_index", ("",)),
        )

        for method_name, call_args in api_calls:
            with pytest.raises(NotConnectError):
                getattr(client, method_name)(*call_args)
Пример #2
0
class MilvusClient(object):
    def __init__(self,
                 collection_name=None,
                 host=None,
                 port=None,
                 timeout=180):
        """Create the SDK handle, retrying until ``timeout`` seconds have passed.

        Args:
            collection_name: default collection for methods that are not given
                an explicit collection name.
            host: server host; SERVER_HOST_DEFAULT is used when falsy.
            port: server port; SERVER_PORT_DEFAULT is used when falsy.
            timeout: number of seconds during which construction is retried.

        Raises:
            Exception: "Server connect timeout" when no handle could be
                created before the deadline.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect remote server
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host,
                                      port=port,
                                      try_connect=False,
                                      pre_ping=False)
                break
            except Exception as e:
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % i)
                i = i + 1
                # back off a little longer after every failed attempt
                time.sleep(i)
        else:
            # Only reached when the loop ended without `break`, i.e. the
            # deadline passed and self._milvus was never assigned.  A
            # separate clock check after the loop could misfire after a
            # successful connect or miss the exact-deadline case.
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        """Return a short label containing the default collection name."""
        label = 'Milvus collection %s'
        return label % self._collection_name

    def check_status(self, status):
        """Raise ``Exception("Status not ok")`` unless ``status.OK()`` is truthy.

        Before raising, the status message, the server status and the current
        row count are written to the error log.
        """
        if status.OK():
            return
        logger.error(status.message)
        logger.error(self._milvus.server_status())
        logger.error(self.count())
        raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise when the first hit of any entry has a distance >= ``epsilon``.

        The offending entry position and distance are logged before raising.
        """
        for position, hits in enumerate(result):
            if not hits[0].distance >= epsilon:
                continue
            logger.error(position)
            logger.error(hits[0].distance)
            raise Exception("Distance wrong")

    # only support the given field name
    def create_collection(self,
                          dimension,
                          data_type=DataType.FLOAT_VECTOR,
                          auto_id=False,
                          collection_name=None,
                          other_fields=None):
        """Create a collection with one vector field and optional scalar fields.

        Args:
            dimension: vector dimension; also stored on ``self._dimension``.
            data_type: type of the vector field.
            auto_id: forwarded as the ``auto_id`` creation parameter.
            collection_name: target collection; the default collection is
                used when falsy.
            other_fields: comma-separated string; the tokens ``int`` and
                ``float`` each add the matching default scalar field.

        Raises:
            Exception: whatever the SDK call raises, after logging it.
        """
        self._dimension = dimension
        target = collection_name or self._collection_name
        vector_field = {
            "name": utils.get_default_field_name(data_type),
            "type": data_type,
            "params": {
                "dim": dimension
            }
        }
        fields = [vector_field]
        if other_fields:
            requested = other_fields.split(",")
            if "int" in requested:
                int_field = {
                    "name": utils.DEFAULT_INT_FIELD_NAME,
                    "type": DataType.INT64
                }
                fields.append(int_field)
            if "float" in requested:
                float_field = {
                    "name": utils.DEFAULT_FLOAT_FIELD_NAME,
                    "type": DataType.FLOAT
                }
                fields.append(float_field)
        try:
            self._milvus.create_collection(
                target, {"fields": fields, "auto_id": auto_id})
            logger.info("Create collection: <%s> successfully" % target)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        """Create partition ``tag`` in the given (or default) collection."""
        target = collection_name or self._collection_name
        self._milvus.create_partition(target, tag)

    def generate_values(self, data_type, vectors, ids):
        """Pick the column values that match a field type.

        Integer fields get ``ids``, floating-point fields get ``ids`` with
        ``0.0`` added to each element, vector fields get ``vectors``; any
        other type yields ``None``.
        """
        if data_type in (DataType.INT32, DataType.INT64):
            return ids
        if data_type in (DataType.FLOAT, DataType.DOUBLE):
            return [(raw_id + 0.0) for raw_id in ids]
        if data_type in (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR):
            return vectors
        return None

    def generate_entities(self, vectors, ids=None, collection_name=None):
        """Build one ``{"name", "type", "values"}`` dict per collection field.

        Field names and types come from ``get_info``; the values of each
        field are produced by ``generate_values``.
        """
        if collection_name is None:
            collection_name = self._collection_name
        schema = self.get_info(collection_name)
        return [
            {
                "name": field["name"],
                "type": field["type"],
                "values": self.generate_values(field["type"], vectors, ids),
            }
            for field in schema["fields"]
        ]

    @time_wrapper
    def insert(self, entities, ids=None, collection_name=None):
        """Insert ``entities`` and return what the SDK ``insert`` call returns.

        An exception raised by the SDK is logged and not propagated; the
        undecorated function returns ``None`` in that case.
        """
        target = collection_name
        if target is None:
            target = self._collection_name
        try:
            return self._milvus.insert(target, entities, ids=ids)
        except Exception as e:
            logger.error(str(e))
        return None
    def get_dimension(self):
        """Return ``params["dim"]`` of the first vector field.

        Returns ``None`` when the default collection has no float or binary
        vector field.
        """
        for field in self.get_info()["fields"]:
            is_vector = field["type"] in (DataType.FLOAT_VECTOR,
                                          DataType.BINARY_VECTOR)
            if is_vector:
                return field["params"]["dim"]
        return None

    def get_rand_ids(self, length):
        """Return up to ``length`` ids taken from one randomly chosen segment.

        Segments come from the first partition in the collection stats.  The
        loop repeats until a listing yields a non-empty id list; if that list
        holds more than ``length`` ids a random sample is returned, otherwise
        the whole list.
        """
        segment_ids = []
        while True:
            segments = self.get_stats()["partitions"][0]["segments"]
            # random choice one segment
            chosen = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(
                    self._collection_name, chosen["id"])
            except Exception as e:
                logger.error(str(e))
            available = len(segment_ids)
            if not available:
                continue
            if available > length:
                return random.sample(segment_ids, length)
            logger.debug("Reset length: %d" % available)
            return segment_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    def get(self):
        """Fetch one entity by a random id in [1, 1000000]; result is discarded."""
        random_id = random.randint(1, 1000000)
        self._milvus.get_entity_by_id(self._collection_name, [random_id])

    @time_wrapper
    def get_entities(self, get_ids):
        """Fetch the entities with the given ids from the default collection."""
        return self._milvus.get_entity_by_id(self._collection_name, get_ids)

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with ``ids`` from the given (or default) collection."""
        target = collection_name
        if target is None:
            target = self._collection_name
        self._milvus.delete_entity_by_id(target, ids)

    def delete_rand(self):
        """Delete a random batch of entities and verify they can no longer be fetched.

        Between 1 and 100 ids are requested from ``get_rand_ids``, deleted and
        flushed; the same ids are then fetched again.

        Raises:
            AssertionError: if any fetched item is truthy after the delete.
        """
        delete_id_length = random.randint(1, 100)
        logger.debug("%s: length to delete: %d" %
                     (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" %
                    (self._collection_name, self.count()))
        get_res = self._milvus.get_entity_by_id(self._collection_name,
                                                delete_ids)
        for item in get_res:
            if item:
                # Explicit raise instead of `assert`: assert statements are
                # stripped when Python runs with -O, which would silently
                # disable this verification.
                raise AssertionError("Entity still present after delete")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        """Flush the given (or default) collection, passing ``_async`` through."""
        target = collection_name
        if target is None:
            target = self._collection_name
        self._milvus.flush([target], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the given (or default) collection and validate the returned status."""
        target = collection_name
        if target is None:
            target = self._collection_name
        self.check_status(self._milvus.compact(target))

    @time_wrapper
    def create_index(self,
                     field_name,
                     index_type,
                     metric_type,
                     _async=False,
                     index_param=None):
        """Build an index on ``field_name`` of the default collection.

        ``index_type`` is looked up in INDEX_MAP (KeyError when unknown) and
        ``metric_type`` is converted with ``utils.metric_type_trans`` before
        both are sent to the SDK together with ``index_param``.
        """
        sdk_index_type = INDEX_MAP[index_type]
        sdk_metric_type = utils.metric_type_trans(metric_type)
        logger.info(
            "Building index start, collection_name: %s, index_type: %s, metric_type: %s"
            % (self._collection_name, sdk_index_type, sdk_metric_type))
        if index_param:
            logger.info(index_param)
        self._milvus.create_index(
            self._collection_name,
            field_name,
            {
                "index_type": sdk_index_type,
                "metric_type": sdk_metric_type,
                "params": index_param
            },
            _async=_async)

    # TODO: need to check
    def describe_index(self, field_name):
        """Return ``{"index_type", "index_param"}`` for the first recognised index.

        The SDK description is scanned field by field; the first index whose
        ``index_type`` equals a value of INDEX_MAP is reported under the
        matching INDEX_MAP key.  Without a match the result is
        ``{"index_type": "flat", "index_param": None}``.
        """
        # stats = self.get_stats()
        info = self._milvus.describe_index(self._collection_name, field_name)
        for field in info["fields"]:
            for index in field['indexes']:
                if not index or "index_type" not in index:
                    continue
                for short_name, sdk_name in INDEX_MAP.items():
                    if index['index_type'] == sdk_name:
                        return {
                            "index_type": short_name,
                            "index_param": index['params']
                        }
        return {"index_type": "flat", "index_param": None}

    def drop_index(self, field_name):
        """Drop the index on ``field_name`` of the default collection."""
        collection = self._collection_name
        logger.info("Drop index: %s" % collection)
        return self._milvus.drop_index(collection, field_name)

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None):
        """Search with ``vector_query`` plus optional filter clauses.

        All clauses are placed under ``{"bool": {"must": [...]}}``.
        """
        target = collection_name
        if target is None:
            target = self._collection_name
        clauses = [vector_query]
        if filter_query:
            clauses += filter_query
        return self._milvus.search(target, {"bool": {"must": clauses}})

    @time_wrapper
    def load_and_query(self,
                       vector_query,
                       filter_query=None,
                       collection_name=None):
        """Load the collection, then search it with the assembled bool query.

        The query body is built before ``load_collection`` is called.
        """
        target = collection_name
        if target is None:
            target = self._collection_name
        clauses = [vector_query]
        if filter_query:
            clauses += filter_query
        bool_query = {"bool": {"must": clauses}}
        self.load_collection(target)
        return self._milvus.search(target, bool_query)

    def get_ids(self, result):
        """Split the flat id list of a search result into fixed-size chunks.

        The chunk size is ``len(result._entities.ids) // len(result)``.

        Args:
            result: object supporting ``len()`` and exposing the flat id
                sequence as ``result._entities.ids``.

        Returns:
            A list of id chunks.  An empty ``result`` yields ``[]``; a result
            without any ids yields one empty list per entry of ``result``.

        Raises:
            ValueError: when there are ids, but fewer than entries in
                ``result``, so the chunk size would be zero.
        """
        idss = result._entities.ids
        len_idss = len(idss)
        len_r = len(result)
        if not len_r:
            # nothing was searched; avoids dividing by zero below
            return []
        if not len_idss:
            return [[] for _ in range(len_r)]
        top_k = len_idss // len_r
        if not top_k:
            raise ValueError("Cannot split %d ids across %d results" %
                             (len_idss, len_r))
        # slicing clamps at the end of the sequence, no explicit min() needed
        return [
            idss[offset:offset + top_k]
            for offset in range(0, len_idss, top_k)
        ]

    def query_rand(self, nq_max=100):
        """Run one search with randomly drawn top_k, nq, nprobe and metric.

        Query vectors are 128-dimensional lists of ``random.random()`` values.
        """
        # for ivf search
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        query_vectors = []
        for _ in range(nq):
            query_vectors.append([random.random() for _ in range(dimension)])
        metric_name = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        field_query = {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_name),
            "params": {"nprobe": nprobe}
        }
        self.query({"vector": {vec_field_name: field_query}})

    def load_query_rand(self, nq_max=100):
        """Run one load-then-search with randomly drawn top_k, nq, nprobe and metric.

        Query vectors are 128-dimensional lists of ``random.random()`` values.
        """
        # for ivf search
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        query_vectors = []
        for _ in range(nq):
            query_vectors.append([random.random() for _ in range(dimension)])
        metric_name = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        field_query = {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_name),
            "params": {"nprobe": nprobe}
        }
        self.load_and_query({"vector": {vec_field_name: field_query}})

    # TODO: need to check
    def count(self, collection_name=None):
        """Return ``row_count`` from the stats of the given (or default) collection."""
        target = self._collection_name if collection_name is None else collection_name
        stats = self._milvus.get_collection_stats(target)
        row_count = stats["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (row_count, target))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop a collection and poll until it reports no rows.

        After the drop call the row count is polled once per second, at most
        ``timeout`` times.  Polling stops when the count is falsy or when
        counting raises; a timeout is only logged, not raised.
        """
        timeout = int(timeout)
        target = self._collection_name if collection_name is None else collection_name
        logger.info("Start delete collection: %s" % target)
        self._milvus.drop_collection(target)
        waited = 0
        while waited < timeout:
            try:
                if not self.count(collection_name=target):
                    break
                time.sleep(1)
                waited = waited + 1
            except Exception as e:
                logger.debug(str(e))
                break
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        """Return the SDK collection stats of the default collection."""
        stats = self._milvus.get_collection_stats(self._collection_name)
        return stats

    def get_info(self, collection_name=None):
        """Return the SDK collection info of the given (or default) collection."""
        # pdb.set_trace()
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.get_collection_info(target)

    def show_collections(self):
        """Return what the SDK ``list_collections`` call returns."""
        collections = self._milvus.list_collections()
        return collections

    def exists_collection(self, collection_name=None):
        """Return the SDK ``has_collection`` answer for the given (or default) collection."""
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.has_collection(target)

    def clean_db(self):
        """Drop every collection returned by ``show_collections``."""
        for existing_name in self.show_collections():
            self.drop(collection_name=existing_name)

    @time_wrapper
    def load_collection(self, collection_name=None):
        """Load the given (or default) collection, passing ``timeout=3000`` to the SDK."""
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.load_collection(target, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        """Release the given (or default) collection, passing ``timeout=3000`` to the SDK."""
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.release_collection(target, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        """Load partitions ``tag_names`` of the given (or default) collection."""
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.load_partitions(target, tag_names, timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        """Release partitions ``tag_names`` of the given (or default) collection."""
        target = self._collection_name if collection_name is None else collection_name
        return self._milvus.release_partitions(target, tag_names, timeout=3000)
Пример #3
0
# ------
# Basic hybrid search entities:
#     Run the hybrid query, requesting the "release_year" and "embedding"
#     fields, then print id, title, distance and both fields of every hit.
# ------
results = client.search(collection_name,
                        query_hybrid,
                        fields=["release_year", "embedding"])
for entities in results:
    for topk_film in entities:
        current_entity = topk_film.entity
        print("==")
        print("- id: {}".format(topk_film.id))
        # titles is indexed by the hit id
        print("- title: {}".format(titles[topk_film.id]))
        print("- distance: {}".format(topk_film.distance))

        print("- release_year: {}".format(current_entity.release_year))
        print("- embedding: {}".format(current_entity.embedding))

# ------
# Basic delete index:
#     You can drop the index of a field.
# ------
client.drop_index(collection_name, "embedding")

# Clean up: drop the collection only if it is still listed.
if collection_name in client.list_collections():
    client.drop_collection(collection_name)

# ------
# Summary:
#     Now we've gone through some basic build index operations, hope it's helpful!
# ------
Пример #4
0
class MilvusClient(object):
    def __init__(self, table_name=None, ip=None, port=None):
        """Create the SDK handle and connect it.

        Without ``ip`` the default host and default port are used (``port``
        is ignored in that case); otherwise ``ip`` and ``port`` are passed to
        ``connect`` as given.  Connection errors propagate to the caller.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        if ip:
            self._milvus.connect(host=ip, port=port)
        else:
            self._milvus.connect(host=SERVER_HOST_DEFAULT,
                                 port=SERVER_PORT_DEFAULT)

    def __str__(self):
        """Return a short label containing the table name."""
        label = 'Milvus table %s'
        return label % self._table_name

    def check_status(self, status):
        """Log the status message and raise unless ``status.OK()`` is truthy."""
        if status.OK():
            return
        logger.error(status.message)
        raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create a table and validate the returned status.

        ``table_name`` becomes the default table when none is set yet.
        ``"l2"`` and ``"ip"`` are mapped to MetricType members; any other
        value is logged as unsupported and forwarded unchanged.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            resolved_metric = MetricType.L2
        elif metric_type == "ip":
            resolved_metric = MetricType.IP
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
            resolved_metric = metric_type
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": resolved_metric,
        }
        self.check_status(self._milvus.create_table(create_param))

    @time_wrapper
    def insert(self, X, ids=None):
        """Add vectors ``X`` to the table; returns ``(status, result)`` after a status check."""
        status, inserted = self._milvus.add_vectors(self._table_name, X, ids)
        self.check_status(status)
        return status, inserted

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index of ``index_type`` with ``nlist`` on the table.

        Known short names are mapped to IndexType members; an unknown value
        is forwarded unchanged.  The returned status is validated.
        """
        # (short name, IndexType attribute), compared in the original order
        known_types = (
            ("flat", "FLAT"),
            ("ivf_flat", "IVFLAT"),
            ("ivf_sq8", "IVF_SQ8"),
            ("nsg", "NSG"),
            ("ivf_sq8h", "IVF_SQ8H"),
            ("ivf_pq", "IVF_PQ"),
        )
        for short_name, member_name in known_types:
            if index_type == short_name:
                index_type = getattr(IndexType, member_name)
                break
        index_params = {"index_type": index_type, "nlist": nlist}
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        self.check_status(
            self._milvus.create_index(self._table_name, index=index_params))

    def describe_index(self):
        """Return what the SDK ``describe_index`` call returns for the table."""
        description = self._milvus.describe_index(self._table_name)
        return description

    def drop_index(self):
        """Log and drop the index of the table."""
        table = self._table_name
        logger.info("Drop index: %s" % table)
        return self._milvus.drop_index(table)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search vectors ``X``; returns ``(status, result)`` after a status check."""
        status, hits = self._milvus.search_vectors(self._table_name, top_k, nprobe, X)
        self.check_status(status)
        return status, hits

    def count(self):
        """Return element 1 of the SDK ``get_table_row_count`` response."""
        response = self._milvus.get_table_row_count(self._table_name)
        return response[1]

    def delete(self, timeout=60):
        """Delete the table and wait until it reports no rows.

        After the delete call the row count is polled once per second, at
        most ``timeout`` times.

        Args:
            timeout: maximum number of one-second polls.

        A timeout is only logged as an error, not raised.
        """
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        i = 0
        while i < timeout:
            if not self.count():
                break
            time.sleep(1)
            i = i + 1
        # The loop ends with i >= timeout only when every poll still saw
        # rows.  The previous `i < timeout` test was inverted: it reported a
        # timeout on success and stayed silent on a real timeout.
        if i >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        """Return what the SDK ``describe_table`` call returns for the table."""
        description = self._milvus.describe_table(self._table_name)
        return description

    def show_tables(self):
        """Return what the SDK ``show_tables`` call returns."""
        tables = self._milvus.show_tables()
        return tables

    def exists_table(self):
        """Return the SDK ``has_table`` answer after validating the status."""
        status, exists = self._milvus.has_table(self._table_name)
        self.check_status(status)
        return exists

    @time_wrapper
    def preload_table(self):
        """Preload the table, passing ``timeout=3000`` to the SDK."""
        outcome = self._milvus.preload_table(self._table_name, timeout=3000)
        return outcome
Пример #5
0
class MilvusClient(object):
    def __init__(self, collection_name=None, ip=None, port=None, timeout=60):
        """Create the SDK handle, retrying for a remote server.

        Without ``ip`` a handle for the default host and port is created once,
        without retries.  With ``ip`` the handle is re-created until
        ``server_status()`` is truthy or ``timeout`` seconds have passed.

        Args:
            collection_name: default collection used by the other methods.
            ip: remote server host; falsy selects the default server.
            port: remote server port, used together with ``ip``.
            timeout: seconds during which the remote connect is retried.

        Raises:
            Exception: "Server connect timeout" when the remote server never
                reported a truthy status within ``timeout`` seconds.
        """
        self._collection_name = collection_name
        if not ip:
            self._milvus = Milvus(host=SERVER_HOST_DEFAULT,
                                  port=SERVER_PORT_DEFAULT)
            return

        i = 1
        start_time = time.time()
        # retry connect for remote server
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=ip, port=port)
                if self._milvus.server_status():
                    logger.debug(
                        "Try connect times: %d, %s" %
                        (i, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed")
                i = i + 1
        else:
            # The loop ran out of time without a successful `break`.
            # Previously the constructor returned silently here, leaving an
            # unusable (or missing) handle behind.
            raise Exception("Server connect timeout")

    def __str__(self):
        """Return a short label containing the collection name."""
        label = 'Milvus collection %s'
        return label % self._collection_name

    def check_status(self, status):
        """Log the status message and raise unless ``status.OK()`` is truthy."""
        if status.OK():
            return
        logger.error(status.message)
        raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise when the first hit of any entry has a distance >= ``epsilon``.

        The offending entry position and distance are logged before raising.
        """
        for position, hits in enumerate(result):
            if not hits[0].distance >= epsilon:
                continue
            logger.error(position)
            logger.error(hits[0].distance)
            raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size,
                          metric_type):
        """Create a collection and validate the returned status.

        ``collection_name`` becomes the default collection when none is set
        yet.  Known metric short names are mapped to MetricType members; any
        other value is logged as unsupported and forwarded unchanged.
        """
        if not self._collection_name:
            self._collection_name = collection_name
        # (short name, MetricType attribute), compared in the original order
        known_metrics = (
            ("l2", "L2"),
            ("ip", "IP"),
            ("jaccard", "JACCARD"),
            ("hamming", "HAMMING"),
            ("sub", "SUBSTRUCTURE"),
            ("super", "SUPERSTRUCTURE"),
        )
        for short_name, member_name in known_metrics:
            if metric_type == short_name:
                metric_type = getattr(MetricType, member_name)
                break
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
        create_param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        self.check_status(self._milvus.create_collection(create_param))

    @time_wrapper
    def insert(self, X, ids=None):
        """Add vectors ``X``; returns ``(status, result)`` after a status check."""
        status, inserted = self._milvus.add_vectors(
            self._collection_name, X, ids)
        self.check_status(status)
        return status, inserted

    @time_wrapper
    def delete_vectors(self, ids):
        """Delete the vectors with ``ids`` and validate the returned status."""
        self.check_status(
            self._milvus.delete_by_id(self._collection_name, ids))

    @time_wrapper
    def flush(self):
        """Flush the collection and validate the returned status."""
        self.check_status(self._milvus.flush([self._collection_name]))

    @time_wrapper
    def compact(self):
        """Compact the collection and validate the returned status."""
        self.check_status(self._milvus.compact(self._collection_name))

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index on the collection and validate the returned status.

        ``index_type`` is looked up in INDEX_MAP (KeyError when unknown).
        """
        sdk_index_type = INDEX_MAP[index_type]
        logger.info(
            "Building index start, collection_name: %s, index_type: %s" %
            (self._collection_name, sdk_index_type))
        if index_param:
            logger.info(index_param)
        self.check_status(
            self._milvus.create_index(self._collection_name, sdk_index_type,
                                      index_param))

    def describe_index(self):
        """Return ``{"index_type", "index_param"}`` for the collection's index.

        ``index_type`` is the first INDEX_MAP key whose value equals the
        described index type, or ``None`` when nothing matches.
        """
        status, description = self._milvus.describe_index(
            self._collection_name)
        self.check_status(status)
        short_name = next(
            (key for key, sdk_type in INDEX_MAP.items()
             if description._index_type == sdk_type),
            None)
        return {"index_type": short_name, "index_param": description._params}

    def drop_index(self):
        """Log and drop the index of the collection."""
        collection = self._collection_name
        logger.info("Drop index: %s" % collection)
        return self._milvus.drop_index(collection)

    @time_wrapper
    def query(self, X, top_k, search_param=None):
        """Search vectors ``X`` and return the result after a status check."""
        status, hits = self._milvus.search_vectors(
            self._collection_name,
            top_k,
            query_records=X,
            params=search_param)
        self.check_status(status)
        return hits

    @time_wrapper
    def query_ids(self, top_k, ids, search_param=None):
        """Search by entity ids and validate both the status and the hits.

        Args:
            top_k: number of hits requested per id.
            ids: ids whose vectors are used as queries.
            search_param: forwarded to the SDK as ``params``.

        Raises:
            Exception: "Status not ok" when the SDK status is not OK, or
                "Distance wrong" when ``check_result_ids`` rejects the result.
        """
        status, result = self._milvus.search_by_ids(self._collection_name,
                                                    ids,
                                                    top_k,
                                                    params=search_param)
        # Check the status first, like the sibling methods do; it used to be
        # ignored, so a failed search surfaced only as an unrelated error
        # while inspecting the result.
        self.check_status(status)
        self.check_result_ids(result)
        return result

    def count(self):
        """Return element 1 of the SDK ``count_collection`` response."""
        response = self._milvus.count_collection(self._collection_name)
        return response[1]

    def delete(self, timeout=120):
        """Drop the collection and poll until it reports no rows.

        The row count is polled once per second, at most ``timeout`` times;
        a timeout is only logged as an error, not raised.
        """
        timeout = int(timeout)
        logger.info("Start delete collection: %s" % self._collection_name)
        self._milvus.drop_collection(self._collection_name)
        waited = 0
        while waited < timeout:
            if not self.count():
                break
            time.sleep(1)
            waited = waited + 1
        if waited >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        """Return what the SDK ``describe_collection`` call returns."""
        description = self._milvus.describe_collection(self._collection_name)
        return description

    def show_collections(self):
        """Return what the SDK ``show_collections`` call returns."""
        collections = self._milvus.show_collections()
        return collections

    def exists_collection(self, collection_name=None):
        """Return the SDK ``has_collection`` answer; the status is not checked."""
        target = self._collection_name if collection_name is None else collection_name
        _status, exists = self._milvus.has_collection(target)
        # self.check_status(status)
        return exists

    @time_wrapper
    def preload_collection(self):
        """Preload the collection (``timeout=3000``) and return the checked status."""
        outcome = self._milvus.preload_collection(self._collection_name,
                                                  timeout=3000)
        self.check_status(outcome)
        return outcome

    def get_server_version(self):
        """Return the version part of the SDK ``server_version`` response."""
        _status, version = self._milvus.server_version()
        return version

    def get_server_mode(self):
        """Return the result of the ``mode`` server command."""
        mode = self.cmd("mode")
        return mode

    def get_server_commit(self):
        """Return the result of the ``build_commit_id`` server command."""
        commit_id = self.cmd("build_commit_id")
        return commit_id

    def get_server_config(self):
        """Return the JSON-decoded result of the ``get_config *`` server command."""
        raw_config = self.cmd("get_config *")
        return json.loads(raw_config)

    def get_mem_info(self):
        """Return ``{"memory_used": <value>}`` from the ``get_system_info`` command.

        The reported ``memory_used`` is converted to int, divided by 1024**3
        and rounded to two decimals.
        """
        system_info = json.loads(self.cmd("get_system_info"))
        divisor = 1024 * 1024 * 1024
        # unit: Gb
        used = round(int(system_info["memory_used"]) / divisor, 2)
        return {"memory_used": used}

    def cmd(self, command):
        """Send ``command`` via the SDK ``_cmd`` call and return its result.

        The command and result are logged before the status is validated.
        """
        status, output = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, output))
        self.check_status(status)
        return output
Пример #6
0
class MilvusClient(object):
    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Keeps re-creating the SDK client until ``server_status()`` returns a
        truthy value or ``timeout`` seconds have passed (default 60s).

        :param collection_name: collection the wrapper methods act on by default.
        :param host: server host; SERVER_HOST_DEFAULT is used when falsy.
        :param port: server port; SERVER_PORT_DEFAULT is used when falsy.
        :param timeout: seconds to keep retrying the connection.
        :raises Exception: "Server connect timeout" when the retry window
            closes without a successful attempt.
        """
        self._collection_name = collection_name
        self._metric_type = None
        # Defined up front so the attribute exists even when the collection
        # has not been created yet.
        self._dimension = None
        start_time = time.time()
        host = host or SERVER_HOST_DEFAULT
        port = port or SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect for remote server
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
        else:
            # Only reached when the loop ended without ``break``, i.e. the
            # retry window closed. A connect that succeeds at the last moment
            # is therefore never misreported as a timeout.
            raise Exception("Server connect timeout")

        if self._collection_name and self.exists_collection():
            # One describe() round trip serves both attributes.
            info = self.describe()[1]
            self._metric_type = metric_type_to_str(info.metric_type)
            self._dimension = info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Rebind the wrapper to another collection name."""
        self._collection_name = name

    def check_status(self, status):
        """Raise ``Exception("Status not ok")`` when ``status.OK()`` is falsy.

        Before raising, the collection name, status message, server status
        and current row count are logged at error level.
        """
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise when the first hit of any result row has distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; ``metric_type`` must be a key of METRIC_MAP.

        The wrapper adopts ``collection_name`` as its bound collection only
        when none was bound before.

        :raises Exception: for an unknown ``metric_type`` or a non-OK status.
        """
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP:
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_type = METRIC_MAP[metric_type]
        create_param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type,
        }
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        """Create partition ``tag_name`` in the bound collection."""
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        """Drop partition ``tag_name`` from the bound collection."""
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        """Return the partition list of the bound collection."""
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """Insert vectors ``X`` (optionally with ``ids``); return ``(status, result)``.

        Targets the bound collection unless ``collection_name`` is given.
        """
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1-100 random vectors, flush, and verify the row count grew to match.

        :raises Exception: "Assert failed after inserting" on a count mismatch.
        """
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` ids sampled from one randomly chosen segment.

        Segments are taken from the first partition in the collection stats.
        The choice is repeated until a segment lists at least one id. When
        the segment holds ``length`` ids or fewer, all of them are returned.
        """
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return ``(segment_count, ids)`` with the first ``length`` ids of every segment.

        Only the first partition in the collection stats is inspected.
        """
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        # take the leading ids from each segment
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        """Return ``(ids, entities)`` for ids picked by ``get_rand_ids(length)``."""
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        """Return the entities stored under ``get_ids`` in the bound collection."""
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with the given ``ids``."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete up to 100 random entities, flush, and verify they are gone.

        :raises Exception: "Assert failed after delete" when a deleted id is
            still returned or the row count did not drop by the number of
            deleted ids.
        """
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        """Flush one collection (the bound one by default)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact one collection (the bound one by default)."""
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` must be a key of INDEX_MAP."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        # Reverse lookup: find the INDEX_MAP key whose value is the reported type.
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        """Drop the index of the bound collection and return the SDK result unchecked."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """Search ``top_k`` neighbours for the query records ``X`` and return the result."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Run one search with random nq, top_k and nprobe (each 1-100).

        Query vectors are entities fetched from the collection itself; only
        the returned status is validated.
        """
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        # for i, item in enumerate(search_res):
        #     if item[0].id != ids[i]:
        #         logger.warning("The index of search result: %d" % i)
        #         raise Exception("Query failed")

    # @time_wrapper
    # def query_ids(self, top_k, ids, search_param=None):
    #     status, result = self._milvus.search_by_id(self._collection_name, ids, top_k, params=search_param)
    #     self.check_result_ids(result)
    #     return result

    def count(self, name=None):
        """Return the row count of ``name`` (bound collection by default).

        A falsy count from the SDK is reported as 0. The SDK is queried
        exactly once; the same reply is used for the debug log and the
        return value.
        """
        if name is None:
            name = self._collection_name
        response = self._milvus.count_entities(name)
        logger.debug(response)
        row_count = response[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection and wait for its row count to reach zero.

        The count is polled once per second, at most ``int(timeout)`` times;
        an error is logged (nothing is raised) when it never reaches zero.

        :raises Exception: when the drop request returns a non-OK status.
        """
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        for _ in range(timeout):
            if not self.count(name=name):
                break
            time.sleep(1)
        else:
            # Every poll was used up (or none was allowed).
            logger.error("Delete collection timeout")

    def describe(self):
        """Return the SDK's ``get_collection_info`` result for the bound collection."""
        # logger.info(self._milvus.get_collection_info(self._collection_name))
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        """Return the SDK's ``list_collections`` result unchanged."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return the second element of ``has_collection``; the status is not checked."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        # self.check_status(status)
        return res

    def clean_db(self):
        """Drop every collection listed by ``show_collections``."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        """Load the bound collection (SDK timeout 3000) and return the checked status."""
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        """Return the version part of the ``server_version`` reply; the status is ignored."""
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        """Return the reply to the ``mode`` server command."""
        return self.cmd("mode")

    def get_server_commit(self):
        """Return the reply to the ``build_commit_id`` server command."""
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the ``get_config *`` reply parsed as JSON."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": <value>}`` taken from the ``get_system_info`` reply.

        The ``memory_used`` entry is divided by 1024**3 and rounded to two
        decimals.
        """
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a raw server command, log command and reply, check status, return reply."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Пример #7
0
def main():
    """Run the demo: create a collection, insert, index, search, clean up.

    A collection left over from an earlier run is dropped and recreated so
    that every later step operates on an existing, empty collection.
    """
    # Specify server addr when create milvus client instance
    # milvus client instance maintain a connection pool, param
    # `pool_size` specify the max connection num.
    milvus = Milvus(_HOST, _PORT)

    collection_name = 'example_collection'
    field_name = 'example_field'

    # Start from a clean slate. Dropping an existing collection without
    # recreating it would make every later call target a missing collection.
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name=collection_name)

    fields = {
        "fields": [{
            "name": field_name,
            "type": DataType.FLOAT_VECTOR,
            "metric_type": "L2",
            "params": {
                "dim": _DIM
            },
            "indexes": [{
                "metric_type": "L2"
            }]
        }]
    }
    milvus.create_collection(collection_name=collection_name,
                             fields=fields)

    # Show collections in Milvus server
    collections = milvus.list_collections()
    print(collections)

    # Show statistics of the demo collection
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # 10 vectors with _DIM dimensions
    # vectors should be a 2-D array
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
    print(vectors)
    # You can also use numpy to generate random vectors:
    #   vectors = np.random.rand(10, _DIM).astype(np.float32)

    # Insert vectors into the demo collection; the ids are kept for the
    # correctness check after the search.
    entities = [{
        "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }]

    res_ids = milvus.insert(collection_name=collection_name, entities=entities)
    print("ids:", res_ids)

    # Flush the inserted data of the collection to disk.
    milvus.flush([collection_name])

    # present collection statistics info
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # create index of vectors, search more rapidly
    index_param = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {
            "nlist": 1024
        }
    }

    # Create an IVF_FLAT index in the demo collection.
    # You can search vectors without creating an index; however, creating
    # an index helps to search faster.
    print("Creating index: {}".format(index_param))
    milvus.create_index(collection_name, field_name, index_param)

    # execute vector similarity search

    print("Searching ... ")

    dsl = {
        "bool": {
            "must": [{
                "vector": {
                    field_name: {
                        "metric_type": "L2",
                        "query": vectors,
                        "topk": 10,
                        "params": {
                            "nprobe": 16
                        }
                    }
                }
            }]
        }
    }

    milvus.load_collection(collection_name)
    results = milvus.search(collection_name, dsl)
    # The first query vector is also the first inserted vector, so its best
    # hit should be itself: distance 0.0 or the first returned id.
    if results[0][0].distance == 0.0 or results[0][0].id == res_ids[0]:
        print('Query result is correct')
    else:
        print('Query result isn\'t correct')

    milvus.drop_index(collection_name, field_name)
    milvus.release_collection(collection_name)

    # Delete the demo collection
    milvus.drop_collection(collection_name)
Пример #8
0
class MilvusClient(object):
    """Benchmark wrapper around the table-based ``Milvus`` SDK client.

    ``check_status`` in this class only logs, so a failed SDK call does not
    raise by itself.
    """

    # Accepted metric names -> attribute name on ``MetricType``. The attribute
    # is resolved only for the name actually requested.
    _METRIC_ATTRS = {
        "l2": "L2",
        "ip": "IP",
        "jaccard": "JACCARD",
        "hamming": "HAMMING",
    }

    def __init__(self, table_name=None, ip=None, port=None, timeout=60):
        """Connect to a Milvus server.

        Without ``ip`` a single connect to SERVER_HOST_DEFAULT and
        SERVER_PORT_DEFAULT is made (``port`` is not used in that case).
        With ``ip`` the connect is retried until ``connected()`` is True or
        ``timeout`` seconds have passed; a timeout is logged as an error and
        the instance is still returned.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        i = 1
        start_time = time.time()
        if not ip:
            self._milvus.connect(host=SERVER_HOST_DEFAULT,
                                 port=SERVER_PORT_DEFAULT)
        else:
            # retry connect for remote server
            while time.time() < start_time + timeout:
                try:
                    self._milvus.connect(host=ip, port=port)
                    if self._milvus.connected() is True:
                        logger.debug(
                            "Try connect times: %d, %s" %
                            (i, round(time.time() - start_time, 2)))
                        break
                except Exception:
                    logger.debug("Milvus connect failed")
                    i = i + 1
            else:
                # The retry window closed without a connection. This used to
                # pass silently; it is now at least visible in the log.
                logger.error("Server connect timeout")

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log the status message at error level when ``status.OK()`` is falsy."""
        if not status.OK():
            logger.error(status.message)
            # raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size,
                     metric_type):
        """Create a table.

        ``metric_type`` is one of "l2", "ip", "jaccard", "hamming". Any other
        value is logged as unsupported and passed to the SDK unchanged. The
        wrapper adopts ``table_name`` only when no table was bound before.
        """
        if not self._table_name:
            self._table_name = table_name
        metric_attr = self._METRIC_ATTRS.get(metric_type)
        if metric_attr is None:
            logger.error("Not supported metric_type: %s" % metric_type)
        else:
            metric_type = getattr(MetricType, metric_attr)
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        """Add vectors ``X`` to the bound table; return ``(status, result)``."""
        status, result = self._milvus.add_vectors(self._table_name, X, ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index; ``index_type`` must be a key of INDEX_MAP."""
        index_params = {
            "index_type": INDEX_MAP[index_type],
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" %
                    (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name,
                                           index=index_params)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "nlist": ...}``."""
        status, result = self._milvus.describe_index(self._table_name)
        # Surface a failed request in the log instead of dropping the status.
        self.check_status(status)
        index_type = None
        # Reverse lookup: find the INDEX_MAP key whose value is the reported type.
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        nlist = result._nlist
        res = {"index_type": index_type, "nlist": nlist}
        return res

    def drop_index(self):
        """Drop the index of the bound table and return the SDK result."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search ``top_k`` neighbours for ``X`` with ``nprobe``; return the result."""
        status, result = self._milvus.search_vectors(self._table_name, top_k,
                                                     nprobe, X)
        self.check_status(status)
        return result

    def count(self):
        """Return the second element of ``get_table_row_count`` for the bound table."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Delete the bound table and wait for its row count to become falsy.

        The count is polled once per second while fewer than ``timeout``
        polls have been made; an error is logged when it never becomes
        falsy.
        """
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        i = 0
        while i < timeout:
            if not self.count():
                break
            time.sleep(1)
            i = i + 1
        if i >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        """Return the SDK's ``describe_table`` result for the bound table."""
        return self._milvus.describe_table(self._table_name)

    def show_tables(self):
        """Return the SDK's ``show_tables`` result unchanged."""
        return self._milvus.show_tables()

    def exists_table(self, table_name=None):
        """Return the second element of ``has_table`` (bound table by default)."""
        if table_name is None:
            table_name = self._table_name
        status, res = self._milvus.has_table(table_name)
        self.check_status(status)
        return res

    @time_wrapper
    def preload_table(self):
        """Preload the bound table (SDK timeout 3000) and return the SDK result."""
        return self._milvus.preload_table(self._table_name, timeout=3000)

    def get_server_version(self):
        """Return the version part of the ``server_version`` reply; the status is ignored."""
        status, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        """Return the reply to the ``mode`` server command."""
        return self.cmd("mode")

    def get_server_commit(self):
        """Return the reply to the ``build_commit_id`` server command."""
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the ``get_config *`` reply parsed as JSON."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": <value>}`` taken from the ``get_system_info`` reply.

        The ``memory_used`` entry is divided by 1024**3 and rounded to two
        decimals.
        """
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used":
            round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a raw server command, log command and reply, check status, return reply."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
Пример #9
0
class MilvusClient(object):
    """Wrapper around the table-based ``Milvus`` SDK client for numpy inputs.

    Vectors are L2-normalised before insert and query when the table was
    created through this wrapper with the inner-product metric.
    """

    # Accepted index names -> attribute name on ``IndexType``. The attribute
    # is resolved only for the name actually requested.
    _INDEX_ATTRS = {
        "flat": "FLAT",
        "ivf_flat": "IVFLAT",
        "ivf_sq8": "IVF_SQ8",
        "ivf_sq8h": "IVF_SQ8H",
        "nsg": "NSG",
        "ivf_pq": "IVF_PQ",
    }

    def __init__(self, table_name=None, host=None, port=None):
        """Connect once to ``host:port``.

        Without ``host`` the connect goes to SERVER_HOST_DEFAULT and
        SERVER_PORT_DEFAULT (``port`` is not used in that case). Connection
        errors propagate to the caller.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        # Known only after create_table(). Defined here so insert()/query()
        # also work on a table that already existed, without normalisation.
        self._metric_type = None
        if not host:
            self._milvus.connect(host=SERVER_HOST_DEFAULT,
                                 port=SERVER_PORT_DEFAULT)
        else:
            self._milvus.connect(host=host, port=port)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log the status message and raise ``Exception("Status not ok")`` when not OK."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size,
                     metric_type):
        """Create a table and remember its metric type.

        ``metric_type`` is "l2" or "ip". Any other value is logged as
        unsupported and passed to the SDK unchanged. The wrapper adopts
        ``table_name`` only when no table was bound before.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
        self._metric_type = metric_type
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids):
        """Add the rows of ``X`` with the given ``ids``; return ``(status, result)``.

        ``X`` must provide ``astype`` and ``tolist``. It is cast to float32
        and, for the inner-product metric, L2-normalised per row first.
        """
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        X = X.astype(numpy.float32)
        status, result = self._milvus.add_vectors(self._table_name,
                                                  X.tolist(),
                                                  ids=ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index (SDK timeout 6 hours).

        Known ``index_type`` names are translated to ``IndexType`` members;
        any other value is passed to the SDK unchanged.
        """
        index_attr = self._INDEX_ATTRS.get(index_type)
        if index_attr is not None:
            index_type = getattr(IndexType, index_attr)
        index_params = {
            "index_type": index_type,
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" %
                    (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name,
                                           index=index_params,
                                           timeout=6 * 3600)
        self.check_status(status)

    def describe_index(self):
        """Return the SDK's ``describe_index`` result for the bound table."""
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        """Drop the index of the bound table and return the SDK result."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search ``top_k`` neighbours for every row of ``X``.

        Returns one list of hit ids per query row. ``X`` is cast to float32
        and, for the inner-product metric, L2-normalised per row first.
        """
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
        X = X.astype(numpy.float32)
        status, results = self._milvus.search_vectors(self._table_name, top_k,
                                                      nprobe, X.tolist())
        self.check_status(status)
        return [[item.id for item in result] for result in results]

    def count(self):
        """Return the second element of ``get_table_row_count`` for the bound table."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, table_name):
        """Delete ``table_name`` and return the SDK result unchecked."""
        logger.info("Start delete table: %s" % table_name)
        return self._milvus.delete_table(table_name)

    def describe(self):
        """Return the SDK's ``describe_table`` result for the bound table."""
        return self._milvus.describe_table(self._table_name)

    def exists_table(self, table_name):
        """Return the raw SDK ``has_table`` result for ``table_name``."""
        return self._milvus.has_table(table_name)

    def get_server_version(self):
        """Return the server version after validating the status."""
        status, res = self._milvus.server_version()
        self.check_status(status)
        return res

    @time_wrapper
    def preload_table(self):
        """Preload the bound table and return the SDK result."""
        return self._milvus.preload_table(self._table_name)
Пример #10
0
class MyMilvus():

    def __init__(self, name, host, port, collection_param):
        """Remember the connection details and target collection, and open a client.

        ``collection_param`` is stored as given and later passed to
        ``create_collection`` on the client.
        """
        self.host, self.port = host, port
        self.client = Milvus(host, port)
        self.collection_name, self.collection_param = name, collection_param

        # self.collection_param = {
        #     "fields": [
        #         {"name": "release_year", "type": DataType.INT32},
        #         {"name": "embedding", "type": DataType.FLOAT_VECTOR, "params": {"dim": 8}},
        #     ],
        #     "segment_row_limit": 4096,
        #     "auto_id": False
        # }

    def create_collection(self):
        """Create the collection if missing, load ``films.csv``, index it, then clean up.

        Steps, in order:
          1. create the collection from ``self.collection_param`` unless the
             server already lists it;
          2. read ``films.csv`` from the working directory (columns: id,
             title, release year, bracketed comma-separated embedding);
          3. insert release years and embeddings under the ids from the
             file, flush, and print the entity count;
          4. build an IVF_FLAT / L2 index (nlist 100) on "embedding" and
             pretty-print the collection info;
          5. drop that index and, if the collection is listed, drop the
             collection as well.
        """
        client = self.client
        name = self.collection_name

        if name not in client.list_collections():
            client.create_collection(name, self.collection_param)

        # ------
        # Load the data:
        #     Values are grouped per field, because insert takes one list of
        #     values per field rather than one record per film.
        # ------
        ids, titles, release_years, embeddings = [], [], [], []
        with open('films.csv', 'r') as csvfile:
            films = list(csv.reader(csvfile))
        for row in films:
            ids.append(int(row[0]))
            # Titles are collected but not inserted; no field is defined for them here.
            titles.append(row[1])
            release_years.append(int(row[2]))
            # Strip the first and last character (the brackets) before splitting.
            embeddings.append([float(value) for value in row[3][1:-1].split(',')])

        hybrid_entities = [
            {"name": "release_year", "values": release_years, "type": DataType.INT32},
            {"name": "embedding", "values": embeddings, "type": DataType.FLOAT_VECTOR},
        ]

        # ------
        # Insert:
        #     The ids read from the file are used as entity ids; the flush
        #     precedes the count that is printed.
        # ------
        ids = client.insert(name, hybrid_entities, ids)
        client.flush([name])
        after_flush_counts = client.count_entities(name)
        print(" > There are {} films in collection `{}` after flush".format(after_flush_counts, name))

        # ------
        # Create index:
        #     The field, the index type, the metric type and the
        #     type-specific params ("nlist" for IVF_FLAT) are all given
        #     explicitly.
        # ------
        index_spec = {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}}
        client.create_index(name, "embedding", index_spec)

        # ------
        # Inspect:
        #     Print what `get_collection_info` reports once the index exists.
        # ------
        info = client.get_collection_info(name)
        pprint(info)

        # ------
        # Hybrid search is intentionally not run by this method.
        # ------

        # ------
        # Clean up:
        #     Drop the index of the field, then the collection itself if the
        #     server still lists it.
        # ------
        client.drop_index(name, "embedding")

        if name in client.list_collections():
            client.drop_collection(name)