class MilvusConnection:
    """Build (if absent) and query a Milvus collection of ``env.base.embeddings``.

    Statuses returned by the Milvus calls are collected in ``self.statuses``
    and exposed through :meth:`get_log`.
    """

    def __init__(self, env, name="movies_L2", port="19530", param=None):
        """Connect to Milvus on localhost and populate the collection if it is missing.

        :param env: object exposing ``env.base.embeddings``, converted with
            ``detach().cpu().numpy()`` (presumably a torch tensor of shape
            (n_items, dimension) -- TODO confirm).
        :param name: collection name used when ``param`` does not override it.
        :param port: Milvus server port.
        :param param: overrides merged on top of the default create-collection
            parameters (dimension 128, index_file_size 1024, L2 metric).
        """
        if param is None:
            param = dict()
        param = {
            "collection_name": name,
            "dimension": 128,
            "index_file_size": 1024,
            "metric_type": MetricType.L2,
            **param,
        }
        # Use the effective collection name everywhere: an override in
        # ``param`` must not create one collection while inserts, flushes and
        # searches target another.
        name = param["collection_name"]
        self.name = name
        self.client = Milvus(host="localhost", port=port)
        self.statuses = {}
        if not self.client.has_collection(name)[1]:
            status_created_collection = self.client.create_collection(param)
            vectors = env.base.embeddings.detach().cpu().numpy().astype("float32")
            # Ids 0..n-1, so returned ids index rows of the embeddings.
            target_ids = list(range(vectors.shape[0]))
            status_inserted, _ = self.client.insert(
                collection_name=name, records=vectors, ids=target_ids)
            status_flushed = self.client.flush([name])
            status_compacted = self.client.compact(collection_name=name)
            self.statuses["created_collection"] = status_created_collection
            self.statuses["inserted"] = status_inserted
            self.statuses["flushed"] = status_flushed
            self.statuses["compacted"] = status_compacted

    def search(self, search_vecs, topk=10, search_param=None):
        """Return the ids of the ``topk`` nearest neighbours of each query as a tensor.

        :param search_vecs: query vectors passed to Milvus as ``query_records``.
        :param topk: number of neighbours per query.
        :param search_param: overrides merged on top of ``{"nprobe": 16}``.
        """
        if search_param is None:
            search_param = dict()
        search_param = {"nprobe": 16, **search_param}
        status, results = self.client.search(
            collection_name=self.name,
            query_records=search_vecs,
            top_k=topk,
            params=search_param,
        )
        # The status is recorded, not checked; inspect it via get_log().
        self.statuses['last_search'] = status
        return torch.tensor(results.id_array)

    def get_log(self):
        """Return the dict of Milvus statuses recorded so far."""
        return self.statuses
