Esempio n. 1
0
    def test_search(self, started_app):
        """Exercise ``search_vectors`` through the started app.

        Checks three cases in order:

        1. ``nprobe=2049`` is answered with ``Status.ILLEGAL_ARGUMENT``.
        2. With ``nprobe=2048`` and ``Milvus.describe_collection`` mocked to
           return ``BAD``, the answer is ``Status.COLLECTION_NOT_EXISTS``.
        3. With ``describe_collection`` mocked to return ``OK`` and
           ``Milvus.search_vectors_in_files`` mocked to return ``nq`` canned
           results, the search succeeds and returns ``nq`` result rows.

        The mocks are installed with ``mock.patch.object`` so the patched
        class attributes are restored when each ``with`` block exits.
        Without this they would leak into later tests in the same process.
        """
        # The collection is named after this test function.
        collection_name = inspect.currentframe().f_code.co_name
        to_index_cnt = random.randint(10, 20)
        collection = TablesFactory(collection_id=collection_name,
                                   state=Tables.NORMAL)
        # Called for its side effect of creating the to-index file records;
        # the returned objects are not used by the test.
        TableFilesFactory.create_batch(
            to_index_cnt,
            collection=collection,
            file_type=TableFiles.FILE_TYPE_TO_INDEX)
        topk = random.randint(5, 10)
        nq = random.randint(5, 10)
        param = {
            'collection_name': collection_name,
            'query_records': self.random_data(nq, collection.dimension),
            'top_k': topk,
            'params': {
                'nprobe': 2049
            }
        }

        # One canned TopKQueryResult per query, each holding topk hits.
        result = [
            milvus_pb2.TopKQueryResult(query_result_arrays=[
                milvus_pb2.QueryResult(id=i, distance=random.random())
                for i in range(topk)
            ]) for _ in range(nq)
        ]

        mock_results = milvus_pb2.TopKQueryResultList(
            status=status_pb2.Status(error_code=status_pb2.SUCCESS,
                                     reason="Success"),
            topk_query_result=result)

        collection_schema = CollectionSchema(
            collection_name=collection_name,
            index_file_size=collection.index_file_size,
            metric_type=collection.metric_type,
            dimension=collection.dimension)

        # An nprobe of 2049 is expected to be rejected; no mocks are needed yet.
        status, _ = self.client.search_vectors(**param)
        assert status.code == Status.ILLEGAL_ARGUMENT

        param['params']['nprobe'] = 2048
        with mock.patch.object(RouterMixin, 'connection',
                               mock.MagicMock(return_value=Milvus()),
                               create=True), \
                mock.patch.object(RouterMixin.query_conn, 'conn',
                                  mock.MagicMock(return_value=Milvus()),
                                  create=True):
            # Backend reports a non-OK status for describe_collection.
            with mock.patch.object(
                    Milvus, 'describe_collection',
                    mock.MagicMock(return_value=(BAD, collection_schema))):
                status, _ = self.client.search_vectors(**param)
            assert status.code == Status.COLLECTION_NOT_EXISTS

            # Backend describes the collection and returns canned results.
            with mock.patch.object(
                    Milvus, 'describe_collection',
                    mock.MagicMock(return_value=(OK, collection_schema))), \
                    mock.patch.object(
                        Milvus, 'search_vectors_in_files',
                        mock.MagicMock(return_value=mock_results)):
                status, ret = self.client.search_vectors(**param)
            assert status.OK()
            assert len(ret) == nq
Esempio n. 2
0
    def Search(self, request, context):
        """Handle a Search RPC and return a ``TopKQueryResultList``.

        Processing steps:

        - Validate ``request.nprobe`` and ``request.topk`` against
          ``self.MAX_NPROBE`` and ``self.MAX_TOPK``.
        - Resolve the table metadata from the ``self.table_meta`` cache. On a
          miss, fetch it with ``describe_table`` through
          ``self.router.connection`` and store it in the cache.
        - Convert the request's query records and query ranges.
        - Delegate to ``self._do_query`` and wrap its status and results in
          the response message.

        Raises:
            exceptions.InvalidArgumentError: if nprobe is <= 0 or exceeds
                ``self.MAX_NPROBE``.
            exceptions.InvalidTopKError: if topk is <= 0 or exceeds
                ``self.MAX_TOPK``.
            exceptions.TableNotFoundError: if ``describe_table`` returns a
                non-OK status on a cache miss.
        """
        table_name = request.table_name
        topk = request.topk
        nprobe = request.nprobe

        logger.info('Search %s: topk=%s nprobe=%s', table_name, topk, nprobe)

        # Attached to every raised error. Presumably the error handler uses
        # it to build a response of this class — confirm against the handler.
        metadata = {'resp_class': milvus_pb2.TopKQueryResultList}

        if nprobe > self.MAX_NPROBE or nprobe <= 0:
            raise exceptions.InvalidArgumentError(
                message='Invalid nprobe: {}'.format(nprobe), metadata=metadata)

        if topk > self.MAX_TOPK or topk <= 0:
            raise exceptions.InvalidTopKError(
                message='Invalid topk: {}'.format(topk), metadata=metadata)

        table_meta = self.table_meta.get(table_name, None)
        if not table_meta:
            # Cache miss: ask a backend for the table description and cache it.
            # NOTE(review): entries are never evicted or invalidated here.
            status, info = self.router.connection(
                metadata=metadata).describe_table(table_name)
            if not status.OK():
                raise exceptions.TableNotFoundError(table_name,
                                                    metadata=metadata)
            self.table_meta[table_name] = info
            table_meta = info

        # perf_counter is monotonic. A time.time() delta can be skewed by
        # wall-clock adjustments and even go negative.
        start = time.perf_counter()

        query_record_array = [
            list(query_record.vector_data)
            for query_record in request.query_record_array
        ]
        query_range_array = [
            Range(query_range.start_value, query_range.end_value)
            for query_range in request.query_range_array
        ]

        status, results = self._do_query(context,
                                         table_name,
                                         table_meta,
                                         query_record_array,
                                         topk,
                                         nprobe,
                                         query_range_array,
                                         metadata=metadata)

        logger.info('SearchVector takes: %s', time.perf_counter() - start)

        return milvus_pb2.TopKQueryResultList(
            status=status_pb2.Status(error_code=status.error_code,
                                     reason=status.reason),
            topk_query_result=results)