Exemplo n.º 1
0
class MilvusANN(object):
    """Demo wrapper around a ``Milvus`` client using the collection-style API.

    The constructor connects to the server immediately and terminates the
    process (``sys.exit(1)``) when the connection status is not OK.

    Attributes:
        milvus: the underlying ``Milvus`` client instance.
        ids: ids returned by the last ``insert_vectors_demo`` call
            (empty list until then).
        _query_vectors: the first 10 vectors of the last inserted batch
            (empty list until then).
    """

    def __init__(self, host='10.46.5.98', port='19530'):
        """Create the client, connect to ``host:port`` and print versions.

        :param host: Milvus server host.
        :param port: Milvus server port (passed through as given).
        """
        self.milvus = Milvus()
        # Populated by insert_vectors_demo(); initialised here so that
        # search_vectors_demo() can be called before any insert.
        self.ids = []
        self._query_vectors = []

        print("Client Version:", self.milvus.client_version())

        status = self.milvus.connect(host, port)

        if status.OK():
            print("Server connected.")
        else:
            print("Server connect fail.")
            sys.exit(1)

        # server_version() is indexed with [-1]; presumably it returns a
        # (status, version) pair -- confirm against the installed client.
        print("Server Version:", self.milvus.server_version()[-1])

    def desc(self, tabel_name=None):
        """Print the description and vector count of a collection.

        :param tabel_name: collection name; when falsy nothing is printed.
            (The misspelt parameter name is kept for backward compatibility.)
        """
        milvus = self.milvus
        # NOTE(review): the listing is requested but its result is discarded.
        milvus.show_collections()
        if tabel_name:
            print(f"Describe: {milvus.describe_collection(tabel_name)[-1]}")
            # Unwrap the last element, exactly as done for
            # describe_collection above, so only the count is printed.
            print(
                f"Vector number in {tabel_name}: {milvus.count_collection(tabel_name)[-1]}"
            )

    def create_tabel_demo(self):
        """Create the collection ``demo_table`` if missing and print its schema."""
        milvus = self.milvus
        table_name = 'demo_table'

        _, ok = milvus.has_collection(table_name)
        if not ok:
            param = {
                'collection_name': table_name,
                'dimension': 16,
                # optional: per the original note, once a data file reaches
                # this size Milvus starts building an index for it.
                'index_file_size': 1024,
                'metric_type': MetricType.L2  # optional
            }

            milvus.create_collection(param)

        # NOTE(review): the collection listing is fetched but never used.
        milvus.show_collections()

        # Describe demo_table
        _, table = milvus.describe_collection(table_name)
        print(table)

    def insert_vectors_demo(self, collection_name):
        """Insert 10000 random 16-dim vectors, build an index and print it.

        Side effects: sets ``self.ids`` to the ids returned by the insert and
        ``self._query_vectors`` to the first 10 inserted vectors. Blocks for
        6 seconds between the insert and the index creation.

        :param collection_name: target collection.
        """
        milvus = self.milvus

        # 10000 vectors of 16 float32 components, as a 2-D nested list.
        vectors = np.random.rand(10000, 16).astype(np.float32).tolist()

        # Insert the vectors; the call returns a status and the id list.
        # NOTE(review): the original comment here read
        # "timestamp 1581655102 786 118" -- presumably describing the format
        # of the generated ids; confirm.
        _, self.ids = milvus.insert(collection_name, vectors)

        # Wait for 6 seconds, until Milvus server persists the vector data.
        time.sleep(6)

        # Row count is requested as in the original flow; result is unused.
        milvus.count_collection(collection_name)

        # Searching works without an index; creating one makes it faster.
        index_param = {'nlist': 2048}
        milvus.create_index(collection_name,
                            index_type=IndexType.IVFLAT,
                            params=index_param)

        # Describe the index that was just created.
        _, index = milvus.describe_index(collection_name)
        print(index)

        # Keep the first 10 vectors for a later similarity search.
        self._query_vectors = vectors[0:10]

    def search_vectors_demo(self, query_vectors, collection_name):
        """Run a top-1 search and report whether the first hit looks correct.

        The first hit is considered correct when its distance is ``0.0`` or
        its id equals the first id stored by ``insert_vectors_demo``. If no
        ids have been stored yet, only the distance criterion applies.

        :param query_vectors: vectors to search for.
        :param collection_name: collection to search in.
        """
        milvus = self.milvus

        status, results = milvus.search_vectors(collection_name,
                                                top_k=1,
                                                query_records=query_vectors,
                                                params={'nprobe': 16})
        if status.OK():
            top_hit = results[0][0]
            # self.ids may be missing or empty when nothing was inserted
            # through this instance; do not fail on that.
            known_ids = getattr(self, 'ids', None)
            id_matches = bool(known_ids) and top_hit.id == known_ids[0]
            if top_hit.distance == 0.0 or id_matches:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

        # print results
        print(results)

    def drop_table(self, collection_name):
        """Drop ``collection_name`` and disconnect from the server."""
        milvus = self.milvus

        milvus.drop_collection(collection_name)

        # Disconnect from Milvus
        milvus.disconnect()
