コード例 #1
0
 def test_bleu_list(self):
   """Score two sentence pairs with ``bleu.bleu_on_list`` under both flag values.

   The hypotheses match the references except for the capitalisation of one
   word ("more" vs. "More"). The third positional argument is presumably a
   case-sensitivity switch, judging by the expected scores below.
   """
   references = ["test 1 two 3", "more tests!"]
   hypotheses = ["test 1 two 3", "More tests!"]

   # Evaluate with the flag off first, then on.
   score_flag_off, score_flag_on = [
       bleu.bleu_on_list(references, hypotheses, flag)
       for flag in (False, True)
   ]

   # Flag off: expected to be a perfect score despite the case difference.
   self.assertEqual(score_flag_off, 100)
   # Flag on: the capitalisation mismatch must cost something.
   self.assertLess(score_flag_on, 100)
コード例 #2
0
  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Decode the aggregated outputs and score them against the references.

    Args:
      aggregated_logs: Container whose iteration yields integer example ids
        and which is indexable by those ids. Each entry is indexed with
        ``[0]`` for the source token ids and ``[1]`` for the translated token
        ids. Ids that are ``>= len(self._references)`` are skipped.
      global_step: Not used by this method.

    Returns:
      A dict with two entries: "sacrebleu_score", the ``.score`` of
      ``sacrebleu.corpus_bleu`` over the decoded translations, and
      "bleu_score", the result of ``bleu.bleu_on_list`` on the same data.
    """

    def _decode(ids):
      """Detokenize ids with the SentencePiece tokenizer into a Python str."""
      return self._sp_tokenizer.detokenize(ids).numpy().decode()

    def _trim_and_decode(ids):
      """Decode ids up to, but not including, the first EOS id.

      Everything from the first EOS onwards (e.g. trailing padding) is
      dropped. When the sequence holds no EOS it is decoded in full.
      """
      try:
        eos_index = list(ids).index(self._eos_id)
      except ValueError:  # No EOS found in sequence
        return _decode(ids)
      # Decode outside the `try` so that a ValueError raised while decoding
      # is not mistaken for "no EOS found" and silently retried untrimmed.
      return _decode(ids[:eos_index])

    translations = []
    for u_id in sorted(aggregated_logs):
      if u_id >= len(self._references):
        continue
      src = _trim_and_decode(aggregated_logs[u_id][0])
      translation = _trim_and_decode(aggregated_logs[u_id][1])
      translations.append(translation)
      if self.task_config.print_translations:
        # Decoding the input ids to reflect what the model sees.
        logging.info("Translating:\n\tInput: %s\n\tOutput: %s\n\tReference: %s",
                     src, translation, self._references[u_id])
    sacrebleu_score = sacrebleu.corpus_bleu(
        translations, [self._references]).score
    bleu_score = bleu.bleu_on_list(self._references, translations)
    return {"sacrebleu_score": sacrebleu_score,
            "bleu_score": bleu_score}