def test_search(self, started_app):
    """Drive ``self.client.search_vectors`` through three scenarios.

    1. ``nprobe`` of 2049 is expected to produce ``Status.ILLEGAL_ARGUMENT``.
    2. With ``Milvus.describe_collection`` mocked to return ``BAD``, the
       search is expected to produce ``Status.COLLECTION_NOT_EXISTS``.
    3. With ``describe_collection`` returning ``OK`` and
       ``Milvus.search_vectors_in_files`` returning a canned result list,
       the search is expected to succeed with one entry per query vector.

    ``started_app`` is not referenced in the body; it is requested only as
    a fixture argument.

    NOTE(review): the mocks below are assigned directly onto
    ``RouterMixin`` and ``Milvus`` and are not restored by this test --
    confirm that other tests do not depend on the real attributes.
    """
    # The collection is named after the currently executing function.
    collection_name = inspect.currentframe().f_code.co_name
    to_index_cnt = random.randint(10, 20)
    collection = TablesFactory(collection_id=collection_name,
                               state=Tables.NORMAL)
    # The returned batch is never read afterwards; presumably the call is
    # made for the fixture rows it creates -- TODO confirm.
    to_index_files = TableFilesFactory.create_batch(
        to_index_cnt,
        collection=collection,
        file_type=TableFiles.FILE_TYPE_TO_INDEX)

    topk = random.randint(5, 10)
    nq = random.randint(5, 10)

    # Keyword arguments for search_vectors; nprobe starts at 2049 for the
    # first scenario and is lowered to 2048 for the remaining two.
    search_kwargs = dict(
        collection_name=collection_name,
        query_records=self.random_data(nq, collection.dimension),
        top_k=topk,
        params={'nprobe': 2049},
    )

    # Canned backend answer: nq results, each holding topk hits whose ids
    # are 0..topk-1 and whose distances are random.
    canned_results = []
    for _ in range(nq):
        hits = [
            milvus_pb2.QueryResult(id=rank, distance=random.random())
            for rank in range(topk)
        ]
        canned_results.append(
            milvus_pb2.TopKQueryResult(query_result_arrays=hits))

    success_status = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                       reason="Success")
    mock_results = milvus_pb2.TopKQueryResultList(
        status=success_status,
        topk_query_result=canned_results)

    collection_schema = CollectionSchema(
        collection_name=collection_name,
        index_file_size=collection.index_file_size,
        metric_type=collection.metric_type,
        dimension=collection.dimension)

    # Scenario 1: nprobe == 2049.
    status, _ = self.client.search_vectors(**search_kwargs)
    assert status.code == Status.ILLEGAL_ARGUMENT

    # Scenario 2: nprobe == 2048, describe_collection reports BAD.
    search_kwargs['params']['nprobe'] = 2048
    RouterMixin.connection = mock.MagicMock(return_value=Milvus())
    RouterMixin.query_conn.conn = mock.MagicMock(return_value=Milvus())
    Milvus.describe_collection = mock.MagicMock(
        return_value=(BAD, collection_schema))
    status, _ = self.client.search_vectors(**search_kwargs)
    assert status.code == Status.COLLECTION_NOT_EXISTS

    # Scenario 3: describe_collection reports OK and the file search
    # returns the canned results.
    Milvus.describe_collection = mock.MagicMock(
        return_value=(OK, collection_schema))
    Milvus.search_vectors_in_files = mock.MagicMock(
        return_value=mock_results)
    status, ret = self.client.search_vectors(**search_kwargs)
    assert status.OK()
    assert len(ret) == nq
def Search(self, request, context):
    """Serve a Search request for a collection.

    Steps performed:
      * parse the search parameters from ``request.extra_params[0].value``
        as JSON;
      * validate ``request.topk`` against ``self.MAX_TOPK``;
      * look up the collection metadata in ``self.collection_meta``,
        falling back to ``describe_collection`` on the router connection
        and caching the answer;
      * convert the query records to ``bytes`` (``binary_data``) when the
        collection's metric type is at or above ``MetricType.HAMMING``,
        otherwise to lists (``float_data``);
      * delegate to ``self._do_query`` and wrap its output.

    Args:
        request: search request carrying ``collection_name``, ``topk``,
            ``extra_params``, ``query_record_array`` and, optionally,
            ``partition_tag_array``.
        context: passed through unchanged to ``self._do_query``.

    Returns:
        ``milvus_pb2.TopKQueryResult`` whose ``row_num`` is the number of
        query records when any ids were returned, and 0 otherwise.

    Raises:
        exceptions.SearchParamError: ``request.extra_params`` is empty.
        exceptions.InvalidTopKError: ``topk`` is <= 0 or exceeds
            ``self.MAX_TOPK``.
        exceptions.CollectionNotFoundError: the metadata is not cached and
            ``describe_collection`` returns a non-OK status.
    """
    metadata = {'resp_class': milvus_pb2.TopKQueryResult}
    collection_name = request.collection_name
    topk = request.topk

    if not len(request.extra_params):
        raise exceptions.SearchParamError(message="Search parma loss",
                                          metadata=metadata)

    raw_params = str(request.extra_params[0].value)
    params = ujson.loads(raw_params)

    logger.info('Search {}: topk={} params={}'.format(
        collection_name, topk, params))

    # nprobe is not range-checked in this method; the check below is
    # commented out in the original and kept here for reference.
    # if nprobe > self.MAX_NPROBE or nprobe <= 0:
    #     raise exceptions.InvalidArgumentError(
    #         message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

    topk_is_valid = 0 < topk <= self.MAX_TOPK
    if not topk_is_valid:
        raise exceptions.InvalidTopKError(
            message='Invalid topk: {}'.format(topk), metadata=metadata)

    # Use cached metadata when present; otherwise ask the backend once
    # and remember the answer.
    collection_meta = self.collection_meta.get(collection_name, None)
    if not collection_meta:
        connection = self.router.connection(metadata=metadata)
        status, info = connection.describe_collection(collection_name)
        if not status.OK():
            raise exceptions.CollectionNotFoundError(collection_name,
                                                     metadata=metadata)
        self.collection_meta[collection_name] = info
        collection_meta = info

    # The timed window covers record conversion and the query itself.
    start = time.time()

    uses_binary_vectors = (
        int(collection_meta.metric_type) >= MetricType.HAMMING.value)
    if uses_binary_vectors:
        query_record_array = [
            bytes(record.binary_data)
            for record in request.query_record_array
        ]
    else:
        query_record_array = [
            list(record.float_data)
            for record in request.query_record_array
        ]

    partition_tags = getattr(request, "partition_tag_array", [])
    status, id_results, dis_results = self._do_query(
        context,
        collection_name,
        collection_meta,
        query_record_array,
        topk,
        params,
        partition_tags=partition_tags,
        metadata=metadata)

    now = time.time()
    logger.info('SearchVector takes: {}'.format(now - start))

    # row_num reports the query count only when something was found.
    if len(id_results):
        row_num = len(request.query_record_array)
    else:
        row_num = 0

    response_status = status_pb2.Status(error_code=status.error_code,
                                        reason=status.reason)
    return milvus_pb2.TopKQueryResult(
        status=response_status,
        row_num=row_num,
        ids=id_results,
        distances=dis_results)
