    def test_glove_sms(self):
        metric = SentenceMoversMetric(wordrep='glove', metric='sms')
        score = metric.evaluate_batch(CANDS, REFS)
        avg = sum([0.998619953103503, 0.4511249300530058, 0.2903306055171392
                   ]) / 3

        self.assertTrue(abs(score["sentence_movers_glove_sms"] - avg) < EPS)

    def test_glove_swms(self):
        metric = SentenceMoversMetric(wordrep='glove', metric='s+wms')
        score = metric.evaluate_batch(CANDS, REFS)
        avg = sum([0.9993097383211589, 0.330209238465829, 0.11248476630690257
                   ]) / 3

        self.assertTrue(abs(score["sentence_movers_glove_s+wms"] - avg) < EPS)


# Module-level imports needed by cli_main below.
import json
import sys
from collections import defaultdict

import gin
import spacy
from nltk.tokenize import RegexpTokenizer
from nltk.stem.snowball import SnowballStemmer

def cli_main(args):
    # =====================================
    # INITIALIZE METRICS
    gin.parse_config_file(args.config_file)
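    # Each metric registers which tokenized view(s) of the data it needs:
    # "line_delimited", "space", "stem", or "spacy" (built in the TOKENIZATION step below).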
    toks_needed = set()
    metrics = [x.strip() for x in args.metrics.split(",")]
    metrics_dict = {}
    if "rouge" in metrics:
        from summ_eval.rouge_metric import RougeMetric
        metrics_dict["rouge"] = RougeMetric()
        toks_needed.add("line_delimited")

    if "bert_score" in metrics:
        from summ_eval.bert_score_metric import BertScoreMetric
        bert_score_metric = BertScoreMetric()
        metrics_dict["bert_score"] = bert_score_metric
        toks_needed.add("space")
    if "mover_score" in metrics:
        from summ_eval.mover_score_metric import MoverScoreMetric
        mover_score_metric = MoverScoreMetric()
        metrics_dict["mover_score"] = mover_score_metric
        toks_needed.add("space")
    if "chrf" in metrics:
        from summ_eval.chrfpp_metric import ChrfppMetric
        metrics_dict["chrf"] = ChrfppMetric()
        toks_needed.add("space")
    if "meteor" in metrics:
        from summ_eval.meteor_metric import MeteorMetric
        metrics_dict["meteor"] = MeteorMetric()
        toks_needed.add("space")
    if "bleu" in metrics:
        from summ_eval.bleu_metric import BleuMetric
        metrics_dict["bleu"] = BleuMetric()
        toks_needed.add("space")
    if "cider" in metrics:
        from summ_eval.cider_metric import CiderMetric
        metrics_dict["cider"] = CiderMetric()
        toks_needed.add("stem")

    if "s3" in metrics:
        from summ_eval.s3_metric import S3Metric
        metrics_dict["s3"] = S3Metric()
        toks_needed.add("stem")
    if "rouge_we" in metrics:
        from summ_eval.rouge_we_metric import RougeWeMetric
        metrics_dict["rouge_we"] = RougeWeMetric()
        toks_needed.add("stem")

    if "stats" in metrics:
        from summ_eval.data_stats_metric import DataStatsMetric
        metrics_dict['stats'] = DataStatsMetric()
        toks_needed.add("spacy")
    if "sms" in metrics:
        from summ_eval.sentence_movers_metric import SentenceMoversMetric
        metrics_dict['sms'] = SentenceMoversMetric()
        toks_needed.add("spacy")
    if "summaqa" in metrics:
        from summ_eval.summa_qa_metric import SummaQAMetric
        metrics_dict['summaqa'] = SummaQAMetric()
        toks_needed.add("spacy")
        toks_needed.add("space")
    if "syntactic" in metrics:
        from summ_eval.syntactic_metric import SyntacticMetric
        metrics_dict["syntactic"] = SyntacticMetric()
        toks_needed.add("space")
    if "supert" in metrics:
        from summ_eval.supert_metric import SupertMetric
        metrics_dict['supert'] = SupertMetric()
        toks_needed.add("space")
    if "blanc" in metrics:
        from summ_eval.blanc_metric import BlancMetric
        metrics_dict['blanc'] = BlancMetric()
        toks_needed.add("space")
    # =====================================

    # =====================================
    # READ INPUT
    print("Reading the input")
    ids = []
    articles = []
    references = []
    summaries = []
    bad_lines = 0
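    # Expected JSONL format: one JSON object per line with "decoded" (the system summary),
    # an optional "id", "reference" (or a "references" list), and "text" (the source article,
    # required when summaqa, supert, or blanc are requested).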
    if args.jsonl_file is not None:
        try:
            with open(args.jsonl_file) as inputf:
                for count, line in enumerate(inputf):
                    try:
                        data = json.loads(line)
                        try:
                            ids.append(data['id'])
                        except KeyError:
                            # the "id" field is optional
                            pass
                        if len(data['decoded']) == 0:
                            bad_lines += 1
                            continue
                        summaries.append(data['decoded'])
                        # references.append(data['reference'])
                        if data.get("reference", None):
                            references.append(data['reference'])
                        else:  # 10 additional references are provided; the first one is the original
                            references.append(data["references"][0])
                        # if "summaqa" in metrics or "stats" in metrics or "supert" in metrics or "blanc" in metrics:
                        # "stats" removed from this check; it uses the references as input (see below)
                        if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
                            try:
                                articles.append(data['text'])
                            except KeyError:
                                raise ValueError("You specified summaqa, supert, or blanc, which "
                                                 "require input articles, but we could not parse the file!")
                    except (json.JSONDecodeError, KeyError, IndexError):
                        # count malformed lines; other errors propagate to the handler below
                        bad_lines += 1
        except Exception as e:
            print("Input did not match required format")
            print(e)
            sys.exit()
        print(f"This many bad lines encountered during loading: {bad_lines}")

    if args.summ_file is not None:
        with open(args.summ_file) as inputf:
            summaries = inputf.read().splitlines()
    if args.ref_file is not None:
        with open(args.ref_file) as inputf:
            references = inputf.read().splitlines()
    # if "summaqa" in metrics or "stats" in metrics or "supert" in metrics or "blanc" in metrics:
    if "summaqa" in metrics or "supert" in metrics or "blanc" in metrics:
        if args.article_file is None and len(articles) == 0:
            raise ValueError("You specified summaqa and stats, which" \
                             "require input articles, but we could not parse the file!")
        if len(articles) == 0:
            with open(args.article_file) as inputf:
                articles = inputf.read().splitlines()
    if len(ids) == 0:
        ids = list(range(0, len(summaries)))
    # =====================================

    # =====================================
    # TOKENIZATION
    print("Preparing the input")
    references_delimited = None
    summaries_delimited = None
    if len(references) > 0:
        if isinstance(references[0], list):
            if "line_delimited" in toks_needed:
                references_delimited = ["\n".join(ref) for ref in references]
            if "space" in toks_needed:
                references_space = [" ".join(ref) for ref in references]
        elif args.eos is not None:
            if "line_delimited" not in toks_needed:
                raise ValueError('You provided a delimiter but are not using a metric which requires one.')
            if args.eos == "\n":
                references_delimited = [ref.split(args.eos) for ref in references]
            else:
                references_delimited = [f"{args.eos}\n".join(ref.split(args.eos)) for ref in references]
        elif "line_delimited" in toks_needed:
            references_delimited = references
        if "space" in toks_needed:
            references_space = references

    if isinstance(summaries[0], list):
        if "line_delimited" in toks_needed:
            summaries_delimited = ["\n".join(summ) for summ in summaries]
        if "space" in toks_needed:
            summaries_space = [" ".join(summ) for summ in summaries]
    elif args.eos is not None:
        if "line_delimited" not in toks_needed:
            raise ValueError('You provided a delimiter but are not using a metric which requires one.')
        if args.eos == "\n":
            summaries_delimited = [summ.split(args.eos) for summ in summaries]
        else:
            summaries_delimited = [f"{args.eos}\n".join(summ.split(args.eos)) for summ in summaries]
    elif "line_delimited" in toks_needed:
        summaries_delimited = summaries
    if "space" in toks_needed:
        summaries_space = summaries

    if "stem" in toks_needed:
        tokenizer = RegexpTokenizer(r'\w+')
        stemmer = SnowballStemmer("english")
        if isinstance(summaries[0], list):
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(summ))] for summ in
                                 summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(" ".join(ref))] for ref in
                                  references]
        else:
            summaries_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(summ)] for summ in summaries]
            references_stemmed = [[stemmer.stem(word) for word in tokenizer.tokenize(ref)] for ref in references]
        summaries_stemmed = [" ".join(summ) for summ in summaries_stemmed]
        references_stemmed = [" ".join(ref) for ref in references_stemmed]

    if "spacy" in toks_needed:
        nlp = spacy.load('en_core_web_sm')
        # nlp = spacy.load('en_core_web_md')
        disable = ["tagger", "textcat"]
        if "summaqa" not in metrics:
            disable.append("ner")
        if isinstance(summaries[0], list):
            summaries_spacy = [nlp(" ".join(text), disable=disable) for text in summaries]
        else:
            summaries_spacy = [nlp(text, disable=disable) for text in summaries]
        if "stats" in metrics:
            summaries_spacy_stats = [[tok.text for tok in summary] for summary in summaries_spacy]
        if "sms" in metrics:
            if isinstance(references[0], list):
                references_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                references_spacy = [nlp(text, disable=disable) for text in references]

        # if "summaqa" in metrics or "stats" in metrics:
        #     if isinstance(articles[0], list):
        #         input_spacy = [nlp(" ".join(text), disable=disable) for text in articles]
        #     else:
        #         input_spacy = [nlp(text, disable=disable) for text in articles]
        #     if "stats" in metrics:
        #         input_spacy_stats = [[tok.text for tok in article] for article in input_spacy]
        # use the references as the input documents for summaqa and stats
        if "summaqa" in metrics or "stats" in metrics:
            if isinstance(references[0], list):
                input_spacy = [nlp(" ".join(text), disable=disable) for text in references]
            else:
                input_spacy = [nlp(text, disable=disable) for text in references]
            if "stats" in metrics:
                input_spacy_stats = [[tok.text for tok in ref] for ref in input_spacy]
    if "supert" in metrics or "blanc" in metrics:
        inputs_space = articles
    # =====================================

    # =====================================
    # GET SCORES
    if args.aggregate:
        final_output = dict()
    else:
        final_output = defaultdict(lambda: defaultdict(int))
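    # When aggregating, final_output maps score names to corpus-level values;
    # otherwise it maps each example id to its dict of per-example scores.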
    for metric, metric_cls in metrics_dict.items():
        print(f"Calculating scores for the {metric} metric.")
        try:
            if metric == "rouge":
                output = metric_cls.evaluate_batch(summaries_delimited, references_delimited, aggregate=args.aggregate)
                # only rouge uses this input so we can delete it
                del references_delimited
                del summaries_delimited
            elif metric in ('bert_score', 'mover_score', 'chrf', 'meteor', 'bleu'):
                output = metric_cls.evaluate_batch(summaries_space, references_space, aggregate=args.aggregate)
            elif metric in ('s3', 'rouge_we', 'cider'):
                output = metric_cls.evaluate_batch(summaries_stemmed, references_stemmed, aggregate=args.aggregate)
            elif metric == "sms":
                output = metric_cls.evaluate_batch(summaries_spacy, references_spacy, aggregate=args.aggregate)
            elif metric in ('summaqa', 'stats', 'supert', 'blanc'):
                if metric == "summaqa":
                    output = metric_cls.evaluate_batch(summaries_space, input_spacy, aggregate=args.aggregate)
                elif metric == "stats":
                    output = metric_cls.evaluate_batch(summaries_spacy_stats, input_spacy_stats,
                                                       aggregate=args.aggregate)
                elif metric in ('supert', 'blanc'):
                    output = metric_cls.evaluate_batch(summaries_space, inputs_space, aggregate=args.aggregate)
            if args.aggregate:
                final_output.update(output)
            else:
                ids = list(range(0, len(ids)))
                for cur_id, cur_output in zip(ids, output):
                    final_output[cur_id].update(cur_output)
        except Exception as e:
            print(e)
            print(f"An error was encountered with the {metric} metric.")
    # =====================================

    # =====================================
    # OUTPUT SCORES
    metrics_str = "_".join(metrics)
    # json_file_end = args.jsonl_file.split("/")[-1]
    json_file_end = args.jsonl_file.replace("/", "_") if args.jsonl_file else ""  # only used by the commented-out path below
    output_path = f"output_{metrics_str}.jsonl"
    print(f"saving to {output_path}")
    # with open(f"outputs/{args.output_file}_{json_file_end}_{metrics_str}.jsonl", "w") as outputf:
    with open(output_path, "w") as outputf:
        if args.aggregate:
            json.dump(final_output, outputf)
        else:
            for key, value in final_output.items():
                value["id"] = key
                json.dump(value, outputf)
                outputf.write("\n")
    # =====================================
    print(f"Write scores to: {output_path}")
    return output_path
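

# --------------------------------------------------------------------------
# Minimal sketch of a command-line entry point for cli_main (illustrative only:
# the flag names and defaults below are assumptions that simply mirror the
# attributes the function reads from `args`, not the project's official CLI).
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Compute summarization metrics.")
    parser.add_argument("--config-file", dest="config_file", required=True,
                        help="gin config file with metric settings")
    parser.add_argument("--metrics", default="rouge",
                        help="comma-separated list, e.g. 'rouge,bert_score,stats'")
    parser.add_argument("--jsonl-file", dest="jsonl_file", default=None)
    parser.add_argument("--summ-file", dest="summ_file", default=None)
    parser.add_argument("--ref-file", dest="ref_file", default=None)
    parser.add_argument("--article-file", dest="article_file", default=None)
    parser.add_argument("--output-file", dest="output_file", default="scores")
    parser.add_argument("--eos", default=None,
                        help="sentence delimiter used inside flat summary/reference strings")
    parser.add_argument("--aggregate", action="store_true",
                        help="report corpus-level averages instead of per-example scores")
    cli_main(parser.parse_args())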
    def test_glove_swms_single(self):
        metric = SentenceMoversMetric(wordrep='glove', metric='s+wms')
        score = metric.evaluate_example(CANDS[0], REFS[0])
        score0 = 0.9993097383211589

        self.assertTrue(abs(score["sentence_movers_glove_s+wms"] - score0) < EPS)

    def test_glove_sms_single(self):
        metric = SentenceMoversMetric(wordrep='glove', metric='sms')
        score = metric.evaluate_example(CANDS[0], REFS[0])
        score0 = 0.998619953103503

        self.assertTrue(abs(score["sentence_movers_glove_sms"] - score0) < EPS)

    def test_glove_wms(self):
        metric = SentenceMoversMetric(wordrep='glove', metric='wms')
        score = metric.evaluate_batch(CANDS, REFS)
        avg = sum([1.0, 0.2417027898540903, 0.04358073486688484]) / 3

        self.assertTrue(abs(score["sentence_movers_glove_wms"] - avg) < EPS)