'dimension': 256, 'index_file_size': 1024, 'metric_type': MetricType.L2 } #milvus.create_collection(param) #删除集合 milvus.drop_collection(collection_name="test01") print("collection:", milvus.list_collections()) milvus.create_collection(param) #print(milvus.list_partitions("test01")) #创建分区 milvus.create_partition('test01', 'tag01') #print(milvus.list_partitions("test01")) #删除分区 #milvus.drop_partition('test01','tag01') print("partition:", milvus.list_partitions("test01")) import time import random import numpy as np #vectors = [[random.random() for _ in range(256)] for _ in range(3)] #print(np.shape(np.array(vectors))) #vector_ids = [1,2,3] vectors = [] vector_ids = [] id_query = {} with open("embed.txt", "r") as f: for line in f: line_lst = line.strip().split("\t") vectors.append(list(map(float, (line_lst[2].split())))) vector_ids.append(int(line_lst[1]))
class Indexer:
    '''
    Indexer: a thin wrapper around a Milvus client bound to one collection.

    Every method that talks to the server raises ``ExertMilvusException``
    (carrying the returned status) when the status code is non-zero.
    '''

    def __init__(self, name, host='127.0.0.1', port='19531'):
        '''
        Connect to the server and remember the collection name.

        :param name: name of the collection this indexer operates on.
        :param host: Milvus server host.
        :param port: Milvus server port.
        '''
        self.client = Milvus(host=host, port=port)
        self.collection = name

    @staticmethod
    def _check_status(status):
        '''Raise ``ExertMilvusException`` if ``status`` carries a non-zero code.'''
        if status.code != 0:
            raise ExertMilvusException(status)

    def init(self, lenient=False, dimension=512, nlist=16384):
        '''
        Create the collection and build an IVF_FLAT index on it.

        :param lenient: when true, return immediately if the collection
            already exists, and tolerate status code 9 from
            ``create_collection`` (presumably the server's "already exists"
            code -- TODO confirm).
        :param dimension: vector dimension of the collection.
        :param nlist: ``nlist`` parameter of the IVF_FLAT index.
        :return: the ``create_index`` status, or ``None`` when ``lenient``
            is true and the collection already existed.
        '''
        if lenient:
            status, result = self.client.has_collection(
                collection_name=self.collection)
            self._check_status(status)
            if result:
                return
        status = self.client.create_collection({
            'collection_name': self.collection,
            'dimension': dimension,
            'index_file_size': 1024,
            'metric_type': MetricType.L2
        })
        if status.code != 0 and not (lenient and status.code == 9):
            raise ExertMilvusException(status)
        # Build the index.
        status = self.client.create_index(collection_name=self.collection,
                                          index_type=IndexType.IVF_FLAT,
                                          params={'nlist': nlist})
        self._check_status(status)
        return status

    def drop(self):
        '''Drop the collection.'''
        status = self.client.drop_collection(collection_name=self.collection)
        self._check_status(status)

    def flush(self):
        '''Flush the collection to disk.'''
        status = self.client.flush([self.collection])
        self._check_status(status)

    def compact(self):
        '''Compact the collection.'''
        status = self.client.compact(collection_name=self.collection)
        self._check_status(status)

    def close(self):
        '''Close the connection.'''
        self.client.close()

    def new_tag(self, tag):
        '''Create a partition with the given tag.'''
        status = self.client.create_partition(collection_name=self.collection,
                                              partition_tag=tag)
        self._check_status(status)

    def list_tag(self):
        '''List the partitions of the collection.'''
        status, result = self.client.list_partitions(
            collection_name=self.collection)
        self._check_status(status)
        return result

    def drop_tag(self, tag):
        '''Drop the partition with the given tag.'''
        status = self.client.drop_partition(collection_name=self.collection,
                                            partition_tag=tag)
        self._check_status(status)

    def index(self, vectors, tag=None, ids=None):
        '''
        Insert vectors into the collection.

        :param vectors: list of vectors to insert.
        :param tag: optional partition tag to insert into.
        :param ids: optional explicit ids for the vectors.
        :return: the ids returned by the server.
        '''
        params = {}
        if tag is not None:
            # The pymilvus ``insert`` keyword is ``partition_tag``; passing
            # ``tag`` does not route the vectors into the partition.
            params['partition_tag'] = tag
        if ids is not None:
            params['ids'] = ids
        status, result = self.client.insert(collection_name=self.collection,
                                            records=vectors, **params)
        self._check_status(status)
        return result

    def listing(self, ids):
        '''Fetch the entities with the given ids.'''
        status, result = self.client.get_entity_by_id(
            collection_name=self.collection, ids=ids)
        self._check_status(status)
        return result

    def counting(self):
        '''Return the number of entities in the collection.'''
        status, result = self.client.count_entities(
            collection_name=self.collection)
        self._check_status(status)
        return result

    def unindex(self, ids):
        '''Delete the entities with the given ids.'''
        status = self.client.delete_entity_by_id(
            collection_name=self.collection, id_array=ids)
        self._check_status(status)

    def search(self, vectors, top_count=100, tags=None, nprobe=16):
        '''
        Search the collection.

        :param vectors: query vectors.
        :param top_count: number of results per query (``top_k``).
        :param tags: optional list of partition tags to restrict the search.
        :param nprobe: ``nprobe`` search parameter.
        :return: the search results returned by the server.
        '''
        params = {'params': {'nprobe': nprobe}}
        if tags is not None:
            params['partition_tags'] = tags
        status, results = self.client.search(collection_name=self.collection,
                                             query_records=vectors,
                                             top_k=top_count, **params)
        self._check_status(status)
        return results
