class MilvusANN(object):
    """Demo helper around a Milvus client using the collection-style API.

    The constructor connects to the server. The ``*_demo`` methods walk
    through creating a collection, inserting random vectors, indexing them
    and running a similarity search.
    """

    def __init__(self, host='10.46.5.98', port='19530'):
        """Create the client and connect to ``host:port``.

        Prints the client and server versions. If the connection status is
        not OK the process is terminated with exit code 1.
        """
        self.milvus = Milvus()
        # State shared between insert_vectors_demo and search_vectors_demo.
        # Initialised here so the search demo does not fail with
        # AttributeError when it is called before any insert.
        self.ids = []
        self._query_vectors = []

        print("Client Version:", self.milvus.client_version())
        status = self.milvus.connect(host, port)
        if status.OK():
            print("Server connected.")
        else:
            print("Server connect fail.")
            sys.exit(1)
        print("Server Version:", self.milvus.server_version()[-1])

    def desc(self, tabel_name=None):
        """Print the description and vector count of ``tabel_name``.

        Nothing is printed when ``tabel_name`` is falsy.
        """
        milvus = self.milvus
        # NOTE(review): the collection listing is fetched but never used;
        # confirm whether it was meant to be printed.
        milvus.show_collections()
        if tabel_name:
            print(f"Describe: {milvus.describe_collection(tabel_name)[-1]}")
            print(
                f"Vector number in {tabel_name}: {milvus.count_collection(tabel_name)}"
            )

    def create_tabel_demo(self):
        """Create the collection ``demo_table`` if missing and print it."""
        milvus = self.milvus
        table_name = 'demo_table'

        _, exists = milvus.has_collection(table_name)
        if not exists:
            param = {
                'collection_name': table_name,
                'dimension': 16,
                # Optional. When a data file reaches this size, Milvus
                # starts building an index for it.
                'index_file_size': 1024,
                'metric_type': MetricType.L2,  # optional
            }
            milvus.create_collection(param)

        # List the collections on the server (the result is not used here).
        milvus.show_collections()

        # Describe demo_table.
        _, table = milvus.describe_collection(table_name)
        print(table)

    def insert_vectors_demo(self, collection_name):
        """Insert 10000 random 16-dim vectors, build an index and print it.

        The returned ids are stored on ``self.ids`` and the first ten
        vectors on ``self._query_vectors`` for use by a later search.
        """
        milvus = self.milvus

        # 10000 float32 vectors of dimension 16, as a 2-D list of floats.
        vectors = np.random.rand(10000, 16).astype(np.float32).tolist()

        # Insert the vectors; returns a status and the list of vector ids.
        # NOTE(review): the original author noted ids look timestamp-based,
        # e.g. "1581655102 786 118" - confirm against the server version.
        _, self.ids = milvus.insert(collection_name, vectors)

        # Wait for 6 seconds, until the Milvus server persists vector data.
        time.sleep(6)

        # Row count is requested but not used.
        milvus.count_collection(collection_name)

        # Searching works without an index; creating one makes it faster.
        index_param = {'nlist': 2048}
        milvus.create_index(collection_name,
                            index_type=IndexType.IVFLAT,
                            params=index_param)

        # Describe the index to show what was built.
        _, index = milvus.describe_index(collection_name)
        print(index)

        # Keep the first 10 vectors as ready-made queries.
        self._query_vectors = vectors[0:10]

    def search_vectors_demo(self, query_vectors, collection_name):
        """Search the nearest vector for each query and print the verdict.

        The result counts as correct when the best hit of the first query
        has distance 0.0 or carries the id of the first inserted vector.
        When no ids are known (no prior insert) only the distance is
        checked; an empty result set is reported as not correct instead of
        raising. Nothing is printed when the search status is not OK.
        """
        milvus = self.milvus

        status, results = milvus.search_vectors(collection_name,
                                                top_k=1,
                                                query_records=query_vectors,
                                                params={'nprobe': 16})
        if not status.OK():
            return

        # The server may return no hits (e.g. empty collection).
        try:
            best_hit = results[0][0]
        except IndexError:
            best_hit = None

        # ids is only populated by insert_vectors_demo; tolerate instances
        # on which it is missing or empty.
        known_ids = getattr(self, 'ids', None)
        id_matches = (bool(known_ids) and best_hit is not None
                      and best_hit.id == known_ids[0])

        if best_hit is not None and (best_hit.distance == 0.0 or id_matches):
            print('Query result is correct')
        else:
            print('Query result isn\'t correct')

        print(results)

    def drop_table(self, collection_name):
        """Drop ``collection_name`` and disconnect from the server."""
        milvus = self.milvus
        milvus.drop_collection(collection_name)
        milvus.disconnect()
