def test_not_connect(self):
    """Each listed API call on an unconnected ``Milvus()`` raises NotConnectError."""
    client = Milvus()
    # Deferred calls, checked in the same order as the original sequence of
    # individual ``pytest.raises`` blocks.
    calls_without_connection = (
        lambda: client.create_collection({}),
        lambda: client.has_collection("a"),
        lambda: client.describe_collection("a"),
        lambda: client.drop_collection("a"),
        lambda: client.create_index("a"),
        lambda: client.insert("a", [], None),
        lambda: client.count_collection("a"),
        lambda: client.show_collections(),
        lambda: client.search("a", 1, 2, [], None),
        lambda: client.search_in_files("a", [], [], 2, 1, None),
        lambda: client._cmd(""),
        lambda: client.preload_collection("a"),
        lambda: client.describe_index("a"),
        lambda: client.drop_index(""),
    )
    for call in calls_without_connection:
        with pytest.raises(NotConnectError):
            call()
class MilvusClient(object):
    """Helper wrapping a pymilvus ``Milvus`` client bound to one collection.

    Methods that accept ``collection_name=None`` fall back to the collection
    name passed to the constructor.
    """

    def __init__(self, collection_name=None, host=None, port=None, timeout=180):
        """Create the pymilvus client, retrying until ``timeout`` seconds elapse.

        :param collection_name: default collection used by most methods.
        :param host: server host; ``SERVER_HOST_DEFAULT`` when falsy.
        :param port: server port; ``SERVER_PORT_DEFAULT`` when falsy.
        :param timeout: total seconds allowed for connection attempts.
        :raises Exception: "Server connect timeout" if no attempt succeeded.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect remote server, sleeping a little longer after each failure
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                break
            except Exception as e:
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % i)
                i = i + 1
                time.sleep(i)
        else:
            # Reached only when the deadline passed without a successful
            # attempt; a success right at the deadline is not a timeout.
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def check_status(self, status):
        """Raise if ``status`` is not OK, logging server diagnostics first."""
        if not status.OK():
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's top hit has a distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    # only support the given field name
    def create_collection(self, dimension, data_type=DataType.FLOAT_VECTOR, auto_id=False, collection_name=None,
                          other_fields=None):
        """Create a collection with one vector field plus optional scalar fields.

        :param dimension: vector dimension; also remembered as ``self._dimension``.
        :param other_fields: comma-separated string; ``"int"`` adds an INT64
            field and ``"float"`` adds a FLOAT field, named by ``utils`` defaults.
        :raises: re-raises whatever the underlying ``create_collection`` raises.
        """
        self._dimension = dimension
        if not collection_name:
            collection_name = self._collection_name
        vec_field_name = utils.get_default_field_name(data_type)
        fields = [{
            "name": vec_field_name,
            "type": data_type,
            "params": {
                "dim": dimension
            }
        }]
        if other_fields:
            other_fields = other_fields.split(",")
            if "int" in other_fields:
                fields.append({
                    "name": utils.DEFAULT_INT_FIELD_NAME,
                    "type": DataType.INT64
                })
            if "float" in other_fields:
                fields.append({
                    "name": utils.DEFAULT_FLOAT_FIELD_NAME,
                    "type": DataType.FLOAT
                })
        create_param = {"fields": fields, "auto_id": auto_id}
        try:
            self._milvus.create_collection(collection_name, create_param)
            logger.info("Create collection: <%s> successfully" % collection_name)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        if not collection_name:
            collection_name = self._collection_name
        self._milvus.create_partition(collection_name, tag)

    def generate_values(self, data_type, vectors, ids):
        """Return the column values for a field of ``data_type`` (None if unsupported)."""
        values = None
        if data_type in [DataType.INT32, DataType.INT64]:
            values = ids
        elif data_type in [DataType.FLOAT, DataType.DOUBLE]:
            values = [(i + 0.0) for i in ids]
        elif data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
            values = vectors
        return values

    def generate_entities(self, vectors, ids=None, collection_name=None):
        """Build one entity column per field declared in the collection schema."""
        entities = []
        if collection_name is None:
            collection_name = self._collection_name
        info = self.get_info(collection_name)
        for field in info["fields"]:
            field_type = field["type"]
            entities.append({
                "name": field["name"],
                "type": field_type,
                "values": self.generate_values(field_type, vectors, ids)
            })
        return entities

    @time_wrapper
    def insert(self, entities, ids=None, collection_name=None):
        """Insert ``entities``; returns the inserted ids, or None after logging a failure."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        try:
            insert_ids = self._milvus.insert(tmp_collection_name, entities, ids=ids)
            return insert_ids
        except Exception as e:
            logger.error(str(e))

    def get_dimension(self):
        """Return the ``dim`` param of the first vector field, or None."""
        info = self.get_info()
        for field in info["fields"]:
            if field["type"] in [
                    DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR
            ]:
                return field["params"]["dim"]

    def get_rand_ids(self, length):
        """Return up to ``length`` random ids drawn from one random non-empty segment.

        Keeps retrying while the chosen segment is empty or listing fails.
        """
        segment_ids = []
        while True:
            stats = self.get_stats()
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(
                    self._collection_name, segment["id"])
            except Exception as e:
                logger.error(str(e))
            if not len(segment_ids):
                continue
            elif len(segment_ids) > length:
                return random.sample(segment_ids, length)
            else:
                logger.debug("Reset length: %d" % len(segment_ids))
                return segment_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    def get(self):
        get_ids = random.randint(1, 1000000)
        self._milvus.get_entity_by_id(self._collection_name, [get_ids])

    @time_wrapper
    def get_entities(self, get_ids):
        get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.delete_entity_by_id(tmp_collection_name, ids)

    def delete_rand(self):
        """Delete a random batch of ids, flush, and verify none can be fetched back.

        :raises AssertionError: if a deleted id still returns an entity.
        """
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.debug("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        for item in get_res:
            # explicit raise: a bare ``assert`` would be stripped under ``python -O``
            if item:
                raise AssertionError("Entity still exists after delete")
        # if count_before - len(delete_ids) < self.count():
        #     logger.error(delete_ids)
        #     raise Exception("Error occured")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.flush([tmp_collection_name], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        status = self._milvus.compact(tmp_collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, field_name, index_type, metric_type, _async=False, index_param=None):
        """Build an index on ``field_name``; ``index_type`` is a key of INDEX_MAP."""
        index_type = INDEX_MAP[index_type]
        metric_type = utils.metric_type_trans(metric_type)
        logger.info(
            "Building index start, collection_name: %s, index_type: %s, metric_type: %s" % (
                self._collection_name, index_type, metric_type))
        if index_param:
            logger.info(index_param)
        index_params = {
            "index_type": index_type,
            "metric_type": metric_type,
            "params": index_param
        }
        self._milvus.create_index(self._collection_name, field_name, index_params, _async=_async)

    # TODO: need to check
    def describe_index(self, field_name):
        """Return ``{"index_type", "index_param"}`` for the first recognised index.

        Falls back to ``{"index_type": "flat", "index_param": None}``.
        """
        # stats = self.get_stats()
        info = self._milvus.describe_index(self._collection_name, field_name)
        index_info = {"index_type": "flat", "index_param": None}
        for field in info["fields"]:
            for index in field['indexes']:
                if not index or "index_type" not in index:
                    continue
                for k, v in INDEX_MAP.items():
                    if index['index_type'] == v:
                        index_info['index_type'] = k
                        index_info['index_param'] = index['params']
                        return index_info
        return index_info

    def drop_index(self, field_name):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name, field_name)

    def _build_query(self, vector_query, filter_query=None):
        # Combine the vector clause and optional filter clauses into a bool/must DSL.
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        return {"bool": {"must": must_params}}

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None):
        """Run a search built from ``vector_query`` plus optional ``filter_query``."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        query = self._build_query(vector_query, filter_query)
        result = self._milvus.search(tmp_collection_name, query)
        return result

    @time_wrapper
    def load_and_query(self, vector_query, filter_query=None, collection_name=None):
        """Load the collection, then run the same search as :meth:`query`."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        query = self._build_query(vector_query, filter_query)
        self.load_collection(tmp_collection_name)
        result = self._milvus.search(tmp_collection_name, query)
        return result

    def get_ids(self, result):
        """Split the flat id list of a search ``result`` into per-query chunks.

        Returns ``[]`` when the result has no queries or no ids; previously
        these inputs raised ZeroDivisionError / ValueError (zero range step).
        """
        idss = result._entities.ids
        len_idss = len(idss)
        len_r = len(result)
        if not len_r or not len_idss:
            return []
        # at least 1 so ``range`` never gets a zero step
        top_k = max(len_idss // len_r, 1)
        return [idss[offset:min(offset + top_k, len_idss)] for offset in range(0, len_idss, top_k)]

    def _rand_vector_query(self, nq_max):
        # Build (and log) a random 128-dim vector-query clause for ivf search.
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)] for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        return {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }

    def query_rand(self, nq_max=100):
        """Run one random search (see ``_rand_vector_query``)."""
        self.query(self._rand_vector_query(nq_max))

    def load_query_rand(self, nq_max=100):
        """Load the collection and run one random search."""
        self.load_and_query(self._rand_vector_query(nq_max))

    # TODO: need to check
    def count(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        row_count = self._milvus.get_collection_stats(
            collection_name)["row_count"]
        logger.debug("Row count: %d in collection: <%s>" % (row_count, collection_name))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop a collection and poll (about 1s per step) until it reports no rows.

        Stops polling as soon as ``count`` raises (e.g. the collection is gone).
        """
        timeout = int(timeout)
        if collection_name is None:
            collection_name = self._collection_name
        logger.info("Start delete collection: %s" % collection_name)
        self._milvus.drop_collection(collection_name)
        i = 0
        while i < timeout:
            try:
                row_count = self.count(collection_name=collection_name)
                if row_count:
                    time.sleep(1)
                    i = i + 1
                    continue
                else:
                    break
            except Exception as e:
                logger.debug(str(e))
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        return self._milvus.get_collection_stats(self._collection_name)

    def get_info(self, collection_name=None):
        # pdb.set_trace()
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.get_collection_info(collection_name)

    def show_collections(self):
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()
        for name in collection_names:
            self.drop(collection_name=name)

    @time_wrapper
    def load_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_collection(collection_name, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_collection(collection_name, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_partitions(collection_name, tag_names, timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_partitions(collection_name, tag_names, timeout=3000)
# ------
# Basic hybrid search entities:
#     Search with a hybrid query and print every hit together with the
#     requested scalar and vector fields.
# ------
results = client.search(collection_name, query_hybrid, fields=["release_year", "embedding"])
for hits in results:
    for hit in hits:
        entity = hit.entity
        print("==")
        print("- id: {}".format(hit.id))
        print("- title: {}".format(titles[hit.id]))
        print("- distance: {}".format(hit.distance))
        print("- release_year: {}".format(entity.release_year))
        print("- embedding: {}".format(entity.embedding))

# ------
# Basic delete index:
#     You can drop the index built on a field.
# ------
client.drop_index(collection_name, "embedding")

existing_collections = client.list_collections()
if collection_name in existing_collections:
    client.drop_collection(collection_name)

# ------
# Summary:
#     Now we've gone through some basic index-building operations; we hope it was helpful!
# ------
class MilvusClient(object):
    """Wrapper around the legacy table-based pymilvus ``Milvus`` client."""

    def __init__(self, table_name=None, ip=None, port=None):
        """Connect to ``ip:port``, or to the default server when ``ip`` is falsy.

        When ``ip`` is falsy the default port is used too, ignoring ``port``.
        Connection errors propagate to the caller.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        if not ip:
            self._milvus.connect(
                host=SERVER_HOST_DEFAULT,
                port=SERVER_PORT_DEFAULT)
        else:
            self._milvus.connect(
                host=ip,
                port=port)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Raise if ``status`` is not OK, logging its message first."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create a table; ``metric_type`` is ``"l2"`` or ``"ip"``.

        An unsupported metric is logged and passed through unchanged.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
        create_param = {'table_name': table_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        status, result = self._milvus.add_vectors(self._table_name, X, ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index; unknown ``index_type`` names are passed through as-is."""
        if index_type == "flat":
            index_type = IndexType.FLAT
        elif index_type == "ivf_flat":
            index_type = IndexType.IVFLAT
        elif index_type == "ivf_sq8":
            index_type = IndexType.IVF_SQ8
        elif index_type == "nsg":
            index_type = IndexType.NSG
        elif index_type == "ivf_sq8h":
            index_type = IndexType.IVF_SQ8H
        elif index_type == "ivf_pq":
            index_type = IndexType.IVF_PQ
        index_params = {
            "index_type": index_type,
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name, index=index_params)
        self.check_status(status)

    def describe_index(self):
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        status, result = self._milvus.search_vectors(self._table_name, top_k, nprobe, X)
        self.check_status(status)
        return status, result

    def count(self):
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Drop the table, then poll once per second until its row count is zero.

        Logs an error if ``timeout`` polls pass with rows still reported.
        """
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        i = 0
        while i < timeout and self.count():
            time.sleep(1)
            i = i + 1
        # Was ``i < timeout``, which reported a timeout exactly when the wait
        # succeeded and stayed silent on a real timeout.
        if i >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        return self._milvus.describe_table(self._table_name)

    def show_tables(self):
        return self._milvus.show_tables()

    def exists_table(self):
        status, res = self._milvus.has_table(self._table_name)
        self.check_status(status)
        return res

    @time_wrapper
    def preload_table(self):
        return self._milvus.preload_table(self._table_name, timeout=3000)
class MilvusClient(object):
    """Wrapper around the collection-based pymilvus ``Milvus`` client (status-tuple API)."""

    def __init__(self, collection_name=None, ip=None, port=None, timeout=60):
        """Connect to the default server, or retry ``ip:port`` for up to ``timeout`` seconds.

        :raises Exception: "Server connect timeout" when a remote server
            (``ip`` given) could not be reached in time.
        """
        self._collection_name = collection_name
        if not ip:
            self._milvus = Milvus(host=SERVER_HOST_DEFAULT, port=SERVER_PORT_DEFAULT)
        else:
            # retry connect for remote server
            i = 1
            start_time = time.time()
            while time.time() < start_time + timeout:
                try:
                    self._milvus = Milvus(host=ip, port=port)
                    if self._milvus.server_status():
                        logger.debug(
                            "Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                        break
                except Exception:
                    logger.debug("Milvus connect failed")
                    i = i + 1
                    # pause between attempts instead of hammering the server in a tight loop
                    time.sleep(1)
            else:
                # Previously the loop just fell through, leaving ``_milvus``
                # unset or pointing at an unreachable server.
                raise Exception("Server connect timeout")

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def check_status(self, status):
        """Raise if ``status`` is not OK, logging its message first."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's top hit has a distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; an unsupported ``metric_type`` is logged and passed through."""
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        elif metric_type == "jaccard":
            metric_type = MetricType.JACCARD
        elif metric_type == "hamming":
            metric_type = MetricType.HAMMING
        elif metric_type == "sub":
            metric_type = MetricType.SUBSTRUCTURE
        elif metric_type == "super":
            metric_type = MetricType.SUPERSTRUCTURE
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
        create_param = {
            'collection_name': collection_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        status, result = self._milvus.add_vectors(self._collection_name, X, ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def delete_vectors(self, ids):
        status = self._milvus.delete_by_id(self._collection_name, ids)
        self.check_status(status)

    @time_wrapper
    def flush(self):
        status = self._milvus.flush([self._collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self):
        status = self._milvus.compact(self._collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` is a key of INDEX_MAP."""
        index_type = INDEX_MAP[index_type]
        logger.info(
            "Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": INDEX_MAP key or None, "index_param": params}``."""
        status, result = self._milvus.describe_index(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    @time_wrapper
    def query(self, X, top_k, search_param=None):
        status, result = self._milvus.search_vectors(self._collection_name, top_k, query_records=X,
                                                     params=search_param)
        self.check_status(status)
        return result

    @time_wrapper
    def query_ids(self, top_k, ids, search_param=None):
        """Search by ids and verify each query's top hit is within ``epsilon``."""
        status, result = self._milvus.search_by_ids(self._collection_name, ids, top_k, params=search_param)
        # check the call succeeded before inspecting the (possibly empty) result
        self.check_status(status)
        self.check_result_ids(result)
        return result

    def count(self):
        return self._milvus.count_collection(self._collection_name)[1]

    def delete(self, timeout=120):
        """Drop the collection and poll once per second until it reports no rows."""
        timeout = int(timeout)
        logger.info("Start delete collection: %s" % self._collection_name)
        self._milvus.drop_collection(self._collection_name)
        i = 0
        while i < timeout:
            if self.count():
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        return self._milvus.describe_collection(self._collection_name)

    def show_collections(self):
        return self._milvus.show_collections()

    def exists_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status, res = self._milvus.has_collection(collection_name)
        # self.check_status(status)
        return res

    @time_wrapper
    def preload_collection(self):
        status = self._milvus.preload_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        status, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": GiB rounded to 2 places}`` from ``get_system_info``."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a server command via ``_cmd``; raise if its status is not OK."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
class MilvusClient(object):
    """Wrapper around the pymilvus ``Milvus`` client (status-tuple API) for one collection."""

    def __init__(self, collection_name=None, host=None, port=None, timeout=60):
        """
        Milvus client wrapper for python-sdk.

        Retries the connection for up to ``timeout`` seconds (default 60s).
        When the collection already exists, caches its metric type and dimension.

        :raises Exception: "Server connect timeout" if no attempt succeeded.
        """
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect for remote server
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False, pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    break
            except Exception:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
                # pause between attempts instead of spinning on the server
                time.sleep(1)
        else:
            # reached only when the deadline passed without a successful attempt
            raise Exception("Server connect timeout")
        self._metric_type = None
        if self._collection_name and self.exists_collection():
            # one describe() round-trip instead of two
            collection_info = self.describe()[1]
            self._metric_type = metric_type_to_str(collection_info.metric_type)
            self._dimension = collection_info.dimension

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def set_collection(self, name):
        self._collection_name = name

    def check_status(self, status):
        """Raise if ``status`` is not OK, logging collection/server diagnostics first."""
        if not status.OK():
            logger.error(self._collection_name)
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise if any query's top hit has a distance >= ``epsilon``."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    def create_collection(self, collection_name, dimension, index_file_size, metric_type):
        """Create a collection; ``metric_type`` must be a key of METRIC_MAP."""
        if not self._collection_name:
            self._collection_name = collection_name
        if metric_type not in METRIC_MAP.keys():
            raise Exception("Not supported metric_type: %s" % metric_type)
        metric_type = METRIC_MAP[metric_type]
        create_param = {'collection_name': collection_name,
                        'dimension': dimension,
                        'index_file_size': index_file_size,
                        "metric_type": metric_type}
        status = self._milvus.create_collection(create_param)
        self.check_status(status)

    def create_partition(self, tag_name):
        status = self._milvus.create_partition(self._collection_name, tag_name)
        self.check_status(status)

    def drop_partition(self, tag_name):
        status = self._milvus.drop_partition(self._collection_name, tag_name)
        self.check_status(status)

    def list_partitions(self):
        status, tags = self._milvus.list_partitions(self._collection_name)
        self.check_status(status)
        return tags

    @time_wrapper
    def insert(self, X, ids=None, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.insert(collection_name, X, ids)
        self.check_status(status)
        return status, result

    def insert_rand(self):
        """Insert 1-100 random normalized vectors, flush, and verify the count grew."""
        insert_xb = random.randint(1, 100)
        X = [[random.random() for _ in range(self._dimension)] for _ in range(insert_xb)]
        X = utils.normalize(self._metric_type, X)
        count_before = self.count()
        status, _ = self.insert(X)
        self.check_status(status)
        self.flush()
        if count_before + insert_xb != self.count():
            raise Exception("Assert failed after inserting")

    def get_rand_ids(self, length):
        """Return up to ``length`` random ids from one random non-empty segment."""
        while True:
            status, stats = self._milvus.get_collection_stats(self._collection_name)
            self.check_status(status)
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            if not status.OK():
                logger.error(status.message)
                continue
            if len(segment_ids):
                break
        if length >= len(segment_ids):
            logger.debug("Reset length: %d" % len(segment_ids))
            return segment_ids
        return random.sample(segment_ids, length)

    def get_rand_ids_each_segment(self, length):
        """Return ``(segment count, first ``length`` ids of every segment)``."""
        res = []
        status, stats = self._milvus.get_collection_stats(self._collection_name)
        self.check_status(status)
        segments = stats["partitions"][0]["segments"]
        segments_num = len(segments)
        # random choice from each segment
        for segment in segments:
            status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
            self.check_status(status)
            res.extend(segment_ids[:length])
        return segments_num, res

    def get_rand_entities(self, length):
        ids = self.get_rand_ids(length)
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
        self.check_status(status)
        return ids, get_res

    @time_wrapper
    def get_entities(self, get_ids):
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, get_ids)
        self.check_status(status)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.delete_entity_by_id(collection_name, ids)
        self.check_status(status)

    def delete_rand(self):
        """Delete random ids, flush, and verify they are gone and the count dropped."""
        delete_id_length = random.randint(1, 100)
        count_before = self.count()
        logger.info("%s: length to delete: %d" % (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" % (self._collection_name, self.count()))
        status, get_res = self._milvus.get_entity_by_id(self._collection_name, delete_ids)
        self.check_status(status)
        for item in get_res:
            if item:
                raise Exception("Assert failed after delete")
        if count_before - len(delete_ids) != self.count():
            raise Exception("Assert failed after delete")

    @time_wrapper
    def flush(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.flush([collection_name])
        self.check_status(status)

    @time_wrapper
    def compact(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status = self._milvus.compact(collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, index_type, index_param=None):
        """Build an index; ``index_type`` is a key of INDEX_MAP."""
        index_type = INDEX_MAP[index_type]
        logger.info("Building index start, collection_name: %s, index_type: %s" % (self._collection_name, index_type))
        if index_param:
            logger.info(index_param)
        status = self._milvus.create_index(self._collection_name, index_type, index_param)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": INDEX_MAP key or None, "index_param": params}``."""
        status, result = self._milvus.get_index_info(self._collection_name)
        self.check_status(status)
        index_type = None
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        return {"index_type": index_type, "index_param": result._params}

    def drop_index(self):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name)

    def query(self, X, top_k, search_param=None, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        status, result = self._milvus.search(collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        return result

    def query_rand(self):
        """Search with records fetched for random existing ids, using random top_k/nq/nprobe."""
        top_k = random.randint(1, 100)
        nq = random.randint(1, 100)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        _, X = self.get_rand_entities(nq)
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" % (self._collection_name, nq, top_k, nprobe))
        status, _ = self._milvus.search(self._collection_name, top_k, query_records=X, params=search_param)
        self.check_status(status)
        # for i, item in enumerate(search_res):
        #     if item[0].id != ids[i]:
        #         logger.warning("The index of search result: %d" % i)
        #         raise Exception("Query failed")

    # @time_wrapper
    # def query_ids(self, top_k, ids, search_param=None):
    #     status, result = self._milvus.search_by_id(self._collection_name, ids, top_k, params=search_param)
    #     self.check_result_ids(result)
    #     return result

    def count(self, name=None):
        """Return the entity count of ``name`` (default collection); falsy counts become 0."""
        if name is None:
            name = self._collection_name
        # single RPC: the result used to be fetched twice, once only for logging
        count_res = self._milvus.count_entities(name)
        logger.debug(count_res)
        row_count = count_res[1]
        if not row_count:
            row_count = 0
        logger.debug("Row count: %d in collection: <%s>" % (row_count, name))
        return row_count

    def drop(self, timeout=120, name=None):
        """Drop a collection and poll once per second until it reports no rows."""
        timeout = int(timeout)
        if name is None:
            name = self._collection_name
        logger.info("Start delete collection: %s" % name)
        status = self._milvus.drop_collection(name)
        self.check_status(status)
        i = 0
        while i < timeout:
            if self.count(name=name):
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def describe(self):
        # logger.info(self._milvus.get_collection_info(self._collection_name))
        return self._milvus.get_collection_info(self._collection_name)

    def show_collections(self):
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        _, res = self._milvus.has_collection(collection_name)
        # self.check_status(status)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()[1]
        for name in collection_names:
            logger.debug(name)
            self.drop(name=name)

    @time_wrapper
    def preload_collection(self):
        status = self._milvus.load_collection(self._collection_name, timeout=3000)
        self.check_status(status)
        return status

    def get_server_version(self):
        _, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": GiB rounded to 2 places}`` from ``get_system_info``."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb
            "memory_used": round(int(result["memory_used"]) / (1024*1024*1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a server command via ``_cmd``; raise if its status is not OK."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
def main():
    """Walk through a basic Milvus workflow against the server at ``_HOST:_PORT``.

    (Re)creates ``example_collection`` with one float-vector field, inserts 10
    random vectors, builds an IVF_FLAT index, searches with the inserted
    vectors as queries, checks the top hit of the first query, and finally
    drops the index and the collection.
    """
    # Specify server addr when creating the milvus client instance.
    # The milvus client instance maintains a connection pool; param
    # `pool_size` specifies the max connection num.
    milvus = Milvus(_HOST, _PORT)

    collection_name = 'example_collection'
    field_name = 'example_field'

    # Start from a clean slate: drop a leftover collection from a previous run,
    # then always create it so every step below has a collection to work on.
    if milvus.has_collection(collection_name):
        milvus.drop_collection(collection_name=collection_name)

    fields = {
        "fields": [{
            "name": field_name,
            "type": DataType.FLOAT_VECTOR,
            "metric_type": "L2",
            "params": {
                "dim": _DIM
            },
            "indexes": [{
                "metric_type": "L2"
            }]
        }]
    }
    milvus.create_collection(collection_name=collection_name, fields=fields)

    # Show collections in Milvus server
    collections = milvus.list_collections()
    print(collections)

    # Describe the collection
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # 10 vectors with _DIM dimensions; elements are Python floats.
    # vectors should be a 2-D array
    vectors = [[random.random() for _ in range(_DIM)] for _ in range(10)]
    print(vectors)
    # You can also use numpy to generate random vectors:
    #     vectors = np.random.rand(10, _DIM).astype(np.float32)

    # Insert vectors into the collection; the returned value is used as the id list.
    entities = [{
        "name": field_name,
        "type": DataType.FLOAT_VECTOR,
        "values": vectors
    }]
    ids = milvus.insert(collection_name=collection_name, entities=entities)
    print("ids:", ids)

    # Flush collection inserted data to disk.
    milvus.flush([collection_name])
    # present collection statistics info
    stats = milvus.get_collection_stats(collection_name)
    print(stats)

    # Create an IVF_FLAT index on the vector field. Searching works without an
    # index, but creating one makes search faster.
    index_param = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {
            "nlist": 1024
        }
    }
    print("Creating index: {}".format(index_param))
    milvus.create_index(collection_name, field_name, index_param)

    # execute vector similarity search
    print("Searching ... ")
    dsl = {
        "bool": {
            "must": [{
                "vector": {
                    field_name: {
                        "metric_type": "L2",
                        "query": vectors,
                        "topk": 10,
                        "params": {
                            "nprobe": 16
                        }
                    }
                }
            }]
        }
    }
    milvus.load_collection(collection_name)
    results = milvus.search(collection_name, dsl)

    # The first query is vectors[0], so its best hit should be itself:
    # distance 0.0 or the id assigned to vectors[0] on insert.
    if results[0][0].distance == 0.0 or results[0][0].id == ids[0]:
        print('Query result is correct')
    else:
        print('Query result isn\'t correct')

    milvus.drop_index(collection_name, field_name)
    milvus.release_collection(collection_name)

    # Delete the collection
    milvus.drop_collection(collection_name)
class MilvusClient(object):
    """Wrapper around the table-based Milvus client, bound to one table.

    Non-OK statuses are logged by ``check_status`` but not raised.
    """

    def __init__(self, table_name=None, ip=None, port=None, timeout=60):
        """Create the client and connect.

        Without ``ip``, a single connect to SERVER_HOST_DEFAULT/SERVER_PORT_DEFAULT
        is made and any error propagates. With ``ip``, connecting is retried
        (one attempt per second) for up to ``timeout`` seconds. On timeout an
        error is logged and no exception is raised.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        if not ip:
            self._milvus.connect(host=SERVER_HOST_DEFAULT, port=SERVER_PORT_DEFAULT)
            return
        # retry connect for remote server
        i = 1
        start_time = time.time()
        while time.time() < start_time + timeout:
            try:
                self._milvus.connect(host=ip, port=port)
                if self._milvus.connected() is True:
                    logger.debug(
                        "Try connect times: %d, %s" % (i, round(time.time() - start_time, 2)))
                    break
            except Exception as e:
                logger.debug("Milvus connect failed")
                logger.debug(str(e))
            # Count every failed attempt and back off, so an unreachable server
            # is not polled in a tight CPU-burning loop.
            i = i + 1
            time.sleep(1)
        else:
            # Loop ran out of time without a successful connect.
            logger.error("Milvus connect timeout after %s seconds" % timeout)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log the status message when ``status`` is not OK (errors are not raised)."""
        if not status.OK():
            logger.error(status.message)

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create a table; ``metric_type`` is one of "l2", "ip", "jaccard", "hamming".

        An unsupported metric name is logged and passed through unchanged.
        The name is only adopted as this client's table if none was set yet.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        elif metric_type == "jaccard":
            metric_type = MetricType.JACCARD
        elif metric_type == "hamming":
            metric_type = MetricType.HAMMING
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids=None):
        """Add vectors ``X`` (optionally with ``ids``); return ``(status, result)``."""
        status, result = self._milvus.add_vectors(self._table_name, X, ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index of ``INDEX_MAP[index_type]`` with ``nlist`` on this table."""
        index_params = {
            "index_type": INDEX_MAP[index_type],
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name, index=index_params)
        self.check_status(status)

    def describe_index(self):
        """Return ``{"index_type": <INDEX_MAP key or None>, "nlist": ...}`` for this table."""
        status, result = self._milvus.describe_index(self._table_name)
        index_type = None
        # Reverse lookup of the server index type in INDEX_MAP.
        for k, v in INDEX_MAP.items():
            if result._index_type == v:
                index_type = k
                break
        nlist = result._nlist
        res = {"index_type": index_type, "nlist": nlist}
        return res

    def drop_index(self):
        """Drop the index on this table and return the server response."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search ``X`` with ``top_k``/``nprobe`` and return the raw result."""
        status, result = self._milvus.search_vectors(self._table_name, top_k, nprobe, X)
        self.check_status(status)
        return result

    def count(self):
        """Return the table row count (second element of the SDK response)."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, timeout=60):
        """Delete this table, then poll once per second until it reports no rows.

        Logs an error (without raising) after ``timeout`` polls.
        """
        logger.info("Start delete table: %s" % self._table_name)
        self._milvus.delete_table(self._table_name)
        i = 0
        while i < timeout:
            if self.count():
                time.sleep(1)
                i = i + 1
                continue
            else:
                break
        if i >= timeout:
            logger.error("Delete table timeout")

    def describe(self):
        """Return the server's description of this table."""
        return self._milvus.describe_table(self._table_name)

    def show_tables(self):
        """Return the raw ``show_tables`` response."""
        return self._milvus.show_tables()

    def exists_table(self, table_name=None):
        """Return whether ``table_name`` (default: this client's table) exists."""
        if table_name is None:
            table_name = self._table_name
        status, res = self._milvus.has_table(table_name)
        self.check_status(status)
        return res

    @time_wrapper
    def preload_table(self):
        """Preload this table (timeout=3000) and return the server response."""
        return self._milvus.preload_table(self._table_name, timeout=3000)

    def get_server_version(self):
        """Return the server version (status is ignored)."""
        status, res = self._milvus.server_version()
        return res

    def get_server_mode(self):
        return self.cmd("mode")

    def get_server_commit(self):
        return self.cmd("build_commit_id")

    def get_server_config(self):
        """Return the full server config parsed from JSON."""
        return json.loads(self.cmd("get_config *"))

    def get_mem_info(self):
        """Return ``{"memory_used": <GiB rounded to 2 decimals>}`` from ``get_system_info``."""
        result = json.loads(self.cmd("get_system_info"))
        result_human = {
            # unit: Gb (assumes memory_used is reported in bytes)
            "memory_used": round(int(result["memory_used"]) / (1024 * 1024 * 1024), 2)
        }
        return result_human

    def cmd(self, command):
        """Run a raw server command, log and status-check it, return its result."""
        status, res = self._milvus._cmd(command)
        logger.info("Server command: %s, result: %s" % (command, res))
        self.check_status(status)
        return res
class MilvusClient(object):
    """Wrapper around the table-based Milvus client, bound to one table.

    For the inner-product metric, vectors are L2-normalized (and cast to
    float32) before insert and search. Non-OK statuses raise.
    """

    def __init__(self, table_name=None, host=None, port=None):
        """Create the client and connect to ``host``/``port``.

        If ``host`` is not given, SERVER_HOST_DEFAULT/SERVER_PORT_DEFAULT are used.
        """
        self._milvus = Milvus()
        self._table_name = table_name
        # insert()/query() read this, but only create_table() sets it; default
        # it so a client attached to an existing table does not hit AttributeError.
        self._metric_type = None
        if not host:
            self._milvus.connect(host=SERVER_HOST_DEFAULT, port=SERVER_PORT_DEFAULT)
        else:
            self._milvus.connect(host=host, port=port)

    def __str__(self):
        return 'Milvus table %s' % self._table_name

    def check_status(self, status):
        """Log and raise ``Exception("Status not ok")`` when ``status`` is not OK."""
        if not status.OK():
            logger.error(status.message)
            raise Exception("Status not ok")

    def create_table(self, table_name, dimension, index_file_size, metric_type):
        """Create a table; ``metric_type`` is "l2" or "ip".

        An unsupported metric name is logged and passed through unchanged.
        The resolved metric is remembered for normalization in insert/query.
        """
        if not self._table_name:
            self._table_name = table_name
        if metric_type == "l2":
            metric_type = MetricType.L2
        elif metric_type == "ip":
            metric_type = MetricType.IP
        else:
            logger.error("Not supported metric_type: %s" % metric_type)
        self._metric_type = metric_type
        create_param = {
            'table_name': table_name,
            'dimension': dimension,
            'index_file_size': index_file_size,
            "metric_type": metric_type
        }
        status = self._milvus.create_table(create_param)
        self.check_status(status)

    @time_wrapper
    def insert(self, X, ids):
        """Insert vectors ``X`` with ``ids``; return ``(status, result)``.

        ``X`` must support ``.tolist()`` (e.g. a numpy array).
        """
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
            X = X.astype(numpy.float32)
        status, result = self._milvus.add_vectors(self._table_name, X.tolist(), ids=ids)
        self.check_status(status)
        return status, result

    @time_wrapper
    def create_index(self, index_type, nlist):
        """Build an index; ``index_type`` is a lowercase name such as "ivf_flat".

        Unrecognized names are passed to the server unchanged.
        """
        # Kept as an if/elif chain so only the requested IndexType member is looked up.
        if index_type == "flat":
            index_type = IndexType.FLAT
        elif index_type == "ivf_flat":
            index_type = IndexType.IVFLAT
        elif index_type == "ivf_sq8":
            index_type = IndexType.IVF_SQ8
        elif index_type == "ivf_sq8h":
            index_type = IndexType.IVF_SQ8H
        elif index_type == "nsg":
            index_type = IndexType.NSG
        elif index_type == "ivf_pq":
            index_type = IndexType.IVF_PQ
        index_params = {
            "index_type": index_type,
            "nlist": nlist,
        }
        logger.info("Building index start, table_name: %s, index_params: %s" % (self._table_name, json.dumps(index_params)))
        status = self._milvus.create_index(self._table_name, index=index_params, timeout=6 * 3600)
        self.check_status(status)

    def describe_index(self):
        """Return the raw ``describe_index`` response for this table."""
        return self._milvus.describe_index(self._table_name)

    def drop_index(self):
        """Drop the index on this table and return the server response."""
        logger.info("Drop index: %s" % self._table_name)
        return self._milvus.drop_index(self._table_name)

    @time_wrapper
    def query(self, X, top_k, nprobe):
        """Search ``X`` and return, per query, the list of result ids.

        ``X`` must support ``.tolist()`` (e.g. a numpy array).
        """
        if self._metric_type == MetricType.IP:
            logger.info("Set normalize for metric_type: Inner Product")
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
            X = X.astype(numpy.float32)
        status, results = self._milvus.search_vectors(self._table_name, top_k, nprobe, X.tolist())
        self.check_status(status)
        return [[item.id for item in result] for result in results]

    def count(self):
        """Return the table row count (second element of the SDK response)."""
        return self._milvus.get_table_row_count(self._table_name)[1]

    def delete(self, table_name):
        """Delete ``table_name`` and return the server response."""
        logger.info("Start delete table: %s" % table_name)
        return self._milvus.delete_table(table_name)

    def describe(self):
        """Return the server's description of this table."""
        return self._milvus.describe_table(self._table_name)

    def exists_table(self, table_name):
        """Return the raw ``has_table`` response for ``table_name``."""
        return self._milvus.has_table(table_name)

    def get_server_version(self):
        """Return the server version, raising if the status is not OK."""
        status, res = self._milvus.server_version()
        self.check_status(status)
        return res

    @time_wrapper
    def preload_table(self):
        """Preload this table and return the server response."""
        return self._milvus.preload_table(self._table_name)
class MyMilvus():
    """Demo wrapper that runs a create/insert/index/cleanup walkthrough on one collection."""

    def __init__(self, name, host, port, collection_param):
        self.host = host
        self.port = port
        self.client = Milvus(host, port)
        self.collection_name = name
        self.collection_param = collection_param
        # Example collection_param:
        # self.collection_param = {
        #     "fields": [
        #         {"name": "release_year", "type": DataType.INT32},
        #         {"name": "embedding", "type": DataType.FLOAT_VECTOR, "params": {"dim": 8}},
        #     ],
        #     "segment_row_limit": 4096,
        #     "auto_id": False
        # }

    @staticmethod
    def _load_films(path='films.csv'):
        """Read the films CSV and return its columns as four parallel lists.

        Each row is expected as ``id, title, release_year, embedding`` where
        ``embedding`` is a comma-separated float list wrapped in one leading
        and one trailing character (presumably ``[`` and ``]``).

        Returns:
            (ids, titles, release_years, embeddings)
        """
        ids, titles, release_years, embeddings = [], [], [], []
        # The csv module requires newline='' so quoted fields are parsed correctly.
        with open(path, 'r', newline='') as csvfile:
            for film in csv.reader(csvfile):
                ids.append(int(film[0]))
                titles.append(film[1])
                release_years.append(int(film[2]))
                # Drop the first and last character (the list brackets) before splitting.
                embeddings.append([float(x) for x in film[3][1:-1].split(',')])
        return ids, titles, release_years, embeddings

    def create_collection(self):
        """Run the demo: create, insert films.csv, flush, index, inspect, then clean up.

        The collection is only created if it does not already exist; it is
        dropped again at the end.
        """
        if self.collection_name not in self.client.list_collections():
            self.client.create_collection(self.collection_name, self.collection_param)

        # ------
        # Basic create index:
        #     Now that we have a collection in Milvus with `segment_row_limit` 4096, we can create index or
        #     insert entities.
        #
        #     We can call `create_index` BEFORE we insert any entities or AFTER. However Milvus won't actually
        #     start build index task if the segment row count is smaller than `segment_row_limit`. So if we want
        #     to make Milvus build index, we need to insert number of entities larger than `segment_row_limit`.
        #
        #     We are going to use data in `films.csv` so you can check out the structure. We need to group
        #     data with the same fields together; `_load_films` shows how to read the file and transform
        #     it into what we need.
        # ------
        ids, titles, release_years, embeddings = self._load_films()

        hybrid_entities = [
            {"name": "release_year", "values": release_years, "type": DataType.INT32},
            {"name": "embedding", "values": embeddings, "type": DataType.FLOAT_VECTOR},
        ]

        # ------
        # Basic insert:
        #     After preparing the data, we are going to insert them into our collection.
        #     The number of films inserted should be 8657.
        # ------
        ids = self.client.insert(self.collection_name, hybrid_entities, ids)

        self.client.flush([self.collection_name])
        after_flush_counts = self.client.count_entities(self.collection_name)
        print(" > There are {} films in collection `{}` after flush".format(after_flush_counts, self.collection_name))

        # ------
        # Basic create index:
        #     Now that we have inserted all the films into Milvus, we are going to build index with these data.
        #
        #     While building index, we have to indicate which `field` to build index for, the `index_type`,
        #     `metric_type` and params for the specific index type. In our case, we want to build a `IVF_FLAT`
        #     index, so the specific params are "nlist". See pymilvus documentation
        #     (https://milvus-io.github.io/milvus-sdk-python/pythondoc/v0.3.0/index.html) for `index_type` we
        #     support and the params accordingly.
        #
        #     If there is already an index for a collection and you call `create_index` with different params,
        #     the older index will be replaced by the new one.
        # ------
        self.client.create_index(self.collection_name, "embedding",
                                 {"index_type": "IVF_FLAT", "metric_type": "L2", "params": {"nlist": 100}})

        # ------
        # Basic create index:
        #     We can get the detail of the index by `get_collection_info`.
        # ------
        info = self.client.get_collection_info(self.collection_name)
        pprint(info)

        # ------
        # Basic hybrid search entities:
        #     If we want to use index, the specific index params need to be provided, in our case, the "params"
        #     should be "nprobe", if no "params" given, Milvus will complain about it and raise an exception.
        # ------
        # query_embedding = [random.random() for _ in range(8)]
        # query_hybrid = {
        #     "bool": {
        #         "must": [
        #             {
        #                 "term": {"release_year": [2002, 1995]}
        #             },
        #             {
        #                 "vector": {
        #                     "embedding": {"topk": 3,
        #                                   "query": [query_embedding],
        #                                   "metric_type": "L2",
        #                                   "params": {"nprobe": 8}}
        #                 }
        #             }
        #         ]
        #     }
        # }
        #
        # results = client.search(collection_name, query_hybrid, fields=["release_year", "embedding"])
        # for entities in results:
        #     for topk_film in entities:
        #         current_entity = topk_film.entity
        #         print("==")
        #         print("- id: {}".format(topk_film.id))
        #         print("- title: {}".format(titles[topk_film.id]))
        #         print("- distance: {}".format(topk_film.distance))
        #         print("- release_year: {}".format(current_entity.release_year))
        #         print("- embedding: {}".format(current_entity.embedding))

        # ------
        # Basic delete index:
        #     You can drop index for a field.
        # ------
        self.client.drop_index(self.collection_name, "embedding")

        if self.collection_name in self.client.list_collections():
            self.client.drop_collection(self.collection_name)