def create_collection(collection_name):
    """Ensure a 3-dimensional collection named `collection_name` exists on the server."""
    conn = Milvus(host, str(port))
    _, exists = conn.has_collection(collection_name)
    if not exists:
        conn.create_collection({
            'collection_name': collection_name,
            'dimension': 3,
        })
    conn.close()
def test_connect_repeatedly(self, args):
    '''
    target: test connect repeatedly
    method: connect again
    expected: status.code is 0, and status.message shows have connected already
    '''
    client = Milvus()
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    # connecting twice must not break the connection
    client.connect(uri=uri)
    client.connect(uri=uri)
    assert client.connected()
def _test_search_concurrent_multiprocessing(self, args):
    '''
    target: test concurrent search with multiprocessess
    method: search with 10 processes, each process uses dependent connection
    expected: status ok and the returned vectors should be query_records
    '''
    # NOTE(review): docstring says 10 processes but process_num is 4 — confirm intended
    nb = 100
    top_k = 10
    process_num = 4
    processes = []
    table = gen_unique_str("test_search_concurrent_multiprocessing")
    uri = "tcp://%s:%s" % (args["ip"], args["port"])
    param = {'table_name': table,
             'dimension': dim,
             'index_type': IndexType.FLAT,
             'store_raw_vector': False}
    # create table
    milvus = Milvus()
    milvus.connect(uri=uri)
    milvus.create_table(param)
    vectors, ids = self.init_data(milvus, table, nb=nb)
    # query with the second half of the inserted vectors; FLAT + identical
    # vectors means the nearest hit must be an inserted id at distance 0
    query_vecs = vectors[nb//2:nb]

    def search(milvus):
        # runs in a child process with its own pre-connected client;
        # assertion failures here surface as a non-zero child exit, not
        # directly as a pytest failure in the parent
        status, result = milvus.search_vectors(table, top_k, query_vecs)
        assert len(result) == len(query_vecs)
        for i in range(len(query_vecs)):
            assert result[i][0].id in ids
            assert result[i][0].distance == 0.0

    for i in range(process_num):
        # each process uses an independent connection (see docstring)
        milvus = Milvus()
        milvus.connect(uri=uri)
        p = Process(target=search, args=(milvus, ))
        processes.append(p)
        p.start()
        # stagger process start-up slightly
        time.sleep(0.2)
    for p in processes:
        p.join()
def milvus_test(usr_features, mov_features, ids):
    """Insert a normalized user-feature vector into the demo collection, search
    the movie-feature vector against it (inner-product, top-1), print the hit
    and its 5-scaled score, then drop the collection."""
    _HOST = '127.0.0.1'
    _PORT = '19530'  # default value
    milvus = Milvus(_HOST, _PORT)
    table_name = 'paddle_demo1'

    _, ok = milvus.has_collection(table_name)
    if not ok:
        milvus.create_collection({
            'collection_name': table_name,
            'dimension': 200,
            'index_file_size': 1024,  # optional
            'metric_type': MetricType.IP  # optional
        })

    insert_vectors = normaliz_data([usr_features.tolist()])
    status, ids = milvus.insert(collection_name=table_name,
                                records=insert_vectors, ids=ids)
    # give the server a moment to make the insert visible
    time.sleep(1)

    status, result = milvus.count_entities(table_name)
    print("rows in table paddle_demo1:", result)

    search_vectors = normaliz_data([mov_features.tolist()])
    status, results = milvus.search(collection_name=table_name,
                                    query_records=search_vectors,
                                    top_k=1,
                                    params={'nprobe': 16})
    print("Searched ids:", results[0][0].id)
    # IP distance of normalized vectors is in [0, 1]; scale to a 5-star score
    print("Score:", float(results[0][0].distance) * 5)
    status = milvus.drop_collection(table_name)
def test_connect_uri_null(self, args):
    '''
    target: test connect with null uri
    method: uri set null
    expected: connected is True
    '''
    client = Milvus()
    empty_uri = ""
    if not self.local_ip(args):
        # remote server: an empty uri must fail to connect
        with pytest.raises(Exception) as e:
            client.connect(uri=empty_uri, timeout=1)
        assert not client.connected()
    else:
        # local server: empty uri falls back to defaults and connects
        client.connect(uri=empty_uri, timeout=1)
        assert client.connected()
def _create_collection(_collection_param):
    """(Re)create the collection described by ``_collection_param``.

    If a collection named ``_collection_name`` already exists it is dropped
    first; the drop is verified after a short wait.

    :param _collection_param: dict passed to ``Milvus.create_collection``
    :raises Exception: when the existing collection cannot be deleted
    """
    milvus = Milvus()
    milvus.connect(**server_config)
    status, ok = milvus.has_collection(_collection_name)
    if ok:
        print("Table {} found, now going to delete it".format(
            _collection_name))
        status = milvus.drop_collection(_collection_name)
        if not status.OK():
            raise Exception("Delete collection error")
        print(
            "delete collection {} successfully!".format(_collection_name))
        # give the server time to finish the drop before re-checking
        time.sleep(5)
        status, ok = milvus.has_collection(_collection_name)
        if ok:
            raise Exception("Delete collection error")
    # BUG FIX: the original passed the module-level `param` instead of the
    # `_collection_param` argument, silently ignoring the caller's spec.
    status = milvus.create_collection(_collection_param)
    if not status.OK():
        print("Create collection {} failed".format(_collection_name))
    milvus.disconnect()
def main():
    """Load bvecs data in BASE_LEN-sized batches and insert each batch into
    its own partition of the collection."""
    client = Milvus(host=SERVER_ADDR, port=SERVER_PORT)
    create_milvus_collection(client)
    tags = get_partition_tag()
    for batch in range(VEC_NUM // BASE_LEN):
        vectors = load_bvecs_data(FILE_PATH, BASE_LEN, batch)
        # ids are globally unique across batches
        vectors_ids = list(range(batch * BASE_LEN, (batch + 1) * BASE_LEN))
        create_partition(tags[batch], client)
        add_vectors(vectors, vectors_ids, tags[batch], client)
def create(self, name, **kwargs):
    """Create a new topo object and add in if not exist.

    Here the topo object is a Pymilvus client instance.
    """
    if not kwargs.get('uri', None):
        raise RuntimeError('\"uri\" is required to create connection pool')
    # copy so the caller's kwargs are not mutated by the retry setting
    milvus_args = copy.deepcopy(kwargs)
    milvus_args["max_retry"] = settings.MAX_RETRY
    pool = Milvus(name=name, **milvus_args)
    status = self.add(pool)
    if status != topology.StatusType.OK:
        pool = None
    return status, pool
def gcon(request, ghandler):
    """pytest fixture: build a Milvus client from the --ip/--port CLI options."""
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    client = Milvus(host=ip, port=port, handler=ghandler)

    def fin():
        # teardown is currently a no-op placeholder
        pass

    request.addfinalizer(fin)
    return client
def _add():
    """Connect, insert 10000 random 128-d vectors into `_table_name`, disconnect."""
    client = Milvus()
    client.connect(**server_config)
    data = _generate_vectors(128, 10000)
    print('\n\tPID: {}, insert {} vectors'.format(os.getpid(), 10000))
    status, _ = client.add_vectors(_table_name, data)
    client.disconnect()
def multi_thread_opr(table_name, utid):
    """Worker body: create `table_name` over an HTTP connection and insert 20
    batches of 10000 random 64-d vectors, logging progress per batch."""
    print("[{}] | T{} | Running .....".format(datetime.datetime.now(), utid))
    http_client = Milvus(handler="HTTP")
    vectors = [[random.random() for _ in range(64)] for _ in range(10000)]
    http_client.connect()
    http_client.create_table({'table_name': table_name, 'dimension': 64})
    for round_no in range(20):
        print("[{}] | T{} | O{} | Start insert data .....".format(
            datetime.datetime.now(), utid, round_no))
        http_client.insert(table_name, vectors)
        print("[{}] | T{} | O{} | Stop insert data .....".format(
            datetime.datetime.now(), utid, round_no))
    http_client.disconnect()
    print("[{}] | T{} | Stopping .....".format(datetime.datetime.now(), utid))
def main():
    """Insert bvecs vectors plus random scalar fields (sex/is_glasses/get_time)
    into the 'mixed06' collection, one BASE_LEN batch at a time, timing each
    insert."""
    # connect_milvus_server()
    milvus = Milvus(host=SERVER_ADDR, port=SERVER_PORT)
    create_milvus_collection(milvus)
    build_collection(milvus)
    for count in range(VEC_NUM // BASE_LEN):
        vectors = load_bvecs_data(FILE_PATH, BASE_LEN, count)
        vectors_ids = list(range(count * BASE_LEN, (count + 1) * BASE_LEN))
        # BUG FIX: the scalar fields were generated with a hard-coded length
        # of 10000; if BASE_LEN != 10000 the field lengths would not match the
        # number of vectors and the insert would fail. Size them to the batch.
        batch = len(vectors)
        sex = [random.randint(0, 2) for _ in range(batch)]
        get_time = [random.randint(2017, 2020) for _ in range(batch)]
        is_glasses = [random.randint(10, 13) for _ in range(batch)]
        hybrid_entities = [{
            "name": "sex",
            "values": sex,
            "type": DataType.INT32
        }, {
            "name": "is_glasses",
            "values": is_glasses,
            "type": DataType.INT32
        }, {
            "name": "get_time",
            "values": get_time,
            "type": DataType.INT32
        }, {
            "name": "Vec",
            "values": vectors,
            "type": DataType.FLOAT_VECTOR
        }]
        time_start = time.time()
        result = milvus.insert('mixed06', hybrid_entities, ids=vectors_ids)
        time_end = time.time()
        print("insert milvue time: ", time_end - time_start)
def __init__(self, collection_name=None, ip=None, port=None, timeout=60):
    """Connect to Milvus, retrying a remote server for up to `timeout` seconds.

    :param collection_name: default collection operated on by this wrapper
    :param ip: server ip; falls back to SERVER_HOST_DEFAULT when falsy
    :param port: server port (only used together with `ip`)
    :param timeout: seconds to keep retrying the remote connection
    :raises Exception: if the remote server cannot be reached within `timeout`
    """
    self._collection_name = collection_name
    try:
        i = 1
        start_time = time.time()
        if not ip:
            self._milvus = Milvus(host=SERVER_HOST_DEFAULT,
                                  port=SERVER_PORT_DEFAULT)
        else:
            # retry connect for remote server
            while time.time() < start_time + timeout:
                try:
                    self._milvus = Milvus(host=ip, port=port)
                    if self._milvus.server_status():
                        logger.debug(
                            "Try connect times: %d, %s" %
                            (i, round(time.time() - start_time, 2)))
                        break
                except Exception as e:
                    logger.debug("Milvus connect failed")
                    i = i + 1
            else:
                # BUG FIX: previously the retry loop could exhaust the timeout
                # without raising, leaving `self._milvus` unset and deferring
                # the failure to the first attribute access. Mirror the
                # sibling wrapper and fail fast.
                raise Exception("Server connect timeout")
    except Exception as e:
        raise e
def __init__(self, collection_name=None, host=None, port=None, timeout=60):
    """
    Milvus client wrapper for python-sdk. Default timeout set 60s.

    Retries the connection until `timeout` seconds elapse, then raises.
    When `collection_name` refers to an existing collection, caches its
    metric type and dimension from describe().
    """
    self._collection_name = collection_name
    try:
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect for remote server
        i = 0
        while time.time() < start_time + timeout:
            try:
                # try_connect/pre_ping disabled: server_status() below is the
                # actual liveness probe
                self._milvus = Milvus(host=host, port=port, try_connect=False,
                                      pre_ping=False)
                if self._milvus.server_status():
                    logger.debug("Try connect times: %d, %s" %
                                 (i, round(time.time() - start_time, 2)))
                    break
            except Exception as e:
                logger.debug("Milvus connect failed: %d times" % i)
                i = i + 1
        if time.time() > start_time + timeout:
            raise Exception("Server connect timeout")
    except Exception as e:
        raise e
    self._metric_type = None
    if self._collection_name and self.exists_collection():
        # cache collection metadata; describe() is called twice here —
        # NOTE(review): could be hoisted into one call
        self._metric_type = metric_type_to_str(self.describe()[1].metric_type)
        self._dimension = self.describe()[1].dimension
def _add():
    """Connect and insert the module-level `vectors` into `_table_name`,
    reporting a failed connect before attempting the insert."""
    milvus = Milvus()
    status = milvus.connect()
    # BUG FIX: `status.OK` is a bound method and is always truthy, so the
    # original `if not status.OK:` could never detect a failed connect.
    if not status.OK():
        print(f'PID: {os.getpid()}, connect failed')
    status, _ = milvus.add_vectors(_table_name, vectors)
    milvus.disconnect()
def connect(request, handler):
    """pytest fixture: build a Milvus client from CLI options, closing it on
    teardown. The --handler option overrides the `handler` argument and picks
    the default port (HTTP vs gRPC)."""
    ip = request.config.getoption("--ip")
    handler = request.config.getoption("--handler")
    if handler == "HTTP":
        port_default = default_http_port
    else:
        port_default = default_grpc_port
    port = request.config.getoption("--port", default=port_default)
    client = Milvus(host=ip, port=port, handler=handler)

    def fin():
        # best-effort close on teardown
        try:
            client.close()
        except:
            pass

    request.addfinalizer(fin)
    return client
def create():
    """Create the star-faces collection with an IVF_SQ8 inner-product index,
    bulk-insert the precomputed feature rows from chs_stars_features.csv and
    write a label file mapping each id to its star name and file name."""
    _HOST = 'localhost'
    _PORT = '19530'
    _collection_name = 'chs_stars_faces_512'
    _DIM = 512  # dimension of vector
    _INDEX_FILE_SIZE = 256  # max file size of stored index

    milvus = Milvus(_HOST, _PORT)
    milvus.create_collection({
        'collection_name': _collection_name,
        'dimension': _DIM,
        'index_file_size': _INDEX_FILE_SIZE,  # optional
        'metric_type': MetricType.IP  # optional
    })
    # recommended nlist: 4 * sqrt(n)
    index_param = {'nlist': 2048}
    status = milvus.create_index(_collection_name, IndexType.IVF_SQ8, index_param)

    with open("../chs_stars_labels.csv", "w") as fw, open("../chs_stars_features.csv", "r") as fr:
        reader = csv.reader(fr)
        writer = csv.writer(fw)
        for index, line in enumerate(tqdm(reader)):
            star, fname, features = line
            # features are stored as a JSON-encoded list in the CSV
            features = json.loads(features)
            status, ids = milvus.insert(collection_name=_collection_name,
                                        records=[features],
                                        ids=[index])
            if not status.OK():
                # skip rows the server rejected, keeping labels in sync
                print(status)
                continue
            writer.writerow([index, star, fname])
def main():
    """Run a date-ranged top-5 search against the 'test_search' table and
    print the hits and the table row count."""
    milvus = Milvus()
    milvus.connect(host=_HOST, port=_PORT)

    table_name = 'test_search'
    dimension = 256
    ranges = [['2019-06-25', '2019-06-25']]
    vectors = Prepare.records([[random.random() for _ in range(dimension)]
                               for _ in range(1)])
    LOGGER.info(ranges)

    status, result = milvus.search_vectors(table_name=table_name,
                                           query_records=vectors,
                                           top_k=5,
                                           query_ranges=ranges)  # Not fully tested yet
    if status.OK():
        pprint(result)

    _, result = milvus.get_table_row_count(table_name)
    print('# Count: {}'.format(result))
    milvus.disconnect()
def test_search_multi_table_IP(search, args):
    '''
    target: test search multi tables of IP
    method: add vectors into 10 tables, and search
    expected: search status ok, the length of result
    '''
    num = 10
    top_k = 10
    nprobe = 1
    tables = []
    idx = []
    for i in range(num):
        table = gen_unique_str("test_add_multitable_%d" % i)
        uri = "tcp://%s:%s" % (args["ip"], args["port"])
        # NOTE(review): the test name says IP, but the metric configured here
        # is L2 — confirm which was intended.
        param = {'table_name': table,
                 'dimension': dim,
                 'index_file_size': 10,
                 'metric_type': MetricType.L2}
        # create table
        milvus = Milvus()
        milvus.connect(uri=uri)
        milvus.create_table(param)
        status, ids = milvus.add_vectors(table, vectors)
        assert status.OK()
        assert len(ids) == len(vectors)
        tables.append(table)
        # remember the ids of the vectors that will be queried back (3 per table)
        idx.append(ids[0])
        idx.append(ids[10])
        idx.append(ids[20])
    # wait for the inserts to become searchable
    time.sleep(6)
    query_vecs = [vectors[0], vectors[10], vectors[20]]
    # start query from random table
    for i in range(num):
        table = tables[i]
        status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
        assert status.OK()
        assert len(result) == len(query_vecs)
        for j in range(len(query_vecs)):
            assert len(result[j]) == top_k
        for j in range(len(query_vecs)):
            # idx is laid out 3-per-table, so table i's expected ids start at 3*i
            assert check_result(result[j], idx[3 * i + j])
def connect(request):
    """pytest fixture: connect to the hard-coded Milvus server, aborting the
    whole pytest run when the server is unreachable; closes on teardown."""
    host = '192.168.1.238'
    port = 19530
    try:
        client = Milvus(host=host, port=port)
    except Exception as e:
        logging.getLogger().error(str(e))
        pytest.exit("Milvus server can not connected, exit pytest ...")

    def fin():
        # best-effort close; log rather than fail teardown
        try:
            client.close()
        except Exception as e:
            logging.getLogger().info(str(e))

    request.addfinalizer(fin)
    return client
def _add_milvus_question(self, question_vector, collection: str, partition: str,
                         milvus: mv.Milvus) -> int:
    """
    Add a standard question vector to Milvus.

    @param {object} question_vector - question vector to insert
    @param {str} collection - question category (Milvus collection name)
    @param {str} partition - scene (Milvus partition tag)
    @param {mv.Milvus} milvus - connected Milvus service object
    @returns {int} - the milvus_id assigned to the inserted vector
    """
    _status, _milvus_ids = milvus.insert(
        collection, [question_vector, ], partition_tag=partition)
    # raises/logs through the shared status-check helper before trusting ids
    self.confirm_milvus_status(_status, 'insert')
    self._log_debug('insert _milvus_ids: %s' % str(_milvus_ids))
    # a single vector was inserted, so a single id comes back
    return _milvus_ids[0]
def dis_connect(request):
    """pytest fixture: return a client that has connected and then been
    disconnected, for tests exercising the disconnected state."""
    ip = request.config.getoption("--ip")
    port = request.config.getoption("--port")
    client = Milvus()
    client.connect(host=ip, port=port)
    client.disconnect()

    def fin():
        # a second disconnect may raise; swallow it on teardown
        try:
            client.disconnect()
        except:
            pass

    request.addfinalizer(fin)
    return client
class TestToServer:
    # Unit tests for the client wrapper: every server call is patched on `Ms`,
    # so no real server is contacted by the test methods themselves.
    # NOTE(review): this client is created and connected at class-definition
    # time (an import-time side effect) — confirm that is intended.
    # NOTE(review): `test_crate_table` looks like a typo for
    # `test_create_table`; renaming changes the collected test id.
    fake_milvus = Milvus()
    fake_milvus.connect(host='127.0.0.1', port='9090')

    @mock.patch.object(Ms, 'server_status')
    def test_ping(self, server_status):
        # both the ping and version probes go through the mocked server_status
        server_status.return_value = 'OK'
        ans = self.fake_milvus.server_status('fake_ping')
        assert ans == 'OK'
        ans = self.fake_milvus.server_status('version')
        assert ans == 'OK'

    @mock.patch.object(Ms, 'create_table')
    def test_crate_table(self, create_table):
        create_table.return_value = Status.SUCCESS
        ans = self.fake_milvus.create_table('fakeparam')
        assert ans == Status.SUCCESS

    @mock.patch.object(Ms, 'add_vectors')
    def test_add_vector(self, add_vectors):
        add_vectors.return_value = ['aaaa']
        ans = self.fake_milvus.add_vectors('fake1', 'fake2')
        assert ans == ['aaaa']

    @mock.patch.object(Ms, 'describe_table')
    def test_describe_table(self, describe_table):
        describe_table.return_value = 'fake_table_name'
        ans = self.fake_milvus.describe_table('fake_param')
        assert ans == 'fake_table_name'

    @mock.patch.object(Ms, 'show_tables')
    def test_show_tables(self, show_tables):
        show_tables.return_value = 'some_table'
        ans = self.fake_milvus.show_tables()
        assert ans == 'some_table'

    @mock.patch.object(Ms, 'get_table_row_count')
    def test_get_table_row_count(self, get_table_row_count):
        get_table_row_count.return_value = 666
        ans = self.fake_milvus.get_table_row_count('fake_table')
        assert ans == 666
def test_disconnect_repeatedly(self, connect, args):
    '''
    target: test disconnect repeatedly
    method: disconnect a connected client, disconnect again
    expected: raise an error after disconnected
    '''
    if connect.connected():
        # reuse the fixture's live connection
        res = connect.disconnect()
        with pytest.raises(Exception) as e:
            res = connect.disconnect()
    else:
        # fixture not connected: build and connect a fresh client first
        client = Milvus()
        uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
        client.connect(uri=uri_value)
        res = client.disconnect()
        with pytest.raises(Exception) as e:
            res = client.disconnect()
def _test_connect_disconnect_with_multiprocess(self, args):
    '''
    target: test uri connect and disconnect repeatly with multiprocess
    method: set correct uri, test with multiprocessing connecting and disconnecting
    expected: all connection is connected after 10 times operation
    '''
    uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
    process_num = 4
    processes = []

    def connect(milvus):
        # connect -> disconnect -> reconnect; the client must end up connected.
        # NOTE(review): the assert runs in the child process, so a failure
        # shows up as a non-zero child exit rather than a pytest failure.
        milvus.connect(uri=uri_value)
        milvus.disconnect()
        milvus.connect(uri=uri_value)
        assert milvus.connected()

    for i in range(process_num):
        # each process gets its own independent client instance
        milvus = Milvus()
        p = Process(target=connect, args=(milvus, ))
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
def predict(start_date_str, end_date_str):
    # Vectorize every paper in [start_date_str, end_date_str] with the saved
    # doc2vec model, write each vector back to the SQL table as a
    # comma-separated string, and batch-insert all vectors into the 'ideaman'
    # Milvus collection.
    print("加载模型")
    model = gensim.models.doc2vec.Doc2Vec.load("./doc2vec.model")
    print("建立milvus链接")
    client = Milvus(host=milvus_ip, port='19530')
    print("读取数据ing")
    # dates are converted to epoch milliseconds for the range query
    start_date = datetime.strptime(start_date_str,
                                   '%Y-%m-%d').timestamp() * 1000
    end_date = datetime.strptime(end_date_str, '%Y-%m-%d').timestamp() * 1000
    res = Paper.query_by_time_interval(start_date, end_date)
    num = 0
    start = time.time()
    id_list = []
    user_id_list = []
    vecs = []
    for i in res:
        paper_id = i.id
        paper_user_id = i.user_id
        paper_str = i.title + " . " + i.description
        vec = get_vector(model, [paper_str])
        # collect the vector for the Milvus batch insert below
        id_list.append(paper_id)
        user_id_list.append(paper_user_id)
        vecs.append(list(vec))
        # write the vector to the SQL database as a comma-separated string;
        # NOTE(review): this serializes via str(vec) and chained replaces,
        # which is fragile — the ",," -> ",0," patch suggests lost values.
        paper_vec = str(vec).replace('\n', '').replace('[', '').replace(
            ']', '').replace("  ", " ").replace(" ", ",")[1:]
        paper_vec = paper_vec.replace(",,", ",0,")
        Paper.update_SQL('doc_vector', paper_vec, paper_user_id)
        num += 1
        if num % 200 == 0:
            print("完成了", num, '篇', '--用时:', time.time() - start)
            start = time.time()
    # hybrid_entities = [
    #     {"name": "id", "values": id_list, "type": DataType.INT32},
    #     {"name": "Vec", "values": vecs, "type": DataType.FLOAT_VECTOR}
    # ]
    client.insert('ideaman', records=vecs, ids=id_list)
    client.flush(collection_name_array=["ideaman"])
    user_id_list.clear()
    id_list.clear()
    vecs.clear()
def _test_connect_with_multiprocess(self, args):
    '''
    target: test uri connect with multiprocess
    method: set correct uri, test with multiprocessing connecting
    expected: all connection is connected
    '''
    uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
    process_num = 4

    def connect(milvus):
        # first connect succeeds; a second connect must raise
        milvus.connect(uri=uri_value)
        with pytest.raises(Exception) as e:
            milvus.connect(uri=uri_value)
        assert milvus.connected()

    workers = []
    for _ in range(process_num):
        worker = Process(target=connect, args=(Milvus(), ))
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
class MilvusConnection:
    # Wraps a Milvus client that (re)creates a collection from the
    # environment's embedding table on first use and exposes a
    # tensor-returning nearest-neighbour search.

    def __init__(self, env, name="movies_L2", port="19530", param=None):
        # caller-supplied `param` keys override the defaults via **param
        if param is None:
            param = dict()
        param = {
            "collection_name": name,
            "dimension": 128,
            "index_file_size": 1024,
            "metric_type": MetricType.L2,
            **param,
        }
        self.name = name
        self.client = Milvus(host="localhost", port=port)
        self.statuses = {}
        # populate the collection (and the status log) only when it does not
        # exist yet; has_collection returns (status, bool)
        if not self.client.has_collection(name)[1]:
            status_created_collection = self.client.create_collection(param)
            # embeddings come from a torch module; move to CPU float32 numpy
            vectors = env.base.embeddings.detach().cpu().numpy().astype(
                "float32")
            # ids are simply the row indices of the embedding table
            target_ids = list(range(vectors.shape[0]))
            status_inserted, inserted_vector_ids = self.client.insert(
                collection_name=name, records=vectors, ids=target_ids)
            status_flushed = self.client.flush([name])
            status_compacted = self.client.compact(collection_name=name)
            self.statuses["created_collection"] = status_created_collection
            self.statuses["inserted"] = status_inserted
            self.statuses["flushed"] = status_flushed
            self.statuses["compacted"] = status_compacted

    def search(self, search_vecs, topk=10, search_param=None):
        # Return the top-k neighbour ids for each query vector as a tensor.
        if search_param is None:
            search_param = dict()
        # caller-supplied keys override the nprobe default
        search_param = {"nprobe": 16, **search_param}
        status, results = self.client.search(
            collection_name=self.name,
            query_records=search_vecs,
            top_k=topk,
            params=search_param,
        )
        self.statuses['last_search'] = status
        return torch.tensor(results.id_array)

    def get_log(self):
        # Status objects recorded by __init__/search, keyed by operation name.
        return self.statuses
return status.OK() and ok if __name__ == "__main__": import numpy dim = 128 nq = 10000 table = "test" file_name = '/poc/yuncong/ann_1000m/query.npy' data = np.load(file_name) vectors = data[0:nq].tolist() # print(vectors) connect = Milvus() # connect.connect(host="192.168.1.27") # print(connect.show_tables()) # print(connect.get_table_row_count(table)) # sys.exit() connect.connect(host="127.0.0.1") connect.delete_table(table) # sys.exit() # time.sleep(2) print(connect.get_table_row_count(table)) param = { 'table_name': table, 'dimension': dim, 'metric_type': MetricType.L2, 'index_file_size': 10 }
class MilvusClient(object):
    """Benchmark/test helper wrapping a pymilvus client for one collection.

    Most methods default to ``self._collection_name`` when no explicit
    collection name is passed; several hot-path methods are timed via the
    ``@time_wrapper`` decorator.
    """

    def __init__(self, collection_name=None, host=None, port=None,
                 timeout=180):
        # Retry the connection until `timeout` seconds elapse, then raise.
        self._collection_name = collection_name
        start_time = time.time()
        if not host:
            host = SERVER_HOST_DEFAULT
        if not port:
            port = SERVER_PORT_DEFAULT
        logger.debug(host)
        logger.debug(port)
        # retry connect remote server
        i = 0
        while time.time() < start_time + timeout:
            try:
                self._milvus = Milvus(host=host, port=port, try_connect=False,
                                      pre_ping=False)
                break
            except Exception as e:
                logger.error(str(e))
                logger.error("Milvus connect failed: %d times" % i)
                i = i + 1
                # back off a little longer on each failed attempt
                time.sleep(i)
        if time.time() > start_time + timeout:
            raise Exception("Server connect timeout")
        # self._metric_type = None

    def __str__(self):
        return 'Milvus collection %s' % self._collection_name

    def check_status(self, status):
        """Raise (after logging server state) when `status` is not OK."""
        if not status.OK():
            logger.error(status.message)
            logger.error(self._milvus.server_status())
            logger.error(self.count())
            raise Exception("Status not ok")

    def check_result_ids(self, result):
        """Raise when any top-1 hit's distance is >= the module-level epsilon."""
        for index, item in enumerate(result):
            if item[0].distance >= epsilon:
                logger.error(index)
                logger.error(item[0].distance)
                raise Exception("Distance wrong")

    # only support the given field name
    def create_collection(self, dimension, data_type=DataType.FLOAT_VECTOR,
                          auto_id=False, collection_name=None,
                          other_fields=None):
        """Create a collection with one vector field plus optional scalar fields.

        `other_fields` is a comma-separated string; "int" adds an INT64 field,
        "float" adds a FLOAT field (names come from utils defaults).
        """
        self._dimension = dimension
        if not collection_name:
            collection_name = self._collection_name
        vec_field_name = utils.get_default_field_name(data_type)
        fields = [{
            "name": vec_field_name,
            "type": data_type,
            "params": {
                "dim": dimension
            }
        }]
        if other_fields:
            other_fields = other_fields.split(",")
            if "int" in other_fields:
                fields.append({
                    "name": utils.DEFAULT_INT_FIELD_NAME,
                    "type": DataType.INT64
                })
            if "float" in other_fields:
                fields.append({
                    "name": utils.DEFAULT_FLOAT_FIELD_NAME,
                    "type": DataType.FLOAT
                })
        create_param = {"fields": fields, "auto_id": auto_id}
        try:
            self._milvus.create_collection(collection_name, create_param)
            logger.info("Create collection: <%s> successfully" %
                        collection_name)
        except Exception as e:
            logger.error(str(e))
            raise

    def create_partition(self, tag, collection_name=None):
        if not collection_name:
            collection_name = self._collection_name
        self._milvus.create_partition(collection_name, tag)

    def generate_values(self, data_type, vectors, ids):
        """Pick the column values for a field of `data_type`: ids for ints,
        ids cast to float for floats, and the vectors for vector fields."""
        values = None
        if data_type in [DataType.INT32, DataType.INT64]:
            values = ids
        elif data_type in [DataType.FLOAT, DataType.DOUBLE]:
            values = [(i + 0.0) for i in ids]
        elif data_type in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]:
            values = vectors
        return values

    def generate_entities(self, vectors, ids=None, collection_name=None):
        """Build an insert payload matching the collection's field schema."""
        entities = []
        if collection_name is None:
            collection_name = self._collection_name
        info = self.get_info(collection_name)
        for field in info["fields"]:
            field_type = field["type"]
            entities.append({
                "name": field["name"],
                "type": field_type,
                "values": self.generate_values(field_type, vectors, ids)
            })
        return entities

    @time_wrapper
    def insert(self, entities, ids=None, collection_name=None):
        # NOTE(review): exceptions are swallowed after logging, so a failed
        # insert returns None instead of ids — confirm callers handle that.
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        try:
            insert_ids = self._milvus.insert(tmp_collection_name, entities,
                                             ids=ids)
            return insert_ids
        except Exception as e:
            logger.error(str(e))

    def get_dimension(self):
        """Return the dim param of the first vector field in the schema."""
        info = self.get_info()
        for field in info["fields"]:
            if field["type"] in [
                    DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR
            ]:
                return field["params"]["dim"]

    def get_rand_ids(self, length):
        """Sample up to `length` entity ids from one randomly chosen segment;
        loops until a non-empty segment is found."""
        segment_ids = []
        while True:
            stats = self.get_stats()
            segments = stats["partitions"][0]["segments"]
            # random choice one segment
            segment = random.choice(segments)
            try:
                segment_ids = self._milvus.list_id_in_segment(
                    self._collection_name, segment["id"])
            except Exception as e:
                logger.error(str(e))
            if not len(segment_ids):
                continue
            elif len(segment_ids) > length:
                return random.sample(segment_ids, length)
            else:
                logger.debug("Reset length: %d" % len(segment_ids))
                return segment_ids

    # def get_rand_ids_each_segment(self, length):
    #     res = []
    #     status, stats = self._milvus.get_collection_stats(self._collection_name)
    #     self.check_status(status)
    #     segments = stats["partitions"][0]["segments"]
    #     segments_num = len(segments)
    #     # random choice from each segment
    #     for segment in segments:
    #         status, segment_ids = self._milvus.list_id_in_segment(self._collection_name, segment["name"])
    #         self.check_status(status)
    #         res.extend(segment_ids[:length])
    #     return segments_num, res

    # def get_rand_entities(self, length):
    #     ids = self.get_rand_ids(length)
    #     status, get_res = self._milvus.get_entity_by_id(self._collection_name, ids)
    #     self.check_status(status)
    #     return ids, get_res

    def get(self):
        # fetch a random id; result is intentionally discarded (load probe)
        get_ids = random.randint(1, 1000000)
        self._milvus.get_entity_by_id(self._collection_name, [get_ids])

    @time_wrapper
    def get_entities(self, get_ids):
        get_res = self._milvus.get_entity_by_id(self._collection_name,
                                                get_ids)
        return get_res

    @time_wrapper
    def delete(self, ids, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.delete_entity_by_id(tmp_collection_name, ids)

    def delete_rand(self):
        """Delete a random batch of existing ids, flush, and verify the
        deleted ids can no longer be fetched."""
        delete_id_length = random.randint(1, 100)
        # NOTE(review): count_before is only used by the commented-out check
        count_before = self.count()
        logger.debug("%s: length to delete: %d" %
                     (self._collection_name, delete_id_length))
        delete_ids = self.get_rand_ids(delete_id_length)
        self.delete(delete_ids)
        self.flush()
        logger.info("%s: count after delete: %d" %
                    (self._collection_name, self.count()))
        get_res = self._milvus.get_entity_by_id(self._collection_name,
                                                delete_ids)
        for item in get_res:
            assert not item
        # if count_before - len(delete_ids) < self.count():
        #     logger.error(delete_ids)
        #     raise Exception("Error occured")

    @time_wrapper
    def flush(self, _async=False, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        self._milvus.flush([tmp_collection_name], _async=_async)

    @time_wrapper
    def compact(self, collection_name=None):
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        status = self._milvus.compact(tmp_collection_name)
        self.check_status(status)

    @time_wrapper
    def create_index(self, field_name, index_type, metric_type, _async=False,
                     index_param=None):
        """Build an index on `field_name`; index/metric names are mapped to
        server values via INDEX_MAP and utils.metric_type_trans."""
        index_type = INDEX_MAP[index_type]
        metric_type = utils.metric_type_trans(metric_type)
        logger.info(
            "Building index start, collection_name: %s, index_type: %s, metric_type: %s"
            % (self._collection_name, index_type, metric_type))
        if index_param:
            logger.info(index_param)
        index_params = {
            "index_type": index_type,
            "metric_type": metric_type,
            "params": index_param
        }
        self._milvus.create_index(self._collection_name, field_name,
                                  index_params, _async=_async)

    # TODO: need to check
    def describe_index(self, field_name):
        """Reverse-map the server's index info back to this module's index
        names; defaults to flat when no index is found."""
        # stats = self.get_stats()
        info = self._milvus.describe_index(self._collection_name, field_name)
        index_info = {"index_type": "flat", "index_param": None}
        for field in info["fields"]:
            for index in field['indexes']:
                if not index or "index_type" not in index:
                    continue
                else:
                    for k, v in INDEX_MAP.items():
                        if index['index_type'] == v:
                            index_info['index_type'] = k
                            index_info['index_param'] = index['params']
                            return index_info
        return index_info

    def drop_index(self, field_name):
        logger.info("Drop index: %s" % self._collection_name)
        return self._milvus.drop_index(self._collection_name, field_name)

    @time_wrapper
    def query(self, vector_query, filter_query=None, collection_name=None):
        """Run a boolean search combining the vector query with optional filters."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        query = {"bool": {"must": must_params}}
        result = self._milvus.search(tmp_collection_name, query)
        return result

    @time_wrapper
    def load_and_query(self, vector_query, filter_query=None,
                       collection_name=None):
        """Like query(), but loads the collection into memory first."""
        tmp_collection_name = self._collection_name if collection_name is None else collection_name
        must_params = [vector_query]
        if filter_query:
            must_params.extend(filter_query)
        query = {"bool": {"must": must_params}}
        self.load_collection(tmp_collection_name)
        result = self._milvus.search(tmp_collection_name, query)
        return result

    def get_ids(self, result):
        """Split the flat id list of a search result into per-query id lists."""
        idss = result._entities.ids
        ids = []
        len_idss = len(idss)
        len_r = len(result)
        # each query contributed the same number of hits
        top_k = len_idss // len_r
        for offset in range(0, len_idss, top_k):
            ids.append(idss[offset:min(offset + top_k, len_idss)])
        return ids

    def query_rand(self, nq_max=100):
        """Issue one search with random nq/top_k/nprobe/metric (load generator)."""
        # for ivf search
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }
        self.query(vector_query)

    def load_query_rand(self, nq_max=100):
        """Same as query_rand(), but loads the collection before searching."""
        # for ivf search
        dimension = 128
        top_k = random.randint(1, 100)
        nq = random.randint(1, nq_max)
        nprobe = random.randint(1, 100)
        search_param = {"nprobe": nprobe}
        query_vectors = [[random.random() for _ in range(dimension)]
                         for _ in range(nq)]
        metric_type = random.choice(["l2", "ip"])
        logger.info("%s, Search nq: %d, top_k: %d, nprobe: %d" %
                    (self._collection_name, nq, top_k, nprobe))
        vec_field_name = utils.get_default_field_name()
        vector_query = {
            "vector": {
                vec_field_name: {
                    "topk": top_k,
                    "query": query_vectors,
                    "metric_type": utils.metric_type_trans(metric_type),
                    "params": search_param
                }
            }
        }
        self.load_and_query(vector_query)

    # TODO: need to check
    def count(self, collection_name=None):
        """Return the server-reported row count for the collection."""
        if collection_name is None:
            collection_name = self._collection_name
        row_count = self._milvus.get_collection_stats(
            collection_name)["row_count"]
        logger.debug("Row count: %d in collection: <%s>" %
                     (row_count, collection_name))
        return row_count

    def drop(self, timeout=120, collection_name=None):
        """Drop the collection and poll (up to `timeout` seconds) until its
        row count reaches zero or the stats call starts failing."""
        timeout = int(timeout)
        if collection_name is None:
            collection_name = self._collection_name
        logger.info("Start delete collection: %s" % collection_name)
        self._milvus.drop_collection(collection_name)
        i = 0
        while i < timeout:
            try:
                row_count = self.count(collection_name=collection_name)
                if row_count:
                    time.sleep(1)
                    i = i + 1
                    continue
                else:
                    break
            except Exception as e:
                # stats failing usually means the collection is already gone
                logger.debug(str(e))
                break
        if i >= timeout:
            logger.error("Delete collection timeout")

    def get_stats(self):
        return self._milvus.get_collection_stats(self._collection_name)

    def get_info(self, collection_name=None):
        # pdb.set_trace()
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.get_collection_info(collection_name)

    def show_collections(self):
        return self._milvus.list_collections()

    def exists_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        res = self._milvus.has_collection(collection_name)
        return res

    def clean_db(self):
        """Drop every collection on the server."""
        collection_names = self.show_collections()
        for name in collection_names:
            self.drop(collection_name=name)

    @time_wrapper
    def load_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_collection(collection_name, timeout=3000)

    @time_wrapper
    def release_collection(self, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_collection(collection_name, timeout=3000)

    @time_wrapper
    def load_partitions(self, tag_names, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.load_partitions(collection_name, tag_names,
                                            timeout=3000)

    @time_wrapper
    def release_partitions(self, tag_names, collection_name=None):
        if collection_name is None:
            collection_name = self._collection_name
        return self._milvus.release_partitions(collection_name, tag_names,
                                               timeout=3000)