# Note: the methods below belong to a unittest.TestCase subclass; CANDS, REFS,
# and EPS are fixtures defined at module level in the test file.
def test_score(self):
    metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17,
                             verbose=False, idf=False, batch_size=3,
                             rescale_with_baseline=False)
    score_dict = metric.evaluate_batch(CANDS, REFS)
    avgP = sum([0.9843302369117737, 0.9832239747047424, 0.9120386242866516]) / 3
    avgR = sum([0.9823839068412781, 0.9732863903045654, 0.920428991317749]) / 3
    avgF = sum([0.9833561182022095, 0.9782299995422363, 0.916214644908905]) / 3
    # use abs() so scores below the expected value also fail, not only ones above
    self.assertTrue(abs(score_dict['bert_score_precision'] - avgP) < EPS)
    self.assertTrue(abs(score_dict['bert_score_recall'] - avgR) < EPS)
    self.assertTrue(abs(score_dict['bert_score_f1'] - avgF) < EPS)
def test_multi_refs(self):
    cands = ['I like lemons.']
    refs = [['I am proud of you.', 'I love lemons.', 'Go go go.']]
    metric = BertScoreMetric(lang='en', batch_size=3, rescale_with_baseline=True)
    score_dict = metric.evaluate_batch(cands, refs)
    # with multiple references the metric scores against the best match, so the
    # result should agree with scoring against that reference alone
    score_dict_best = metric.evaluate_batch(cands, [refs[0][1]])
    self.assertTrue(abs(score_dict['bert_score_precision'] - score_dict_best['bert_score_precision']) < EPS)
    self.assertTrue(abs(score_dict['bert_score_recall'] - score_dict_best['bert_score_recall']) < EPS)
    self.assertTrue(abs(score_dict['bert_score_f1'] - score_dict_best['bert_score_f1']) < EPS)
def test_idf_score_rescale(self):
    metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17,
                             verbose=False, idf=True, batch_size=3,
                             rescale_with_baseline=True)
    score_dict = metric.evaluate_batch(CANDS, REFS)
    avgP = sum([0.903778135776520, 0.854439020156860, 0.375287383794785]) / 3
    avgR = sum([0.897446095943451, 0.820639789104462, 0.509167850017548]) / 3
    avgF = sum([0.900772094726562, 0.837753534317017, 0.442304641008377]) / 3
    self.assertTrue(abs(score_dict['bert_score_precision'] - avgP) < EPS)
    self.assertTrue(abs(score_dict['bert_score_recall'] - avgR) < EPS)
    self.assertTrue(abs(score_dict['bert_score_f1'] - avgF) < EPS)
def test_score_rescale(self):
    metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17,
                             verbose=False, idf=False, batch_size=3,
                             rescale_with_baseline=True)
    score_dict = metric.evaluate_batch(CANDS, REFS)
    avgP = sum([0.907000780105591, 0.900435566902161, 0.477955609560013]) / 3
    avgR = sum([0.895456790924072, 0.841467440128326, 0.527785062789917]) / 3
    avgF = sum([0.901383399963379, 0.871010780334473, 0.503565192222595]) / 3
    self.assertTrue(abs(score_dict['bert_score_precision'] - avgP) < EPS)
    self.assertTrue(abs(score_dict['bert_score_recall'] - avgR) < EPS)
    self.assertTrue(abs(score_dict['bert_score_f1'] - avgF) < EPS)
def test_idf_score(self):
    metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17,
                             verbose=False, idf=True, batch_size=3,
                             rescale_with_baseline=False)
    score_dict = metric.evaluate_batch(CANDS, REFS)
    avgP = sum([0.9837872385978699, 0.9754738807678223, 0.8947395086288452]) / 3
    avgR = sum([0.9827190637588501, 0.9697767496109009, 0.9172918796539307]) / 3
    avgF = sum([0.9832529425621033, 0.972616970539093, 0.9058753848075867]) / 3
    self.assertTrue(abs(score_dict['bert_score_precision'] - avgP) < EPS)
    self.assertTrue(abs(score_dict['bert_score_recall'] - avgR) < EPS)
    self.assertTrue(abs(score_dict['bert_score_f1'] - avgF) < EPS)
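# The tests above exercise the public API pattern used throughout: construct a
# metric, then call evaluate_batch(candidates, references). A minimal usage
# sketch for reference (the sentences below are illustrative stand-ins, not
# the CANDS/REFS fixtures):
#
#   from summ_eval.bert_score_metric import BertScoreMetric
#   metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17)
#   scores = metric.evaluate_batch(['The cat sat on the mat.'],
#                                  ['A cat was sitting on the mat.'])
#   print(scores['bert_score_precision'], scores['bert_score_recall'],
#         scores['bert_score_f1'])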
# Module-level imports this function relies on (defined at the top of the file):
#   import json, sys
#   import gin, spacy
#   from collections import defaultdict
#   from nltk.stem.snowball import SnowballStemmer
#   from nltk.tokenize import RegexpTokenizer
def cli_main(args):
    # =====================================
    # INITIALIZE METRICS
    gin.parse_config_file(args.config_file)
    toks_needed = set()
    metrics = [x.strip() for x in args.metrics.split(",")]
    metrics_dict = {}
    if "rouge" in metrics:
        from summ_eval.rouge_metric import RougeMetric
        metrics_dict["rouge"] = RougeMetric()
        toks_needed.add("line_delimited")
    if "bert_score" in metrics:
        from summ_eval.bert_score_metric import BertScoreMetric
        metrics_dict["bert_score"] = BertScoreMetric()
        toks_needed.add("space")
    if "mover_score" in metrics:
        from summ_eval.mover_score_metric import MoverScoreMetric
        metrics_dict["mover_score"] = MoverScoreMetric()
        toks_needed.add("space")
    if "chrf" in metrics:
        from summ_eval.chrfpp_metric import ChrfppMetric
        metrics_dict["chrf"] = ChrfppMetric()
        toks_needed.add("space")
    if "meteor" in metrics:
        from summ_eval.meteor_metric import MeteorMetric
        metrics_dict["meteor"] = MeteorMetric()
        toks_needed.add("space")
    if "bleu" in metrics:
        from summ_eval.bleu_metric import BleuMetric
        metrics_dict["bleu"] = BleuMetric()
        toks_needed.add("space")
    if "cider" in metrics:
        from summ_eval.cider_metric import CiderMetric
        metrics_dict["cider"] = CiderMetric()
        toks_needed.add("stem")
    if "s3" in metrics:
        from summ_eval.s3_metric import S3Metric
        metrics_dict["s3"] = S3Metric()
        toks_needed.add("stem")
    if "rouge_we" in metrics:
        from summ_eval.rouge_we_metric import RougeWeMetric
        metrics_dict["rouge_we"] = RougeWeMetric()
        toks_needed.add("stem")
    if "stats" in metrics:
        from summ_eval.data_stats_metric import DataStatsMetric
        metrics_dict["stats"] = DataStatsMetric()
        toks_needed.add("spacy")
    if "sms" in metrics:
        from summ_eval.sentence_movers_metric import SentenceMoversMetric
        metrics_dict["sms"] = SentenceMoversMetric()
        toks_needed.add("spacy")
    if "summaqa" in metrics:
        from summ_eval.summa_qa_metric import SummaQAMetric
        metrics_dict["summaqa"] = SummaQAMetric()
        toks_needed.add("spacy")
        toks_needed.add("space")
    if "syntactic" in metrics:
        from summ_eval.syntactic_metric import SyntacticMetric
        metrics_dict["syntactic"] = SyntacticMetric()
        toks_needed.add("space")
    if "supert" in metrics:
        from summ_eval.supert_metric import SupertMetric
        metrics_dict["supert"] = SupertMetric()
        toks_needed.add("space")
    if "blanc" in metrics:
        from summ_eval.blanc_metric import BlancMetric
        metrics_dict["blanc"] = BlancMetric()
        toks_needed.add("space")
    # =====================================

    # =====================================
    # READ INPUT
    print("Reading the input")
    ids = []
    articles = []
    references = []
    summaries = []
    bad_lines = 0
    if args.jsonl_file is not None:
        try:
            with open(args.jsonl_file) as inputf:
                for line in inputf:
                    try:
                        data = json.loads(line)
                        if "id" in data:
                            ids.append(data["id"])
                        if len(data["decoded"]) == 0:
                            bad_lines += 1
                            continue
                        summaries.append(data["decoded"])
                        if data.get("reference"):
                            references.append(data["reference"])
                        else:
                            # ten additional references were added; the first is the original
                            references.append(data["references"][0])
                        if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
                            if "text" not in data:
                                raise ValueError("You specified summaqa, supert, or blanc, which "
                                                 "require input articles, but we could not parse them from the file!")
                            articles.append(data["text"])
                    except (json.JSONDecodeError, KeyError):
                        # skip lines that are not valid JSON or miss required fields
                        bad_lines += 1
        except Exception as e:
            print("Input did not match required format")
            print(e)
            sys.exit()
        print(f"This many bad lines encountered during loading: {bad_lines}")
    if args.summ_file is not None:
        with open(args.summ_file) as inputf:
            summaries = inputf.read().splitlines()
    if args.ref_file is not None:
        with open(args.ref_file) as inputf:
            references = inputf.read().splitlines()
    if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
        if args.article_file is None and len(articles) == 0:
            raise ValueError("You specified summaqa, supert, or blanc, which "
                             "require input articles, but no article file was given!")
        if len(articles) == 0:
            with open(args.article_file) as inputf:
                articles = inputf.read().splitlines()
    if len(ids) == 0:
        ids = list(range(0, len(summaries)))
    # =====================================

    # =====================================
    # TOKENIZATION
    print("Preparing the input")
    references_delimited = None
    summaries_delimited = None
    if len(references) > 0:
        if isinstance(references[0], list):
            if "line_delimited" in toks_needed:
                references_delimited = ["\n".join(ref) for ref in references]
            if "space" in toks_needed:
                references_space = [" ".join(ref) for ref in references]
        else:
            if args.eos is not None:
                if "line_delimited" not in toks_needed:
                    raise ValueError("You provided a delimiter but are not using a metric which requires one.")
                if args.eos == "\n":
                    references_delimited = [ref.split(args.eos) for ref in references]
                else:
                    references_delimited = [f"{args.eos}\n".join(ref.split(args.eos)) for ref in references]
            elif "line_delimited" in toks_needed:
                references_delimited = references
            if "space" in toks_needed:
                references_space = references

    if isinstance(summaries[0], list):
        if "line_delimited" in toks_needed:
            summaries_delimited = ["\n".join(summ) for summ in summaries]
        if "space" in toks_needed:
            summaries_space = [" ".join(summ) for summ in summaries]
    else:
        if args.eos is not None:
            if "line_delimited" not in toks_needed:
                raise ValueError("You provided a delimiter but are not using a metric which requires one.")
            if args.eos == "\n":
                summaries_delimited = [summ.split(args.eos) for summ in summaries]
            else:
                summaries_delimited = [f"{args.eos}\n".join(summ.split(args.eos)) for summ in summaries]
        elif "line_delimited" in toks_needed:
            summaries_delimited = summaries
        if "space" in toks_needed:
            summaries_space = summaries

    if "stem" in toks_needed:
        tokenizer = RegexpTokenizer(r"\w+")
        stemmer = SnowballStemmer("english")
        if isinstance(summaries[0], list):
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(summ))] for summ in summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(ref))] for ref in references]
        else:
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(summ)] for summ in summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(ref)] for ref in references]
        summaries_stemmed = [" ".join(summ) for summ in summaries_stemmed]
        references_stemmed = [" ".join(ref) for ref in references_stemmed]

    if "spacy" in toks_needed:
        nlp = spacy.load("en_core_web_sm")
        disable = ["tagger", "textcat"]
        if "summaqa" not in metrics:
            disable.append("ner")
        if isinstance(summaries[0], list):
            summaries_spacy = [nlp(" ".join(text), disable=disable) for text in summaries]
        else:
            summaries_spacy = [nlp(text, disable=disable) for text in summaries]
        if "stats" in metrics:
            summaries_spacy_stats = [[tok.text for tok in summary] for summary in summaries_spacy]
        if "sms" in metrics:
            if isinstance(references[0], list):
                references_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                references_spacy = [nlp(text, disable=disable) for text in references]
        # use the reference as the "article" for summaqa and stats
        if "summaqa" in metrics or "stats" in metrics:
            if isinstance(references[0], list):
                input_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                input_spacy = [nlp(text, disable=disable) for text in references]
            if "stats" in metrics:
                input_spacy_stats = [[tok.text for tok in ref] for ref in input_spacy]
    if "supert" in metrics or "blanc" in metrics:
        inputs_space = articles
    # =====================================

    # =====================================
    # GET SCORES
    if args.aggregate:
        final_output = dict()
    else:
        final_output = defaultdict(lambda: defaultdict(int))
    for metric, metric_cls in metrics_dict.items():
        print(f"Calculating scores for the {metric} metric.")
        try:
            if metric == "rouge":
                output = metric_cls.evaluate_batch(summaries_delimited, references_delimited, aggregate=args.aggregate)
                # only ROUGE uses the delimited input, so free it
                del references_delimited
                del summaries_delimited
            elif metric in ("bert_score", "mover_score", "chrf", "meteor", "bleu"):
                output = metric_cls.evaluate_batch(summaries_space, references_space, aggregate=args.aggregate)
            elif metric in ("s3", "rouge_we", "cider"):
                output = metric_cls.evaluate_batch(summaries_stemmed, references_stemmed, aggregate=args.aggregate)
            elif metric == "sms":
                output = metric_cls.evaluate_batch(summaries_spacy, references_spacy, aggregate=args.aggregate)
            elif metric == "summaqa":
                output = metric_cls.evaluate_batch(summaries_space, input_spacy, aggregate=args.aggregate)
            elif metric == "stats":
                output = metric_cls.evaluate_batch(summaries_spacy_stats, input_spacy_stats, aggregate=args.aggregate)
            elif metric in ("supert", "blanc"):
                output = metric_cls.evaluate_batch(summaries_space, inputs_space, aggregate=args.aggregate)
            if args.aggregate:
                final_output.update(output)
            else:
                ids = list(range(0, len(ids)))
                for cur_id, cur_output in zip(ids, output):
                    final_output[cur_id].update(cur_output)
        except Exception as e:
            print(e)
            print(f"An error was encountered with the {metric} metric.")
    # =====================================

    # =====================================
    # OUTPUT SCORES
    metrics_str = "_".join(metrics)
    output_path = f"output_{metrics_str}.jsonl"
    print(f"Saving to {output_path}")
    with open(output_path, "w") as outputf:
        if args.aggregate:
            json.dump(final_output, outputf)
        else:
            for key, value in final_output.items():
                value["id"] = key
                json.dump(value, outputf)
                outputf.write("\n")
    # =====================================
    print(f"Wrote scores to: {output_path}")
    return output_path
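# For reference, a hedged sketch of argparse wiring that could drive cli_main.
# The flag names here are assumptions inferred from the args.* attributes read
# above; the repo's actual parser may differ.
def _example_arg_parser():
    import argparse
    parser = argparse.ArgumentParser(description="Compute summarization metrics over system outputs.")
    parser.add_argument("--config-file", dest="config_file", required=True,
                        help="gin config file with per-metric settings")
    parser.add_argument("--metrics", default="rouge,bert_score",
                        help="comma-separated metric names, e.g. 'rouge,bert_score,stats'")
    parser.add_argument("--jsonl-file", dest="jsonl_file", default=None,
                        help="jsonl input with 'decoded' and 'reference' fields")
    parser.add_argument("--summ-file", dest="summ_file", default=None)
    parser.add_argument("--ref-file", dest="ref_file", default=None)
    parser.add_argument("--article-file", dest="article_file", default=None,
                        help="required by summaqa/supert/blanc if articles are not in the jsonl")
    parser.add_argument("--eos", default=None, help="sentence delimiter used in the inputs")
    parser.add_argument("--aggregate", action="store_true")
    return parser
# usage (sketch): cli_main(_example_arg_parser().parse_args())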
# Module-level imports/helpers this function relies on (defined elsewhere in
# the script): json, tqdm, RougeMetric, BertScoreMetric, batch_data,
# replace_html, wrap_results.
def main(args):
    try:
        assert len(args.exps) == len(args.summary_type) == 2
    except AssertionError:
        # fall back to a single-file evaluation
        assert len(args.exps) > 0, "At least one experiment must be specified."
        print(f"Performing single-file evaluation on {args.exps[0]}")
        args.exps = [args.exps[0], args.exps[0]]
        args.summary_type = ["reference", "decoded"]
    batch_size = 4096
    inputs = []
    for exp, key in zip(args.exps, args.summary_type):
        ipfiles = exp.glob("eval.jsonl*")
        data = {}
        # legacy compatibility with existing code
        if key == "reference":
            key = "gold"
        for ipfile in ipfiles:
            with ipfile.open("r") as fd:
                d = [json.loads(line) for line in fd]
            data.update({ex["paper_id"]: ex[key] for ex in d})
        inputs.append(data)
    # rearrange into a pid -> (system_1, system_2) mapping
    data = list({pid: (inputs[0][pid], inputs[1][pid]) for pid in inputs[0]}.items())
    rouge_metric = RougeMetric()
    bertscore_metric = BertScoreMetric()
    results = []
    # ceiling division so tqdm's total matches the actual number of batches
    num_batches = (len(data) + batch_size - 1) // batch_size
    for batch in tqdm(batch_data(data, batch_size), total=num_batches, ncols=80, ascii=True):
        batch_pids, batch_texts = zip(*batch)
        clean_refs = [replace_html(ex[0]) for ex in batch_texts]
        clean_preds = [replace_html(ex[1]) for ex in batch_texts]
        # assign zero scores to examples with empty references
        empty_ref_indices = [i for i, x in enumerate(clean_refs) if x == ""]
        if len(empty_ref_indices) > 0:
            wrapped = []
            for i in empty_ref_indices:
                wrapped.append({
                    "paper_id": batch_pids[i],
                    args.summary_type[0]: batch_texts[i][0],
                    args.summary_type[1]: batch_texts[i][1],
                    "rouge": {
                        "rouge_1_f_score": 0,
                        "rouge_2_f_score": 0,
                        "rouge_3_f_score": 0,
                        "rouge_4_f_score": 0,
                        "rouge_l_f_score": 0,
                        "rouge_w_1.2_f_score": 0,
                        "rouge_s*_f_score": 0,
                        "rouge_su*_f_score": 0,
                    },
                    "bert_score": dict(bert_score_f1=0),
                })
            results.extend(wrapped)
            batch_pids = [ex for i, ex in enumerate(batch_pids) if i not in empty_ref_indices]
            clean_refs = [ex for i, ex in enumerate(clean_refs) if i not in empty_ref_indices]
            clean_preds = [ex for i, ex in enumerate(clean_preds) if i not in empty_ref_indices]
        rouges = rouge_metric.evaluate_batch(clean_preds, clean_refs, aggregate=False)
        bertscores = bertscore_metric.evaluate_batch(clean_preds, clean_refs, aggregate=False)
        results.extend(wrap_results(args, batch_pids, batch_texts, rouges, bertscores))
    with args.output_file.open("w") as fd:
        for example in results:
            print(json.dumps(example), file=fd)
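# batch_data is defined elsewhere in this script. For illustration, a minimal
# implementation consistent with how it is called above (yield consecutive
# fixed-size slices of a list) might look like the hypothetical sketch below;
# the real helper may differ.
def _example_batch_data(data, batch_size):
    """Hypothetical sketch: yield consecutive batch_size-sized slices of data."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]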