class MilvusClient(object):
    """Wrapper around the Milvus python-sdk client used by the benchmark.

    Most helpers validate the returned ``Status`` via :meth:`check_status`,
    which logs diagnostics and raises ``Exception("Status not ok")``.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """Milvus client wrapper for python-sdk. Default timeout set 60s.

        :param collection_name: default collection used by the helper methods.
        :param host: server host; falls back to ``SERVER_HOST_DEFAULT``.
        :param port: server port; falls back to ``SERVER_PORT_DEFAULT``.
        :param timeout: seconds to keep retrying the connection.
        :raises Exception: ``"Server connect timeout"`` if no attempt succeeds.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # Retry connecting to the (possibly still starting) remote server.
        connected = False
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    connected = True
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
            # Back off instead of busy-looping against the server.
            time.sleep(1)
        # Judge by the loop's outcome, not the clock: a connection that
        # succeeded just before the deadline must not be reported as a timeout.
        if not connected:
            raise Exception("Server connect timeout")
        self._metric_type = None
        self._dimension = None
        if self._collection_name and self.exists_collection():
            info = self.describe()[1]
            self._metric_type = metric_type_to_str(info.metric_type)
            self._dimension = info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        """Switch the default collection used by the helper methods."""
        self._collection_name = name

    def check_status(self, status):
        """Log diagnostics and raise if ``status`` is not OK."""
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's top hit has distance ``>= epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; ``metric_type`` must be a key of ``METRIC_MAP``."""
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP:
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_type = METRIC_MAP[metric_type]
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        """Insert vectors ``X`` (optionally with explicit ``ids``); return ``(status, result)``."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1..100 random vectors and verify the row count grew accordingly."""
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` random ids from one random non-empty segment."""
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # Randomly choose one segment.
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return ``(segment_count, ids)`` with the first ``length`` ids of every segment."""
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        """Return ``(ids, entities)`` for up to ``length`` random ids."""
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete 1..100 random ids and verify they are gone and the count dropped."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` must be a key of ``INDEX_MAP``."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "index_param": ...}``."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        """Search ``X`` for ``top_k`` neighbours and return the raw search result."""
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Search with random nq / top_k / nprobe using stored entities as queries."""
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)

    def count(self, name=None):
        """Return the row count of ``name`` (default collection), 0 if falsy."""
        if name is None:
            name = self._collection_name
        # One RPC, reused for the debug log and the value.
        res = self._milvus.count_entities(name)
        logger.debug(res)
        row_count = res[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection and poll (1s steps) until its count reads as zero."""
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        i = 0
        while i < timeout and self.count(name=name):
            time.sleep(1)
            i = i + 1
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return the ``has_collection`` result; the status is intentionally ignored."""
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": <GiB, 2 decimals>}`` from ``get_system_info``."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a raw server command, log it, check the status and return the result."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
insert_future = client.insert(collection_name, Batmans, partition_tag="American", _async=True,
                              _callback=batman_insert_cb)
# NOTE(review): the return value of done() is discarded; this relies on the
# SDK's Future.done() waiting for the insert to finish before the flush below.
# Confirm against the pymilvus version in use (otherwise call result()).
insert_future.done()
# ------
# Basic insert entities:
#     After inserting entities into a collection, we need to flush the collection
#     to make sure it's on disk, so that we are able to retrieve it.
# ------
print("\n----------flush----------")
flush_future = client.flush([collection_name], _async=True)
# Block until the asynchronous flush has completed.
flush_future.result()

# ------
# Basic hybrid search entities:
#     Getting films by id is not enough; we are going to get films based on vector similarities.
#     Let's say we have a film with its `embedding` and we want to find the `top3` films that are
#     most similar to it by L2 distance.
#     Other than vector similarities, we also want to obtain films whose:
#       `released year` is 2002 or 2003,
#       `duration` is larger than 250 minutes.
#
#     Milvus provides a Query DSL (Domain Specific Language) to support structured data filtering in queries.
#     For now Milvus supports TermQuery and RangeQuery; they are structured as below.
#     For more information about the meaning of "must" and "bool" and other options,
#     please refer to the DSL chapter of our pymilvus documentation
class SearchEngine:
    """CRUD and similarity-search facade over a single Milvus collection.

    Call :meth:`set_collection` before any data operation; the data methods
    assert that a collection has been selected.
    """

    def __init__(self, host, port):
        # MILVUS_HOST / MILVUS_PORT environment variables take precedence.
        self.host = os.environ.get('MILVUS_HOST', host)
        self.port = os.environ.get('MILVUS_PORT', str(port))
        self.engine = Milvus(host=self.host, port=self.port)
        self.collection_name = None

    #################################################
    # HANDLE COLLECTION
    #################################################
    def create_collection(self, collection_name, dimension):
        """Create an inner-product (IP) collection with the given dimension."""
        param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': 1000,
            'metric_type': MetricType.IP
        }
        self.engine.create_collection(param)
        print('[INFO] collection {}을 생성했습니다.'.format(collection_name))

    def drop_collection(self, collection_name):
        """Drop ``collection_name``."""
        self.engine.drop_collection(collection_name=collection_name)
        print('[INFO] collection {}을 삭제했습니다.'.format(collection_name))

    def get_collection_stats(self, collection_name):
        """Print the collection info and statistics reported by the server."""
        print(self.engine.get_collection_info(collection_name))
        print(self.engine.get_collection_stats(collection_name))

    def set_collection(self, collection_name):
        """Select the collection used by the insert/delete/update/search methods."""
        self.collection_name = collection_name
        print('[INFO] setting collection {}'.format(self.collection_name))

    #################################################
    # UTILS
    #################################################
    def check_set_collection(self):
        """Assert that a collection has been selected for data operations."""
        assert self.collection_name is not None, '[ERROR] collection을 setting해 주십시오!!'

    def check_exist_data_by_key(self, key):
        """Return True if an entity exists for the first id in ``key``."""
        self.check_set_collection()
        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)
        vector = vector if vector else [vector]
        return True if vector[0] else False

    def convert_key_format(self, key):
        """Wrap a single int id in a list; pass lists of ids through."""
        return [key] if isinstance(key, int) else key

    def convert_value_format(self, value):
        """Return ``value`` as a 2-D batch: 1-D vectors become shape (1, dim).

        2-D batches pass through unchanged; any other rank is rejected.
        """
        rank = len(value.shape)
        # Accept a single vector (rank 1) or a batch (rank 2). The old check
        # (rank < 2) rejected batches and let rank-0 scalars through.
        assert rank in (1, 2), '[ERROR] value의 dim을 1 또는 2로 입력해 주세요!!'
        return value.reshape(1, -1) if rank == 1 else value

    @staticmethod
    def _split_result(result, skip=0):
        """Return per-query ``(ids, distances)`` lists, dropping the first ``skip`` hits of each query."""
        hits_per_query = [result[i][skip:] if skip else result[i] for i in range(len(result))]
        li_id = [[hit.id for hit in hits] for hits in hits_per_query]
        li_dist = [[hit.distance for hit in hits] for hits in hits_per_query]
        return li_id, li_dist

    #################################################
    # INSERT
    #################################################
    def insert_data(self, key, value):
        """Insert ``value`` under ``key`` unless the (first) key already exists."""
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)
        if self.check_exist_data_by_key(key):
            print("[ERROR] 이미 collection에 데이터가 존재합니다.")
            return
        self.engine.insert(collection_name=self.collection_name, records=value, ids=key)
        self.engine.flush([self.collection_name])
        print('[INFO] insert key {}'.format(key))

    #################################################
    # DELETE
    #################################################
    def delete_data(self, key):
        """Delete ``key`` from the collection if the (first) key exists."""
        key = self.convert_key_format(key)
        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return
        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])
        print('[INFO] delete key {}'.format(key))

    #################################################
    # UPDATE
    #################################################
    def update_data(self, key, value):
        """Replace the vector stored under ``key`` (delete, flush, insert, flush)."""
        key = self.convert_key_format(key)
        value = self.convert_value_format(value)
        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return
        self.engine.delete_entity_by_id(self.collection_name, key)
        self.engine.flush([self.collection_name])
        self.engine.insert(collection_name=self.collection_name, records=value, ids=key)
        self.engine.flush([self.collection_name])
        print('[INFO] update key {}'.format(key))

    #################################################
    # SEARCH
    #################################################
    def search_by_feature(self, feature, top_k):
        """Search by feature vector(s); return per-query id and distance lists."""
        self.check_set_collection()
        feature = self.convert_value_format(feature)
        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=feature,
                                       top_k=top_k)
        return self._split_result(result)

    def search_by_key(self, key, top_k):
        """Search using the stored vector(s) of ``key``; None if the key is missing."""
        self.check_set_collection()
        key = self.convert_key_format(key)
        if not self.check_exist_data_by_key(key):
            print("[ERROR] collection에 데이터가 존재하지 않습니다.")
            return
        _, vector = self.engine.get_entity_by_id(
            collection_name=self.collection_name, ids=key)
        # Request one extra hit and drop each query's first hit, presumably the
        # stored vector matching itself.
        _, result = self.engine.search(collection_name=self.collection_name,
                                       query_records=vector,
                                       top_k=top_k + 1)
        return self._split_result(result, skip=1)
def main():
    """Demo: create a collection and partition, insert/delete/search vectors, then clean up.

    Uses ``uri``, ``collection_name``, ``partition_tag`` and ``_DIM`` defined
    outside this function (presumably module-level settings).
    """
    milvus = Milvus(uri=uri)
    try:
        param = {
            'collection_name': collection_name,
            'dimension': _DIM,
            'index_file_size': 32,
            # 'metric_type': MetricType.IP
            'metric_type': MetricType.L2
        }
        # Show collections in the Milvus server (listed before creation).
        _, collections = milvus.list_collections()
        # Create the collection.
        milvus.create_collection(param)
        try:
            # Create a partition in the collection.
            milvus.create_partition(collection_name, partition_tag)
            print(f'collections in Milvus: {collections}')
            # Describe demo_collection.
            _, collection = milvus.get_collection_info(collection_name)
            print(f'descript demo_collection: {collection}')

            # Build fake vectors.
            vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
            vectors1 = [[random.random() for _ in range(_DIM)] for _ in range(10)]
            status, ids = milvus.insert(collection_name=collection_name, records=vectors,
                                        ids=list(range(10)), partition_tag=partition_tag)
            print(f'status: {status} | id: {ids}')
            if not status.OK():
                print(f"insert failded: {status}")
            status1, id1 = milvus.insert(collection_name=collection_name, records=vectors1,
                                         ids=list(range(10, 20)), partition_tag=partition_tag)
            print(f'status1: {status1} | id1: {id1}')

            # Delete the first batch (ids 0..9); only ids 10..19 remain.
            ids_deleted = list(range(10))
            status_delete = milvus.delete_entity_by_id(collection_name=collection_name,
                                                       id_array=ids_deleted)
            if status_delete.OK():
                print('delete successful')

            # Flush inserted data of the collection to disk.
            milvus.flush([collection_name])
            # Get demo_collection row count.
            status, result = milvus.count_entities(collection_name)
            print(f"demo_collection row count: {result}")

            # Obtain raw vectors by providing vector ids.
            status, result_vectors = milvus.get_entity_by_id(collection_name, list(range(10, 20)))

            # Create an IVF_FLAT index on demo_collection to search more rapidly.
            index_param = {'nlist': 2}
            status = milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param)
            if status.OK():
                print("create index ivf_flat succeeed")

            # Query with the first two vectors of the surviving batch.
            query_vectors = vectors1[0:2]
            # Execute a vector similarity search.
            search_param = {"nprobe": 16}
            param = {
                'collection_name': collection_name,
                'query_records': query_vectors,
                'top_k': 1,
                'params': search_param
            }
            status, results = milvus.search(**param)
            if status.OK():
                # The query vector is stored, so its top hit should be itself.
                if results[0][0].distance == 0.0:
                    print('query result is correct')
                else:
                    print('not correct')
                print(results)
            else:
                print(f'search failed: {status}')
        finally:
            # Drop the demo collection even if a step above raised.
            milvus.drop_collection(collection_name=collection_name)
    finally:
        # Always release the client connection.
        milvus.close()
def main():
    """Demo of the field-based Milvus API: create, insert, index, search, clean up.

    Uses ``_HOST``, ``_PORT`` and ``_DIM`` defined outside this function.
    """
    # Specify the server address when creating the milvus client instance.
    # The client instance maintains a connection pool; param `pool_size`
    # specifies the max connection num.
    milvus = Milvus(_HOST, _PORT)

    collection_name = 'example_collection'
    field_name = 'example_field'
    fields = {
        "fields": [{
            "name": field_name,
            "type": DataType.FLOAT_VECTOR,
            "metric_type": "L2",
            "params": {
                "dim": _DIM
            },
            "indexes": [{
                "metric_type": "L2"
            }]
        }]
    }
    # Start from a fresh collection: drop a stale copy, then always create it
    # (previously an existing collection was dropped but never recreated).
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name=collection_name)
    milvus.create_collection(collection_name=collection_name, fields=fields)

    # Show collections in the Milvus server.
    collections = milvus.list_collections()
    print(collections)

    # Describe the collection.
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # 10 random vectors of dimension _DIM (Python floats); a 2-D list.
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
    print(vectors)
    # You can also use numpy to generate random vectors:
    #     vectors = np.random.rand(10000, _DIM).astype(np.float32)

    # Insert the vectors; the call returns the assigned ids.
    entities = [{
        "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }]
    res_ids = milvus.insert(collection_name=collection_name, entities=entities)
    print("ids:", res_ids)

    # Flush inserted data of the collection to disk.
    milvus.flush([collection_name])

    # Present collection statistics.
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # Create an IVF_FLAT index. Searching works without an index, but an
    # index makes the search faster.
    index_param = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {
            "nlist": 1024
        }
    }
    print("Creating index: {}".format(index_param))
    milvus.create_index(collection_name, field_name, index_param)

    # Execute a vector similarity search using the inserted vectors as queries.
    print("Searching ... ")
    dsl = {
        "bool": {
            "must": [{
                "vector": {
                    field_name: {
                        "metric_type": "L2",
                        "query": vectors,
                        "topk": 10,
                        "params": {
                            "nprobe": 16
                        }
                    }
                }
            }]
        }
    }
    milvus.load_collection(collection_name)
    results = milvus.search(collection_name, dsl)
    # The first query vector is the first inserted vector, so its top hit should
    # be an exact match. Equivalent check:
    #     `results.distance_array[0][0] == 0.0 or results.id_array[0][0] == res_ids[0]`
    if results[0][0].distance == 0.0 or results[0][0].id == res_ids[0]:
        print('Query result is correct')
    else:
        print('Query result isn\'t correct')

    milvus.drop_index(collection_name, field_name)
    milvus.release_collection(collection_name)

    # Delete the collection.
    milvus.drop_collection(collection_name)
vectors_list = vectors_array.tolist()
# Append one extra all-ones vector; it matches the query vector used below.
vectors_list.append([1] * vec_dim)
ids_list = list(range(len(vectors_list)))
# A single insert must not exceed 256 MB. Inserted data is first held in a
# buffer whose size is governed by index_file_size (default 1024 MB).
milvus.insert(
    collection_name='test01',
    records=vectors_list,
    partition_tag="tag01",
    ids=ids_list,
)

# Some collection information.
print(milvus.list_collections())
print(milvus.get_collection_info('test01'))
print(milvus.get_index_info('test01'))
t2 = time.time()
print("create cost:", t2 - t1)

# Build the query vector(s).
query_vec_list = [[1] * vec_dim]
# Search. nprobe here (and nlist at index-build time) affect query speed and
# accuracy differently depending on the index type.
# For IVF_FLAT, nprobe is the number of clusters searched and must not exceed nlist.
search_param = {'nprobe': 10}
# The inserted data is still buffered in memory; flush it to storage first,
# otherwise the search cannot find it.
# NOTE(review): this flush runs after t2, so "search cost" below includes flush time.
milvus.flush(collection_name_array=['test01'])
results = milvus.search(
    collection_name='test01',
    query_records=query_vec_list,
    top_k=10,
    params=search_param,
)
print(results)
print("search cost:", time.time() - t2)

# Example command for starting a CPU Milvus server in Docker:
# sudo docker run -d --name milvus_cpu \
# -p 19530:19530 \
# -p 19121:19121 \
# -v /home/$USER/milvus/db:/var/lib/milvus/db \
# -v /home/$USER/milvus/conf:/var/lib/milvus/conf \
# -v /home/$USER/milvus/logs:/var/lib/milvus/logs \
# -v /home/$USER/milvus/wal:/var/lib/milvus/wal \
# milvusdb/milvus:0.10.2-cpu-d081520-8a2393
class Test:
    """Milvus end-to-end benchmark: create, insert, flush, index, load, search.

    Input vectors are read from ``.npy`` files named
    ``<prefix><5-digit file index><suffix>``.
    """

    def __init__(self, nvec):
        """Configure the benchmark.

        :param nvec: number of vectors to insert; must be at least
            ``insert_bulk_size`` and a multiple of it.
        """
        self.cname = "benchmark"
        self.fname = "feature"
        self.dim = 128
        self.client = Milvus("localhost", 19530)
        self.prefix = '/sift1b/binary_128d_'
        self.suffix = '.npy'
        self.vecs_per_file = 100000
        self.maxfiles = 1000
        self.insert_bulk_size = 5000
        self.nvec = nvec
        self.insert_cost = 0
        self.flush_cost = 0
        self.create_index_cost = 0
        self.search_cost = 0
        # Logical ``and``: a bitwise ``&`` binds tighter than the comparisons
        # and silently turned this into an unrelated bit test.
        assert self.nvec >= self.insert_bulk_size and self.nvec % self.insert_bulk_size == 0

    def run(self, suite):
        """Run every benchmark step and return a report of timings.

        :param suite: dict with iterables under ``"nq"``, ``"topk"`` and
            ``"nprobe"``; every combination is searched.
        :returns: dict of metric name -> ``{"value": str, "unit": str}``.
            Ordinary errors are logged and the partial report is returned.
        """
        report = dict()
        try:
            # step 1 create collection
            logging.info(f'step 1 create collection')
            self._create_collection()
            logging.info(f'step 1 complete')

            # step 2 fill data
            logging.info(f'step 2 insert')
            start = time.time()
            self._insert()
            self.insert_cost = time.time() - start
            report["insert-speed"] = {
                "value": format(self.nvec / self.insert_cost, ".4f"),
                "unit": "vec/sec"
            }
            logging.info(f'step 2 complete')

            # step 3 flush
            logging.info(f'step 3 flush')
            start = time.time()
            self._flush()
            self.flush_cost = time.time() - start
            report["flush-cost"] = {
                "value": format(self.flush_cost, ".4f"),
                "unit": "s"
            }
            logging.info(f'step 3 complete')

            # step 4 create index
            logging.info(f'step 4 create index')
            start = time.time()
            self._create_index()
            self.create_index_cost = time.time() - start
            report["create-index-cost"] = {
                "value": format(self.create_index_cost, ".4f"),
                "unit": "s"
            }
            logging.info(f'step 4 complete')

            # step 5 load
            logging.info(f'step 5 load')
            self._load_collection()
            logging.info(f'step 5 complete')

            # step 6 search
            logging.info(f'step 6 search')
            for nq in suite["nq"]:
                for topk in suite["topk"]:
                    for nprobe in suite["nprobe"]:
                        start = time.time()
                        self._search(nq=nq, topk=topk, nprobe=nprobe)
                        self.search_cost = time.time() - start
                        report[f"search-q{nq}-k{topk}-p{nprobe}-cost"] = {
                            "value": format(self.search_cost, ".4f"),
                            "unit": "s"
                        }
            logging.info(f'step 6 complete')
        except AssertionError as ae:
            logging.exception(ae)
        except Exception as e:
            logging.error(f'test failed: {e}')
        # Returning outside ``finally`` so KeyboardInterrupt / SystemExit are
        # no longer silently swallowed.
        return report

    def _create_collection(self):
        """(Re)create the benchmark collection with one L2 float-vector field."""
        logging.debug(f'create_collection() start')

        if self.client.has_collection(self.cname):
            logging.debug(f'collection {self.cname} existed')
            self.client.drop_collection(self.cname)
            logging.info(f'drop collection {self.cname}')

        logging.debug(f'before create collection: {self.cname}')
        self.client.create_collection(self.cname, {
            "fields": [{
                "name": self.fname,
                "type": DataType.FLOAT_VECTOR,
                "metric_type": "L2",
                "params": {"dim": self.dim},
                "indexes": [{"metric_type": "L2"}]
            }]
        })
        logging.info(f'created collection: {self.cname}')
        assert self.client.has_collection(self.cname)

        logging.debug(f'create_collection() finished')

    def _insert(self):
        """Insert vectors from the ``.npy`` files in bulks until ``nvec`` are inserted."""
        logging.debug(f'insert() start')

        count = 0
        for i in range(0, self.maxfiles):
            filename = self.prefix + str(i).zfill(5) + self.suffix
            logging.debug(f'filename: {filename}')
            array = np.load(filename)
            logging.debug(f'numpy array shape: {array.shape}')
            step = self.insert_bulk_size
            for p in range(0, self.vecs_per_file, step):
                entities = [
                    {"name": self.fname, "type": DataType.FLOAT_VECTOR, "values": array[p:p + step].tolist()}]
                logging.debug(f'before insert slice: {p}, {p + step}')
                self.client.insert(self.cname, entities)
                logging.info(f'after insert slice: {p}, {p + step}')
                count += step
                logging.debug(f'insert count: {count}')
                if count == self.nvec:
                    logging.debug(f'inner break')
                    break
            if count == self.nvec:
                logging.debug(f'outer break')
                break

        logging.debug(f'insert() finished')

    def _flush(self):
        """Flush the collection and assert the row count equals ``nvec``."""
        logging.debug(f'flush() start')

        logging.debug(f'before flush: {self.cname}')
        self.client.flush([self.cname])
        logging.info(f'after flush')
        stats = self.client.get_collection_stats(self.cname)
        logging.debug(stats)
        assert stats["row_count"] == self.nvec

        logging.debug(f'flush() finished')

    def _create_index(self):
        """Build an IVF_FLAT (nlist=1024, L2) index on the vector field."""
        logging.debug(f'create_index() start')

        index_params = {
            "metric_type": "L2",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 1024}
        }
        self.client.create_index(self.cname, self.fname, index_params)
        logging.debug(f'create index {self.cname} : {self.fname} : {index_params}')

        logging.debug(f'create_index() finished')

    def _load_collection(self):
        logging.debug(f'load_collection() start')

        logging.debug(f'before load collection: {self.cname}')
        self.client.load_collection(self.cname)

        logging.debug(f'load_collection() finished')

    def _search(self, nq, topk, nprobe):
        """Run one search with ``nq`` generated query vectors."""
        logging.debug(f'search() start')

        result = self.client.search(self.cname,
                                    {"bool": {"must": [{"vector": {
                                        self.fname: {
                                            "metric_type": "L2",
                                            "query": _gen_vectors(nq, self.dim),
                                            "topk": topk,
                                            "params": {"nprobe": nprobe}
                                        }
                                    }}]}}
                                    )
        logging.debug(f'{result}')

        logging.debug(f'search() finished')
class MilvusDocumentStore(SQLDocumentStore):
    """
    Milvus (https://milvus.io/) is a highly reliable, scalable Document Store specialized on storing and processing
    vectors. Therefore, it is particularly suited for Haystack users that work with dense retrieval methods (like DPR).
    In contrast to FAISS, Milvus ...
     - runs as a separate service (e.g. a Docker container) and can scale easily in a distributed environment
     - allows dynamic data management (i.e. you can insert/delete vectors without recreating the whole index)
     - encapsulates multiple ANN libraries (FAISS, ANNOY ...)

    This class uses Milvus for all vector related storage, processing and querying.
    The meta-data (e.g. for filtering) and the document text are however stored in a separate SQL Database as Milvus
    does not allow these data types (yet).

    Usage:
    1. Start a Milvus server (see https://milvus.io/docs/v0.10.5/install_milvus.md)
    2. Init a MilvusDocumentStore in Haystack
    """

    def __init__(
            self,
            sql_url: str = "sqlite:///",
            milvus_url: str = "tcp://localhost:19530",
            connection_pool: str = "SingletonThread",
            index: str = "document",
            vector_dim: int = 768,
            index_file_size: int = 1024,
            similarity: str = "dot_product",
            index_type: IndexType = IndexType.FLAT,
            index_param: Optional[Dict[str, Any]] = None,
            search_param: Optional[Dict[str, Any]] = None,
            update_existing_documents: bool = False,
            return_embedding: bool = False,
            embedding_field: str = "embedding",
            **kwargs,
    ):
        """
        :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based
                        SQLite DB. For large scale deployment, Postgres is recommended. If using MySQL then the same
                        server can also be used for Milvus metadata. For more details see
                        https://milvus.io/docs/v0.10.5/data_manage.md.
        :param milvus_url: Milvus server connection URL for storing and processing vectors.
                           Protocol, host and port will automatically be inferred from the URL.
                           See https://milvus.io/docs/v0.10.5/install_milvus.md for instructions to start a Milvus
                           instance.
        :param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
        :param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
        :param vector_dim: The embedding vector size. Default: 768.
        :param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value
                                is 1024 MB. When the size of newly inserted vectors reaches the specified volume,
                                Milvus packs these vectors into a new segment. Milvus creates one index file for each
                                segment. When conducting a vector search, Milvus searches all index files one by one.
                                As a rule of thumb, we would see a 30% ~ 50% increase in the search performance after
                                changing the value of index_file_size from 1024 to 2048.
                                Note that an overly large index_file_size value may cause failure to load a segment
                                into the memory or graphics memory.
                                (From https://milvus.io/docs/v0.10.5/performance_faq.md#How-can-I-get-the-best-performance-from-Milvus-through-setting-index_file_size)
        :param similarity: The similarity function used to compare document vectors. 'dot_product' is the default and
                           recommended for DPR embeddings. It is implemented with Milvus' inner-product (IP) metric.
                           'cosine' is recommended for Sentence Transformers, but is not directly supported by Milvus.
                           However, you can normalize your embeddings and use `dot_product` to get the same results.
                           See https://milvus.io/docs/v0.10.5/metric.md?Inner-product-(IP)#floating.
        :param index_type: Type of approximate nearest neighbour (ANN) index used. The choice here determines your
                           tradeoff between speed and accuracy. Some popular options:
                           - FLAT (default): Exact method, slow
                           - IVF_FLAT, inverted file based heuristic, fast
                           - HNSW: Graph based, fast
                           - ANNOY: Tree based, fast
                           See: https://milvus.io/docs/v0.10.5/index.md
        :param index_param: Configuration parameters for the chosen index_type needed at indexing time.
                            For example: {"nlist": 16384} as the number of cluster units to create for index_type
                            IVF_FLAT. See https://milvus.io/docs/v0.10.5/index.md
        :param search_param: Configuration parameters for the chosen index_type needed at query time.
                             For example: {"nprobe": 10} as the number of cluster units to query for index_type
                             IVF_FLAT. See https://milvus.io/docs/v0.10.5/index.md
        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
                                          documents. When set as True, any document with an existing ID gets updated.
                                          If set to False, an error is raised if the document ID of the document being
                                          added already exists.
        :param return_embedding: To return document embedding.
        :param embedding_field: Name of field containing an embedding vector.
        :raises ValueError: if `similarity` is not "dot_product".
        """
        # Validate arguments before opening a connection so a bad value does not leave a dangling client.
        if similarity == "dot_product":
            # Dot-product similarity is Milvus' inner-product metric (IP). L2 would rank by Euclidean
            # distance, contradicting the documented similarity and the sigmoid "probability" computed in
            # query_by_embedding(). Note: collections that already exist keep the metric they were created
            # with, because metric_type is only sent when a collection is created.
            metric_type = MetricType.IP
        else:
            raise ValueError(
                "The Milvus document store can currently only support dot_product similarity. "
                "Please set similarity=\"dot_product\"")

        self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
        self.vector_dim = vector_dim
        self.index_file_size = index_file_size
        self.metric_type = metric_type

        self.index_type = index_type
        self.index_param = index_param or {"nlist": 16384}
        self.search_param = search_param or {"nprobe": 10}
        self.index = index
        self._create_collection_and_index_if_not_exist(self.index)
        self.return_embedding = return_embedding
        self.embedding_field = embedding_field

        super().__init__(url=sql_url, update_existing_documents=update_existing_documents, index=index)

    def __del__(self):
        """Close the Milvus connection, if one was ever opened."""
        # __init__ may have raised before `milvus_server` was assigned; avoid an AttributeError here.
        milvus_server = getattr(self, "milvus_server", None)
        if milvus_server is not None:
            return milvus_server.close()

    def _create_collection_and_index_if_not_exist(
            self, index: Optional[str] = None, index_param: Optional[Dict[str, Any]] = None):
        """
        Create the Milvus collection `index` and its vector index if the collection does not exist yet.

        :param index: Collection name; defaults to `self.index`.
        :param index_param: Index build parameters; defaults to `self.index_param`.
        :raises RuntimeError: if any Milvus call reports a non-success status.
        """
        index = index or self.index
        index_param = index_param or self.index_param

        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if not ok:
            collection_param = {
                'collection_name': index,
                'dimension': self.vector_dim,
                'index_file_size': self.index_file_size,
                'metric_type': self.metric_type
            }

            status = self.milvus_server.create_collection(collection_param)
            if status.code != Status.SUCCESS:
                raise RuntimeError(f'Collection creation on Milvus server failed: {status}')

            status = self.milvus_server.create_index(index, self.index_type, index_param)
            if status.code != Status.SUCCESS:
                raise RuntimeError(f'Index creation on Milvus server failed: {status}')

    def _create_document_field_map(self) -> Dict:
        """Map the index name to the embedding field name (used when converting dicts to Documents)."""
        return {
            self.index: self.embedding_field,
        }

    def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
                        batch_size: int = 10_000):
        """
        Add new documents to the DocumentStore.

        :param documents: List of `Dicts` or List of `Documents`. If they already contain the embeddings, we'll index
                          them right away in Milvus. If not, you can later call update_embeddings() to create & index
                          them.
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :raises AttributeError: if an embedding is neither a list nor a numpy.ndarray.
        :raises RuntimeError: if Milvus rejects the vector insertion.
        :return: None
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)
        field_map = self._create_document_field_map()

        if len(documents) == 0:
            logger.warning("Calling DocumentStore.write_documents() with empty list")
            return

        document_objects = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d
                            for d in documents]

        # Whether vectors are written is decided from the first document only.
        add_vectors = document_objects[0].embedding is not None

        batched_documents = get_batches_from_generator(document_objects, batch_size)
        with tqdm(total=len(document_objects)) as progress_bar:
            for document_batch in batched_documents:
                vector_ids = []
                if add_vectors:
                    doc_ids = []
                    embeddings = []
                    for doc in document_batch:
                        doc_ids.append(doc.id)
                        if isinstance(doc.embedding, np.ndarray):
                            embeddings.append(doc.embedding.tolist())
                        elif isinstance(doc.embedding, list):
                            embeddings.append(doc.embedding)
                        else:
                            raise AttributeError(
                                f'Format of supplied document embedding {type(doc.embedding)} is not '
                                f'supported. Please use list or numpy.ndarray')

                    if self.update_existing_documents:
                        # Remove the stale vectors of documents that are about to be overwritten.
                        existing_docs = super().get_documents_by_id(ids=doc_ids, index=index)
                        self._delete_vector_ids_from_milvus(documents=existing_docs, index=index)

                    status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings)
                    if status.code != Status.SUCCESS:
                        raise RuntimeError(f'Vector embedding insertion failed: {status}')

                docs_to_write_in_sql = []
                for idx, doc in enumerate(document_batch):
                    meta = doc.meta
                    if add_vectors:
                        # Link the SQL row to its Milvus vector.
                        meta["vector_id"] = vector_ids[idx]
                    docs_to_write_in_sql.append(doc)

                super().write_documents(docs_to_write_in_sql, index=index)
                # Advance by the real batch length; the last batch is usually smaller than batch_size.
                progress_bar.update(len(document_batch))

        self.milvus_server.flush([index])
        if self.update_existing_documents:
            self.milvus_server.compact(collection_name=index)

    def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None, batch_size: int = 10_000):
        """
        Updates the embeddings in the document store using the encoding model specified in the retriever.
        This can be useful if want to add or change the embeddings for your documents (e.g. after changing the
        retriever config).

        :param retriever: Retriever to use to get embeddings for text
        :param index: (SQL) index name for storing the docs and metadata
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        :raises RuntimeError: if Milvus rejects the vector insertion.
        :return: None
        """
        index = index or self.index
        self._create_collection_and_index_if_not_exist(index)

        document_count = self.get_document_count(index=index)
        if document_count == 0:
            logger.warning("Calling DocumentStore.update_embeddings() on an empty index")
            return

        logger.info(f"Updating embeddings for {document_count} docs...")

        result = self.get_all_documents_generator(index=index, batch_size=batch_size, return_embedding=False)
        batched_documents = get_batches_from_generator(result, batch_size)
        with tqdm(total=document_count) as progress_bar:
            for document_batch in batched_documents:
                # Drop the old vectors before inserting the freshly computed ones.
                self._delete_vector_ids_from_milvus(documents=document_batch, index=index)

                embeddings = retriever.embed_passages(document_batch)  # type: ignore
                embeddings_list = [embedding.tolist() for embedding in embeddings]
                assert len(document_batch) == len(embeddings_list)

                status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list)
                if status.code != Status.SUCCESS:
                    raise RuntimeError(f'Vector embedding insertion failed: {status}')

                vector_id_map = {doc.id: vector_id for vector_id, doc in zip(vector_ids, document_batch)}
                self.update_vector_ids(vector_id_map, index=index)
                progress_bar.update(len(document_batch))

        self.milvus_server.flush([index])
        self.milvus_server.compact(collection_name=index)

    def query_by_embedding(self,
                           query_emb: np.array,
                           filters: Optional[dict] = None,
                           top_k: int = 10,
                           index: Optional[str] = None,
                           return_embedding: Optional[bool] = None) -> List[Document]:
        """
        Find the document that is most similar to the provided `query_emb` by using a vector similarity metric.

        :param query_emb: Embedding of the query (e.g. gathered from DPR)
        :param filters: Optional filters to narrow down the search space. Not supported by this store: passing any
                        filters raises an Exception.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param top_k: How many documents to return
        :param index: (SQL) index name for storing the docs and metadata
        :param return_embedding: To return document embedding
        :raises RuntimeError: if a Milvus call reports a non-success status.
        :return: Documents in the rank order returned by Milvus, with `score` and `probability` set.
        """
        if filters:
            raise Exception("Query filters are not implemented for the MilvusDocumentStore.")

        index = index or self.index
        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if not ok:
            raise Exception("No index exists. Use 'update_embeddings()` to create an index.")

        if return_embedding is None:
            return_embedding = self.return_embedding

        query_emb = query_emb.reshape(1, -1).astype(np.float32)
        status, search_result = self.milvus_server.search(
            collection_name=index,
            query_records=query_emb,
            top_k=top_k,
            params=self.search_param
        )
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Vector embedding search failed: {status}')

        vector_ids_for_query = []
        scores_for_vector_ids: Dict[str, float] = {}
        for vector_id_list, distance_list in zip(search_result.id_array, search_result.distance_array):
            for vector_id, distance in zip(vector_id_list, distance_list):
                vector_ids_for_query.append(str(vector_id))
                scores_for_vector_ids[str(vector_id)] = distance

        documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)

        # The SQL lookup does not guarantee any order; restore the ranking Milvus returned.
        rank = {vector_id: position for position, vector_id in enumerate(vector_ids_for_query)}
        documents = sorted(documents, key=lambda doc: rank.get(str(doc.meta["vector_id"]), len(rank)))

        if return_embedding:
            self._populate_embeddings_to_docs(index=index, docs=documents)

        for doc in documents:
            doc.score = scores_for_vector_ids[str(doc.meta["vector_id"])]
            doc.probability = float(expit(np.asarray(doc.score / 100)))

        return documents

    def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
        """
        Delete all documents (from SQL AND Milvus).

        :param index: (SQL) index name for storing the docs and metadata
        :param filters: Optional filters to narrow down the search space.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :raises RuntimeError: if a Milvus call reports a non-success status.
        :return: None
        """
        index = index or self.index
        super().delete_all_documents(index=index, filters=filters)
        # NOTE(review): `filters` only restricts the SQL deletion above; the whole Milvus collection is still
        # dropped below, leaving surviving SQL rows with dangling vector_ids. Confirm whether this is intended.
        status, ok = self.milvus_server.has_collection(collection_name=index)
        if status.code != Status.SUCCESS:
            raise RuntimeError(f'Milvus has collection check failed: {status}')
        if ok:
            status = self.milvus_server.drop_collection(collection_name=index)
            if status.code != Status.SUCCESS:
                raise RuntimeError(f'Milvus drop collection failed: {status}')

            self.milvus_server.flush([index])
            self.milvus_server.compact(collection_name=index)

    def get_all_documents_generator(
            self,
            index: Optional[str] = None,
            filters: Optional[Dict[str, List[str]]] = None,
            return_embedding: Optional[bool] = None,
            batch_size: int = 10_000,
    ) -> Generator[Document, None, None]:
        """
        Get all documents from the document store. Under-the-hood, documents are fetched in batches from the
        document store and yielded as individual documents. This method can be used to iteratively process
        a large number of documents without having to load all documents in memory.

        :param index: Name of the index to get the documents from. If None, the
                      DocumentStore's default index (self.index) will be used.
        :param filters: Optional filters to narrow down the documents to return.
                        Example: {"name": ["some", "more"], "category": ["only_one"]}
        :param return_embedding: Whether to return the document embeddings.
        :param batch_size: When working with large number of documents, batching can help reduce memory footprint.
        """
        index = index or self.index
        documents = super().get_all_documents_generator(index=index, filters=filters, batch_size=batch_size)
        if return_embedding is None:
            return_embedding = self.return_embedding

        for doc in documents:
            if return_embedding:
                self._populate_embeddings_to_docs(index=index, docs=[doc])
            yield doc