Exemplo n.º 2
0
def main():
    """Walk through a table-style Milvus session end to end.

    Steps, in order: connect to ``_HOST:_PORT``, print client/server
    versions, describe ``table01`` (creating it when the description is
    falsy), list tables, add 20 random 256-dim vectors, wait 6 seconds,
    search with 2 random query vectors (top 10), print the row count and
    disconnect. Every step prints its status or result.
    """
    client = Milvus()

    # Client version
    print('# Client version: {}'.format(client.client_version()))

    # Connect to the server; adjust _HOST / _PORT to the right endpoint.
    connect_status = client.connect(host=_HOST, port=_PORT)
    print('# Connect Status: {}'.format(connect_status))

    # Connection flag as reported by the client
    print('# Is connected: {}'.format(client.connected))

    # Server version
    print('# Server version: {}'.format(client.server_version()))

    table_name = 'table01'
    dimension = 256

    def random_records(count):
        # `count` vectors of `dimension` random floats, wrapped for the client.
        return Prepare.records([[random.random() for _ in range(dimension)]
                                for _ in range(count)])

    # Describe the table
    describe_status, table = client.describe_table(table_name)
    print('# Describe table status: {}'.format(describe_status))
    print('# Describe table:{}'.format(table))

    # Create `table01` only when no description came back for it.
    if not table:
        schema = Prepare.table_schema(table_name=table_name,
                                      dimension=dimension,
                                      index_type=IndexType.IDMAP,
                                      store_raw_vector=False)
        create_status = client.create_table(schema)
        print('# Create table status: {}'.format(create_status))

    # List all tables
    _, tables = client.show_tables()
    pprint(tables)

    # Insert 20 random vectors into the table
    records = random_records(20)
    add_status, ids = client.add_vectors(table_name=table_name,
                                         records=records)
    print('# Add vector status: {}'.format(add_status))
    pprint(ids)

    # After the first insert the server needs at least 5s to persist the
    # vector data, so wait 6s before searching.
    print('# Waiting for 6s...')
    time.sleep(6)

    # Search with 2 random query vectors, asking for the 10 nearest hits.
    queries = random_records(2)
    search_status, results = client.search_vectors(table_name=table_name,
                                                   query_records=queries,
                                                   top_k=10)
    print('# Search vectors status: {}'.format(search_status))
    pprint(results)

    # Row count of the table
    count_status, row_count = client.get_table_row_count(table_name)
    print('# Status: {}'.format(count_status))
    print('# Count: {}'.format(row_count))

    # Disconnect
    disconnect_status = client.disconnect()
    print('# Disconnect Status: {}'.format(disconnect_status))
