Example #1
    def test_score(self):
        metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17, verbose=False, idf=False,
                                 batch_size=3, rescale_with_baseline=False)
        score_dict = metric.evaluate_batch(CANDS, REFS)

        avgP = sum([0.9843302369117737, 0.9832239747047424, 0.9120386242866516])/3
        avgR = sum([0.9823839068412781, 0.9732863903045654, 0.920428991317749])/3
        avgF = sum([0.9833561182022095, 0.9782299995422363, 0.916214644908905])/3
        self.assertTrue((score_dict['bert_score_precision'] - avgP) < EPS)
        self.assertTrue((score_dict['bert_score_recall'] - avgR) < EPS)
        self.assertTrue((score_dict['bert_score_f1'] - avgF) < EPS)
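
The test above exercises BertScoreMetric.evaluate_batch from SummEval's test suite. Below is a minimal standalone sketch of the same call pattern, assuming summ_eval and its bert-score dependency are installed; the candidate and reference sentences are made-up placeholders rather than the CANDS/REFS fixtures used by the tests.

# Sketch only: assumes summ_eval is installed; the sentences are illustrative placeholders.
from summ_eval.bert_score_metric import BertScoreMetric

cands = ["The cat sat on the mat.", "A quick brown fox jumps over the dog."]
refs = ["A cat was sitting on the mat.", "The quick brown fox jumps over the lazy dog."]

metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17,
                         verbose=False, idf=False, batch_size=3,
                         rescale_with_baseline=False)

# As in the assertions above, evaluate_batch returns a dict keyed by
# bert_score_precision, bert_score_recall, and bert_score_f1.
score_dict = metric.evaluate_batch(cands, refs)
print(score_dict['bert_score_precision'],
      score_dict['bert_score_recall'],
      score_dict['bert_score_f1'])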
Example #2
    def test_multi_refs(self):
        cands = ['I like lemons.']
        refs = [['I am proud of you.', 'I love lemons.', 'Go go go.']]
        metric = BertScoreMetric(lang='en', batch_size=3, rescale_with_baseline=True)

        score_dict = metric.evaluate_batch(cands, refs)
        score_dict_best = metric.evaluate_batch(cands, [refs[0][1]])

        self.assertTrue((score_dict['bert_score_precision'] - score_dict_best['bert_score_precision']) < EPS)
        self.assertTrue((score_dict['bert_score_recall'] - score_dict_best['bert_score_recall']) < EPS)
        self.assertTrue((score_dict['bert_score_f1'] - score_dict_best['bert_score_f1']) < EPS)
Example #3
    def test_idf_score_rescale(self):
        metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17, verbose=False, idf=True,
                                 batch_size=3, rescale_with_baseline=True)
        score_dict = metric.evaluate_batch(CANDS, REFS)

        avgP = sum([0.903778135776520, 0.854439020156860, 0.375287383794785])/3
        avgR = sum([0.897446095943451, 0.820639789104462, 0.509167850017548])/3
        avgF = sum([0.900772094726562, 0.837753534317017, 0.442304641008377])/3
        self.assertTrue((score_dict['bert_score_precision'] - avgP) < EPS)
        self.assertTrue((score_dict['bert_score_recall'] - avgR) < EPS)
        self.assertTrue((score_dict['bert_score_f1'] - avgF) < EPS)
Example #4
    def test_score_rescale(self):
        metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17, verbose=False, idf=False,
                                 batch_size=3, rescale_with_baseline=True)
        score_dict = metric.evaluate_batch(CANDS, REFS)

        avgP = sum([0.907000780105591, 0.900435566902161, 0.477955609560013])/3
        avgR = sum([0.895456790924072, 0.841467440128326, 0.527785062789917])/3
        avgF = sum([0.901383399963379, 0.871010780334473, 0.503565192222595])/3
        self.assertTrue((score_dict['bert_score_precision'] - avgP) < EPS)
        self.assertTrue((score_dict['bert_score_recall'] - avgR) < EPS)
        self.assertTrue((score_dict['bert_score_f1'] - avgF) < EPS)
Example #5
    def test_idf_score(self):
        metric = BertScoreMetric(lang='en', model_type='roberta-large', num_layers=17, verbose=False, idf=True,
                                 batch_size=3, rescale_with_baseline=False)
        score_dict = metric.evaluate_batch(CANDS, REFS)

        avgP = sum([0.9837872385978699, 0.9754738807678223, 0.8947395086288452])/3
        avgR = sum([0.9827190637588501, 0.9697767496109009, 0.9172918796539307])/3
        avgF = sum([0.9832529425621033, 0.972616970539093, 0.9058753848075867])/3
        self.assertTrue((score_dict['bert_score_precision'] - avgP) < EPS)
        self.assertTrue((score_dict['bert_score_recall'] - avgR) < EPS)
        self.assertTrue((score_dict['bert_score_f1'] - avgF) < EPS)
Example #6
def cli_main(args):
    # =====================================
    # INITIALIZE METRICS
    gin.parse_config_file(args.config_file)
    toks_needed = set()
    metrics = [x.strip() for x in args.metrics.split(",")]
    metrics_dict = {}
    if "rouge" in metrics:
        from summ_eval.rouge_metric import RougeMetric
        metrics_dict["rouge"] = RougeMetric()
        toks_needed.add("line_delimited")

    if "bert_score" in metrics:
        from summ_eval.bert_score_metric import BertScoreMetric
        bert_score_metric = BertScoreMetric()
        metrics_dict["bert_score"] = bert_score_metric
        toks_needed.add("space")
    if "mover_score" in metrics:
        from summ_eval.mover_score_metric import MoverScoreMetric
        mover_score_metric = MoverScoreMetric()
        metrics_dict["mover_score"] = mover_score_metric
        toks_needed.add("space")
    if "chrf" in metrics:
        from summ_eval.chrfpp_metric import ChrfppMetric
        metrics_dict["chrf"] = ChrfppMetric()
        toks_needed.add("space")
    if "meteor" in metrics:
        from summ_eval.meteor_metric import MeteorMetric
        metrics_dict["meteor"] = MeteorMetric()
        toks_needed.add("space")
    if "bleu" in metrics:
        from summ_eval.bleu_metric import BleuMetric
        metrics_dict["bleu"] = BleuMetric()
        toks_needed.add("space")
    if "cider" in metrics:
        from summ_eval.cider_metric import CiderMetric
        metrics_dict["cider"] = CiderMetric()
        toks_needed.add("stem")

    if "s3" in metrics:
        from summ_eval.s3_metric import S3Metric
        metrics_dict["s3"] = S3Metric()
        toks_needed.add("stem")
    if "rouge_we" in metrics:
        from summ_eval.rouge_we_metric import RougeWeMetric
        metrics_dict["rouge_we"] = RougeWeMetric()
        toks_needed.add("stem")

    if "stats" in metrics:
        from summ_eval.data_stats_metric import DataStatsMetric
        metrics_dict['stats'] = DataStatsMetric()
        toks_needed.add("spacy")
    if "sms" in metrics:
        from summ_eval.sentence_movers_metric import SentenceMoversMetric
        metrics_dict['sms'] = SentenceMoversMetric()
        toks_needed.add("spacy")
    if "summaqa" in metrics:
        from summ_eval.summa_qa_metric import SummaQAMetric
        metrics_dict['summaqa'] = SummaQAMetric()
        toks_needed.add("spacy")
        toks_needed.add("space")
    if "syntactic" in metrics:
        from summ_eval.syntactic_metric import SyntacticMetric
        metrics_dict["syntactic"] = SyntacticMetric()
        toks_needed.add("space")
    if "supert" in metrics:
        from summ_eval.supert_metric import SupertMetric
        metrics_dict['supert'] = SupertMetric()
        toks_needed.add("space")
    if "blanc" in metrics:
        from summ_eval.blanc_metric import BlancMetric
        metrics_dict['blanc'] = BlancMetric()
        toks_needed.add("space")
    # =====================================

    # =====================================
    # READ INPUT
    print("Reading the input")
    ids = []
    articles = []
    references = []
    summaries = []
    bad_lines = 0
    if args.jsonl_file is not None:
        try:
            with open(args.jsonl_file) as inputf:
                for count, line in enumerate(inputf):
                    try:
                        data = json.loads(line)
                        try:
                            ids.append(data['id'])
                        except KeyError:
                            pass  # the 'id' field is optional
                        if len(data['decoded']) == 0:
                            bad_lines += 1
                            continue
                        summaries.append(data['decoded'])
                        # references.append(data['reference'])
                        if data.get("reference", None):
                            references.append(data['reference'])
                        else:  # 10 additional references are provided; the first is the original
                            references.append(data["references"][0])
                        # if "summaqa" in metrics or "stats" in metrics or "supert" in metrics or "blanc" in metrics:
                        # remove stats
                        if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
                            try:
                                articles.append(data['text'])
                            except:
                                raise ValueError("You specified summaqa and stats, which" \
                                                 "require input articles, but we could not parse the file!")
                    except Exception:
                        bad_lines += 1
        except Exception as e:
            print("Input did not match required format")
            print(e)
            sys.exit()
        print(f"This many bad lines encountered during loading: {bad_lines}")

    if args.summ_file is not None:
        with open(args.summ_file) as inputf:
            summaries = inputf.read().splitlines()
    if args.ref_file is not None:
        with open(args.ref_file) as inputf:
            references = inputf.read().splitlines()
    # if "summaqa" in metrics or "stats" in metrics or "supert" in metrics or "blanc" in metrics:
    if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
        if args.article_file is None and len(articles) == 0:
            raise ValueError("You specified summaqa and stats, which" \
                             "require input articles, but we could not parse the file!")
        if len(articles) == 0:
            with open(args.article_file) as inputf:
                articles = inputf.read().splitlines()
    if len(ids) == 0:
        ids = list(range(0, len(summaries)))
    # =====================================

    # =====================================
    # TOKENIZATION
    print("Preparing the input")
    references_delimited = None
    summaries_delimited = None
    if len(references) > 0:
        if isinstance(references[0], list):
            if "line_delimited" in toks_needed:
                references_delimited = ["\n".join(ref) for ref in references]
            if "space" in toks_needed:
                references_space = [" ".join(ref) for ref in references]
        elif args.eos is not None:
            if "line_delimited" not in toks_needed:
                raise ValueError('You provided a delimiter but are not using a metric which requires one.')
            if args.eos == "\n":
                references_delimited = [ref.split(args.eos) for ref in references]
            else:
                references_delimited = [f"{args.eos}\n".join(ref.split(args.eos)) for ref in references]
        elif "line_delimited" in toks_needed:
            references_delimited = references
        if "space" in toks_needed:
            references_space = references

    if isinstance(summaries[0], list):
        if "line_delimited" in toks_needed:
            summaries_delimited = ["\n".join(summ) for summ in summaries]
        if "space" in toks_needed:
            summaries_space = [" ".join(summ) for summ in summaries]
    elif args.eos is not None:
        if "line_delimited" not in toks_needed:
            raise ValueError('You provided a delimiter but are not using a metric which requires one.')
        if args.eos == "\n":
            summaries_delimited = [ref.split(args.eos) for ref in summaries]
        else:
            summaries_delimited = [f"{args.eos}\n".join(ref.split(args.eos)) for ref in summaries]
    elif "line_delimited" in toks_needed:
        summaries_delimited = summaries
    if "space" in toks_needed:
        summaries_space = summaries

    if "stem" in toks_needed:
        tokenizer = RegexpTokenizer(r'\w+')
        stemmer = SnowballStemmer("english")
        if isinstance(summaries[0], list):
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(summ))] for summ in
                                 summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(ref))] for ref in
                                  references]
        else:
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(summ)] for summ in summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(ref)] for ref in references]
        summaries_stemmed = [" ".join(summ) for summ in summaries_stemmed]
        references_stemmed = [" ".join(ref) for ref in references_stemmed]

    if "spacy" in toks_needed:
        nlp = spacy.load('en_core_web_sm')
        # nlp = spacy.load('en_core_web_md')
        disable = ["tagger", "textcat"]
        if "summaqa" not in metrics:
            disable.append("ner")
        if isinstance(summaries[0], list):
            summaries_spacy = [nlp(" ".join(text), disable=disable) for text in summaries]
        else:
            summaries_spacy = [nlp(text, disable=disable) for text in summaries]
        if "stats" in metrics:
            summaries_spacy_stats = [[tok.text for tok in summary] for summary in summaries_spacy]
        if "sms" in metrics:
            if isinstance(references[0], list):
                references_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                references_spacy = [nlp(text, disable=disable) for text in references]

        # if "summaqa" in metrics or "stats" in metrics:
        #     if isinstance(articles[0], list):
        #         input_spacy = [nlp(" ".join(text), disable=disable) for text in articles]
        #     else:
        #         input_spacy = [nlp(text, disable=disable) for text in articles]
        #     if "stats" in metrics:
        #         input_spacy_stats = [[tok.text for tok in article] for article in input_spacy]
        # use reference as article for stats
        if "summaqa" in metrics or "stats" in metrics:
            if isinstance(references[0], list):
                input_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                input_spacy = [nlp(text, disable=disable) for text in references]
            if "stats" in metrics:
                input_spacy_stats = [[tok.text for tok in ref] for ref in input_spacy]
    if "supert" in metrics or "blanc" in metrics:
        inputs_space = articles
    # =====================================

    # =====================================
    # GET SCORES
    if args.aggregate:
        final_output = dict()
    else:
        final_output = defaultdict(lambda: defaultdict(int))
    # import pdb;pdb.set_trace()
    for metric, metric_cls in metrics_dict.items():
        print(f"Calculating scores for the {metric} metric.")
        try:
            if metric == "rouge":
                output = metric_cls.evaluate_batch(summaries_delimited, references_delimited, aggregate=args.aggregate)
                # only rouge uses this input so we can delete it
                del references_delimited
                del summaries_delimited
            elif metric in ('bert_score', 'mover_score', 'chrf', 'meteor', 'bleu'):
                output = metric_cls.evaluate_batch(summaries_space, references_space, aggregate=args.aggregate)
            elif metric in ('s3', 'rouge_we', 'cider'):
                output = metric_cls.evaluate_batch(summaries_stemmed, references_stemmed, aggregate=args.aggregate)
            elif metric == "sms":
                output = metric_cls.evaluate_batch(summaries_spacy, references_spacy, aggregate=args.aggregate)
            elif metric in ('summaqa', 'stats', 'supert', 'blanc'):
                if metric == "summaqa":
                    output = metric_cls.evaluate_batch(summaries_space, input_spacy, aggregate=args.aggregate)
                elif metric == "stats":
                    output = metric_cls.evaluate_batch(summaries_spacy_stats, input_spacy_stats,
                                                       aggregate=args.aggregate)
                elif metric in ('supert', 'blanc'):
                    output = metric_cls.evaluate_batch(summaries_space, inputs_space, aggregate=args.aggregate)
            if args.aggregate:
                final_output.update(output)
            else:
                ids = list(range(0, len(ids)))
                for cur_id, cur_output in zip(ids, output):
                    final_output[cur_id].update(cur_output)
        except Exception as e:
            print(e)
            print(f"An error was encountered with the {metric} metric.")
    # =====================================

    # =====================================
    # OUTPUT SCORES
    metrics_str = "_".join(metrics)
    # json_file_end = args.jsonl_file.split("/")[-1]
    json_file_end = args.jsonl_file.replace("/", "_")
    output_path = f"output_{metrics_str}.jsonl"
    print(f"saving to {output_path}")
    # with open(f"outputs/{args.output_file}_{json_file_end}_{metrics_str}.jsonl", "w") as outputf:
    with open(output_path, "w") as outputf:
        if args.aggregate:
            json.dump(final_output, outputf)
        else:
            for key, value in final_output.items():
                value["id"] = key
                json.dump(value, outputf)
                outputf.write("\n")
    # =====================================
    print(f"Write scores to: {output_path}")
    return output_path
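
cli_main reads a handful of attributes from the args namespace: config_file, metrics, jsonl_file, summ_file, ref_file, article_file, eos, aggregate, and output behaviour. The sketch below shows one possible argparse wrapper that could drive it; the flag names are assumptions inferred from those attributes, not necessarily the options defined by the package's actual CLI.

# Hypothetical driver for cli_main; flag names are inferred from the attributes
# the function reads and may differ from the package's real command-line entry point.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Run SummEval metrics over summaries.")
    parser.add_argument("--config_file", required=True, help="gin config file used to initialize the metrics")
    parser.add_argument("--metrics", default="rouge,bert_score", help="comma-separated metric names")
    parser.add_argument("--jsonl_file", default=None, help="JSONL input with 'decoded' and 'reference' fields")
    parser.add_argument("--summ_file", default=None, help="plain-text summaries, one per line")
    parser.add_argument("--ref_file", default=None, help="plain-text references, one per line")
    parser.add_argument("--article_file", default=None, help="source articles, needed by summaqa/supert/blanc")
    parser.add_argument("--eos", default=None, help="sentence delimiter used inside each summary")
    parser.add_argument("--aggregate", action="store_true", help="report corpus-level averages instead of per-example scores")
    return parser.parse_args()

if __name__ == "__main__":
    cli_main(parse_args())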
Example #7
def main(args):
    try:
        assert len(args.exps) == len(args.summary_type) == 2
    except AssertionError:
        # do a single file evaluation.
        assert len(args.exps) > 0, "At least one experiment must be specified."
        print(f"Performing single-file evaluation on {args.exps[0]}")
        args.exps = [args.exps[0], args.exps[0]]
        args.summary_type = ['reference', 'decoded']

    batch_size = 4096

    inputs = []
    for exp, key in zip(args.exps, args.summary_type):
        ipfiles = exp.glob("eval.jsonl*")
        data = {}
        # legacy compatibility with existing code
        if key == "reference":
            key = "gold"
        for ipfile in ipfiles:
            with ipfile.open("r") as fd:
                d = [json.loads(line) for line in fd]
            data.update({ex["paper_id"]: ex[key] for ex in d})
        inputs.append(data)

    # rearrange to pid-to-(system_1, system_2) dict.
    data = list({pid: (inputs[0][pid], inputs[1][pid])
                 for pid in inputs[0]}.items())

    rouge_metric = RougeMetric()
    bertscore_metric = BertScoreMetric()

    results = []
    for batch in tqdm(batch_data(data, batch_size),
                      total=len(data) / batch_size,
                      ncols=80,
                      ascii=True):
        batch_pids, batch_texts = zip(*batch)
        clean_refs = [replace_html(ex[0]) for ex in batch_texts]
        clean_preds = [replace_html(ex[1]) for ex in batch_texts]

        # assign zero-scores for empty reference cases
        empty_ref_indices = [i for i, x in enumerate(clean_refs) if x == ""]
        if len(empty_ref_indices) > 0:
            wrapped = []
            for i in empty_ref_indices:
                wrapped.append({
                    "paper_id": batch_pids[i],
                    args.summary_type[0]: batch_texts[i][0],
                    args.summary_type[1]: batch_texts[i][1],
                    "rouge": {
                        "rouge_1_f_score": 0,
                        "rouge_2_f_score": 0,
                        "rouge_3_f_score": 0,
                        "rouge_4_f_score": 0,
                        "rouge_l_f_score": 0,
                        "rouge_w_1.2_f_score": 0,
                        "rouge_s*_f_score": 0,
                        "rouge_su*_f_score": 0,
                    },
                    "bert_score": dict(bert_score_f1=0),
                })
            results.extend(wrapped)
            batch_pids = [
                ex for i, ex in enumerate(batch_pids)
                if i not in empty_ref_indices
            ]
            clean_refs = [
                ex for i, ex in enumerate(clean_refs)
                if i not in empty_ref_indices
            ]
            clean_preds = [
                ex for i, ex in enumerate(clean_preds)
                if i not in empty_ref_indices
            ]

        try:
            rouges = rouge_metric.evaluate_batch(clean_preds,
                                                 clean_refs,
                                                 aggregate=False)
            bertscores = bertscore_metric.evaluate_batch(clean_preds,
                                                         clean_refs,
                                                         aggregate=False)
        except:
            import pdb

            pdb.set_trace()
        results.extend(
            wrap_results(args, batch_pids, batch_texts, rouges, bertscores))

    with args.output_file.open("w") as fd:
        for example in results:
            print(json.dumps(example), file=fd)
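
main relies on three helpers that are not shown in this example: batch_data, replace_html, and wrap_results. The loop consumes batch_data as fixed-size slices of (paper_id, (reference, decoded)) pairs, so a minimal sketch consistent with that usage (the real implementation in the source project may differ) looks like this:

# Sketch of a batch_data helper matching how the loop above consumes it;
# the actual helper in the source project is not shown and may differ.
def batch_data(data, batch_size):
    """Yield successive batch_size-sized slices of a list of items."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]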