コード例 #1
0
class Ingester:
    """Prepare a Milvus collection (and optional partition) and insert into it.

    The constructor connects to the server and makes sure the target
    collection / partition exist; ``ingest`` forwards entities to the
    client's ``insert`` call.
    """

    def __init__(self,
                 host,
                 port,
                 collection,
                 collection_param,
                 partition=None,
                 drop=False,
                 batch_size=100,
                 dtype=np.float32):
        """Connect to Milvus and ensure the collection/partition exist.

        :param host: host passed to ``Milvus``.
        :param port: port passed to ``Milvus``.
        :param collection: name of the collection to insert into.
        :param collection_param: parameters forwarded to
            ``create_collection`` when the collection has to be created.
        :param partition: optional partition tag; created if missing.
        :param drop: when true, an existing collection is dropped first and
            then re-created.
        :param batch_size: stored on the instance; not used by the code in
            this class.
        :param dtype: stored on the instance; not used by the code in this
            class.
        """
        self.collection, self.partition = collection, partition
        self.client = Milvus(host, port)
        self.dtype, self.batch_size = dtype, batch_size

        client = self.client
        # Only query the server for existence when a drop was requested.
        if drop:
            if client.has_collection(collection):
                client.drop_collection(collection)
        if not client.has_collection(collection):
            client.create_collection(collection, collection_param)

        if not partition:
            return
        if not client.has_partition(collection, partition):
            client.create_partition(collection, partition)

    def ingest(self, entities, ids):
        """Insert ``entities`` with the given ``ids``.

        The insert targets the configured partition when one was given to
        the constructor, otherwise the collection itself.

        :return: whatever the client's ``insert`` call returns.
        """
        insert_kwargs = {'ids': ids}
        if self.partition:
            insert_kwargs['partition_tag'] = self.partition
        return self.client.insert(self.collection, entities, **insert_kwargs)
コード例 #2
0
    def _add_partition(self, collection: str, partition: str, milvus: mv.Milvus) -> None:
        """
        Create the scenario class (partition) if it does not exist yet.

        Both Milvus calls return a status that is handed to
        ``confirm_milvus_status`` together with the operation name
        (presumably that helper raises on a failed status - confirm against
        its definition).

        @param {str} collection - question category (Milvus collection)
        @param {str} partition - scenario class (Milvus partition)
        @param {mv.Milvus} milvus - Milvus connection object
        """
        _lookup_status, _found = milvus.has_partition(collection, partition)
        self.confirm_milvus_status(_lookup_status, 'has_partition')
        if _found:
            return

        # Partition is missing: create the scenario.
        _create_status = milvus.create_partition(collection, partition)
        self.confirm_milvus_status(_create_status, 'create_partition')
