def evaluate(self, result_path):
    """Compute the corpus-level BLEU score of a hypothesis file.

    When ``self.dataset_token`` is set, the references are obtained from
    sacrebleu via ``download_test_set(self.dataset_token,
    self.langpair_token)``; otherwise ``self.ref_lines`` is used as the
    single reference stream.

    Args:
        result_path: Path of the text file holding one hypothesis per line.

    Returns:
        The BLEU score (``bleu.score``) as a ``float``.

    Raises:
        AssertionError: If ``result_path`` does not exist.
        SystemError: If sacrebleu returns no reference file for the
            configured dataset / language pair.
    """
    from sacrebleu import download_test_set, corpus_bleu, smart_open

    assert os.path.exists(result_path)

    if self.dataset_token is not None:
        # The first returned path is skipped; only the remaining ones are
        # used as references.
        _, *ref_paths = download_test_set(self.dataset_token,
                                          self.langpair_token)
        if not ref_paths:
            raise SystemError(
                "Error with dataset_token and langpair_token: {} {}".
                format(self.dataset_token, self.langpair_token))
        refs = []
        for ref_path in ref_paths:
            # Context manager so every reference file handle is released.
            with smart_open(ref_path, encoding="utf-8") as ref_file:
                refs.append(ref_file.readlines())
    else:
        refs = [self.ref_lines]

    # Read hypotheses with the same encoding as the references instead of
    # the locale default, and close the file once read.
    with open(result_path, encoding="utf-8") as hyp_file:
        hyp_lines = hyp_file.readlines()

    bleu = corpus_bleu(hyp_lines, refs)
    return float(bleu.score)
def save(self, *args, **kwargs):
    """Save a test set object, filling in missing files from sacrebleu.

    If ``self.sacrebleu`` is set and the source and/or reference file is
    missing, the test set is fetched with ``download_test_set`` and the
    missing file fields are populated (``paths[0]`` as the source,
    ``paths[1]`` as the reference) before the instance itself is saved.

    Args:
        *args: Forwarded unchanged to the parent ``save``.
        **kwargs: Forwarded unchanged to the parent ``save``.
    """
    if self.sacrebleu:
        # Download and update source and reference files
        if not self.source_file or not self.reference_file:
            from sacrebleu import download_test_set

            lang_pair = self.langpair()
            paths = download_test_set(self.sacrebleu, lang_pair)

            # NOTE(review): assumes these are Django FileFields, whose
            # ``save`` accepts ``save=False`` -- confirm. It stops each field
            # save from re-entering this method, so the instance is persisted
            # once, below, with the caller's arguments.
            if not self.source_file:
                src_name = '{0}.{1}.src.txt'.format(self.sacrebleu, lang_pair)
                with open(paths[0], encoding='utf-8') as src_handle:
                    self.source_file.save(
                        src_name, File(src_handle), save=False)

            if not self.reference_file:
                ref_name = '{0}.{1}.ref.txt'.format(self.sacrebleu, lang_pair)
                with open(paths[1], encoding='utf-8') as ref_handle:
                    self.reference_file.save(
                        ref_name, File(ref_handle), save=False)

    super(TestSet, self).save(*args, **kwargs)
def multi_bleu_score(candidate, target_vocab):
    """Return the corpus BLEU score of ``candidate`` on the WMT14 test set.

    The language pair is derived from the module-level ``DATASET`` name by
    replacing underscores with hyphens. ``candidate`` is first converted
    with ``reverse_index(candidate, target_vocab)``.

    Args:
        candidate: Model output, passed to ``reverse_index``.
        target_vocab: Vocabulary passed to ``reverse_index``.

    Returns:
        The ``score`` attribute of the sacrebleu result.
    """
    lang_pair = '-'.join(DATASET.split('_'))
    candidate = reverse_index(candidate, target_vocab)

    # download_test_set returns file paths; the first one is skipped and the
    # rest are treated as references.
    _, *ref_paths = sacrebleu.download_test_set('wmt14', lang_pair)

    # corpus_bleu needs the reference sentences themselves, so read each
    # reference file into a list of lines rather than passing its path.
    refs = []
    for ref_path in ref_paths:
        with open(ref_path, encoding='utf-8') as ref_file:
            refs.append(ref_file.readlines())

    bleu = sacrebleu.corpus_bleu(candidate, refs)
    return bleu.score