def test_bleu_list(self):
    """Check that casing only matters when bleu_on_list is asked to respect it.

    The hypotheses differ from the references solely by the capital "M" in
    the second sentence, so the case-insensitive score must be a perfect 100
    while the case-sensitive score must fall below it.
    """
    references = ["test 1 two 3", "more tests!"]
    hypotheses = ["test 1 two 3", "More tests!"]

    # The third positional argument toggles case sensitivity
    # (False -> ignore case, True -> respect case).
    score_ignoring_case = bleu.bleu_on_list(references, hypotheses, False)
    score_respecting_case = bleu.bleu_on_list(references, hypotheses, True)

    self.assertEqual(score_ignoring_case, 100)
    self.assertLess(score_respecting_case, 100)
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Decode aggregated model outputs and score them against the references.

    Args:
      aggregated_logs: Mapping from an integer id to an entry whose element 0
        holds the input token ids and element 1 holds the output token ids.
        Ids are visited in sorted order. Ids >= len(self._references) are
        skipped. NOTE(review): ids are used directly as indices into
        self._references, so they are presumably 0-based example positions.
        Confirm this against the aggregation step.
      global_step: Not used by this method.

    Returns:
      A dict with "sacrebleu_score" (sacrebleu corpus BLEU score) and
      "bleu_score" (bleu.bleu_on_list score). Both compare the decoded
      translations with the full self._references list.
    """

    def _decode(ids):
        # Detokenize, then turn the resulting bytes value into a Python str.
        return self._sp_tokenizer.detokenize(ids).numpy().decode()

    def _trim_and_decode(ids):
        """Drop everything from the first EOS id onward, then decode.

        If no EOS id is present, the whole sequence is decoded as-is. Any
        padding in that case is not removed here.
        """
        # Only the EOS lookup belongs inside the try. UnicodeDecodeError
        # (raised by bytes.decode inside _decode) is a subclass of
        # ValueError. A wider try would therefore treat a decode failure as
        # "no EOS found" and silently retry on the untrimmed ids.
        try:
            index = list(ids).index(self._eos_id)
        except ValueError:  # No EOS found in sequence.
            return _decode(ids)
        return _decode(ids[:index])

    translations = []
    for u_id in sorted(aggregated_logs):
        # No reference exists for these ids, so they cannot be scored.
        # NOTE(review): presumably padding examples; confirm.
        if u_id >= len(self._references):
            continue
        entry = aggregated_logs[u_id]
        translation = _trim_and_decode(entry[1])
        translations.append(translation)
        if self.task_config.print_translations:
            # Decode the input ids to reflect what the model sees. This is
            # done only when logging, because the result is used nowhere else.
            src = _trim_and_decode(entry[0])
            logging.info("Translating:\n\tInput: %s\n\tOutput: %s\n\tReference: %s",
                         src, translation, self._references[u_id])

    sacrebleu_score = sacrebleu.corpus_bleu(
        translations, [self._references]).score
    bleu_score = bleu.bleu_on_list(self._references, translations)
    return {"sacrebleu_score": sacrebleu_score,
            "bleu_score": bleu_score}