Exemplo n.º 3
0
class ANN(object):
    """Convenience wrapper around a ``Milvus`` client (field-based API).

    Attributes:
        client: the underlying ``Milvus`` client instance.
    """

    def __init__(self, host='10.119.33.90', port='19530', show_info=False):
        """Create the client for ``host:port``.

        :param host: Milvus server host.
        :param port: Milvus server port.
        :param show_info: when true, log client and server versions.
        """
        self.client = Milvus(host, port)

        if show_info:
            logger.info({
                "ClientVersion": self.client.client_version(),
                "ServerVersion": self.client.server_version()
            })

    def create_collection(self,
                          collection_name,
                          collection_param,
                          partition_tag=None,
                          overwrite=True):
        """Create a collection, optionally replacing an existing one.

        :param collection_name: name of the collection.
        :param collection_param: collection schema, e.g.
            collection_param = {
                "fields": [
                    #  Milvus doesn't support string type now, but we are considering supporting it soon.
                    #  {"name": "title", "type": DataType.STRING},
                    {"name": "category_", "type": DataType.INT32},
                    {"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 768}},
                ],
                "segment_row_limit": 4096,
                "auto_id": False
            }
        :param partition_tag: when not ``None``, a partition with this tag is
            created in the collection afterwards.
        :param overwrite: when true and the collection exists, it is dropped
            and re-created; when false and it exists, a message is printed
            and the collection is left as is.
        :return: None
        """
        # Query existence once instead of once per branch.
        exists = self.client.has_collection(collection_name)

        if exists and overwrite:
            self.client.drop_collection(collection_name)
            self.client.flush()
            # Pause between the drop and the re-create (kept from the
            # original; presumably lets the server finish the drop).
            time.sleep(5)
            exists = False

        if exists:
            print(f"{collection_name} already exist !!!")
        else:
            self.client.create_collection(collection_name, collection_param)

        if partition_tag is not None:
            self.client.create_partition(collection_name,
                                         partition_tag=partition_tag)

    def create_index(self,
                     collection_name,
                     field_name,
                     index_type='IVF_FLAT',
                     metric_type='IP',
                     index_params=None):
        """Create an index on ``field_name`` of ``collection_name``.

        :param collection_name: collection holding the field.
        :param field_name: field to index (e.g. ``'embedding'``).
        :param index_type: index type name, default ``'IVF_FLAT'``.
        :param metric_type: metric type name, default ``'IP'``.
        :param index_params: index-specific parameters; defaults to
            ``{'nlist': 1024}`` when ``None``.
        :return: None

        Reference values noted by the original author (verify against the
        installed pymilvus version):

        MetricType:
            INVALID = 0
            L2 = 1
            IP = 2
            # Only supported for byte vectors
            HAMMING = 3
            JACCARD = 4
            TANIMOTO = 5
            #
            SUBSTRUCTURE = 6
            SUPERSTRUCTURE = 7
        IndexType:
            INVALID = 0
            FLAT = 1
            IVFLAT = 2
            IVF_SQ8 = 3
            RNSG = 4
            IVF_SQ8H = 5
            IVF_PQ = 6
            HNSW = 11
            ANNOY = 12

            # alternative name
            IVF_FLAT = IVFLAT
            IVF_SQ8_H = IVF_SQ8H

        class DataType(IntEnum):
            NULL = 0
            INT8 = 1
            INT16 = 2
            INT32 = 3
            INT64 = 4

            STRING = 20

            BOOL = 30

            FLOAT = 40
            DOUBLE = 41

            VECTOR = 100
            UNKNOWN = 9999

        class RangeType(IntEnum):
            LT = 0   # less than
            LTE = 1  # less than or equal
            EQ = 2   # equal
            GT = 3   # greater than
            GTE = 4  # greater than or equal
            NE = 5   # not equal
        """
        if index_params is None:
            index_params = {'nlist': 1024}

        params = {
            'index_type': index_type,
            # 'index_file_size': 1024,
            'params': index_params,
            'metric_type': metric_type,
        }
        self.client.create_index(collection_name, field_name, params)

    def batch_insert(self, collection_name, entities, batch_size=100000):
        """Insert ``entities`` in row batches of at most ``batch_size``.

        Each entity is a dict with a ``'values'`` list; the length of the
        first entity's ``'values'`` determines the number of rows. Every
        batch is inserted and then flushed. The caller's ``entities`` are
        not modified: each batch is built from shallow copies.

        :param collection_name: target collection.
        :param entities: list of entity dicts, each with a ``'values'`` key.
        :param batch_size: maximum number of rows per insert call.
        :return: concatenated list of ids returned by the insert calls;
            empty when there is nothing to insert.
        :raises ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        if not entities:
            return []

        total = len(entities[0]['values'])

        ids = []
        # Stepping by batch_size yields exactly ceil(total / batch_size)
        # batches, so no trailing empty insert is sent.
        for start in range(0, total, batch_size):
            stop = start + batch_size
            batch = [{**entity, 'values': entity['values'][start:stop]}
                     for entity in entities]
            ids += self.client.insert(collection_name, batch)
            self.client.flush()
        return ids

    def search(self):  # TODO: retrieve the same information
        """Placeholder; not implemented yet."""
        pass

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` if it exists."""
        if self.client.has_collection(collection_name):
            self.client.drop_collection(collection_name)

    def drop_partition(self, collection_name, partition_tag):
        """Drop partition ``partition_tag`` of ``collection_name`` if it exists.

        The drop call is issued with ``timeout=30``.
        """
        if self.client.has_partition(collection_name, partition_tag):
            self.client.drop_partition(collection_name,
                                       partition_tag,
                                       timeout=30)