def test_text_generator_empty(self):
     # Passing empty paths for both the reference and the candidate file
     # must make scoring fail with an AssertionError.
     FLAGS.bleurt_checkpoint = get_test_checkpoint()
     with self.assertRaises(AssertionError):
         # Both building the generator and consuming it happen inside the
         # assertRaises context, so the check may fire in either step.
         empty_path_generator = score_files._text_generator("", "")
         score_files.score_files(empty_path_generator,
                                 FLAGS.bleurt_checkpoint)
 def test_score_files_sentence_pairs(self):
     # Scoring a JSONL file of sentence pairs must write a scores file whose
     # contents match the expected reference scores.
     ckpt = get_test_checkpoint()
     pairs_path, _, _ = get_test_data()
     with tempfile.TemporaryDirectory() as temp_dir:
         FLAGS.scores_file = os.path.join(temp_dir, "scores")
         pairs_generator = score_files._json_generator(pairs_path)
         score_files.score_files(pairs_generator, ckpt)
         self.assertTrue(tf.io.gfile.exists(FLAGS.scores_file))
         produced_scores = get_scores_from_scores_file(FLAGS.scores_file)
         self.assertLen(produced_scores, 4)
         self.assertAllClose(produced_scores, ref_scores)
 def test_sentence_pairs_consume_buffer(self):
     # Tests scoring more sentence pairs than the BLEURT batch size, which
     # requires score_files to call _consume_buffer when the buffer fills.
     checkpoint = get_test_checkpoint()
     sentence_pairs_file, _, _ = get_test_data()
     # FLAGS is module-level state shared by every test in this process.
     # Restore the batch size once this test finishes so that later tests do
     # not silently run with a batch size of 1.
     self.addCleanup(setattr, FLAGS, "bleurt_batch_size",
                     FLAGS.bleurt_batch_size)
     with tempfile.TemporaryDirectory() as temp_dir:
         # A batch size of 1 is smaller than the 4 pairs in the test data.
         FLAGS.bleurt_batch_size = 1
         FLAGS.scores_file = os.path.join(temp_dir, "scores")
         generator = score_files._json_generator(sentence_pairs_file)
         score_files.score_files(generator, checkpoint)
         self.assertTrue(tf.io.gfile.exists(FLAGS.scores_file))
         scores = get_scores_from_scores_file(FLAGS.scores_file)
         self.assertLen(scores, 4)
         self.assertAllClose(scores, ref_scores)
 def test_score_diff_sentence_pairs(self):
     # A JSONL record that has a candidate but no matching reference (so the
     # candidate and reference counts differ) must raise an AssertionError.
     ckpt = get_test_checkpoint()
     with tempfile.TemporaryDirectory() as temp_dir:
         pairs_path = os.path.join(temp_dir, "sentence_pairs.jsonl")
         FLAGS.sentence_pairs_file = pairs_path
         with tf.io.gfile.GFile(pairs_path, "w+") as pairs_out:
             pairs_out.write('{"candidate": "sashimi"}')
         with self.assertRaises(AssertionError):
             unmatched_generator = score_files._json_generator(pairs_path)
             score_files.score_files(unmatched_generator, ckpt)
 def test_score_files_text(self):
     # Scoring a reference text file against a candidate text file must write
     # a scores file whose contents match the expected reference scores.
     ckpt = get_test_checkpoint()
     _, references_path, candidates_path = get_test_data()
     with tempfile.TemporaryDirectory() as temp_dir:
         FLAGS.scores_file = os.path.join(temp_dir, "scores")
         text_generator = score_files._text_generator(
             references_path, candidates_path)
         score_files.score_files(text_generator, ckpt)
         self.assertTrue(tf.io.gfile.exists(FLAGS.scores_file))
         produced_scores = get_scores_from_scores_file(FLAGS.scores_file)
         self.assertLen(produced_scores, 4)
         self.assertAllClose(produced_scores, ref_scores)
 def test_score_diff_text_files(self):
     # Text files with a different number of lines (three references versus
     # two candidates) must make scoring raise an AssertionError.
     ckpt = get_test_checkpoint()
     with tempfile.TemporaryDirectory() as temp_dir:
         references_path = os.path.join(temp_dir, "references")
         candidates_path = os.path.join(temp_dir, "candidates")
         FLAGS.reference_file = references_path
         FLAGS.candidate_file = candidates_path
         with tf.io.gfile.GFile(references_path, "w+") as refs_out:
             refs_out.write("nigiri\nshrimp tempura\ntonkatsu")
         with tf.io.gfile.GFile(candidates_path, "w+") as cands_out:
             cands_out.write("ramen\nfish")
         with self.assertRaises(AssertionError):
             mismatched_generator = score_files._text_generator(
                 references_path, candidates_path)
             score_files.score_files(mismatched_generator, ckpt)
 def test_score_empty_reference_and_candidate_pair(self):
     # A single JSONL record whose candidate and reference are both empty
     # strings must still be scored, producing exactly one known score.
     ckpt = get_test_checkpoint()
     expected_scores = [0.679957]
     with tempfile.TemporaryDirectory() as temp_dir:
         pairs_path = os.path.join(temp_dir, "sentence_pairs.jsonl")
         FLAGS.sentence_pairs_file = pairs_path
         FLAGS.scores_file = os.path.join(temp_dir, "scores")
         with tf.io.gfile.GFile(pairs_path, "w+") as pairs_out:
             pairs_out.write('{"candidate": "", "reference": ""}')
         empty_pair_generator = score_files._json_generator(pairs_path)
         score_files.score_files(empty_pair_generator, ckpt)
         produced_scores = get_scores_from_scores_file(FLAGS.scores_file)
         self.assertLen(produced_scores, 1)
         self.assertAllClose(produced_scores, expected_scores)
 def test_score_empty_candidate_and_reference_text(self):
     # Text files that each hold one empty line must still be scored,
     # producing exactly one known score.
     ckpt = get_test_checkpoint()
     expected_scores = [0.679957]
     with tempfile.TemporaryDirectory() as temp_dir:
         references_path = os.path.join(temp_dir, "references")
         candidates_path = os.path.join(temp_dir, "candidates")
         FLAGS.reference_file = references_path
         FLAGS.candidate_file = candidates_path
         FLAGS.scores_file = os.path.join(temp_dir, "scores")
         with tf.io.gfile.GFile(references_path, "w+") as refs_out:
             refs_out.write("\n")
         with tf.io.gfile.GFile(candidates_path, "w+") as cands_out:
             cands_out.write("\n")
         empty_line_generator = score_files._text_generator(
             references_path, candidates_path)
         score_files.score_files(empty_line_generator, ckpt)
         produced_scores = get_scores_from_scores_file(FLAGS.scores_file)
         self.assertLen(produced_scores, 1)
         self.assertAllClose(produced_scores, expected_scores)