示例#1
0
 def test_default(self):
     """A ``Query`` built with no arguments yields the default request body.

     The expected body contains a ``userInput`` YQL clause wrapping
     ``self.query`` (a fixture defined outside this method) and the
     ``default`` rank profile with ``listFeatures`` set to ``"false"``.
     """
     expected_body = {
         "yql": 'select * from sources * where (userInput("this is  a test"));',
         "ranking": {"profile": "default", "listFeatures": "false"},
     }
     default_model = Query()
     actual_body = default_model.create_body(query=self.query)
     self.assertDictEqual(actual_body, expected_body)
示例#2
0
 def test_evaluate_query(self):
     """``evaluate_query`` queries once and merges the output of every metric.

     ``self.app.query`` and both metrics are replaced by mocks. The test
     checks that the app is queried exactly once with the query text, the
     query model and ``hits``; that *each* metric is invoked exactly once
     with the (mocked) query result, the relevant docs, the id field and
     the default score; and that the returned dict combines ``query_id``
     with the values produced by both metrics.
     """
     self.app.query = Mock(return_value={})
     eval_metric = Mock()
     eval_metric.evaluate_query = Mock(return_value={"metric": 1})
     eval_metric2 = Mock()
     eval_metric2.evaluate_query = Mock(return_value={"metric_2": 2})
     query_model = Query()
     relevant_docs = self.labelled_data[0]["relevant_docs"]

     evaluation = self.app.evaluate_query(
         eval_metrics=[eval_metric, eval_metric2],
         query_model=query_model,
         query_id="0",
         query="this is a test",
         id_field="vespa_id_field",
         relevant_docs=relevant_docs,
         default_score=0,
         hits=10,
     )

     self.app.query.assert_called_once_with(
         query="this is a test", query_model=query_model, hits=10)
     # Verify both metrics, not only the first one: a regression that
     # calls the second metric with wrong arguments must not go unnoticed.
     for metric in (eval_metric, eval_metric2):
         metric.evaluate_query.assert_called_once_with(
             {}, relevant_docs, "vespa_id_field", 0)
     self.assertDictEqual(evaluation, {
         "query_id": "0",
         "metric": 1,
         "metric_2": 2
     })
示例#3
0
    def test_query(self):
        """Check the ``request_body`` reported by ``Vespa.query`` in debug mode.

        Three cases are exercised with ``debug_request=True``:

        1. an explicit ``body`` comes back as the request body;
        2. a query string plus a ``Query`` model using ``OR`` matching is
           rendered into YQL, ranking and ``hits`` entries;
        3. the same request with a ``recall`` tuple additionally yields a
           ``recall`` entry.
        """
        app = Vespa(url="http://localhost", port=8080)

        # Case 1: explicit body.
        raw_body = {"yql": "select * from sources * where test"}
        body_response = app.query(body=raw_body, debug_request=True)
        self.assertDictEqual(body_response.request_body, raw_body)

        # Case 2: body generated from a query model.
        expected_body = {
            "yql": 'select * from sources * where ([{"grammar": "any"}]userInput("this is a test"));',
            "ranking": {"profile": "default", "listFeatures": "false"},
            "hits": 10,
        }
        model_response = app.query(
            query="this is a test",
            query_model=Query(match_phase=OR(), rank_profile=RankProfile()),
            debug_request=True,
            hits=10,
        )
        self.assertDictEqual(model_response.request_body, expected_body)

        # Case 3: same request restricted by a recall tuple.
        expected_recall_body = {**expected_body, "recall": "+(id:1 id:5)"}
        recall_response = app.query(
            query="this is a test",
            query_model=Query(match_phase=OR(), rank_profile=RankProfile()),
            debug_request=True,
            hits=10,
            recall=("id", [1, 5]),
        )
        self.assertDictEqual(recall_response.request_body, expected_recall_body)
示例#4
0
    def test_collect_training_data(self):
        """``collect_training_data`` delegates per relevant doc and returns a frame.

        ``collect_training_data_point`` is mocked to return three canned
        rows. The test checks that it is called with the labelled query,
        the relevant doc's id and score, the forwarded options and
        ``default_score=0`` (which this test does not pass explicitly), and
        that the result equals a ``DataFrame`` built from the canned rows.
        """
        # (document_id, relevant, a, b) for each canned row.
        row_specs = [("abc", 1, 1, 2), ("def", 0, 3, 4), ("ghi", 0, 5, 6)]
        # Key order is kept as-is because it fixes the DataFrame column order.
        canned_rows = [
            {
                "document_id": doc_id,
                "query_id": "123",
                "relevant": relevant,
                "a": feature_a,
                "b": feature_b,
            }
            for doc_id, relevant, feature_a, feature_b in row_specs
        ]
        self.app.collect_training_data_point = Mock(return_value=canned_rows)

        relevant_doc = {"id": "abc", "score": 1}
        labelled_data = [{
            "query_id": 123,
            "query": "this is a query",
            "relevant_docs": [relevant_doc],
        }]
        query_model = Query(rank_profile=RankProfile(list_features=True))

        data = self.app.collect_training_data(
            labelled_data=labelled_data,
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            timeout="15s",
        )

        expected_call = call(
            query="this is a query",
            query_id=123,
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            relevant_score=1,
            default_score=0,
            timeout="15s",
        )
        self.app.collect_training_data_point.assert_has_calls([expected_call])
        expected_frame = DataFrame.from_records(canned_rows)
        assert_frame_equal(data, expected_frame)
