def test_glove_sms(self):
    metric = SentenceMoversMetric(wordrep='glove', metric='sms')
    score = metric.evaluate_batch(CANDS, REFS)
    avg = sum([0.998619953103503, 0.4511249300530058, 0.2903306055171392]) / 3
    # compare against the expected mean within EPS tolerance; abs() so the
    # check also fails when the score is too low, not only too high
    self.assertTrue(abs(score["sentence_movers_glove_sms"] - avg) < EPS)
def test_glove_swms(self):
    metric = SentenceMoversMetric(wordrep='glove', metric='s+wms')
    score = metric.evaluate_batch(CANDS, REFS)
    avg = sum([0.9993097383211589, 0.330209238465829, 0.11248476630690257]) / 3
    self.assertTrue(abs(score["sentence_movers_glove_s+wms"] - avg) < EPS)
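# Note: CANDS, REFS, and EPS are module-level fixtures defined elsewhere in
# this test file. Purely for illustration, they have roughly this shape (the
# texts and tolerance below are hypothetical placeholders, not the real data):
#
# CANDS = ["candidate summary one .", "candidate summary two .", "candidate summary three ."]
# REFS = ["reference summary one .", "reference summary two .", "reference summary three ."]
# EPS = 1e-5  # tolerance for the floating-point comparisons in the tests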
# Module-level imports assumed by cli_main (the original file defines these at
# the top, outside the excerpt shown here):
import json
import sys
from collections import defaultdict

import gin
import spacy
from nltk.stem.snowball import SnowballStemmer
from nltk.tokenize import RegexpTokenizer


def cli_main(args):
    # =====================================
    # INITIALIZE METRICS
    gin.parse_config_file(args.config_file)
    toks_needed = set()
    metrics = [x.strip() for x in args.metrics.split(",")]
    metrics_dict = {}
    # Each metric is imported lazily so unused (and heavyweight) dependencies
    # are never loaded; each metric also registers the tokenization it needs.
    if "rouge" in metrics:
        from summ_eval.rouge_metric import RougeMetric
        metrics_dict["rouge"] = RougeMetric()
        toks_needed.add("line_delimited")
    if "bert_score" in metrics:
        from summ_eval.bert_score_metric import BertScoreMetric
        metrics_dict["bert_score"] = BertScoreMetric()
        toks_needed.add("space")
    if "mover_score" in metrics:
        from summ_eval.mover_score_metric import MoverScoreMetric
        metrics_dict["mover_score"] = MoverScoreMetric()
        toks_needed.add("space")
    if "chrf" in metrics:
        from summ_eval.chrfpp_metric import ChrfppMetric
        metrics_dict["chrf"] = ChrfppMetric()
        toks_needed.add("space")
    if "meteor" in metrics:
        from summ_eval.meteor_metric import MeteorMetric
        metrics_dict["meteor"] = MeteorMetric()
        toks_needed.add("space")
    if "bleu" in metrics:
        from summ_eval.bleu_metric import BleuMetric
        metrics_dict["bleu"] = BleuMetric()
        toks_needed.add("space")
    if "cider" in metrics:
        from summ_eval.cider_metric import CiderMetric
        metrics_dict["cider"] = CiderMetric()
        toks_needed.add("stem")
    if "s3" in metrics:
        from summ_eval.s3_metric import S3Metric
        metrics_dict["s3"] = S3Metric()
        toks_needed.add("stem")
    if "rouge_we" in metrics:
        from summ_eval.rouge_we_metric import RougeWeMetric
        metrics_dict["rouge_we"] = RougeWeMetric()
        toks_needed.add("stem")
    if "stats" in metrics:
        from summ_eval.data_stats_metric import DataStatsMetric
        metrics_dict["stats"] = DataStatsMetric()
        toks_needed.add("spacy")
    if "sms" in metrics:
        from summ_eval.sentence_movers_metric import SentenceMoversMetric
        metrics_dict["sms"] = SentenceMoversMetric()
        toks_needed.add("spacy")
    if "summaqa" in metrics:
        from summ_eval.summa_qa_metric import SummaQAMetric
        metrics_dict["summaqa"] = SummaQAMetric()
        toks_needed.add("spacy")
        toks_needed.add("space")
    if "syntactic" in metrics:
        from summ_eval.syntactic_metric import SyntacticMetric
        metrics_dict["syntactic"] = SyntacticMetric()
        toks_needed.add("space")
    if "supert" in metrics:
        from summ_eval.supert_metric import SupertMetric
        metrics_dict["supert"] = SupertMetric()
        toks_needed.add("space")
    if "blanc" in metrics:
        from summ_eval.blanc_metric import BlancMetric
        metrics_dict["blanc"] = BlancMetric()
        toks_needed.add("space")
    # =====================================

    # =====================================
    # READ INPUT
    print("Reading the input")
    ids = []
    articles = []
    references = []
    summaries = []
    bad_lines = 0
    if args.jsonl_file is not None:
        try:
            with open(args.jsonl_file) as inputf:
                for line in inputf:
                    try:
                        data = json.loads(line)
                        if "id" in data:
                            ids.append(data["id"])
                        if len(data["decoded"]) == 0:
                            bad_lines += 1
                            continue
                        summaries.append(data["decoded"])
                        if data.get("reference", None):
                            references.append(data["reference"])
                        else:
                            # ten additional references were added; the first
                            # one is the original
                            references.append(data["references"][0])
                        if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
                            if "text" not in data:
                                raise ValueError(
                                    "You specified summaqa, supert, or blanc, which "
                                    "require input articles, but we could not parse the file!")
                            articles.append(data["text"])
                    except (json.JSONDecodeError, KeyError, IndexError):
                        bad_lines += 1
        except Exception as e:
            print("Input did not match required format")
            print(e)
            sys.exit()
        print(f"This many bad lines encountered during loading: {bad_lines}")
    if args.summ_file is not None:
        with open(args.summ_file) as inputf:
            summaries = inputf.read().splitlines()
    if args.ref_file is not None:
        with open(args.ref_file) as inputf:
            references = inputf.read().splitlines()
    if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
        if args.article_file is None and len(articles) == 0:
            raise ValueError(
                "You specified summaqa, supert, or blanc, which require "
                "input articles, but none were provided!")
        if len(articles) == 0:
            with open(args.article_file) as inputf:
                articles = inputf.read().splitlines()
    if len(ids) == 0:
        ids = list(range(0, len(summaries)))
    # =====================================

    # =====================================
    # TOKENIZATION
    print("Preparing the input")
    references_delimited = None
    summaries_delimited = None
    if len(references) > 0:
        if isinstance(references[0], list):
            if "line_delimited" in toks_needed:
                references_delimited = ["\n".join(ref) for ref in references]
            if "space" in toks_needed:
                references_space = [" ".join(ref) for ref in references]
        elif args.eos is not None:
            if "line_delimited" not in toks_needed:
                raise ValueError("You provided a delimiter but are not using a metric which requires one.")
            if args.eos == "\n":
                references_delimited = [ref.split(args.eos) for ref in references]
            else:
                references_delimited = [f"{args.eos}\n".join(ref.split(args.eos)) for ref in references]
        elif "line_delimited" in toks_needed:
            references_delimited = references
        # only fall back to the raw strings when they were not already joined
        # from sentence lists above
        if "space" in toks_needed and not isinstance(references[0], list):
            references_space = references
    if isinstance(summaries[0], list):
        if "line_delimited" in toks_needed:
            summaries_delimited = ["\n".join(summ) for summ in summaries]
        if "space" in toks_needed:
            summaries_space = [" ".join(summ) for summ in summaries]
    elif args.eos is not None:
        if "line_delimited" not in toks_needed:
            raise ValueError("You provided a delimiter but are not using a metric which requires one.")
        if args.eos == "\n":
            summaries_delimited = [summ.split(args.eos) for summ in summaries]
        else:
            summaries_delimited = [f"{args.eos}\n".join(summ.split(args.eos)) for summ in summaries]
    elif "line_delimited" in toks_needed:
        summaries_delimited = summaries
    if "space" in toks_needed and not isinstance(summaries[0], list):
        summaries_space = summaries

    if "stem" in toks_needed:
        tokenizer = RegexpTokenizer(r"\w+")
        stemmer = SnowballStemmer("english")
        if isinstance(summaries[0], list):
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(summ))]
                                 for summ in summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(ref))]
                                  for ref in references]
        else:
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(summ)]
                                 for summ in summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(ref)]
                                  for ref in references]
        summaries_stemmed = [" ".join(summ) for summ in summaries_stemmed]
        references_stemmed = [" ".join(ref) for ref in references_stemmed]

    if "spacy" in toks_needed:
        nlp = spacy.load("en_core_web_sm")
        disable = ["tagger", "textcat"]
        if "summaqa" not in metrics:
            disable.append("ner")
        if isinstance(summaries[0], list):
            summaries_spacy = [nlp(" ".join(text), disable=disable) for text in summaries]
        else:
            summaries_spacy = [nlp(text, disable=disable) for text in summaries]
        if "stats" in metrics:
            summaries_spacy_stats = [[tok.text for tok in summary] for summary in summaries_spacy]
        if "sms" in metrics:
            if isinstance(references[0], list):
                references_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                references_spacy = [nlp(text, disable=disable) for text in references]
        # use the reference as the "article" for summaqa and stats
        if "summaqa" in metrics or "stats" in metrics:
            if isinstance(references[0], list):
                input_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                input_spacy = [nlp(text, disable=disable) for text in references]
            if "stats" in metrics:
                input_spacy_stats = [[tok.text for tok in ref] for ref in input_spacy]
    if "supert" in metrics or "blanc" in metrics:
        inputs_space = articles
    # =====================================

    # =====================================
    # GET SCORES
    if args.aggregate:
        final_output = dict()
    else:
        final_output = defaultdict(lambda: defaultdict(int))
    for metric, metric_cls in metrics_dict.items():
        print(f"Calculating scores for the {metric} metric.")
        try:
            if metric == "rouge":
                output = metric_cls.evaluate_batch(summaries_delimited, references_delimited,
                                                   aggregate=args.aggregate)
                # only ROUGE uses this input, so we can free it
                del references_delimited
                del summaries_delimited
            elif metric in ('bert_score', 'mover_score', 'chrf', 'meteor', 'bleu'):
                output = metric_cls.evaluate_batch(summaries_space, references_space,
                                                   aggregate=args.aggregate)
            elif metric in ('s3', 'rouge_we', 'cider'):
                output = metric_cls.evaluate_batch(summaries_stemmed, references_stemmed,
                                                   aggregate=args.aggregate)
            elif metric == "sms":
                output = metric_cls.evaluate_batch(summaries_spacy, references_spacy,
                                                   aggregate=args.aggregate)
            elif metric == "summaqa":
                output = metric_cls.evaluate_batch(summaries_space, input_spacy,
                                                   aggregate=args.aggregate)
            elif metric == "stats":
                output = metric_cls.evaluate_batch(summaries_spacy_stats, input_spacy_stats,
                                                   aggregate=args.aggregate)
            elif metric in ('supert', 'blanc'):
                output = metric_cls.evaluate_batch(summaries_space, inputs_space,
                                                   aggregate=args.aggregate)
            if args.aggregate:
                final_output.update(output)
            else:
                ids = list(range(0, len(ids)))
                for cur_id, cur_output in zip(ids, output):
                    final_output[cur_id].update(cur_output)
        except Exception as e:
            print(e)
            print(f"An error was encountered with the {metric} metric.")
    # =====================================

    # =====================================
    # OUTPUT SCORES
    metrics_str = "_".join(metrics)
    output_path = f"output_{metrics_str}.jsonl"
    print(f"Saving to {output_path}")
    with open(output_path, "w") as outputf:
        if args.aggregate:
            json.dump(final_output, outputf)
        else:
            for key, value in final_output.items():
                value["id"] = key
                json.dump(value, outputf)
                outputf.write("\n")
    # =====================================
    print(f"Wrote scores to: {output_path}")
    return output_path
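
# A minimal, hypothetical entry point showing how cli_main could be wired up
# with argparse. The flag names below are assumptions inferred from the
# attributes cli_main reads (config_file, metrics, jsonl_file, summ_file,
# ref_file, article_file, eos, aggregate); the real package may define its
# CLI differently.
def _example_main():
    import argparse
    parser = argparse.ArgumentParser(description="Compute summarization metrics.")
    parser.add_argument("--config-file", required=True,
                        help="gin config file parsed at startup")
    parser.add_argument("--metrics", default="rouge",
                        help="comma-separated metric names, e.g. 'rouge,bert_score,sms'")
    parser.add_argument("--jsonl-file", default=None,
                        help="jsonl input with 'decoded', 'reference'/'references', optional 'id'/'text'")
    parser.add_argument("--summ-file", default=None,
                        help="plain-text file with one summary per line")
    parser.add_argument("--ref-file", default=None,
                        help="plain-text file with one reference per line")
    parser.add_argument("--article-file", default=None,
                        help="source articles; required by summaqa/supert/blanc")
    parser.add_argument("--eos", default=None,
                        help="sentence delimiter used to split summaries/references")
    parser.add_argument("--aggregate", action="store_true",
                        help="report corpus-level averages instead of per-example scores")
    cli_main(parser.parse_args())

# if __name__ == "__main__":
#     _example_main()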
def test_glove_swms_single(self):
    metric = SentenceMoversMetric(wordrep='glove', metric='s+wms')
    score = metric.evaluate_example(CANDS[0], REFS[0])
    score0 = 0.9993097383211589
    self.assertTrue(abs(score["sentence_movers_glove_s+wms"] - score0) < EPS)
def test_glove_sms_single(self):
    metric = SentenceMoversMetric(wordrep='glove', metric='sms')
    score = metric.evaluate_example(CANDS[0], REFS[0])
    score0 = 0.998619953103503
    self.assertTrue(abs(score["sentence_movers_glove_sms"] - score0) < EPS)
def test_glove_wms(self):
    metric = SentenceMoversMetric(wordrep='glove', metric='wms')
    score = metric.evaluate_batch(CANDS, REFS)
    avg = sum([1.0, 0.2417027898540903, 0.04358073486688484]) / 3
    self.assertTrue(abs(score["sentence_movers_glove_wms"] - avg) < EPS)
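
# Taken together, the tests above document the SentenceMoversMetric API:
# evaluate_example(candidate, reference) scores a single pair, and
# evaluate_batch(candidates, references) scores parallel lists; results are
# keyed as "sentence_movers_<wordrep>_<metric>". A minimal usage sketch
# (the texts are placeholders):
#
# metric = SentenceMoversMetric(wordrep='glove', metric='sms')
# result = metric.evaluate_batch(
#     ["the cat sat on the mat ."],         # candidate summaries
#     ["a cat was sitting on the mat ."],   # reference summaries
# )
# print(result["sentence_movers_glove_sms"])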