class MyMilvus():
    """Wrapper around a pymilvus ``Milvus`` client bound to a single collection.

    ``create_collection`` runs the end-to-end "films" demo. It creates the collection (if missing),
    bulk-inserts the rows of a CSV file, builds an IVF_FLAT index on ``embedding`` and prints the
    collection info. By default it then drops the index and the collection again.

    Example ``collection_param``::

        {
            "fields": [
                {"name": "release_year", "type": DataType.INT32},
                {"name": "embedding", "type": DataType.FLOAT_VECTOR, "params": {"dim": 8}},
            ],
            "segment_row_limit": 4096,
            "auto_id": False,
        }
    """

    def __init__(self, name, host, port, collection_param):
        self.host = host
        self.port = port
        self.client = Milvus(host, port)
        self.collection_name = name
        self.collection_param = collection_param

    @staticmethod
    def _load_films(csv_path):
        """Parse the films CSV into column lists ``(ids, titles, release_years, embeddings)``.

        Each row is read as ``id, title, release_year, "[f1,f2,...]"``. This is presumably the layout
        of the demo ``films.csv``; confirm it for other files.
        """
        ids, titles, release_years, embeddings = [], [], [], []
        # newline='' is required by the csv module so quoted fields are parsed correctly.
        with open(csv_path, 'r', newline='') as csvfile:
            for film in csv.reader(csvfile):
                ids.append(int(film[0]))
                titles.append(film[1])
                release_years.append(int(film[2]))
                # Strip the surrounding "[" and "]", then parse the comma-separated floats.
                embeddings.append([float(value) for value in film[3][1:-1].split(',')])
        return ids, titles, release_years, embeddings

    def create_collection(self, csv_path='films.csv', cleanup=True):
        """Create/populate/index ``self.collection_name`` from a films CSV and print its info.

        :param csv_path: CSV file to load. The default ``'films.csv'`` is resolved relative to the
            current working directory.
        :param cleanup: When True (the default, matching the historical behaviour), drop the index and
            the collection at the end. Pass False to keep the populated, indexed collection.
        """
        if self.collection_name not in self.client.list_collections():
            self.client.create_collection(self.collection_name, self.collection_param)

        # Milvus expects column-oriented data: one entry per field, holding all values of that field.
        ids, _titles, release_years, embeddings = self._load_films(csv_path)
        hybrid_entities = [
            {"name": "release_year", "values": release_years, "type": DataType.INT32},
            {"name": "embedding", "values": embeddings, "type": DataType.FLOAT_VECTOR},
        ]

        # Explicit ids are passed, which presumably requires the collection to use "auto_id": False
        # (as in the example param above).
        self.client.insert(self.collection_name, hybrid_entities, ids)
        self.client.flush([self.collection_name])
        after_flush_counts = self.client.count_entities(self.collection_name)
        print(" > There are {} films in collection `{}` after flush".format(after_flush_counts, self.collection_name))

        # Build an IVF_FLAT index on the vector field. Calling create_index again with different params
        # replaces the previous index.
        self.client.create_index(self.collection_name, "embedding",
                                 {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}})

        info = self.client.get_collection_info(self.collection_name)
        pprint(info)

        if cleanup:
            self.client.drop_index(self.collection_name, "embedding")
            if self.collection_name in self.client.list_collections():
                self.client.drop_collection(self.collection_name)
