Esempio n. 1
0
    def test_search(self, started_app):
        """Drive ``client.search_vectors`` through two failures and a success.

        Scenarios, in order:

        1. ``nprobe`` of 2049 is expected to yield ``Status.ILLEGAL_ARGUMENT``.
        2. ``nprobe`` of 2048 with ``describe_collection`` mocked to return
           ``BAD`` is expected to yield ``Status.COLLECTION_NOT_EXISTS``.
        3. ``describe_collection`` mocked to return ``OK`` and
           ``search_vectors_in_files`` mocked to return a prepared result list
           is expected to succeed with one result entry per query vector.
        """
        # The collection is named after this test function.
        name = inspect.currentframe().f_code.co_name
        index_file_count = random.randint(10, 20)
        collection = TablesFactory(collection_id=name, state=Tables.NORMAL)
        # The returned batch is not referenced again in this test.
        TableFilesFactory.create_batch(
            index_file_count,
            collection=collection,
            file_type=TableFiles.FILE_TYPE_TO_INDEX)

        top_k = random.randint(5, 10)
        query_count = random.randint(5, 10)
        search_kwargs = dict(
            collection_name=name,
            query_records=self.random_data(query_count, collection.dimension),
            top_k=top_k,
            params={'nprobe': 2049},
        )

        # One TopKQueryResult per query, each holding top_k hits with
        # sequential ids and random distances.
        per_query_results = []
        for _ in range(query_count):
            hits = [
                milvus_pb2.QueryResult(id=hit_id, distance=random.random())
                for hit_id in range(top_k)
            ]
            per_query_results.append(
                milvus_pb2.TopKQueryResult(query_result_arrays=hits))

        success = status_pb2.Status(error_code=status_pb2.SUCCESS,
                                    reason="Success")
        mocked_response = milvus_pb2.TopKQueryResultList(
            status=success, topk_query_result=per_query_results)

        schema = CollectionSchema(
            collection_name=name,
            index_file_size=collection.index_file_size,
            metric_type=collection.metric_type,
            dimension=collection.dimension)

        # Scenario 1: nprobe 2049 must be rejected.
        status, _ = self.client.search_vectors(**search_kwargs)
        assert status.code == Status.ILLEGAL_ARGUMENT

        # Scenario 2: nprobe 2048, but the collection lookup reports failure.
        # NOTE(review): the mocks below are assigned on the classes and are
        # not restored at the end of this test -- confirm that is intended.
        search_kwargs['params']['nprobe'] = 2048
        RouterMixin.connection = mock.MagicMock(return_value=Milvus())
        RouterMixin.query_conn.conn = mock.MagicMock(return_value=Milvus())
        Milvus.describe_collection = mock.MagicMock(
            return_value=(BAD, schema))
        status, _ = self.client.search_vectors(**search_kwargs)
        assert status.code == Status.COLLECTION_NOT_EXISTS

        # Scenario 3: lookup succeeds and the file search returns the
        # prepared response.
        Milvus.describe_collection = mock.MagicMock(
            return_value=(OK, schema))
        Milvus.search_vectors_in_files = mock.MagicMock(
            return_value=mocked_response)

        status, found = self.client.search_vectors(**search_kwargs)
        assert status.OK()
        assert len(found) == query_count
Esempio n. 2
0
    def Search(self, request, context):
        """Handle a search request and return a ``TopKQueryResult``.

        Reads the collection name, ``topk`` and the JSON-encoded search
        parameters (first entry of ``request.extra_params``) from the
        request, validates them, resolves the collection metadata (cached in
        ``self.collection_meta``), converts the query records and delegates
        the actual search to ``self._do_query``.

        Args:
            request: search request message.
            context: call context, forwarded to ``self._do_query``.

        Returns:
            ``milvus_pb2.TopKQueryResult`` carrying the status, row count,
            ids and distances produced by ``self._do_query``.

        Raises:
            exceptions.SearchParamError: ``request.extra_params`` is empty or
                its first value is not valid JSON.
            exceptions.InvalidTopKError: ``topk`` is not in
                ``(0, self.MAX_TOPK]``.
            exceptions.CollectionNotFoundError: the metadata is not cached
                and ``describe_collection`` returns a non-OK status.
        """
        metadata = {'resp_class': milvus_pb2.TopKQueryResult}

        collection_name = request.collection_name
        topk = request.topk

        if len(request.extra_params) == 0:
            raise exceptions.SearchParamError(message="Search param loss", metadata=metadata)

        # A malformed payload is a client error: report it the same way as
        # missing params instead of letting a raw ValueError escape.
        try:
            params = ujson.loads(str(request.extra_params[0].value))
        except ValueError as err:
            raise exceptions.SearchParamError(
                message='Invalid search params: {}'.format(err),
                metadata=metadata) from err

        logger.info('Search {}: topk={} params={}'.format(
            collection_name, topk, params))

        # NOTE(review): nprobe used to be range-checked here against
        # self.MAX_NPROBE; that check is disabled and nprobe presumably now
        # travels inside `params` -- confirm it is validated downstream.

        if topk > self.MAX_TOPK or topk <= 0:
            raise exceptions.InvalidTopKError(
                message='Invalid topk: {}'.format(topk), metadata=metadata)

        collection_meta = self.collection_meta.get(collection_name, None)

        if not collection_meta:
            # Cache miss: fetch the metadata once and remember it.
            status, info = self.router.connection(
                metadata=metadata).describe_collection(collection_name)
            if not status.OK():
                raise exceptions.CollectionNotFoundError(collection_name,
                                                         metadata=metadata)

            self.collection_meta[collection_name] = info
            collection_meta = info

        start = time.time()

        # Metric types at or above HAMMING read the records' binary payload;
        # all others read the float payload.
        if int(collection_meta.metric_type) >= MetricType.HAMMING.value:
            query_record_array = [
                bytes(query_record.binary_data)
                for query_record in request.query_record_array
            ]
        else:
            query_record_array = [
                list(query_record.float_data)
                for query_record in request.query_record_array
            ]

        status, id_results, dis_results = self._do_query(
            context,
            collection_name,
            collection_meta,
            query_record_array,
            topk,
            params,
            partition_tags=getattr(request, "partition_tag_array", []),
            metadata=metadata)

        now = time.time()
        logger.info('SearchVector takes: {}'.format(now - start))

        # row_num is reported as 0 when no ids came back.
        topk_result_list = milvus_pb2.TopKQueryResult(
            status=status_pb2.Status(error_code=status.error_code,
                                     reason=status.reason),
            row_num=len(request.query_record_array) if len(id_results) else 0,
            ids=id_results,
            distances=dis_results)
        return topk_result_list
