Esempio n. 1
0
 def test_raises_error_if_feature_is_not_separated_by_colon(self):
     """A feature token that uses '=' instead of ':' must be rejected.

     ``assertRaisesRegex`` treats the expected message as a regular
     expression, so the '.' in the offending token is escaped to require a
     literal dot instead of matching any character.
     """
     contents = ["1 qid:1 2:0.0 3=0.1 4:10.0"]
     with self.assertRaisesRegex(
             ParserError,
             r"failed to extract feature index and value from '3=0\.1'"):
         # Exhaust the parser; the error is raised while iterating.
         for _ in LibSVMRankingParser(contents, {}):
             pass
Esempio n. 2
0
 def _generate_examples(self, path):
     """Yields examples parsed from the LibSVM ranking file at ``path``.

     Args:
       path: location of the data file, opened through ``tf.io.gfile``.

     Yields:
       Whatever ``LibSVMRankingParser`` produces for the decoded lines.
     """
     # The Istella files appear to be latin1-encoded rather than utf-8, so the
     # file is opened in binary mode and each raw line is decoded explicitly.
     with tf.io.gfile.GFile(path, "rb") as raw_file:
         decoded_lines = (raw_line.decode("latin1") for raw_line in raw_file)
         parser = LibSVMRankingParser(
             decoded_lines, _FEATURE_NAMES, _LABEL_NAME)
         yield from parser
Esempio n. 3
0
 def test_raises_error_if_feature_value_is_malformatted(self):
     """A feature value that is not a valid number must be rejected.

     ``assertRaisesRegex`` treats the expected message as a regular
     expression, so both dots in the offending token '3:0.00.1' are escaped;
     unescaped, they would match any character.
     """
     contents = ["1 qid:1 2:0.0 3:0.00.1 4:10.0"]
     with self.assertRaisesRegex(
             ParserError,
             r"failed to extract feature index and value from '3:0\.00\.1'"):
         # Exhaust the parser; the error is raised while iterating.
         for _ in LibSVMRankingParser(contents, {}):
             pass
Esempio n. 4
0
 def test_raises_error_if_qid_is_missing(self):
     """A line with no ``qid:`` token after the label must be rejected."""
     line_without_qid = "0 2:0.0 3:0.0"
     expected_message = "line must contain a qid after the relevance label"
     with self.assertRaisesRegex(ParserError, expected_message):
         # Materializing the parser forces every line to be parsed.
         list(LibSVMRankingParser([line_without_qid], {}))
Esempio n. 5
0
    def test_parses_a_libsvm_ranking_dataset(self):
        """Parses the shared test dataset and checks each query group in turn.

        The parser is expected to yield one ``(qid, features)`` pair per
        query, where ``features`` maps the label and every named feature to
        one value per document of that query.
        """
        names_by_index = {
            1: "bm25",
            2: "tfidf",
            3: "querylen",
            4: "doclen",
            5: "qualityscore",
        }
        parser = LibSVMRankingParser(_TEST_DATASET.split("\n"),
                                     feature_names=names_by_index)
        results = iter(parser)

        # Expected qid and per-document values for each query, in the order
        # the parser should yield them.
        expected_queries = [
            ("1", {
                "label": [3., 1., 0., 0.],
                "bm25": [1., 0., 0., 0.],
                "tfidf": [1., 0., 1., 0.],
                "querylen": [0., 1., 0., 1.],
                "doclen": [0.2, 0.1, 0.4, 0.3],
                "qualityscore": [0., 1., 0., 0],
            }),
            ("2", {
                "label": [0., 1., 0., 0., 1., 2.],
                "bm25": [0., 1., 0., 0., 0., 1.],
                "tfidf": [0., 0., 0., 0., 0., 1.],
                "querylen": [1., 1., 1., 1., 1., 0.],
                "doclen": [0.2, 0.4, 0.1, 0.2, 0.1, 0.3],
                "qualityscore": [0., 0., 0., 0., 1., 0.],
            }),
            ("3", {
                "label": [3., 0.],
                "bm25": [1., 0.],
                "tfidf": [0., 1.],
                "querylen": [0., 1.],
                "doclen": [0.4, 0.5],
                "qualityscore": [1., 0.],
            }),
        ]

        for expected_qid, expected_features in expected_queries:
            qid, features = next(results)
            self.assertEqual(qid, expected_qid)
            for feature_name, expected_values in expected_features.items():
                self.assertAllEqual(features[feature_name], expected_values)

        # Assert that the end of the file has been reached.
        with self.assertRaises(StopIteration):
            next(results)
Esempio n. 6
0
 def test_raises_error_if_label_is_not_a_number(self):
     """A non-numeric relevance label must be rejected."""
     bad_label_line = "abc qid:1 2:0.0 3:0.0"
     expected = "label 'abc' could not be converted to a float"
     with self.assertRaisesRegex(ParserError, expected):
         # Materializing the parser forces every line to be parsed.
         list(LibSVMRankingParser([bad_label_line], {}))
Esempio n. 7
0
 def test_raises_error_if_line_only_contains_label(self):
     """A line holding only a label, with no qid or features, is rejected."""
     label_only = "1"
     with self.assertRaisesRegex(
             ParserError, "could not extract label, qid and features"):
         parser = LibSVMRankingParser([label_only], {})
         list(parser)
Esempio n. 8
0
 def test_raises_error_with_offending_line(self):
     """The raised ParserError exposes the raw text of the bad line."""
     bad_line = "malformatted line"
     contents = ["1 qid:1 2:1000.0", "2 qid:1 9:9.9", bad_line]

     def reports_bad_line(error):
         return error.line == bad_line

     with self.assertRaisesWithPredicateMatch(ParserError, reports_bad_line):
         list(LibSVMRankingParser(contents, {}))
Esempio n. 9
0
 def test_raises_error_with_offending_line_number(self):
     """The raised ParserError records the 1-based number of the bad line."""
     contents = ["1 qid:1 2:0.0", "malformatted line", "1 qid:1 2:0.0"]
     # The malformed entry is the second element, i.e. line number 2.
     expected_line_number = 2
     with self.assertRaisesWithPredicateMatch(
             ParserError,
             lambda err: err.line_number == expected_line_number):
         list(LibSVMRankingParser(contents, {}))
Esempio n. 10
0
 def test_raises_error_if_qid_is_malformatted(self):
     """A ``qid:`` token with nothing after the colon must be rejected."""
     line_with_empty_qid = "1 qid: 2:0.0 3:0.0"
     with self.assertRaisesRegex(ParserError, "qid can not be empty"):
         list(LibSVMRankingParser([line_with_empty_qid], {}))
Esempio n. 11
0
 def _generate_examples(self, path):
     """Yields examples parsed from the LibSVM ranking file at ``path``.

     Args:
       path: location of the data file, opened in text mode via ``tf.io.gfile``.

     Yields:
       Whatever ``LibSVMRankingParser`` produces for the file's lines.
     """
     with tf.io.gfile.GFile(path, mode="r") as text_file:
         parser = LibSVMRankingParser(text_file, _FEATURE_NAMES, _LABEL_NAME)
         yield from parser