def setup(self) -> None:
    """Build an in-process gRPC test harness around the Milvus client.

    Uses ``grpc_testing`` so no real server is needed.  ``try_connect=False``
    and ``pre_ping=False`` keep the client from probing the fake channel
    (consistent with the other ``setup`` fixture in this test suite).
    """
    self._real_time = grpc_testing.strict_real_time()
    self._real_time_channel = grpc_testing.channel(
        milvus_pb2.DESCRIPTOR.services_by_name.values(), self._real_time)
    self._servicer = milvus_pb2.DESCRIPTOR.services_by_name[
        'MilvusService']
    self._milvus = Milvus(channel=self._real_time_channel,
                          try_connect=False, pre_ping=False)
def __init__(self, collection_name=None, host=None, port=None, timeout=300):
    """Connect to a Milvus server, retrying until ``timeout`` seconds elapse.

    Args:
        collection_name: default collection used by later operations.
        host: server host; falls back to ``config.SERVER_HOST_DEFAULT``.
        port: server port; falls back to ``config.SERVER_PORT_DEFAULT``.
        timeout: total seconds to keep retrying the connection.

    Raises:
        Exception: if no connection could be made within ``timeout``.
    """
    self._collection_name = collection_name
    self._collection_info = None
    start_time = time.time()
    if not host:
        host = config.SERVER_HOST_DEFAULT
    if not port:
        port = config.SERVER_PORT_DEFAULT
    # Retry connecting to the remote server until the deadline passes.
    retries = 0
    while time.time() < start_time + timeout:
        try:
            self._milvus = Milvus(
                host=host, port=port, try_connect=False, pre_ping=False)
            break
        except Exception as e:
            logger.error(str(e))
            logger.error("Milvus connect failed: %d times" % retries)
            retries += 1
            time.sleep(30)
    else:
        # Loop exhausted without a successful `break` -> never connected.
        # (The original re-checked the wall clock *after* the loop, which
        # could raise even when the connection had in fact succeeded but
        # the successful attempt itself finished past the deadline.)
        raise Exception("Server connect timeout")
def get_milvus(host, port, uri=None, handler=None, **kwargs):
    """Return a Milvus client built from either a URI or a host/port pair.

    ``handler`` defaults to "GRPC"; ``try_connect`` may be overridden
    through ``kwargs`` (defaults to True).
    """
    if handler is None:
        handler = "GRPC"
    try_connect = kwargs.get("try_connect", True)
    # A URI, when supplied, takes precedence over host/port.
    if uri is not None:
        return Milvus(uri=uri, handler=handler, try_connect=try_connect)
    return Milvus(host=host, port=port, handler=handler, try_connect=try_connect)
def test_connect_repeatedly(self, args):
    """
    target: test connect repeatedly
    method: connect again
    expected: status.code is 0, and status.message shows have connected already
    """
    uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
    # Connecting twice with the same URI must not raise.
    for _ in range(2):
        milvus = Milvus(uri=uri_value, handler=args["handler"])
class TestCreateCollection:
    """Unit tests for `Milvus.create_collection`, served by an in-process
    `grpc_testing` channel instead of a real Milvus server."""

    @pytest.fixture(scope="function")
    def collection_name(self):
        # Fresh pseudo-random name per test so repeated runs don't collide.
        return f"test_collection_{random.randint(100000, 999999)}"

    def setup(self) -> None:
        # Fake real-time environment + channel exposing the Milvus proto services.
        self._real_time = grpc_testing.strict_real_time()
        self._real_time_channel = grpc_testing.channel(
            milvus_pb2.DESCRIPTOR.services_by_name.values(), self._real_time)
        self._servicer = milvus_pb2.DESCRIPTOR.services_by_name[
            'MilvusService']
        # try_connect/pre_ping disabled: the testing channel has no live server
        # to probe, so any connection check would hang or fail.
        self._milvus = Milvus(channel=self._real_time_channel,
                              try_connect=False, pre_ping=False)

    def teardown(self) -> None:
        # Nothing to release: the grpc_testing channel holds no real resources.
        pass

    def test_create_collection(self, collection_name):
        # Schema under test: auto-id INT64 primary key + 4-dim float vector.
        id_field = {
            "name": "my_id",
            "type": DataType.INT64,
            "auto_id": True,
            "is_primary": True,
        }
        vector_field = {
            "name": "embedding",
            "type": DataType.FLOAT_VECTOR,
            "metric_type": "L2",
            "params": {
                "dim": "4"
            },
        }
        fields = {"fields": [id_field, vector_field]}
        # Fire the RPC asynchronously so we can service it by hand below.
        future = self._milvus.create_collection(
            collection_name=collection_name, fields=fields, _async=True)
        # Pop the pending CreateCollection invocation off the fake channel...
        invocation_metadata, request, rpc = self._real_time_channel.take_unary_unary(
            self._servicer.methods_by_name['CreateCollection'])
        rpc.send_initial_metadata(())
        # ...and answer it with a success status.
        rpc.terminate(
            common_pb2.Status(error_code=common_pb2.Success, reason="success"),
            (), grpc.StatusCode.OK, '')
        # The client must have serialized the same name and schema we passed in.
        request_schema = schema_pb2.CollectionSchema()
        request_schema.ParseFromString(request.schema)
        assert request.collection_name == collection_name
        assert Fields.equal(request_schema.fields, fields["fields"])
        # The future resolves to the status the RPC was terminated with.
        return_value = future.result()
        assert return_value.error_code == common_pb2.Success
        assert return_value.reason == "success"