def main():
    """Run the table-API demo end to end against a Milvus server.

    Connects using the module-level ``_HOST`` / ``_PORT``, makes sure table
    ``table01`` exists, adds 20 random 256-dimensional vectors, waits for
    them to be persisted, searches with 2 random queries, prints the row
    count and disconnects. Every step prints its status.
    """
    client = Milvus()

    # Client version.
    print('# Client version: {}'.format(client.client_version()))

    # Connect; change _HOST and _PORT to match the target server.
    connect_status = client.connect(host=_HOST, port=_PORT)
    print('# Connect Status: {}'.format(connect_status))

    # Connection state and server version.
    print('# Is connected: {}'.format(client.connected))
    print('# Server version: {}'.format(client.server_version()))

    # Describe the table; a falsy description triggers creation below.
    table_name = 'table01'
    describe_status, table = client.describe_table(table_name)
    print('# Describe table status: {}'.format(describe_status))
    print('# Describe table:{}'.format(table))

    dimension = 256
    if not table:
        schema = Prepare.table_schema(
            table_name=table_name,
            dimension=dimension,
            index_type=IndexType.IDMAP,
            store_raw_vector=False,
        )
        create_status = client.create_table(schema)
        print('# Create table status: {}'.format(create_status))

    # Show the tables known to the server.
    _, tables = client.show_tables()
    pprint(tables)

    def random_rows(count):
        # `count` rows of `dimension` random floats each.
        return [[random.random() for _ in range(dimension)]
                for _ in range(count)]

    # Add 20 vectors to the table.
    records = Prepare.records(random_rows(20))
    add_status, ids = client.add_vectors(table_name=table_name,
                                         records=records)
    print('# Add vector status: {}'.format(add_status))
    pprint(ids)

    # When adding vectors for the first time the server takes at least 5s
    # to persist them, so wait 6s before searching.
    print('# Waiting for 6s...')
    time.sleep(6)

    # Search the 10 nearest neighbours of 2 random queries.
    queries = Prepare.records(random_rows(2))
    search_status, results = client.search_vectors(table_name=table_name,
                                                   query_records=queries,
                                                   top_k=10)
    print('# Search vectors status: {}'.format(search_status))
    pprint(results)

    # Row count of the table.
    count_status, row_count = client.get_table_row_count(table_name)
    print('# Status: {}'.format(count_status))
    print('# Count: {}'.format(row_count))

    # Disconnect.
    disconnect_status = client.disconnect()
    print('# Disconnect Status: {}'.format(disconnect_status))
