예제 #1
0
    def test_query_properties_match_and_rank(self):
        """Check QueryModel.create_body for OR and ANN match phases.

        Both cases use the same ``query_vector`` ranking feature (mapping any
        query to ``[1, 2, 3]``) and the ``bm25`` rank profile with
        ``list_features=True``; only the match phase, and therefore the YQL
        string, differs between them.
        """

        def build_model(match_phase):
            # Fresh feature and profile objects per model so the two cases
            # share no state.
            return QueryModel(
                query_properties=[
                    QueryRankingFeature(name="query_vector",
                                        mapping=lambda x: [1, 2, 3])
                ],
                match_phase=match_phase,
                rank_profile=RankProfile(name="bm25", list_features=True),
            )

        def expected_body(yql):
            # Everything except the YQL string is identical for both cases.
            return {
                "yql": yql,
                "ranking": {
                    "profile": "bm25",
                    "listFeatures": "true"
                },
                "ranking.features.query(query_vector)": "[1, 2, 3]",
            }

        or_model = build_model(OR())
        self.assertDictEqual(
            or_model.create_body(query=self.query),
            expected_body(
                'select * from sources * where ([{"grammar": "any"}]userInput("this is  a test"));'
            ),
        )

        ann_model = build_model(
            ANN(
                doc_vector="doc_vector",
                query_vector="query_vector",
                hits=10,
                label="label",
            ))
        self.assertDictEqual(
            ann_model.create_body(query=self.query),
            expected_body(
                'select * from sources * where ([{"targetNumHits": 10, "label": "label", "approximate": true}]nearestNeighbor(doc_vector, query_vector));'
            ),
        )
예제 #2
0
    def test_collect_training_data_point(self):
        """Check the calls and rows produced by collect_training_data_point.

        ``self.app.query`` is mocked to return the recall response first and
        the additional-documents response second. The test checks the two
        ``query`` calls that are made and the resulting list of row dicts
        (document id, query id, label, features ``a``/``b``, title).
        """
        responses = [
            VespaResult(self.raw_vespa_result_recall),
            VespaResult(self.raw_vespa_result_additional),
        ]
        self.app.query = Mock(side_effect=responses)
        query_model = QueryModel(rank_profile=RankProfile(list_features=True))
        query_text = "this is a query"

        data = self.app.collect_training_data_point(
            query=query_text,
            query_id="123",
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            fields=["rankfeatures", "title"],
            timeout="15s",
        )

        # The first call restricts recall to the relevant doc; the second
        # asks for `number_additional_docs` extra hits.
        expected_calls = [
            call(query=query_text,
                 query_model=query_model,
                 recall=("vespa_id_field", ["abc"]),
                 timeout="15s"),
            call(query=query_text,
                 query_model=query_model,
                 hits=2,
                 timeout="15s"),
        ]
        self.assertEqual(self.app.query.call_count, 2)
        self.app.query.assert_has_calls(expected_calls)

        # (document_id, label, a, b, title) for each expected row.
        rows = [
            ("abc", 1, 1, 2, "this is a title"),
            ("def", 0, 3, 4, "this is a title 2"),
            ("ghi", 0, 5, 6, "this is a title 3"),
        ]
        expected_data = [
            {
                "document_id": doc_id,
                "query_id": "123",
                "label": label,
                "a": a,
                "b": b,
                "title": title,
            }
            for doc_id, label, a, b, title in rows
        ]
        self.assertEqual(data, expected_data)
예제 #3
0
    def test_query(self):
        """Check the request body Vespa.query reports when debug_request=True.

        Three cases are covered:
        - a raw ``body`` is passed through unchanged;
        - a query model plus ``hits`` produces YQL, ranking and hits;
        - the same request with ``recall`` also gets a ``recall`` entry.
        """
        app = Vespa(url="http://localhost", port=8080)

        raw_body = {"yql": "select * from sources * where test"}
        raw_request = app.query(body=raw_body, debug_request=True)
        self.assertDictEqual(raw_request.request_body, raw_body)

        expected = {
            "yql":
            'select * from sources * where ([{"grammar": "any"}]userInput("this is a test"));',
            "ranking": {
                "profile": "default",
                "listFeatures": "false"
            },
            "hits": 10,
        }

        plain_request = app.query(
            query="this is a test",
            query_model=Query(match_phase=OR(), rank_profile=RankProfile()),
            debug_request=True,
            hits=10,
        )
        self.assertDictEqual(plain_request.request_body, expected)

        recall_request = app.query(
            query="this is a test",
            query_model=Query(match_phase=OR(), rank_profile=RankProfile()),
            debug_request=True,
            hits=10,
            recall=("id", [1, 5]),
        )
        # Identical to the previous expectation apart from the recall term.
        self.assertDictEqual(
            recall_request.request_body,
            {**expected, "recall": "+(id:1 id:5)"},
        )