import random import sys import math import time import numpy as np from pymilvus import Milvus, DataType # This example shows how to use milvus to calculate vectors distance _HOST = '127.0.0.1' _PORT = '19530' # Create milvus client instance milvus = Milvus(_HOST, _PORT) _PRECISION = 1e-3 def gen_float_vectors(num, dim): vec_list = [[random.random() for _ in range(dim)] for _ in range(num)] return vec_list def gen_binary_vectors(num, dim): zero_fill = 0 if dim % 8 > 0: zero_fill = 8 - dim % 8 binary_vectors = [] raw_vectors = [] for i in range(num):
We will be using `films.csv` to demonstrate how we can build an index and search by index on Milvus. We assume you have read `example.py` and have a basic idea of how to communicate with Milvus using pymilvus. This example is runnable for Milvus(0.11.x) and pymilvus(0.3.x). """ import random import csv from pprint import pprint from pymilvus import Milvus, DataType _HOST = '127.0.0.1' _PORT = '19530' client = Milvus(_HOST, _PORT) collection_name = 'demo_index' if collection_name in client.list_collections(): client.drop_collection(collection_name) collection_param = { "fields": [ {"name": "release_year", "type": DataType.INT64, "is_primary": True}, {"name": "embedding", "type": DataType.FLOAT_VECTOR, "params": {"dim": 8}}, ], } client.create_collection(collection_name, collection_param)
values = ids elif data_type in [DataType.FLOAT, DataType.DOUBLE]: values = [(i + 0.0) for i in ids] elif data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]: values = vectors return values def generate_entities(info, vectors, ids=None): entities = [] for field in info["fields"]: if field["name"] == "_id": continue field_type = field["type"] entities.append({ "name": field["name"], "type": field_type, "values": generate_values(field_type, vectors, ids) }) return entities m = Milvus(host="127.0.0.1") info = m.describe_collection(name) print(info) ids = [random.randint(1, 10000000)] X = [[random.random() for _ in range(dim)] for _ in range(1)] entities = generate_entities(info, X, ids) print(entities) m.insert(name, entities, ids=ids)
class MilvusClient(object):
    """Benchmark/test helper wrapping a pymilvus `Milvus` client.

    Holds a default collection name so most operations can be called
    without repeating it; individual methods accept an optional
    `collection_name` override.  Relies on module-level helpers that are
    defined elsewhere in this file/package: `config`, `logger`, `utils`,
    `time_wrapper`, `INDEX_MAP`, `epsilon`, `DEFAULT_WARM_QUERY_NQ`,
    `DEFAULT_WARM_QUERY_TOPK`.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=300):
        """Connect to the server, retrying every 30s until `timeout` seconds."""
        self._collection_name = collection_name
        self._collection_info = None
        start_time = time.time()
        if not host:
            host = config.SERVER_HOST_DEFAULT
        if not port:
            port = config.SERVER_PORT_DEFAULT
        # retry connect remote server
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(
                    host=host, port=port, try_connect=False, pre_ping=False)
                break
            except Exception as e:
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % i)
                i = i + 1
                time.sleep(30)
        # NOTE(review): this re-checks the wall clock after the loop, so a
        # successful-but-slow connect can still raise here — confirm intended.
        if time.time() > start_time + timeout:
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, collection_name):
        """Switch the default collection used by subsequent calls."""
        self._collection_name = collection_name

    # TODO: server not support
    # def check_status(self, status):
    #     if not status.OK():
    #         logger.error(status.message)
    #         logger.error(self._milvus.server_status())
    #         logger.error(self.count())
    #         raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any top-1 hit's distance reaches `epsilon` (self-match check)."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    @property
    def collection_name(self):
        return self._collection_name

    # only support the given field name
    def create_collection(self, dimension, data_type=DataType.FLOAT_VECTOR, auto_id=False,
                          collection_name=None, other_fields=None):
        """Create a collection with one vector field, an INT64 `id` primary key,
        and optional extra scalar fields.

        `other_fields` is a comma-separated string of field names; the field
        type is inferred from the name prefix ("int"/"float"/"double").
        """
        self._dimension = dimension
        if not collection_name:
            collection_name = self._collection_name
        vec_field_name = utils.get_default_field_name(data_type)
        fields = [
            {"name": vec_field_name, "type": data_type, "params": {"dim": dimension}},
            {"name": "id", "type": DataType.INT64, "is_primary": True}
        ]
        if other_fields:
            other_fields = other_fields.split(",")
            for other_field_name in other_fields:
                if other_field_name.startswith("int"):
                    field_type = DataType.INT64
                elif other_field_name.startswith("float"):
                    field_type = DataType.FLOAT
                elif other_field_name.startswith("double"):
                    field_type = DataType.DOUBLE
                else:
                    raise Exception("Field name not supported")
                fields.append({"name": other_field_name, "type": field_type})
        create_param = {
            "fields": fields,
            "auto_id": auto_id}
        try:
            self._milvus.create_collection(collection_name, create_param)
            logger.info("Create collection: <%s> successfully" % collection_name)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        if not collection_name:
            collection_name = self._collection_name
        self._milvus.create_partition(collection_name, tag)

    @time_wrapper
    def insert(self, entities, collection_name=None):
        """Insert entities; returns the primary keys, or None on failure
        (errors are logged and swallowed)."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        try:
            insert_res = self._milvus.insert(tmp_collection_name, entities)
            return insert_res.primary_keys
        except Exception as e:
            logger.error(str(e))

    def get_dimension(self):
        """Return the dim of the first vector field found in the schema."""
        info = self.get_info()
        for field in info["fields"]:
            if field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
                return field["params"]["dim"]

    def get_rand_ids(self, length):
        """Sample up to `length` entity ids from one randomly chosen segment.

        Loops until a non-empty segment is found; returns fewer than
        `length` ids when the chosen segment is smaller.
        """
        segment_ids = []
        while True:
            stats = self.get_stats()
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["id"])
            except Exception as e:
                logger.error(str(e))
            if not len(segment_ids):
                continue
            elif len(segment_ids) > length:
                return random.sample(segment_ids, length)
            else:
                logger.debug("Reset length: %d" % len(segment_ids))
                return segment_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.delete_entity_by_id(tmp_collection_name, ids)

    def delete_rand(self):
        """Delete a random batch of 1-100 ids, flush, and verify the deleted
        ids are no longer retrievable."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.debug("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        for item in get_res:
            assert not item
        # if count_before - len(delete_ids) < self.count():
        #     logger.error(delete_ids)
        #     raise Exception("Error occured")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.flush([tmp_collection_name], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        # NOTE(review): `check_status` is commented out above, so this call
        # raises AttributeError at runtime — confirm whether compact() is used.
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        status = self._milvus.compact(tmp_collection_name)
        self.check_status(status)

    # only support "in" in expr
    @time_wrapper
    def get(self, ids, collection_name=None):
        """Fetch entities by id via a `query` with an "id in [...]" expression."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        # res = self._milvus.get(tmp_collection_name, ids, output_fields=None, partition_names=None)
        ids_expr = "id in %s" % (str(ids))
        res = self._milvus.query(tmp_collection_name, ids_expr, output_fields=None, partition_names=None)
        return res

    @time_wrapper
    def create_index(self, field_name, index_type, metric_type, _async=False, index_param=None):
        """Build an index; `index_type`/`metric_type` are benchmark-level names
        translated via INDEX_MAP / utils.metric_type_trans."""
        index_type = INDEX_MAP[index_type]
        metric_type = utils.metric_type_trans(metric_type)
        logger.info("Building index start, collection_name: %s, index_type: %s, metric_type: %s" % (
            self._collection_name, index_type, metric_type))
        if index_param:
            logger.info(index_param)
        index_params = {
            "index_type": index_type,
            "metric_type": metric_type,
            "params": index_param
        }
        self._milvus.create_index(self._collection_name, field_name, index_params, _async=_async)

    # TODO: need to check
    def describe_index(self, field_name, collection_name=None):
        """Return {"index_type", "metric_type", "index_param"}, defaulting to a
        flat (no-index) description when the server reports nothing."""
        # stats = self.get_stats()
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        info = self._milvus.describe_index(tmp_collection_name, field_name)
        logger.info(info)
        index_info = {"index_type": "flat", "metric_type": None, "index_param": None}
        if info:
            index_info = {"index_type": info["index_type"], "metric_type": info["metric_type"], "index_param": info["params"]}
        # transfer index type name
        for k, v in INDEX_MAP.items():
            if index_info['index_type'] == v:
                index_info['index_type'] = k
        return index_info

    def drop_index(self, field_name):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name, field_name)

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None, timeout=300):
        """Run a DSL search: `vector_query` plus optional filter clauses are
        combined under a bool/must query."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        query = {
            "bool": {"must": must_params}
        }
        result = self._milvus.search(tmp_collection_name, query, timeout=timeout)
        return result

    @time_wrapper
    def warm_query(self, index_field_name, search_param, metric_type, times=2):
        """Issue `times` throwaway searches with random vectors to warm caches."""
        query_vectors = [[random.random() for _ in range(self._dimension)] for _ in range(DEFAULT_WARM_QUERY_NQ)]
        # index_info = self.describe_index(index_field_name)
        vector_query = {"vector": {index_field_name: {
            "topk": DEFAULT_WARM_QUERY_TOPK,
            "query": query_vectors,
            "metric_type": metric_type,
            "params": search_param}
        }}
        must_params = [vector_query]
        query = {
            "bool": {"must": must_params}
        }
        logger.debug("Start warm up query")
        for i in range(times):
            self._milvus.search(self._collection_name, query)
        logger.debug("End warm up query")

    @time_wrapper
    def load_and_query(self, vector_query, filter_query=None, collection_name=None, timeout=120):
        """Like `query`, but loads the collection into memory first."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        query = {
            "bool": {"must": must_params}
        }
        self.load_collection(tmp_collection_name)
        result = self._milvus.search(tmp_collection_name, query, timeout=timeout)
        return result

    def get_ids(self, result):
        """Collect the per-query hit id lists from a search result."""
        # idss = result._entities.ids
        ids = []
        # len_idss = len(idss)
        # len_r = len(result)
        # top_k = len_idss // len_r
        # for offset in range(0, len_idss, top_k):
        #     ids.append(idss[offset: min(offset + top_k, len_idss)])
        for res in result:
            ids.append(res.ids)
        return ids

    def query_rand(self, nq_max=100):
        """Fire one randomized IVF-style search (random nq/top_k/nprobe/metric).
        Assumes a 128-dim vector field — TODO confirm against collection schema."""
        # for ivf search
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {"vector": {vec_field_name: {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_type),
            "params": search_param}
        }}
        self.query(vector_query)

    def load_query_rand(self, nq_max=100):
        """Same as `query_rand`, but loads the collection before searching."""
        # for ivf search
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {"vector": {vec_field_name: {
            "topk": top_k,
            "query": query_vectors,
            "metric_type": utils.metric_type_trans(metric_type),
            "params": search_param}
        }}
        self.load_and_query(vector_query)

    # TODO: need to check
    def count(self, collection_name=None):
        """Return the server-reported row count for the collection."""
        if collection_name is None:
            collection_name = self._collection_name
        row_count = self._milvus.get_collection_stats(collection_name)["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (row_count, collection_name))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop the collection and poll (1s interval) until its row count
        reaches zero or `timeout` seconds elapse."""
        timeout = int(timeout)
        if collection_name is None:
            collection_name = self._collection_name
        logger.info("Start delete collection: %s" % collection_name)
        self._milvus.drop_collection(collection_name)
        i = 0
        while i < timeout:
            try:
                row_count = self.count(collection_name=collection_name)
                if row_count:
                    time.sleep(1)
                    i = i + 1
                    continue
                else:
                    break
            except Exception as e:
                # Counting a dropped collection may fail — treat as done.
                logger.warning("Collection count failed: {}".format(str(e)))
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        return self._milvus.get_collection_stats(self._collection_name)

    def get_info(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.describe_collection(collection_name)

    def show_collections(self):
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()
        for name in collection_names:
            self.drop(collection_name=name)

    @time_wrapper
    def load_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_collection(collection_name, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_collection(collection_name, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_partitions(collection_name, tag_names, timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_partitions(collection_name, tag_names, timeout=3000)
from pymilvus import Milvus, DataType
import random
from pprint import pprint

if __name__ == "__main__":
    # Connect and create a collection with a random name, then verify it exists.
    client = Milvus("localhost", "19530")
    collection_name = f"test_{random.randint(10000, 99999)}"

    vector_field = {
        "name": "f1",
        "type": DataType.FLOAT_VECTOR,
        "metric_type": "L2",
        "params": {"dim": 4},
        "indexes": [{"metric_type": "L2"}],
    }
    age_field = {
        "name": "age",
        "type": DataType.FLOAT,
    }
    id_field = {
        "name": "id",
        "type": DataType.INT64,
        "auto_id": True,
        "is_primary": True,
    }
    schema = {"fields": [vector_field, age_field, id_field]}

    client.create_collection(collection_name, schema, orm=True)
    assert client.has_collection(collection_name)
We will be using `films.csv` to demonstrate how we can build an index and search by index on Milvus. We assume you have read `example.py` and have a basic idea of how to communicate with Milvus using pymilvus. This example is runnable for Milvus(0.11.x) and pymilvus(0.3.x). """ import random import csv from pprint import pprint from pymilvus import Milvus, DataType _HOST = '127.0.0.1' _PORT = '19530' client = Milvus(_HOST, _PORT) collection_name = 'demo_index' if collection_name in client.list_collections(): client.drop_collection(collection_name) collection_param = { "fields": [ { "name": "id", "type": DataType.INT64, "is_primary": True }, { "name": "release_year", "type": DataType.INT64
def main():
    """End-to-end pymilvus example: create a collection, insert random
    vectors, build an IVF_FLAT index, search, verify, and clean up.

    Requires module-level `_HOST`, `_PORT`, `_DIM` constants.
    """
    # Specify server addr when creating the milvus client instance.
    # The client maintains a connection pool; param `pool_size` specifies
    # the max connection num.
    milvus = Milvus(_HOST, _PORT)

    collection_name = 'example_collection'
    field_name = 'example_field'
    id_name = "id"

    # Start from a clean slate: drop any leftover collection, then create it.
    # (Fixed: the original *only* created the collection when it was missing
    # and dropped it otherwise, leaving every later call to fail on reruns.)
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name=collection_name)
    fields = {"fields": [
        {
            "name": field_name,
            "type": DataType.FLOAT_VECTOR,
            "metric_type": "L2",
            "params": {"dim": _DIM},
            "indexes": [{"metric_type": "L2"}],
        },
        {
            "name": id_name,
            "type": DataType.INT64,
            "auto_id": True,
            "is_primary": True,
        },
    ]}
    milvus.create_collection(collection_name=collection_name, fields=fields)

    # Show collections in Milvus server
    collections = milvus.list_collections()
    print(collections)

    # Present collection statistics info
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # Vectors should be a 2-D array of float values
    # (you could also use numpy: np.random.rand(10, _DIM).astype(np.float32))
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
    print(vectors)

    # Insert vectors into the collection; the result carries the generated ids.
    entities = [{"name": field_name, "type": DataType.FLOAT_VECTOR, "values": vectors}]
    res_ids = milvus.insert(collection_name=collection_name, entities=entities)
    print("ids:", res_ids.primary_keys)

    # Flush inserted data to disk and re-check the stats.
    milvus.flush([collection_name])
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # Create an IVF_FLAT index so searches run faster.
    # You can search vectors without creating an index; however, creating
    # an index helps to search faster.
    index_param = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 1024}
    }
    print("Creating index: {}".format(index_param))
    status = milvus.create_index(collection_name, field_name, index_param)

    # Execute vector similarity search.
    print("Searching ... ")
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
    milvus.load_collection(collection_name)
    results = milvus.search_with_expression(collection_name, vectors, field_name,
                                            param=search_params, limit=10,
                                            output_fields=[id_name])
    print("search results: ", results[0][0].entity)

    # The top hit for the first query vector should be the vector itself.
    # (Fixed: the original compared against the undefined name `ids`;
    # the inserted ids live in `res_ids.primary_keys`.)
    if results[0][0].distance == 0.0 or results[0][0].id == res_ids.primary_keys[0]:
        print('Query result is correct')
    else:
        print('Query result isn\'t correct')

    # Clean up: drop the index, release memory, and delete the collection.
    milvus.drop_index(collection_name, field_name)
    milvus.release_collection(collection_name)
    status = milvus.drop_collection(collection_name)