release_years.append(int(film[2])) embeddings.append(list(map(float, film[3][1:][:-1].split(',')))) hybrid_entities = [ {"name": "release_year", "values": release_years, "type": DataType.INT64}, {"name": "embedding", "values": embeddings, "type": DataType.FLOAT_VECTOR}, ] # ------ # Basic insert: # After preparing the data, we are going to insert them into our collection. # The number of films inserted should be 8657. # ------ ids = client.insert(collection_name, hybrid_entities, ids) client.flush([collection_name]) after_flush_counts = client.count_entities(collection_name) print(" > There are {} films in collection `{}` after flush".format(after_flush_counts, collection_name)) # ------ # Basic create index: # Now that we have insert all the films into Milvus, we are going to build index with these datas. # # While build index, we have to indicate which `field` to build index for, the `index_type`, # `metric_type` and params for the specific index type. In our case, we want to build a `IVF_FLAT` # index, so the specific params are "nlist". See pymilvus documentation # (https://milvus-io.github.io/milvus-sdk-python/pythondoc/v0.3.0/index.html) for `index_type` we # support and the params accordingly
values = ids elif data_type in [DataType.FLOAT, DataType.DOUBLE]: values = [(i + 0.0) for i in ids] elif data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: values = vectors return values def generate_entities(info, vectors, ids=None): entities = [] for field in info["fields"]: if field["name"] == "_id": continue field_type = field["type"] entities.append({ "name": field["name"], "type": field_type, "values": generate_values(field_type, vectors, ids) }) return entities m = Milvus(host="127.0.0.1") info = m.describe_collection(name) print(info) ids = [random.randint(1, 10000000)] X = [[random.random() for _ in range(dim)] for _ in range(1)] entities = generate_entities(info, X, ids) print(entities) m.insert(name, entities, ids=ids)
"name": "age", "type": DataType.FLOAT, }, { "name": "id", "type": DataType.INT64, "auto_id": True, "is_primary": True, } ], }, orm=True) assert c.has_collection(collection_name) ids = c.insert(collection_name, [ {"name": "f1", "type": DataType.FLOAT_VECTOR, "values": [[1.1, 2.2, 3.3, 4.4], [5.5, 6.6, 7.7, 8.8]]}, {"name": "age", "type": DataType.FLOAT, "values": [3.45, 8.9]} ], orm=True) c.flush([collection_name]) c.load_collection(collection_name) ############################################################# search_params = {"metric_type": "L2", "params": {"nprobe": 1}} results = c.search_with_expression(collection_name, [[1.1, 2.2, 3.3, 4.4]], "f1", param=search_params, limit=2, output_fields=["id"]) print("search results: ", results[0][0].entity, " + ", results[0][1].entity) #############################################################
class MilvusClient(object):
    """Benchmark helper that wraps a ``Milvus`` SDK client for one collection.

    Most methods accept an optional ``collection_name``; when it is omitted
    the collection name given at construction time (or set later through
    ``set_collection``) is used.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=300):
        """Create the underlying ``Milvus`` client, retrying until ``timeout``.

        :param collection_name: default collection used by the other methods.
        :param host: server host; falls back to ``config.SERVER_HOST_DEFAULT``.
        :param port: server port; falls back to ``config.SERVER_PORT_DEFAULT``.
        :param timeout: seconds to keep retrying before giving up.
        :raises Exception: "Server connect timeout" if no client could be
            created within ``timeout`` seconds.
        """
        self._collection_name = collection_name
        self._collection_info = None
        # Filled in by create_collection(); warm_query() falls back to the
        # server-side schema when this client did not create the collection.
        self._dimension = None
        start_time = time.time()
        if not host:
            host = config.SERVER_HOST_DEFAULT
        if not port:
            port = config.SERVER_PORT_DEFAULT
        # retry connect remote server
        failures = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(
                    host=host,
                    port=port,
                    try_connect=False,
                    pre_ping=False)
                break
            except Exception as e:
                failures += 1
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % failures)
                time.sleep(30)
        else:
            # Only reached when the loop ran out of time without a successful
            # connect; a slow-but-successful connect must not raise.
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, collection_name):
        """Change the default collection used by subsequent calls."""
        self._collection_name = collection_name

    # TODO: server not support
    # def check_status(self, status):
    #     if not status.OK():
    #         logger.error(status.message)
    #         logger.error(self._milvus.server_status())
    #         logger.error(self.count())
    #         raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Verify that the top hit of every query has a distance below ``epsilon``.

        :param result: iterable of per-query hit lists; the first hit of each
            must expose a ``distance`` attribute.
        :raises Exception: "Distance wrong" on the first query whose top hit
            is at distance ``>= epsilon``.
        """
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    @property
    def collection_name(self):
        """Name of the default collection."""
        return self._collection_name

    # only support the given field name
    def create_collection(self, dimension, data_type=DataType.FLOAT_VECTOR, auto_id=False,
                          collection_name=None, other_fields=None):
        """Create a collection with a vector field, an ``id`` primary key and
        optional scalar fields.

        :param dimension: vector dimension; also remembered for warm_query().
        :param data_type: type of the vector field.
        :param auto_id: forwarded as the collection's ``auto_id`` setting.
        :param collection_name: overrides the default collection name.
        :param other_fields: comma-separated scalar field names; each name
            must start with ``int``, ``float`` or ``double``, which selects
            INT64, FLOAT or DOUBLE respectively.
        :raises Exception: "Field name not supported" for any other prefix;
            SDK errors are logged and re-raised.
        """
        self._dimension = dimension
        if not collection_name:
            collection_name = self._collection_name
        vec_field_name = utils.get_default_field_name(data_type)
        fields = [
            {"name": vec_field_name, "type": data_type, "params": {"dim": dimension}},
            {"name": "id", "type": DataType.INT64, "is_primary": True}
        ]
        if other_fields:
            other_fields = other_fields.split(",")
            for other_field_name in other_fields:
                if other_field_name.startswith("int"):
                    field_type = DataType.INT64
                elif other_field_name.startswith("float"):
                    field_type = DataType.FLOAT
                elif other_field_name.startswith("double"):
                    field_type = DataType.DOUBLE
                else:
                    raise Exception("Field name not supported")
                fields.append({"name": other_field_name, "type": field_type})
        create_param = {
            "fields": fields,
            "auto_id": auto_id}
        try:
            self._milvus.create_collection(collection_name, create_param)
            logger.info("Create collection: <%s> successfully" % collection_name)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        """Create partition ``tag`` in the given (or default) collection."""
        if not collection_name:
            collection_name = self._collection_name
        self._milvus.create_partition(collection_name, tag)

    @time_wrapper
    def insert(self, entities, collection_name=None):
        """Insert ``entities`` and return the primary keys of the new rows.

        SDK errors are logged rather than raised, in which case ``None`` is
        returned.
        """
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        try:
            insert_res = self._milvus.insert(tmp_collection_name, entities)
            return insert_res.primary_keys
        except Exception as e:
            logger.error(str(e))
            return None

    def get_dimension(self):
        """Return the ``dim`` of the first vector field in the collection schema.

        Returns ``None`` when the schema contains no vector field.
        """
        info = self.get_info()
        for field in info["fields"]:
            if field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
                return field["params"]["dim"]
        return None

    def get_rand_ids(self, length):
        """Return up to ``length`` ids taken from one randomly chosen segment.

        Only segments of the first partition in the collection stats are
        considered. If the chosen segment holds more than ``length`` ids a
        random sample of ``length`` is returned, otherwise all of its ids.

        NOTE(review): this loops until some segment yields a non-empty id
        list, so it never returns if every lookup fails or returns nothing.
        """
        segment_ids = []
        while True:
            stats = self.get_stats()
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["id"])
            except Exception as e:
                logger.error(str(e))
            if not len(segment_ids):
                continue
            elif len(segment_ids) > length:
                return random.sample(segment_ids, length)
            else:
                logger.debug("Reset length: %d" % len(segment_ids))
                return segment_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        """Fetch entities by id from the default collection."""
        get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        """Delete the entities with the given ids."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.delete_entity_by_id(tmp_collection_name, ids)

    def delete_rand(self):
        """Delete between 1 and 100 random entities, flush, and assert that
        they can no longer be fetched by id."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.debug("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        for item in get_res:
            assert not item
        # if count_before - len(delete_ids) < self.count():
        #     logger.error(delete_ids)
        #     raise Exception("Error occured")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        """Flush the given (or default) collection."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.flush([tmp_collection_name], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        """Compact the given (or default) collection and return the SDK result.

        The former ``self.check_status(status)`` call was removed:
        ``check_status`` is disabled (see the TODO above), so invoking it
        raised ``AttributeError`` on every compact.
        """
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        return self._milvus.compact(tmp_collection_name)

    # only support "in" in expr
    @time_wrapper
    def get(self, ids, collection_name=None):
        """Query entities whose ``id`` is in ``ids`` via an ``id in [...]`` expression."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        # res = self._milvus.get(tmp_collection_name, ids, output_fields=None, partition_names=None)
        ids_expr = "id in %s" % (str(ids))
        res = self._milvus.query(tmp_collection_name, ids_expr, output_fields=None, partition_names=None)
        return res

    @time_wrapper
    def create_index(self, field_name, index_type, metric_type, _async=False, index_param=None):
        """Build an index on ``field_name`` of the default collection.

        :param index_type: benchmark-level index name, translated through
            ``INDEX_MAP``.
        :param metric_type: benchmark-level metric name, translated through
            ``utils.metric_type_trans``.
        :param index_param: index-specific params, passed through as
            ``params`` (may be ``None``).
        """
        index_type = INDEX_MAP[index_type]
        metric_type = utils.metric_type_trans(metric_type)
        logger.info("Building index start, collection_name: %s, index_type: %s, metric_type: %s" % (
            self._collection_name, index_type, metric_type))
        if index_param:
            logger.info(index_param)
        index_params = {
            "index_type": index_type,
            "metric_type": metric_type,
            "params": index_param
        }
        self._milvus.create_index(self._collection_name, field_name, index_params, _async=_async)

    # TODO: need to check
    def describe_index(self, field_name, collection_name=None):
        """Return ``index_type``/``metric_type``/``index_param`` for ``field_name``.

        When the server reports no index, ``index_type`` is ``"flat"`` and
        the other two entries are ``None``. The server-side index type is
        translated back to its ``INDEX_MAP`` key.
        """
        # stats = self.get_stats()
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        info = self._milvus.describe_index(tmp_collection_name, field_name)
        logger.info(info)
        index_info = {"index_type": "flat", "metric_type": None, "index_param": None}
        if info:
            index_info = {"index_type": info["index_type"], "metric_type": info["metric_type"],
                          "index_param": info["params"]}
            # transfer index type name
            for k, v in INDEX_MAP.items():
                if index_info['index_type'] == v:
                    index_info['index_type'] = k
        return index_info

    def drop_index(self, field_name):
        """Drop the index on ``field_name`` of the default collection."""
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name, field_name)

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None, timeout=300):
        """Run a boolean ``must`` search made of ``vector_query`` plus the
        optional list of ``filter_query`` clauses."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        query = {
            "bool": {"must": must_params}
        }
        result = self._milvus.search(tmp_collection_name, query, timeout=timeout)
        return result

    @time_wrapper
    def warm_query(self, index_field_name, search_param, metric_type, times=2):
        """Issue ``times`` throw-away searches with random vectors.

        Uses the dimension remembered by create_collection(); when this
        client did not create the collection the dimension is read from the
        collection schema instead.
        """
        dimension = self._dimension
        if dimension is None:
            dimension = self.get_dimension()
        query_vectors = [[random.random() for _ in range(dimension)] for _ in range(DEFAULT_WARM_QUERY_NQ)]
        # index_info = self.describe_index(index_field_name)
        vector_query = {"vector": {index_field_name: {
            "topk": DEFAULT_WARM_QUERY_TOPK,
            "query": query_vectors,
            "metric_type": metric_type,
            "params": search_param}
        }}
        must_params = [vector_query]
        query = {
            "bool": {"must": must_params}
        }
        logger.debug("Start warm up query")
        for i in range(times):
            self._milvus.search(self._collection_name, query)
        logger.debug("End warm up query")

    @time_wrapper
    def load_and_query(self, vector_query, filter_query=None, collection_name=None, timeout=120):
        """Load the collection, then run the same boolean search as query()."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        query = {
            "bool": {"must": must_params}
        }
        self.load_collection(tmp_collection_name)
        result = self._milvus.search(tmp_collection_name, query, timeout=timeout)
        return result

    def get_ids(self, result):
        """Return the ``ids`` attribute of every per-query result, in order."""
        # idss = result._entities.ids
        # len_idss = len(idss)
        # len_r = len(result)
        # top_k = len_idss // len_r
        # for offset in range(0, len_idss, top_k):
        #     ids.append(idss[offset: min(offset + top_k, len_idss)])
        return [res.ids for res in result]

    def _rand_vector_query(self, nq_max):
        """Build a random vector query clause shared by the *_rand searches.

        The order of the ``random`` calls matches the original inline code so
        seeded runs produce the same queries.
        """
        # for ivf search
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        return {"vector": {vec_field_name: {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_type),
            "params": search_param}
        }}

    def query_rand(self, nq_max=100):
        """Search with a random number (1..nq_max) of random 128-dim vectors."""
        self.query(self._rand_vector_query(nq_max))

    def load_query_rand(self, nq_max=100):
        """Like query_rand(), but loads the collection before searching."""
        self.load_and_query(self._rand_vector_query(nq_max))

    # TODO: need to check
    def count(self, collection_name=None):
        """Return ``row_count`` from the collection stats."""
        if collection_name is None:
            collection_name = self._collection_name
        row_count = self._milvus.get_collection_stats(collection_name)["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (row_count, collection_name))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop the collection and poll (once per second) until it is empty.

        Polling stops as soon as the row count is zero or counting fails
        (which is logged as a warning). If the row count is still non-zero
        after ``timeout`` polls, an error is logged; nothing is raised.
        """
        timeout = int(timeout)
        if collection_name is None:
            collection_name = self._collection_name
        logger.info("Start delete collection: %s" % collection_name)
        self._milvus.drop_collection(collection_name)
        for _ in range(timeout):
            try:
                row_count = self.count(collection_name=collection_name)
            except Exception as e:
                logger.warning("Collection count failed: {}".format(str(e)))
                break
            if not row_count:
                break
            time.sleep(1)
        else:
            # Every poll still saw rows (or no poll was allowed at all).
            logger.error("Delete collection timeout")

    def get_stats(self):
        """Return the SDK collection stats for the default collection."""
        return self._milvus.get_collection_stats(self._collection_name)

    def get_info(self, collection_name=None):
        """Return the SDK description (schema) of the collection."""
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.describe_collection(collection_name)

    def show_collections(self):
        """Return the list of collections known to the server."""
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        """Return whether the given (or default) collection exists."""
        if collection_name is None:
            collection_name = self._collection_name
        res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()
        for name in collection_names:
            self.drop(collection_name=name)

    @time_wrapper
    def load_collection(self, collection_name=None):
        """Load the collection (SDK timeout: 3000)."""
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_collection(collection_name, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        """Release the collection (SDK timeout: 3000)."""
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_collection(collection_name, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        """Load the partitions named in ``tag_names`` (SDK timeout: 3000)."""
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_partitions(collection_name, tag_names, timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        """Release the partitions named in ``tag_names`` (SDK timeout: 3000)."""
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_partitions(collection_name, tag_names, timeout=3000)
def main():
    """Run an end-to-end Milvus demo: create a collection, insert random
    vectors, build an IVF_FLAT index, search, verify the top hit and clean up.
    """
    # Specify server addr when create milvus client instance
    # milvus client instance maintain a connection pool, param
    # `pool_size` specify the max connection num.
    milvus = Milvus(_HOST, _PORT)

    collection_name = 'example_collection'
    field_name = 'example_field'
    id_name = "id"

    # Always start from a fresh collection. Previously an existing collection
    # was dropped but never recreated, so every later call failed.
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name=collection_name)

    fields = {"fields": [
        {
            "name": field_name,
            "type": DataType.FLOAT_VECTOR,
            "metric_type": "L2",
            "params": {"dim": _DIM},
            "indexes": [{"metric_type": "L2"}]
        },
        {
            "name": id_name,
            "type": DataType.INT64,
            "auto_id": True,
            "is_primary": True,
        }
    ]}
    milvus.create_collection(collection_name=collection_name, fields=fields)

    # Show collections in Milvus server
    collections = milvus.list_collections()
    print(collections)

    # Describe demo_collection
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # 10 vectors with _DIM dimensions
    # element per dimension is a Python float
    # vectors should be a 2-D array
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
    print(vectors)
    # You can also use numpy to generate random vectors:
    #   vectors = np.random.rand(10000, _DIM).astype(np.float32)

    # Insert vectors into demo_collection; the result carries the generated ids
    entities = [{"name": field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}]
    res_ids = milvus.insert(collection_name=collection_name, entities=entities)
    # Keep the ids: the verification step below compares against ids[0].
    ids = res_ids.primary_keys
    print("ids:", ids)

    # Flush collection inserted data to disk.
    milvus.flush([collection_name])

    # present collection statistics info
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # create index of vectors, search more rapidly
    index_param = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 1024}
    }

    # Create ivflat index in demo_collection
    # You can search vectors without creating an index; however, creating
    # an index helps to search faster
    print("Creating index: {}".format(index_param))
    milvus.create_index(collection_name, field_name, index_param)

    # execute vector similarity search
    print("Searching ... ")

    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
    milvus.load_collection(collection_name)
    results = milvus.search_with_expression(collection_name, vectors, field_name, param=search_params, limit=10,
                                            output_fields=[id_name])
    print("search results: ", results[0][0].entity)

    # indicate search result
    # The first query vector is itself stored in the collection, so its best
    # hit should be at distance 0.0 and/or carry the first inserted id.
    if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
        print('Query result is correct')
    else:
        print('Query result isn\'t correct')

    milvus.drop_index(collection_name, field_name)
    milvus.release_collection(collection_name)

    # Delete demo_collection
    milvus.drop_collection(collection_name)