# Basic create collection: # After create collection `demo_films`, we create a partition tagged "American", it means the films we # will be inserted are from American. # ------ client.create_collection(collection_name, collection_param) client.create_partition(collection_name, "American") # ------ # Basic create collection: # You can check the collection info and partitions we've created by `get_collection_info` and # `list_partitions` # ------ print("--------get collection info--------") collection = client.get_collection_info(collection_name) pprint(collection) partitions = client.list_partitions(collection_name) print("\n----------list partitions----------") pprint(partitions) # ------ # Basic insert entities: # We have three films of The_Lord_of_the_Rings series here with their id, duration release_year # and fake embeddings to be inserted. They are listed below to give you a overview of the structure. # ------ The_Lord_of_the_Rings = [ { "title": "The_Fellowship_of_the_Ring", "id": 1, "duration": 208, "release_year": 2001, "embedding": [random.random() for _ in range(8)]
class MilvusClient(object):
    """Helper around the pymilvus ``Milvus`` client.

    Most methods act on the collection given at construction time (or set
    later via ``set_collection``); ``check_status`` raises ``Exception`` when
    the server returns a non-OK status.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk. Default timeout set 60s.

        The connection is retried until ``server_status()`` answers or
        ``timeout`` seconds have elapsed.

        :param collection_name: default collection for subsequent calls.
        :param host: server host; falls back to ``SERVER_HOST_DEFAULT``.
        :param port: server port; falls back to ``SERVER_PORT_DEFAULT``.
        :param timeout: connection timeout in seconds.
        :raises Exception: "Server connect timeout" if no attempt succeeds.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # Retry connecting to the (possibly still starting) remote server.
        connected = False
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    connected = True
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
                # Back off briefly instead of hammering an unreachable server.
                time.sleep(0.5)
        # Judge the outcome of the loop rather than re-reading the clock: an
        # attempt that succeeds just past the deadline is still a success.
        if not connected:
            raise Exception("Server connect timeout")
        self._metric_type = None
        self._dimension = None
        if self._collection_name and self.exists_collection():
            # One RPC is enough to read both fields.
            info = self.describe()[1]
            self._metric_type = metric_type_to_str(info.metric_type)
            self._dimension = info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Switch the default collection used by subsequent calls."""
        self._collection_name = name

    def check_status(self, status):
        """Log diagnostics and raise ``Exception`` if ``status`` is not OK."""
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's top hit is at distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; ``metric_type`` must be a key of ``METRIC_MAP``."""
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP:
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_type = METRIC_MAP[metric_type]
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """Insert vectors ``X`` (optionally with ``ids``); return (status, ids)."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1-100 random vectors and assert the row count grew accordingly."""
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` ids sampled from one random segment of the
        first partition reported by ``get_collection_stats``."""
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return (segment count, first ``length`` ids of every segment)."""
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        # random choice from each segment
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete 1-100 random ids and assert they are gone and counted."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` must be a key of ``INDEX_MAP``."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Run one search with random nq / top_k / nprobe using stored vectors."""
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)

    def count(self, name=None):
        """Return the entity count of ``name`` (default collection); 0 if falsy."""
        if name is None:
            name = self._collection_name
        # Issue the RPC once and reuse the reply for logging and the count.
        res = self._milvus.count_entities(name)
        logger.debug(res)
        row_count = res[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection, then poll (1s steps) until its count reads 0."""
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        i = 0
        while i < timeout:
            if self.count(name=name):
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        # NOTE: the status of has_collection is deliberately not checked.
        _, res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Send a raw server command and return its textual result."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
def main():
    """Walk through partition operations on a Milvus server.

    Creates ``demo_partition_collection`` with a ``random`` partition, inserts
    10000 random vectors into that partition, builds an IVF_FLAT index,
    searches inside the partition, then drops the partition and the
    collection. The client connection is closed even if a step raises.
    """
    # Connect to Milvus server
    # You may need to change _HOST and _PORT accordingly
    param = {'host': _HOST, 'port': _PORT}

    # You can create an instance bound to the specified server address and
    # invoke RPC methods directly
    client = Milvus(**param)
    try:
        # Names of the demo collection and of the partition used below.
        collection_name = 'demo_partition_collection'
        partition_tag = "random"

        # create collection
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': _INDEX_FILE_SIZE,  # optional
            'metric_type': MetricType.L2  # optional
        }
        client.create_collection(param)

        # Show collections in Milvus server
        _, collections = client.list_collections()

        # Describe collection
        _, collection = client.get_collection_info(collection_name)
        print(collection)

        # create partition
        client.create_partition(collection_name, partition_tag=partition_tag)

        # display partitions
        _, partitions = client.list_partitions(collection_name)

        # 10000 vectors with _DIM dimensions
        # element per dimension is float32 type
        # vectors should be a 2-D array
        vectors = [[random.random() for _ in range(_DIM)] for _ in range(10000)]
        # You can also use numpy to generate random vectors:
        #     `vectors = np.random.rand(10000, _DIM).astype(np.float32).tolist()`

        # Insert vectors into the partition of the collection; returns status and vector ids
        status, ids = client.insert(collection_name=collection_name, records=vectors,
                                    partition_tag=partition_tag)

        # Flush so the server persists the inserted vectors, instead of
        # sleeping for a fixed amount of time and hoping it was long enough.
        client.flush([collection_name])

        # Get demo_collection row count
        status, num = client.count_entities(collection_name)

        # create index of vectors, search more rapidly
        index_param = {
            'nlist': 2048
        }

        # Create ivflat index in demo_collection
        # You can search vectors without creating an index. However, creating
        # an index helps to search faster.
        status = client.create_index(collection_name, IndexType.IVF_FLAT, index_param)

        # describe index, get information of index
        status, index = client.get_index_info(collection_name)
        print(index)

        # Use the top 10 vectors for similarity search
        query_vectors = vectors[0:10]

        # execute vector similarity search, restricted to the partition created above
        search_param = {
            "nprobe": 10
        }

        param = {
            'collection_name': collection_name,
            'query_records': query_vectors,
            'top_k': 1,
            'partition_tags': [partition_tag],
            'params': search_param
        }

        status, results = client.search(**param)

        if status.OK():
            # indicate search result
            # also use by:
            #   `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == ids[0]`
            if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
                print('Query result is correct')
            else:
                print('Query result isn\'t correct')

            # print results
            print(results)

        # Drop partition. You can also invoke `drop_collection()`, so that all
        # partitions belonging to the designated collection will be deleted.
        status = client.drop_partition(collection_name, partition_tag)

        # Delete collection. All partitions of this collection will be dropped.
        status = client.drop_collection(collection_name)
    finally:
        # Release the connection regardless of how the demo ended.
        client.close()