예제 #4
0
    def test_collect_training_data(self):
        """Check how collect_training_data delegates and builds its frame.

        ``collect_training_data_point`` is replaced by a mock, so this test
        checks the keyword arguments passed to it for the single labelled
        query and that the returned DataFrame equals the mock's rows.
        """
        # Key order (document_id, query_id, relevant, a, b) fixes the
        # DataFrame column order.
        point_rows = [
            {
                "document_id": doc_id,
                "query_id": "123",
                "relevant": relevant,
                "a": a,
                "b": b,
            }
            for doc_id, relevant, a, b in (
                ("abc", 1, 1, 2),
                ("def", 0, 3, 4),
                ("ghi", 0, 5, 6),
            )
        ]
        self.app.collect_training_data_point = Mock(return_value=point_rows)

        labelled_data = [{
            "query_id": 123,
            "query": "this is a query",
            "relevant_docs": [{"id": "abc", "score": 1}],
        }]
        query_model = Query(rank_profile=RankProfile(list_features=True))

        data = self.app.collect_training_data(
            labelled_data=labelled_data,
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            timeout="15s",
        )

        expected_call = call(
            query="this is a query",
            query_id=123,
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            relevant_score=1,
            default_score=0,
            timeout="15s",
        )
        self.app.collect_training_data_point.assert_has_calls([expected_call])
        assert_frame_equal(data, DataFrame.from_records(point_rows))
예제 #5
0
    def test_collect_training_data_point_0_recall_hits(self):
        """Check that no rows are produced when the recall query finds nothing.

        The mocked recall response reports ``totalCount`` 0 and has no
        ``children`` entry. The test asserts that ``self.app.query`` is called
        exactly once (the recall call) and that the result is an empty list.
        """
        # Use a local response rather than overwriting the
        # `self.raw_vespa_result_recall` fixture attribute, so this test does
        # not mutate shared fixture state.
        empty_recall_result = {
            "root": {
                "id": "toplevel",
                "relevance": 1.0,
                "fields": {
                    "totalCount": 0
                },
                "coverage": {
                    "coverage": 100,
                    "documents": 62529,
                    "full": True,
                    "nodes": 2,
                    "results": 1,
                    "resultsFull": 1,
                },
            }
        }
        self.app.query = Mock(side_effect=[
            VespaQueryResponse(
                empty_recall_result, status_code=None, url=None),
            # Should not be consumed. It is here so that an unexpected second
            # call fails on the call_count assertion, not on StopIteration.
            VespaQueryResponse(
                self.raw_vespa_result_additional, status_code=None, url=None),
        ])
        query_model = QueryModel(rank_profile=RankProfile(list_features=True))
        data = self.app.collect_training_data_point(
            query="this is a query",
            query_id="123",
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            fields=["rankfeatures"],
            timeout="15s",
        )

        self.assertEqual(self.app.query.call_count, 1)
        self.app.query.assert_has_calls([
            call(
                query="this is a query",
                query_model=query_model,
                recall=("vespa_id_field", ["abc"]),
                timeout="15s",
            ),
        ])
        self.assertEqual(data, [])
예제 #6
0
파일: test_query.py 프로젝트: vdvorak/vespa
 def test_match_and_rank(self):
     """Check Query.create_body with an ANN match phase and embedding model.

     The embedding model maps any input to ``[1, 2, 3]``. The expected body
     holds the nearestNeighbor YQL, the ``bm25`` ranking settings, and that
     vector as the ``query(query_vector)`` ranking feature.
     """
     ann_match = ANN(
         doc_vector="doc_vector",
         query_vector="query_vector",
         embedding_model=lambda x: [1, 2, 3],
         hits=10,
         label="label",
     )
     bm25_profile = RankProfile(name="bm25", list_features=True)
     query = Query(match_phase=ann_match, rank_profile=bm25_profile)

     expected_yql = (
         'select * from sources * where ([{"targetNumHits": 10, "label": "label"}]nearestNeighbor(doc_vector, query_vector));'
     )
     expected_body = {
         "yql": expected_yql,
         "ranking": {
             "profile": "bm25",
             "listFeatures": "true"
         },
         "ranking.features.query(query_vector)": "[1, 2, 3]",
     }
     self.assertDictEqual(query.create_body(query=self.query), expected_body)
예제 #7
0
 def test_rank_profile(self):
     """Check that RankProfile keeps its name and exposes list_features as "true"."""
     profile = RankProfile(name="rank_profile", list_features=True)
     # Attributes are checked in insertion order: name first, then list_features.
     expected = {"name": "rank_profile", "list_features": "true"}
     for attribute, value in expected.items():
         self.assertEqual(getattr(profile, attribute), value)