def test_text_generator_empty(self):
  # Building the text generator from empty reference and candidate paths
  # and scoring with it is expected to raise an AssertionError.
  checkpoint_path = get_test_checkpoint()
  FLAGS.bleurt_checkpoint = checkpoint_path
  with self.assertRaises(AssertionError):
    empty_path_pairs = score_files._text_generator("", "")
    score_files.score_files(empty_path_pairs, checkpoint_path)
def test_score_files_sentence_pairs(self):
  # Scores the sentence pairs of the test JSONL file, then checks that the
  # scores file exists and that its four scores match the reference scores.
  bleurt_checkpoint = get_test_checkpoint()
  pairs_jsonl, _, _ = get_test_data()
  with tempfile.TemporaryDirectory() as work_dir:
    scores_path = os.path.join(work_dir, "scores")
    FLAGS.scores_file = scores_path
    pair_generator = score_files._json_generator(pairs_jsonl)
    score_files.score_files(pair_generator, bleurt_checkpoint)
    # The scores file must be read before the temporary directory is removed.
    self.assertTrue(tf.io.gfile.exists(scores_path))
    written_scores = get_scores_from_scores_file(scores_path)
    self.assertLen(written_scores, 4)
    self.assertAllClose(written_scores, ref_scores)
def test_sentence_pairs_consume_buffer(self):
  # Scores the test JSONL sentence pairs with a BLEURT batch size of 1, so
  # the number of pairs exceeds the batch size and _consume_buffer is needed.
  bleurt_checkpoint = get_test_checkpoint()
  pairs_jsonl, _, _ = get_test_data()
  with tempfile.TemporaryDirectory() as work_dir:
    FLAGS.bleurt_batch_size = 1
    scores_path = os.path.join(work_dir, "scores")
    FLAGS.scores_file = scores_path
    pair_generator = score_files._json_generator(pairs_jsonl)
    score_files.score_files(pair_generator, bleurt_checkpoint)
    # The scores file must be read before the temporary directory is removed.
    written_scores = get_scores_from_scores_file(scores_path)
    self.assertLen(written_scores, 4)
    self.assertAllClose(written_scores, ref_scores)
def test_score_diff_sentence_pairs(self):
  # A JSONL record that carries a candidate but no reference is expected
  # to make scoring fail with an AssertionError.
  bleurt_checkpoint = get_test_checkpoint()
  with tempfile.TemporaryDirectory() as work_dir:
    pairs_path = os.path.join(work_dir, "sentence_pairs.jsonl")
    FLAGS.sentence_pairs_file = pairs_path
    with tf.io.gfile.GFile(pairs_path, "w+") as pairs_out:
      pairs_out.write("{\"candidate\": \"sashimi\"}")
    with self.assertRaises(AssertionError):
      pair_generator = score_files._json_generator(pairs_path)
      score_files.score_files(pair_generator, bleurt_checkpoint)
def test_score_files_text(self):
  # Scores the test reference and candidate text files, then checks that
  # the scores file exists and that its four scores match the reference
  # scores.
  bleurt_checkpoint = get_test_checkpoint()
  _, references_path, candidates_path = get_test_data()
  with tempfile.TemporaryDirectory() as work_dir:
    scores_path = os.path.join(work_dir, "scores")
    FLAGS.scores_file = scores_path
    pair_generator = score_files._text_generator(references_path,
                                                 candidates_path)
    score_files.score_files(pair_generator, bleurt_checkpoint)
    # The scores file must be read before the temporary directory is removed.
    self.assertTrue(tf.io.gfile.exists(scores_path))
    written_scores = get_scores_from_scores_file(scores_path)
    self.assertLen(written_scores, 4)
    self.assertAllClose(written_scores, ref_scores)
def test_score_diff_text_files(self):
  # Three references paired with only two candidates are expected to make
  # scoring fail with an AssertionError.
  bleurt_checkpoint = get_test_checkpoint()
  with tempfile.TemporaryDirectory() as work_dir:
    references_path = os.path.join(work_dir, "references")
    FLAGS.reference_file = references_path
    candidates_path = os.path.join(work_dir, "candidates")
    FLAGS.candidate_file = candidates_path
    with tf.io.gfile.GFile(references_path, "w+") as references_out:
      references_out.write("nigiri\nshrimp tempura\ntonkatsu")
    with tf.io.gfile.GFile(candidates_path, "w+") as candidates_out:
      candidates_out.write("ramen\nfish")
    with self.assertRaises(AssertionError):
      pair_generator = score_files._text_generator(references_path,
                                                   candidates_path)
      score_files.score_files(pair_generator, bleurt_checkpoint)
def test_score_empty_reference_and_candidate_pair(self):
  # Scores a single JSONL sentence pair whose candidate and reference are
  # both empty strings and checks the one resulting score.
  bleurt_checkpoint = get_test_checkpoint()
  with tempfile.TemporaryDirectory() as work_dir:
    pairs_path = os.path.join(work_dir, "sentence_pairs.jsonl")
    FLAGS.sentence_pairs_file = pairs_path
    scores_path = os.path.join(work_dir, "scores")
    FLAGS.scores_file = scores_path
    with tf.io.gfile.GFile(pairs_path, "w+") as pairs_out:
      pairs_out.write(
          "{\"candidate\": \"\", \"reference\": \"\"}")
    pair_generator = score_files._json_generator(pairs_path)
    score_files.score_files(pair_generator, bleurt_checkpoint)
    # The scores file must be read before the temporary directory is removed.
    written_scores = get_scores_from_scores_file(scores_path)
    self.assertLen(written_scores, 1)
    self.assertAllClose(written_scores, [0.679957])
def test_score_empty_candidate_and_reference_text(self):
  # Scores reference and candidate text files that each hold a single
  # empty line and checks the one resulting score.
  bleurt_checkpoint = get_test_checkpoint()
  with tempfile.TemporaryDirectory() as work_dir:
    references_path = os.path.join(work_dir, "references")
    FLAGS.reference_file = references_path
    candidates_path = os.path.join(work_dir, "candidates")
    FLAGS.candidate_file = candidates_path
    scores_path = os.path.join(work_dir, "scores")
    FLAGS.scores_file = scores_path
    with tf.io.gfile.GFile(references_path, "w+") as references_out:
      references_out.write("\n")
    with tf.io.gfile.GFile(candidates_path, "w+") as candidates_out:
      candidates_out.write("\n")
    pair_generator = score_files._text_generator(references_path,
                                                 candidates_path)
    score_files.score_files(pair_generator, bleurt_checkpoint)
    # The scores file must be read before the temporary directory is removed.
    written_scores = get_scores_from_scores_file(scores_path)
    self.assertLen(written_scores, 1)
    self.assertAllClose(written_scores, [0.679957])