Esempio n. 3
0
    def Search(self, request, context):
        """Handle a search request and return a ``TopKQueryResult``.

        Validates ``nprobe`` and ``topk``, resolves the table metadata
        (cached in ``self.table_meta``), converts the query records and query
        ranges from the request and delegates the search to
        ``self._do_query``.

        Args:
            request: search request message.
            context: call context, forwarded to ``self._do_query``.

        Returns:
            ``milvus_pb2.TopKQueryResult`` carrying the status, row count,
            ids and distances produced by ``self._do_query``.

        Raises:
            exceptions.InvalidArgumentError: ``nprobe`` is not in
                ``(0, self.MAX_NPROBE]``.
            exceptions.InvalidTopKError: ``topk`` is not in
                ``(0, self.MAX_TOPK]``.
            exceptions.TableNotFoundError: the metadata is not cached and
                ``describe_table`` returns a non-OK status.
        """
        table_name = request.table_name

        topk = request.topk
        nprobe = request.nprobe

        logger.info('Search {}: topk={} nprobe={}'.format(
            table_name, topk, nprobe))

        metadata = {'resp_class': milvus_pb2.TopKQueryResult}

        if nprobe > self.MAX_NPROBE or nprobe <= 0:
            raise exceptions.InvalidArgumentError(
                message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

        if topk > self.MAX_TOPK or topk <= 0:
            raise exceptions.InvalidTopKError(
                message='Invalid topk: {}'.format(topk), metadata=metadata)

        table_meta = self.table_meta.get(table_name, None)

        if not table_meta:
            # Cache miss: fetch the metadata once and remember it.
            status, info = self.router.connection(
                metadata=metadata).describe_table(table_name)
            if not status.OK():
                raise exceptions.TableNotFoundError(table_name,
                                                    metadata=metadata)

            self.table_meta[table_name] = info
            table_meta = info

        query_record_array = [
            list(query_record.vector_data)
            for query_record in request.query_record_array
        ]

        query_range_array = [
            Range(query_range.start_value, query_range.end_value)
            for query_range in request.query_range_array
        ]

        status, id_results, dis_results = self._do_query(
            context,
            table_name,
            table_meta,
            query_record_array,
            topk,
            nprobe,
            query_range_array,
            partition_tags=getattr(request, "partition_tag_array", []),
            metadata=metadata)

        # row_num is reported as 0 when no ids came back.
        topk_result_list = milvus_pb2.TopKQueryResult(
            status=status_pb2.Status(error_code=status.error_code,
                                     reason=status.reason),
            row_num=len(request.query_record_array) if len(id_results) else 0,
            ids=id_results,
            distances=dis_results)
        return topk_result_list
Esempio n. 4
0
def gen_one_binary(topk):
    """Build a ``TopKQueryResult`` from ``topk`` random ids and distances.

    The ids are random integers drawn with ``randrange(10000000, 99999999)``
    and packed with struct format ``'l'``; the distances come from
    ``random.random()`` and are packed with format ``'d'``. Both packed
    buffers are passed positionally to the message constructor.
    """
    # Draw all ids first, then all distances, so the random sequence is
    # consumed in that order.
    id_values = [random.randrange(10000000, 99999999) for _ in range(topk)]
    distance_values = [random.random() for _ in range(topk)]

    packed_ids = struct.pack('{}l'.format(topk), *id_values)
    packed_distances = struct.pack('{}d'.format(topk), *distance_values)
    return milvus_pb2.TopKQueryResult(packed_ids, packed_distances)