示例#5
0
    def test_collect_training_data_point(self):
        """``collect_training_data_point`` issues a recall query and a hits query.

        ``self.app.query`` is mocked to return, in order, the recall and the
        additional-docs results (fixtures defined outside this method). The
        test expects exactly two queries -- the first with a ``recall`` on
        the relevant id, the second with ``hits=2`` -- and three output rows
        where only document ``abc`` is marked ``relevant``.
        """
        canned_results = [
            VespaResult(self.raw_vespa_result_recall),
            VespaResult(self.raw_vespa_result_additional),
        ]
        self.app.query = Mock(side_effect=canned_results)
        query_model = Query(rank_profile=RankProfile(list_features=True))

        data = self.app.collect_training_data_point(
            query="this is a query",
            query_id="123",
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            timeout="15s",
        )

        recall_call = call(
            query="this is a query",
            query_model=query_model,
            recall=("vespa_id_field", ["abc"]),
            timeout="15s",
        )
        additional_docs_call = call(
            query="this is a query",
            query_model=query_model,
            hits=2,
            timeout="15s",
        )
        self.assertEqual(self.app.query.call_count, 2)
        self.app.query.assert_has_calls([recall_call, additional_docs_call])

        # (document_id, relevant, a, b) for each expected row.
        row_specs = [("abc", 1, 1, 2), ("def", 0, 3, 4), ("ghi", 0, 5, 6)]
        expected_data = [
            {
                "document_id": doc_id,
                "query_id": "123",
                "relevant": relevant,
                "a": feature_a,
                "b": feature_b,
            }
            for doc_id, relevant, feature_a, feature_b in row_specs
        ]
        self.assertEqual(data, expected_data)
示例#6
0
 def test_disable_rank_features(self):
     """Collecting a data point with a plain ``Query()`` raises ``AssertionError``.

     NOTE(review): judging by the test name, the failure is presumably due
     to rank features not being listed by the default model -- confirm
     against ``collect_training_data_point``.
     """
     point_args = {
         "query": "this is a query",
         "query_id": "123",
         "relevant_id": "abc",
         "id_field": "vespa_id_field",
         "number_additional_docs": 2,
     }
     with self.assertRaises(AssertionError):
         # The model is built inside the context, as in the call under test.
         self.app.collect_training_data_point(query_model=Query(),
                                              **point_args)
示例#7
0
 def test_match_and_rank(self):
     """An ``ANN`` match phase and a named rank profile shape the request body.

     The expected body has a ``nearestNeighbor`` YQL clause annotated with
     ``targetNumHits`` and ``label``, the ``bm25`` profile with
     ``listFeatures`` set to ``"true"``, and the embedding returned by the
     stub model serialised under ``ranking.features.query(query_vector)``.
     """
     ann_match = ANN(
         doc_vector="doc_vector",
         query_vector="query_vector",
         # Stub embedding model: always returns the same vector.
         embedding_model=lambda x: [1, 2, 3],
         hits=10,
         label="label",
     )
     bm25_profile = RankProfile(name="bm25", list_features=True)
     model = Query(match_phase=ann_match, rank_profile=bm25_profile)

     expected_body = {
         "yql": 'select * from sources * where ([{"targetNumHits": 10, "label": "label"}]nearestNeighbor(doc_vector, query_vector));',
         "ranking": {"profile": "bm25", "listFeatures": "true"},
         "ranking.features.query(query_vector)": "[1, 2, 3]",
     }
     actual_body = model.create_body(query=self.query)
     self.assertDictEqual(actual_body, expected_body)
示例#8
0
    def test_collect_training_data_point_0_recall_hits(self):
        """No rows are collected when the recall query reports zero hits.

        The recall fixture is replaced by a result whose ``totalCount`` is 0.
        The test expects ``self.app.query`` to be called exactly once (the
        recall query) and the collected data to be an empty list.
        """
        coverage_info = {
            "coverage": 100,
            "documents": 62529,
            "full": True,
            "nodes": 2,
            "results": 1,
            "resultsFull": 1,
        }
        self.raw_vespa_result_recall = {
            "root": {
                "id": "toplevel",
                "relevance": 1.0,
                "fields": {"totalCount": 0},
                "coverage": coverage_info,
            }
        }
        canned_results = [
            VespaResult(self.raw_vespa_result_recall),
            VespaResult(self.raw_vespa_result_additional),
        ]
        self.app.query = Mock(side_effect=canned_results)
        query_model = Query(rank_profile=RankProfile(list_features=True))

        data = self.app.collect_training_data_point(
            query="this is a query",
            query_id="123",
            relevant_id="abc",
            id_field="vespa_id_field",
            query_model=query_model,
            number_additional_docs=2,
            timeout="15s",
        )

        # Only the recall query may be issued.
        self.app.query.assert_called_once_with(
            query="this is a query",
            query_model=query_model,
            recall=("vespa_id_field", ["abc"]),
            timeout="15s",
        )
        self.assertEqual(data, [])