class ANN(object):
    """Convenience wrapper around a Milvus client (field/entity style API)."""

    def __init__(self, host='10.119.33.90', port='19530', show_info=False):
        """Create the client for ``host:port``.

        :param show_info: when true, log the client and server versions.
        """
        self.client = Milvus(host, port)
        if show_info:
            logger.info({
                "ClientVersion": self.client.client_version(),
                "ServerVersion": self.client.server_version()
            })

    def create_collection(self, collection_name, collection_param,
                          partition_tag=None, overwrite=True):
        """Create a collection, optionally replacing an existing one.

        :param collection_name: name of the collection.
        :param collection_param: collection schema, for example::

            collection_param = {
                "fields": [
                    # Milvus doesn't support string type now.
                    # {"name": "title", "type": DataType.STRING},
                    {"name": "category_", "type": DataType.INT32},
                    {"name": "vector", "type": DataType.FLOAT_VECTOR,
                     "params": {"dim": 768}},
                ],
                "segment_row_limit": 4096,
                "auto_id": False
            }

        :param partition_tag: when not None, a partition with this tag is
            created in the collection afterwards.
        :param overwrite: when the collection already exists, drop and
            re-create it if true; otherwise leave it and print a notice.
        :return: None
        """
        # Query the server once instead of once per branch.
        exists = self.client.has_collection(collection_name)

        if exists and not overwrite:
            print(f"{collection_name} already exist !!!")
        else:
            if exists:
                self.client.drop_collection(collection_name)
                self.client.flush()
                # NOTE(review): presumably gives the server time to finish
                # the drop before re-creating - confirm the delay is needed.
                time.sleep(5)
            self.client.create_collection(collection_name, collection_param)

        if partition_tag is not None:
            self.client.create_partition(collection_name,
                                         partition_tag=partition_tag)

    def create_index(self, collection_name, field_name,
                     index_type='IVF_FLAT', metric_type='IP',
                     index_params=None):
        """Create an index on ``field_name`` of ``collection_name``.

        :param field_name: vector field to index, e.g. ``'embedding'``.
        :param index_type: index type name.
        :param metric_type: metric type name.
        :param index_params: index parameters; defaults to
            ``{'nlist': 1024}`` when None.
        :return: None

        Enum reference recorded by the original author (verify against the
        installed pymilvus version)::

            MetricType: INVALID=0, L2=1, IP=2,
                        # only supported for byte vectors
                        HAMMING=3, JACCARD=4, TANIMOTO=5,
                        SUBSTRUCTURE=6, SUPERSTRUCTURE=7
            IndexType:  INVALID=0, FLAT=1, IVFLAT=2, IVF_SQ8=3, RNSG=4,
                        IVF_SQ8H=5, IVF_PQ=6, HNSW=11, ANNOY=12
                        # alternative names
                        IVF_FLAT=IVFLAT, IVF_SQ8_H=IVF_SQ8H
            DataType:   NULL=0, INT8=1, INT16=2, INT32=3, INT64=4,
                        STRING=20, BOOL=30, FLOAT=40, DOUBLE=41,
                        VECTOR=100, UNKNOWN=9999
            RangeType:  LT=0, LTE=1, EQ=2, GT=3, GTE=4, NE=5
        """
        if index_params is None:
            index_params = {'nlist': 1024}

        params = {
            'index_type': index_type,
            'params': index_params,
            'metric_type': metric_type,
        }
        self.client.create_index(collection_name, field_name, params)

    def batch_insert(self, collection_name, entities, batch_size=100000):
        """Insert ``entities`` in consecutive slices of ``batch_size`` rows.

        Each entity is a dict whose ``'values'`` entry holds one column; the
        row count is taken from the first entity. Every batch is built from
        shallow copies, so the caller's ``entities`` are left untouched, and
        no empty trailing batch is sent when the row count is an exact
        multiple of ``batch_size``. The client is flushed after each batch.

        :param collection_name: target collection.
        :param entities: list of entity dicts, each with a ``'values'`` list.
        :param batch_size: maximum rows per insert call; must be positive.
        :return: list of the ids returned by all insert calls, in order.
        :raises ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}")
        if not entities:
            return []

        total = len(entities[0]['values'])
        ids = []
        for start in range(0, total, batch_size):
            stop = start + batch_size
            # Copy each entity with only its slice of values so the input
            # dicts are not mutated.
            batch = [{**entity, 'values': entity['values'][start:stop]}
                     for entity in entities]
            ids += self.client.insert(collection_name, batch)
            self.client.flush()
        return ids

    def search(self):
        """Placeholder; not implemented yet."""
        # TODO: retrieve the same information.
        pass

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` if it exists."""
        if self.client.has_collection(collection_name):
            self.client.drop_collection(collection_name)

    def drop_partition(self, collection_name, partition_tag):
        """Drop partition ``partition_tag`` of the collection if it exists."""
        if self.client.has_partition(collection_name, partition_tag):
            self.client.drop_partition(collection_name, partition_tag,
                                       timeout=30)