class ANN(object):
    """Helper around a pymilvus ``Milvus`` client for collection, index and partition management."""

    def __init__(self, host='10.119.33.90', port='19530', show_info=False):
        self.client = Milvus(host, port)
        if show_info:
            logger.info({
                "ClientVersion": self.client.client_version(),
                "ServerVersion": self.client.server_version()
            })

    def create_collection(self, collection_name, collection_param, partition_tag=None, overwrite=True):
        """Create ``collection_name``, optionally replacing an existing one, and optionally add a partition.

        :param collection_name: name of the collection.
        :param collection_param: schema dict passed to ``Milvus.create_collection``, e.g.::

            {
                "fields": [
                    {"name": "category_", "type": DataType.INT32},
                    {"name": "vector", "type": DataType.FLOAT_VECTOR, "params": {"dim": 768}},
                ],
                "segment_row_limit": 4096,
                "auto_id": False
            }

        :param partition_tag: if not None, a partition with this tag is created in the collection.
        :param overwrite: if True, an existing collection is dropped and recreated. If False, it is kept
            and a message is printed.
        """
        exists = self.client.has_collection(collection_name)
        if exists and overwrite:
            self.client.drop_collection(collection_name)
            self.client.flush()
            # Give the server time to finish the drop before recreating the same name.
            time.sleep(5)
            self.client.create_collection(collection_name, collection_param)
        elif exists:
            print(f"{collection_name} already exist !!!")
        else:
            self.client.create_collection(collection_name, collection_param)

        if partition_tag is not None:
            self.client.create_partition(collection_name, partition_tag=partition_tag)

    def create_index(self, collection_name, field_name, index_type='IVF_FLAT', metric_type='IP', index_params=None):
        """Build (or replace) the index on ``field_name`` of ``collection_name``.

        :param collection_name: target collection.
        :param field_name: vector field to index (e.g. ``'embedding'``).
        :param index_type: one of the Milvus index types, e.g. 'FLAT', 'IVF_FLAT', 'IVF_SQ8', 'IVF_SQ8H',
            'IVF_PQ', 'RNSG', 'HNSW', 'ANNOY'.
        :param metric_type: 'L2' or 'IP' for float vectors. 'HAMMING', 'JACCARD', 'TANIMOTO',
            'SUBSTRUCTURE' and 'SUPERSTRUCTURE' are only supported for binary vectors.
        :param index_params: index-specific parameters; defaults to ``{'nlist': 1024}``.
        """
        if index_params is None:
            index_params = {'nlist': 1024}
        params = {
            'index_type': index_type,
            'params': index_params,
            'metric_type': metric_type,
        }
        self.client.create_index(collection_name, field_name, params)

    def batch_insert(self, collection_name, entities, batch_size=100000):
        """Insert column-oriented ``entities`` in chunks of at most ``batch_size`` rows, then flush.

        :param collection_name: target collection.
        :param entities: list of field dicts (``{"name", "type", "values"}``). The row count is taken
            from ``len(entities[0]['values'])``. The caller's dicts are left unmodified.
        :param batch_size: maximum number of rows per ``insert`` call; must be positive.
        :raises ValueError: if ``batch_size`` is not positive.
        :return: concatenation of the ids returned by every ``insert`` call.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        n = len(entities[0]['values'])
        ids = []
        # Step through the rows exactly; no trailing empty batch when n is a multiple of batch_size.
        for start in range(0, n, batch_size):
            end = start + batch_size
            # Build fresh dicts per batch instead of overwriting the caller's 'values'.
            batch = [{**entity, 'values': entity['values'][start:end]} for entity in entities]
            ids += self.client.insert(collection_name, batch)
        self.client.flush()
        return ids

    def search(self):
        # TODO: retrieve the same information (not implemented yet).
        pass

    def drop_collection(self, collection_name):
        """Drop ``collection_name`` if it exists."""
        if self.client.has_collection(collection_name):
            self.client.drop_collection(collection_name)

    def drop_partition(self, collection_name, partition_tag):
        """Drop partition ``partition_tag`` of ``collection_name`` if it exists."""
        if self.client.has_partition(collection_name, partition_tag):
            self.client.drop_partition(collection_name, partition_tag, timeout=30)