def Search(self, request, context):
    """Serve a Search request for a table.

    Validates ``nprobe`` and ``topk``, resolves the table metadata (from
    the ``self.table_meta`` cache, or via ``describe_table`` on the router
    connection, caching the answer), converts the query records and query
    ranges, and delegates to ``self._do_query``.

    Args:
        request: search request carrying ``table_name``, ``topk``,
            ``nprobe``, ``query_record_array``, ``query_range_array`` and,
            optionally, ``partition_tag_array``.
        context: passed through unchanged to ``self._do_query``.

    Returns:
        ``milvus_pb2.TopKQueryResult`` whose ``row_num`` is the number of
        query records when any ids were returned, and 0 otherwise.

    Raises:
        exceptions.InvalidArgumentError: ``nprobe`` is <= 0 or exceeds
            ``self.MAX_NPROBE``.
        exceptions.InvalidTopKError: ``topk`` is <= 0 or exceeds
            ``self.MAX_TOPK``.
        exceptions.TableNotFoundError: the metadata is not cached and
            ``describe_table`` returns a non-OK status.
    """
    table_name = request.table_name
    topk = request.topk
    nprobe = request.nprobe

    logger.info('Search {}: topk={} nprobe={}'.format(
        table_name, topk, nprobe))

    metadata = {'resp_class': milvus_pb2.TopKQueryResult}

    if nprobe > self.MAX_NPROBE or nprobe <= 0:
        raise exceptions.InvalidArgumentError(
            message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

    if topk > self.MAX_TOPK or topk <= 0:
        raise exceptions.InvalidTopKError(
            message='Invalid topk: {}'.format(topk), metadata=metadata)

    # Use cached metadata when present; otherwise ask the backend once
    # and remember the answer.
    table_meta = self.table_meta.get(table_name, None)
    if not table_meta:
        status, info = self.router.connection(
            metadata=metadata).describe_table(table_name)
        if not status.OK():
            raise exceptions.TableNotFoundError(table_name,
                                                metadata=metadata)
        self.table_meta[table_name] = info
        table_meta = info

    query_record_array = [
        list(query_record.vector_data)
        for query_record in request.query_record_array
    ]
    query_range_array = [
        Range(query_range.start_value, query_range.end_value)
        for query_range in request.query_range_array
    ]

    status, id_results, dis_results = self._do_query(
        context,
        table_name,
        table_meta,
        query_record_array,
        topk,
        nprobe,
        query_range_array,
        partition_tags=getattr(request, "partition_tag_array", []),
        metadata=metadata)

    # row_num reports the query count only when something was found.
    return milvus_pb2.TopKQueryResult(
        status=status_pb2.Status(error_code=status.error_code,
                                 reason=status.reason),
        row_num=len(request.query_record_array) if len(id_results) else 0,
        ids=id_results,
        distances=dis_results)
def gen_one_binary(topk):
    """Build a ``TopKQueryResult`` from random, struct-packed ids and distances.

    ``topk`` random ids are drawn from ``[10000000, 99999999)`` and packed
    with the native ``'l'`` format; ``topk`` random distances from
    ``random.random()`` are packed with the native ``'d'`` format.

    NOTE(review): the two packed buffers are handed to
    ``milvus_pb2.TopKQueryResult`` positionally -- confirm that this
    constructor accepts positional arguments and which fields they map to.
    NOTE(review): the width of native ``'l'`` is platform dependent --
    confirm that the consumer unpacks with the same format.

    Args:
        topk: number of id/distance pairs to generate.

    Returns:
        The ``milvus_pb2.TopKQueryResult`` built from the two byte strings.
    """
    # All ids are drawn before any distance, which fixes the order in
    # which the random generator is consumed.
    random_ids = [random.randrange(10000000, 99999999) for _ in range(topk)]
    random_distances = [random.random() for _ in range(topk)]

    count_prefix = str(topk)
    packed_ids = struct.pack(count_prefix + 'l', *random_ids)
    packed_distances = struct.pack(count_prefix + 'd', *random_distances)

    return milvus_pb2.TopKQueryResult(packed_ids, packed_distances)