コード例 #3
0
class ANN(object):
    """Convenience wrapper around a Milvus client.

    Groups the collection, partition, index and bulk-insert calls used for
    approximate-nearest-neighbour search behind a small set of methods.
    """

    def __init__(self, host='10.119.33.90', port='19530', show_info=False):
        """Connect to a Milvus server.

        :param host: host passed to ``Milvus``.
        :param port: port passed to ``Milvus``.
        :param show_info: when true, log the client and server versions.
        """
        self.client = Milvus(host, port)

        if show_info:
            logger.info({
                "ClientVersion": self.client.client_version(),
                "ServerVersion": self.client.server_version()
            })

    def create_collection(self,
                          collection_name,
                          collection_param,
                          partition_tag=None,
                          overwrite=True):
        """Create a collection and, optionally, a partition inside it.

        :param collection_name: name of the collection.
        :param collection_param: schema forwarded to the client's
            ``create_collection``, e.g.::

                collection_param = {
                    "fields": [
                        #  Milvus doesn't support string type now, but we are considering supporting it soon.
                        #  {"name": "title", "type": DataType.STRING},
                        {"name": "category_", "type": DataType.INT32},
                        {"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 768}},
                    ],
                    "segment_row_limit": 4096,
                    "auto_id": False
                }

        :param partition_tag: optional partition to create in the collection;
            skipped when the partition already exists.
        :param overwrite: when true, an existing collection is dropped and
            re-created; when false, an existing collection is kept and a
            notice is printed.
        :return: None
        """
        # Query existence once instead of once per branch.
        exists = self.client.has_collection(collection_name)

        if exists and overwrite:
            self.client.drop_collection(collection_name)
            self.client.flush()
            # Fixed pause kept from the original implementation; presumably
            # it gives the server time to finish the drop - TODO confirm.
            time.sleep(5)
            exists = False

        if exists:
            print(f"{collection_name} already exist !!!")
        else:
            self.client.create_collection(collection_name, collection_param)

        # Only create the partition when it is missing, so that calling this
        # method again on a kept collection does not try to re-create it.
        if partition_tag is not None and not self.client.has_partition(
                collection_name, partition_tag):
            self.client.create_partition(collection_name,
                                         partition_tag=partition_tag)

    def create_index(self,
                     collection_name,
                     field_name,
                     index_type='IVF_FLAT',
                     metric_type='IP',
                     index_params=None):
        """Build an index on one field of a collection.

        :param collection_name: name of the collection.
        :param field_name: field to index (e.g. ``'embedding'``).
        :param index_type: index type name, see the reference below.
        :param metric_type: metric type name, see the reference below.
        :param index_params: index-specific parameters; defaults to
            ``{'nlist': 1024}``.
        :return: None

        Reference values as noted by the original author (verify against the
        installed pymilvus version):

        MetricType:
            INVALID = 0
            L2 = 1
            IP = 2
            # Only supported for byte vectors
            HAMMING = 3
            JACCARD = 4
            TANIMOTO = 5
            #
            SUBSTRUCTURE = 6
            SUPERSTRUCTURE = 7
        IndexType:
            INVALID = 0
            FLAT = 1
            IVFLAT = 2
            IVF_SQ8 = 3
            RNSG = 4
            IVF_SQ8H = 5
            IVF_PQ = 6
            HNSW = 11
            ANNOY = 12

            # alternative name
            IVF_FLAT = IVFLAT
            IVF_SQ8_H = IVF_SQ8H

        class DataType(IntEnum):
            NULL = 0
            INT8 = 1
            INT16 = 2
            INT32 = 3
            INT64 = 4

            STRING = 20

            BOOL = 30

            FLOAT = 40
            DOUBLE = 41

            VECTOR = 100
            UNKNOWN = 9999

        class RangeType(IntEnum):
            LT = 0   # less than
            LTE = 1  # less than or equal
            EQ = 2   # equal
            GT = 3   # greater than
            GTE = 4  # greater than or equal
            NE = 5   # not equal
        """
        if index_params is None:
            index_params = {'nlist': 1024}

        params = {
            'index_type': index_type,
            'params': index_params,
            'metric_type': metric_type,
        }
        self.client.create_index(collection_name, field_name, params)

    def batch_insert(self, collection_name, entities, batch_size=100000):
        """Insert ``entities`` in slices of at most ``batch_size`` rows.

        Each entity is a dict with a ``'values'`` sequence; the row count is
        taken from the first entity. Every slice is inserted and flushed
        separately. The caller's ``entities`` are not modified.

        :param collection_name: target collection.
        :param entities: list of field dicts, each holding ``'values'``.
        :param batch_size: maximum number of rows per insert call.
        :return: concatenated list of the ids returned by each insert.
        :raises ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not entities:
            return []

        # Split into batches; stepping by batch_size never yields an empty
        # trailing slice, even when the row count is an exact multiple.
        total = len(entities[0]['values'])
        ids = []
        for start in range(0, total, batch_size):
            stop = start + batch_size
            # Build fresh dicts so the caller's data stays intact.
            batch = [{**e, 'values': e['values'][start:stop]} for e in entities]
            ids.extend(self.client.insert(collection_name, batch))
            self.client.flush()
        return ids

    def search(self):  # TODO: fetch the same information
        """Placeholder; not implemented yet."""
        pass

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` if it exists."""
        if self.client.has_collection(collection_name):
            self.client.drop_collection(collection_name)

    def drop_partition(self, collection_name, partition_tag):
        """Drop ``partition_tag`` from ``collection_name`` if it exists."""
        if self.client.has_partition(collection_name, partition_tag):
            self.client.drop_partition(collection_name,
                                       partition_tag,
                                       timeout=30)