def test_raises_error_if_feature_is_not_separated_by_colon(self):
    """A feature token written as '3=0.1' must raise a ParserError."""
    # The third token uses '=' where the other features use ':'.
    malformed_lines = ["1 qid:1 2:0.0 3=0.1 4:10.0"]
    expected_message = "failed to extract feature index and value from '3=0.1'"

    with self.assertRaisesRegex(ParserError, expected_message):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(malformed_lines, {}):
            pass
def _generate_examples(self, path):
    """Yields examples.

    Opens `path` in binary mode, decodes each line as latin1 and yields
    everything LibSVMRankingParser produces from the decoded lines.
    """
    # The Istella datasets seem to be encoded as latin1 rather than utf-8, so
    # the file is read as bytes and every line is decoded as latin1 by hand.
    with tf.io.gfile.GFile(path, "rb") as raw_file:
        decoded_lines = (raw_line.decode("latin1") for raw_line in raw_file)
        parser = LibSVMRankingParser(decoded_lines, _FEATURE_NAMES, _LABEL_NAME)
        yield from parser
def test_raises_error_if_feature_value_is_malformatted(self):
    """A feature token written as '3:0.00.1' must raise a ParserError."""
    # The value '0.00.1' contains two decimal points.
    malformed_lines = ["1 qid:1 2:0.0 3:0.00.1 4:10.0"]
    expected_message = (
        "failed to extract feature index and value from '3:0.00.1'")

    with self.assertRaisesRegex(ParserError, expected_message):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(malformed_lines, {}):
            pass
def test_raises_error_if_qid_is_missing(self):
    """A line with no qid token after the label must raise a ParserError."""
    # The label '0' is followed directly by features.
    lines_without_qid = ["0 2:0.0 3:0.0"]
    expected_message = "line must contain a qid after the relevance label"

    with self.assertRaisesRegex(ParserError, expected_message):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(lines_without_qid, {}):
            pass
def test_parses_a_libsvm_ranking_dataset(self):
    """Parses the test dataset and checks each (qid, features) pair in order."""
    feature_names = {
        1: "bm25",
        2: "tfidf",
        3: "querylen",
        4: "doclen",
        5: "qualityscore",
    }
    results = iter(
        LibSVMRankingParser(
            _TEST_DATASET.split("\n"), feature_names=feature_names))

    # One entry per expected query, in the order the parser must yield them.
    expected_groups = [
        ("1", {
            "label": [3.0, 1.0, 0.0, 0.0],
            "bm25": [1.0, 0.0, 0.0, 0.0],
            "tfidf": [1.0, 0.0, 1.0, 0.0],
            "querylen": [0.0, 1.0, 0.0, 1.0],
            "doclen": [0.2, 0.1, 0.4, 0.3],
            "qualityscore": [0.0, 1.0, 0.0, 0.0],
        }),
        ("2", {
            "label": [0.0, 1.0, 0.0, 0.0, 1.0, 2.0],
            "bm25": [0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            "tfidf": [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
            "querylen": [1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
            "doclen": [0.2, 0.4, 0.1, 0.2, 0.1, 0.3],
            "qualityscore": [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        }),
        ("3", {
            "label": [3.0, 0.0],
            "bm25": [1.0, 0.0],
            "tfidf": [0.0, 1.0],
            "querylen": [0.0, 1.0],
            "doclen": [0.4, 0.5],
            "qualityscore": [1.0, 0.0],
        }),
    ]

    for expected_qid, expected_features in expected_groups:
        qid, features = next(results)
        self.assertEqual(qid, expected_qid)
        for name, expected_values in expected_features.items():
            self.assertAllEqual(features[name], expected_values)

    # Nothing may be left once the third query has been consumed.
    with self.assertRaises(StopIteration):
        next(results)
def test_raises_error_if_label_is_not_a_number(self):
    """A line whose label is 'abc' must raise a ParserError."""
    lines_with_bad_label = ["abc qid:1 2:0.0 3:0.0"]
    expected_message = "label 'abc' could not be converted to a float"

    with self.assertRaisesRegex(ParserError, expected_message):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(lines_with_bad_label, {}):
            pass
def test_raises_error_if_line_only_contains_label(self):
    """A line consisting of just '1' must raise a ParserError."""
    label_only_lines = ["1"]
    expected_message = "could not extract label, qid and features"

    with self.assertRaisesRegex(ParserError, expected_message):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(label_only_lines, {}):
            pass
def test_raises_error_with_offending_line(self):
    """The raised ParserError must carry the text of the bad line in `.line`."""
    # Only the last entry is malformed; the first two share qid 1.
    lines = ["1 qid:1 2:1000.0", "2 qid:1 9:9.9", "malformatted line"]

    def reports_offending_line(error):
        return error.line == "malformatted line"

    with self.assertRaisesWithPredicateMatch(
            ParserError, reports_offending_line):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(lines, {}):
            pass
def test_raises_error_with_offending_line_number(self):
    """The raised ParserError must report `line_number == 2` for this input."""
    # The malformed entry is the second of three.
    lines = ["1 qid:1 2:0.0", "malformatted line", "1 qid:1 2:0.0"]

    def reports_second_line(error):
        return error.line_number == 2

    with self.assertRaisesWithPredicateMatch(
            ParserError, reports_second_line):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(lines, {}):
            pass
def test_raises_error_if_qid_is_malformatted(self):
    """A 'qid:' token with nothing after the colon must raise a ParserError."""
    lines_with_empty_qid = ["1 qid: 2:0.0 3:0.0"]
    expected_message = "qid can not be empty"

    with self.assertRaisesRegex(ParserError, expected_message):
        # Exhaust the parser; the error is expected while iterating.
        for _example in LibSVMRankingParser(lines_with_empty_qid, {}):
            pass
def _generate_examples(self, path):
    """Yields examples.

    Opens `path` in text mode and yields everything LibSVMRankingParser
    produces from the open file.
    """
    with tf.io.gfile.GFile(path, "r") as text_file:
        parser = LibSVMRankingParser(text_file, _FEATURE_NAMES, _LABEL_NAME)
        yield from parser