Example #1
    def __init__(self, args):
        super(KobeModel, self).__init__()

        self.encoder = Encoder(
            vocab_size=args.text_vocab_size + args.cond_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_encoder_layers,
            dropout=args.dropout,
            mode=args.mode,
        )
        self.decoder = Decoder(
            vocab_size=args.text_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_decoder_layers,
            dropout=args.dropout,
        )
        self.lr = args.lr
        self.d_model = args.d_model
        self.loss = nn.CrossEntropyLoss(reduction="mean",
                                        ignore_index=0,
                                        label_smoothing=0.1)
        self._reset_parameters()

        self.decoding_strategy = args.decoding_strategy
        self.vocab = BertTokenizer.from_pretrained(args.text_vocab_path)
        self.bleu = BLEU(tokenize=args.tokenize)
        self.sacre_tokenizer = _get_tokenizer(args.tokenize)()
        self.bert_scorer = BERTScorer(lang=args.tokenize,
                                      rescale_with_baseline=True)
Example #2
def bert_based(gts, res):
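    # Computes corpus-level BERTScore (mean rescaled F1) and, for each item, the best
    # BLEURT score over its references; `out_file` and `bleurt_checkpoint` are assumed
    # to be defined elsewhere in the original module.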
    refs, cands = [], []
    for refers in gts.values():
        sub_refs = []
        for ref in refers:
            sub_refs.append(ref + '.')
        refs.append(sub_refs)
    for cand in res.values():
        cands.append(cand[0] + '.')

    scorer = BERTScorer(lang="en", rescale_with_baseline=True)
    P, R, F1 = scorer.score(cands, refs, verbose=True)
    out_file.write('BERTScore = %s' % F1.mean().item() + "\n")
    BERTScore = F1.mean().item()

    total_bleurt_score = []
    scorer = bleurt_sc.BleurtScorer(bleurt_checkpoint)

    for ref_caption, cand in zip(refs, cands):
        bleurt_score_per_img = []
        for ref in ref_caption:
            bleurt_score_per_img.append(
                scorer.score([ref], [cand], batch_size=None)[0])
        total_bleurt_score.append(max(bleurt_score_per_img))
    out_file.write('BLEURT = %s' % statistics.mean(total_bleurt_score) + "\n")

def bertscore_bias():
    scorer = BERTScorer(lang="zh", rescale_with_baseline=True)
    df = read_data()
    sample1 = df.apply(
        lambda x: scorer.score([x.context], [x['True']])[2].item(), axis=1)
    sample2 = df.apply(
        lambda x: scorer.score([x.context], [x['False']])[2].item(), axis=1)
    print('True', sample1.mean(), 'False', sample2.mean())
    return stats.ttest_ind(sample1, sample2, equal_var=False)[1]
Example #4
def get_bertscore_sentence_scores(
    sys_sents: List[str],
    refs_sents: List[List[str]],
    lowercase: bool = False,
    tokenizer: str = "13a",
):
    scorer = BERTScorer(lang="en", rescale_with_baseline=True)

    sys_sents = [utils_prep.normalize(sent, lowercase, tokenizer) for sent in sys_sents]
    refs_sents = [[utils_prep.normalize(sent, lowercase, tokenizer) for sent in ref_sents] for ref_sents in refs_sents]
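    # Transpose from one list per reference set to one list of references per system
    # sentence, which is the layout BERTScorer.score expects for multi-reference input.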
    refs_sents = [list(r) for r in zip(*refs_sents)]

    return scorer.score(sys_sents, refs_sents)
Example #5
    def __init__(self, threshold, top_k):
        scorer = BERTScorer(lang="en", rescale_with_baseline=True)
        model = scorer._model
        embedding = ContextualEmbedding(model, "roberta-large", 510)
        baseline_val = scorer.baseline_vals[2].item()

        super(BertscoreAligner, self).__init__(embedding, threshold, top_k,
                                               baseline_val)
def gen_samples():
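    # Builds samples from a story corpus: one sentence per story is replaced with '<MASK>'
    # to form a context, and the sentence from *other* stories that best matches the sampled
    # sentence (highest BERTScore recall) is retrieved, yielding
    # (masked context, original sentence, best match) tuples.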

    scorer = BERTScorer(lang="zh", rescale_with_baseline=True)
    data_file = '32-deduplicate-story.csv'
    df = pd.read_csv(data_file)
    # import pdb;pdb.set_trace()
    stories = list(df.story.dropna())
    stories_split = [split_by_fullstop(x) for x in stories]
    stories_split_select = [
        random.randint(0,
                       len(x) - 1) for x in stories_split
    ]
    stories_sentencesample = [
        x[y] for x, y in zip(stories_split, stories_split_select)
    ]
    stories_split_copy = copy.deepcopy(stories_split)
    stories_context = []
    for ss, sss in zip(stories_split_copy, stories_split_select):
        ss[sss] = '<MASK>'
        stories_context.append(ss)
    stories_context = [''.join(x) for x in stories_context]
    positive_samples = [
        (x, y, True) for x, y in zip(stories_context, stories_sentencesample)
    ]
    cands = stories_sentencesample
    assert len(cands) == len(stories_split)
    refs = []
    for i, cand in enumerate(cands):
        refs.append([
            x for j, y in enumerate(stories_split) for x in y
            if len(x) > 0 and j != i
        ])
    bestmatch = []
    print(len(cands))
    for i, (c, ref) in enumerate(zip(cands, refs)):
        print(i, 'th candidate...')
        cand = [c] * len(ref)
        # import pdb; pdb.set_trace()
        P, R, F1 = scorer.score(cand, ref)
        bestmatch.append(int(torch.argmax(R)))
    negative_samples = [(x, y[z], False)
                        for x, y, z in zip(stories_context, refs, bestmatch)]
    return [(x, w, y[z]) for x, y, z, w in zip(
        stories_context, refs, bestmatch, stories_sentencesample)]
Example #7
    def add_bertscore(self):
        bert_scorer = BERTScorer(num_layers=12,
                                 model_type="bert-base-german-cased")

        def bert_score(hypothesis: str, reference: str):
            P, R, F1 = bert_scorer.score([hypothesis], [reference])
            return F1.item()

        self.scores.append(bert_score)
Example #8
    def run_bartscore(self):
        ''' Computes BARTScore between the sets of hypothesis
            and reference summaries.
        '''
        print('\n===== BARTScore =====\n')
        bartscore = BERTScorer(lang="en",
                               model_type=self.bartscore_model,
                               num_layers=12)

        for hyps_path, refs_path in zip(self.hyps_paths, self.refs_paths):
            self.load_summs(hyps_path, refs_path)
            P, R, F1 = bartscore.score(self.hyps, self.refs, batch_size=64)
            self.df_scores.loc[self.df_scores['hyps_path'] == hyps_path,
                               'bartscore'] = F1.tolist()
            self.save_temp_csv()
            print(F1.mean())

        del P, R, F1, bartscore
        torch.cuda.empty_cache()

def cal_bert_score(target, source, batch_size=2, chunks=5000, prefix="", folder="./bert_scores"):

    import glob
    import pickle as pkl
    import os
    import torch
    import gc
    from tqdm import trange
    from bert_score import BERTScorer
    scores = {}
    scores["f1"] = []
    scores["p"] = []
    scores["r"] = []
    
    # for i in tqdm(range(0, len(texts_basis), batch_size)):
    #     src_texts = texts_basis[i:i + batch_size]
    #     # print(src_texts[:2])
    #     summaries = texts_to_compare[i:i + batch_size]
    #     # print(summaries[:2])
    #     P, R, F1 = score(summaries, src_texts, lang='en', verbose=False, batch_size=batch_size, device="cuda:0")
    #     scores["f1"].extend(F1.numpy())
    #     scores["p"].extend(P.numpy())
    #     scores["r"].extend(R.numpy())

    scorer = BERTScorer(lang="en", model_type="xlnet-base-cased", batch_size=batch_size, device="cuda:0")
    t = trange(0, len(target), chunks, desc=f"done in {chunks}")
    files = glob.glob(f"{folder}/*p_r_f1_*.pkl")
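    # Process the data in chunks and skip any chunk whose score pickle already exists,
    # so an interrupted run can be resumed without recomputing finished chunks.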
    for i in t:
        if f"{folder}/{prefix}_p_r_f1_{i}.pkl" not in files:
            t.set_description(f"done in {chunks}")
            cal_bert_score_chunk(scorer, target[i:i+chunks], source[i:i+chunks], i, batch_size, True, prefix, folder)
            gc.collect()
            torch.cuda.empty_cache()
        else:
            t.set_description(f"done in {chunks}, skipped: {i}")

    files = glob.glob(f"{folder}/*p_r_f1_*.pkl")
Example #10
def load_bert_score(model: str, device: str):
    """
    Load BERTScore model from HuggingFace hub

    Args:
        model (str): model name to be loaded
        device (str): device info

    Returns:
        function: BERTScore score function

    """
    print("Loading BERTScore Pipeline...")

    try:
        scorer = BERTScorer(
            model_type=model,
            lang="en",
            rescale_with_baseline=True,
            device=device,
        )
        return scorer.score
    except KeyError:
        print("Input model is not supported by BERTScore")
Example #11
import os
from bert_score import BERTScorer
scorer = BERTScorer(lang="zh", batch_size=1, device='cuda:0')

logdir = './logs/S2S/decode_val_600maxenc_4beam_35mindec_150maxdec_ckpt-62256'
decodeddir = logdir + '/decoded'
referencedir = logdir + '/reference'
dir_or_files = os.listdir(decodeddir)
dir_or_files = sorted(dir_or_files)
count = 0
for file in dir_or_files:
    f = open(os.path.join(decodeddir, file), 'r', encoding='utf-8')
    decodetext = []
    for line in f.readlines():
        decodetext.append(line[1:])
    f.close()
    f = open(os.path.join(referencedir, file[0:6] + '_reference.txt'),
             'r',
             encoding='utf-8')
    reftext = []
    for line in f.readlines():
        reftext.append(line[1:])
        # reftext.append(line[1:])
    f.close()
    # count += 1
    # if count == 10:
    # 	break
print(scorer.score(decodetext, reftext))
Example #12
def main(params):
    # Loading data
    dataset, num_labels = load_data(params)
    dataset = dataset["train"]
    text_key = 'text'
    if params.dataset == "dbpedia14":
        text_key = 'content'
    print(f"Loaded dataset {params.dataset}, that has {len(dataset)} rows")

    # Load model and tokenizer from HuggingFace
    model_class = transformers.AutoModelForSequenceClassification
    model = model_class.from_pretrained(params.model,
                                        num_labels=num_labels).cuda()

    if params.ckpt is not None:
        state_dict = torch.load(params.ckpt)
        model.load_state_dict(state_dict)
    tokenizer = textattack.models.tokenizers.AutoTokenizer(params.model)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(
        model, tokenizer, batch_size=params.batch_size)

    # Create radioactive directions and modify classification layer to use those
    if params.radioactive:
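        # Replace the classification layer with fixed random, unit-norm "radioactive"
        # directions (seeded for reproducibility) and zero biases.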
        torch.manual_seed(0)
        radioactive_directions = torch.randn(num_labels, 768)
        radioactive_directions /= torch.norm(radioactive_directions,
                                             dim=1,
                                             keepdim=True)
        print(radioactive_directions)
        model.classifier.weight.data = radioactive_directions.cuda()
        model.classifier.bias.data = torch.zeros(num_labels).cuda()

    start_index = params.chunk_id * params.chunk_size
    end_index = start_index + params.chunk_size

    if params.target_dir is not None:
        target_file = join(params.target_dir, f"{params.chunk_id}.csv")
        f = open(target_file, "w")
        writer = csv.writer(f,
                            delimiter=',',
                            quotechar='"',
                            quoting=csv.QUOTE_NONNUMERIC)

    # Creating attack
    print(f"Building {params.attack} attack")
    if params.attack == "custom":
        current_label = -1
        if params.targeted:
            current_label = dataset[start_index]['label']
            assert all([
                dataset[i]['label'] == current_label
                for i in range(start_index, end_index)
            ])
        attack = build_attack(model_wrapper, current_label)
    elif params.attack == "bae":
        print(f"Building BAE method with threshold={params.bae_threshold:.2f}")
        attack = build_baegarg2019(model_wrapper,
                                   threshold_cosine=params.bae_threshold,
                                   query_budget=params.query_budget)
    elif params.attack == "bert-attack":
        assert params.query_budget is None
        attack = BERTAttackLi2020.build(model_wrapper)
    elif params.attack == "clare":
        assert params.query_budget is None
        attack = CLARE2020.build(model_wrapper)

    # Launching attack
    begin_time = time.time()
    samples = [
        (dataset[i][text_key],
         attack.goal_function.get_output(AttackedText(dataset[i][text_key])))
        for i in range(start_index, end_index)
    ]
    results = list(attack.attack_dataset(samples))

    # Storing attacked text
    bert_scorer = BERTScorer(model_type="bert-base-uncased", idf=False)

    n_success = 0
    similarities = []
    queries = []
    use = USE()

    for i_result, result in enumerate(results):
        print("")
        print(50 * "*")
        print("")
        text = dataset[start_index + i_result][text_key]
        ptext = result.perturbed_text()
        i_data = start_index + i_result
        if params.target_dir is not None:
            if params.dataset == 'dbpedia14':
                writer.writerow([
                    dataset[i_data]['label'] + 1, dataset[i_data]['title'],
                    ptext
                ])
            else:
                writer.writerow([dataset[i_data]['label'] + 1, ptext])

        print("True label ", dataset[i_data]['label'])
        print(f"CLEAN TEXT\n {text}")
        print(f"ADV TEXT\n {ptext}")

        if type(result) not in [SuccessfulAttackResult, FailedAttackResult]:
            print("WARNING: Attack neither succeeded nor failed...")
        print(result.goal_function_result_str())
        precision, recall, f1 = [
            r.item() for r in bert_scorer.score([ptext], [text])
        ]
        print(
            f"Bert scores: precision {precision:.2f}, recall: {recall:.2f}, f1: {f1:.2f}"
        )
        initial_logits = model_wrapper([text])
        final_logits = model_wrapper([ptext])
        print("Initial logits", initial_logits)
        print("Final logits", final_logits)
        print("Logits difference", final_logits - initial_logits)

        # Statistics
        n_success += 1 if type(result) is SuccessfulAttackResult else 0
        queries.append(result.num_queries)
        similarities.append(use.compute_sim([text], [ptext]))

    print("Processing all samples took %.2f" % (time.time() - begin_time))
    print(f"Total success: {n_success}/{len(results)}")
    logs = {
        "success_rate": n_success / len(results),
        "avg_queries": sum(queries) / len(queries),
        "queries": queries,
        "avg_similarity": sum(similarities) / len(similarities),
        "similarities": similarities,
    }
    print("__logs:" + json.dumps(logs))
    if params.target_dir is not None:
        f.close()
Example #13
### BERT score with human reference
from bert_score import BERTScorer

scorer = BERTScorer(lang="en", rescale_with_baseline=False)

from bert_score import score
import glob

human_files = "/data/private/E2E/dataset/f_test.txt"

human_open = open(human_files, "r")
human_dataset = human_open.readlines()
human_open.close()

human_references = []

temp_reference = []
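# Blank lines in the file delimit groups of human references; each group becomes one
# entry in human_references.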
for i in range(len(human_dataset)):
    if human_dataset[i] == '\n':
        human_references.append(temp_reference)
        temp_reference = []
    else:
        temp_reference.append(human_dataset[i].strip())
human_references.append(temp_reference)
human_compare = []
for i in range(len(human_references)):
    for k in range(len(human_references[i])):
        human_compare.append(human_references[i][k])

# output_path = "/data/private/E2E/predictions/final/*"
# output_path = "/data/private/E2E/predictions/reproduce/try_2/*"
Example #14
from beam import BeamSearch
from tensorboardX import SummaryWriter
from utils import batch_bleu, bert_dual_sequence_mask
import os
import logging
from utils import mkdir
from dotmap import DotMap
import tqdm
import json
import sacrebleu
import numpy as np
from transformers import BertModel, RobertaModel
from itertools import chain
from bert_score import BERTScorer

scorer = BERTScorer(lang="en")
bert_score = scorer.score

__transformers__ = [BertModel, RobertaModel]


def build_trainer(model, args, datamaker, phase="train"):
    if phase not in ["train"]:
        raise NotImplementedError(
            "PRETRAIN and TUNE modes to be implemented, only TRAIN mode is supported"
        )

    trainer = Trainer(
        model,
        patience=args.patience,
        # val_interval=100,
Example #15
    # if metric is moverscore, prepare the fitness func args here itself (optimization)
    if args.metric == 'moverscore':
        from all_metrics.moverscore import get_idf_dict
        with open('stopwords.txt', 'r', encoding='utf-8') as f:
            stop_words = set(f.read().strip().split(' '))
        moverscore_args = {
            'idf_dict_ref': get_idf_dict([" ".join(ref) for ref in src_docs]),
            'idf_dict_hyp': get_idf_dict([" ".join(src) for src in tgt_docs]),
            'stop_words': stop_words
        }
    else:
        moverscore_args = None

    if args.metric == 'bertscore':
        from bert_score import BERTScorer
        scorer = BERTScorer(lang='en', rescale_with_baseline=True)

    with parallel_backend('multiprocessing', args.n_jobs):
        Parallel()(delayed(generate)(
            src_doc_str=src_docs[doc_num],
            tgt_doc_str=tgt_docs[doc_num],
            length_max=args.max_len,
            n_epochs=args.n_epochs,
            population_size=args.pop_size,
            doc_num=doc_num,
            optim_metric=args.metric,
            out_path=args.out_path,
            moverscore_args=moverscore_args,
            prf='f',
        ) for doc_num in range(args.start_doc_idx, args.end_doc_idx))
Example #16
class Scorer:
    def __init__(self,
                 src_path,
                 ref_path,
                 metric,
                 ref_sep,
                 fast_moverscore=False,
                 num_ref=1):
        self.src_path = src_path
        self.ref_path = ref_path
        self.metric = metric
        self.ref_sep = ref_sep
        self.num_ref = num_ref

        self.ref_lines_with_tags = read_file(ref_path)
        self.ref_lines = [
            ' '.join(
                get_sents_from_tags(ref.replace(self.ref_sep, ''),
                                    sent_start_tag='<t>',
                                    sent_end_tag='</t>'))
            for ref in self.ref_lines_with_tags
        ]

        for i, ref in enumerate(self.ref_lines):
            if len(ref) == 0:
                self.ref_lines[i] = '### DUPLICATE ###'

        self.idf_refs = None
        self.idf_hyps = None
        if metric == 'moverscore':
            from all_metrics.moverscore import get_idf_dict
            with open('all_metrics/stopwords.txt', 'r', encoding='utf-8') as f:
                self.stop_words = set(f.read().strip().split(' '))
            if fast_moverscore:
                assert src_path is not None, f"src_path must be provided for fast moverscore"
                src_lines_with_tags = read_file(src_path)
                src_lines = [
                    ' '.join(
                        get_sents_from_tags(src,
                                            sent_start_tag='<t>',
                                            sent_end_tag='</t>'))
                    for src in src_lines_with_tags
                ]
                self.idf_refs = get_idf_dict(self.ref_lines)
                self.idf_hyps = get_idf_dict(src_lines)

        if metric == 'bertscore':
            from bert_score import BERTScorer
            self.bert_scorer = BERTScorer(lang='en',
                                          rescale_with_baseline=True)

        if metric == 'js2':
            ref_sents = [
                get_sents_from_tags(ref_line.replace(ref_sep, ''),
                                    sent_start_tag='<t>',
                                    sent_end_tag='</t>')
                for ref_line in self.ref_lines_with_tags
            ]

            self.ref_freq = [compute_tf(rs, N=2) for rs in ref_sents]

        if metric == 'rwe':
            self.embs = we.load_embeddings('../data/peyrard_s3/deps.words')

    def score(self, file_num, summ_path, model_name, variant_name):
        """
        :return: a list with format: [{score: value}] with scores for each doc in each dict
        """
        logger.info(
            f"getting scores for model: {model_name}, variant: {variant_name}, file num: {file_num}"
        )
        summ_lines_with_tags = read_file(summ_path)
        summ_lines = [
            ' '.join(
                get_sents_from_tags(summ,
                                    sent_start_tag='<t>',
                                    sent_end_tag='</t>'))
            for summ in summ_lines_with_tags
        ]
        for i, summ in enumerate(summ_lines):
            if len(summ) == 0:
                summ_lines[i] = '### DUPLICATE ###'

        if self.metric == 'moverscore':
            from all_metrics.moverscore import word_mover_score, get_idf_dict
            idf_refs = get_idf_dict(
                self.ref_lines) if self.idf_refs is None else self.idf_refs
            idf_hyps = get_idf_dict(
                summ_lines) if self.idf_hyps is None else self.idf_hyps
            scores = word_mover_score(self.ref_lines,
                                      summ_lines,
                                      idf_refs,
                                      idf_hyps,
                                      self.stop_words,
                                      n_gram=1,
                                      remove_subwords=True,
                                      batch_size=64,
                                      device='cuda:0')
            scores = [{'mover_score': s} for s in scores]

        elif self.metric == 'bertscore':
            (P, R, F) = self.bert_scorer.score(summ_lines, self.ref_lines)
            P, R, F = list(P.numpy()), list(R.numpy()), list(F.numpy())
            scores = [{
                'bert_precision_score': p,
                'bert_recall_score': r,
                'bert_f_score': f_score
            } for p, r, f_score in zip(P, R, F)]

        elif self.metric == 'js2':
            summ_sents = [
                get_sents_from_tags(summ_line,
                                    sent_start_tag='<t>',
                                    sent_end_tag='</t>')
                for summ_line in summ_lines_with_tags
            ]
            # import pdb; pdb.set_trace()
            scores = [{
                'js-2': -js_divergence(summ_sent, ref_freq, N=2)
            } for summ_sent, ref_freq in zip(summ_sents, self.ref_freq)]

        elif self.metric == 'rouge':
            args = argparse.Namespace(check_repeats=True,
                                      delete=True,
                                      get_each_score=True,
                                      stemming=True,
                                      method='sent_tag_verbatim',
                                      n_bootstrap=1000,
                                      run_google_rouge=False,
                                      run_rouge=True,
                                      source=summ_path,
                                      target=self.ref_path,
                                      ref_sep=self.ref_sep,
                                      num_ref=self.num_ref,
                                      temp_dir='../data/temp/')

            scores = baseline_main(
                args, return_pyrouge_scores=True)['individual_score_results']
            scores = [scores[doc_id] for doc_id in range(len(self.ref_lines))]

        elif self.metric == 'rwe':
            scores = [{
                'rouge_1_we':
                pd_rouge.rouge_n_we(ref, [summ], self.embs, n=1, alpha=0.5)
            } for ref, summ in zip(self.ref_lines, summ_lines)]

        elif self.metric == 'sms' or self.metric == 'wms':
            from all_metrics.sentence_mover.smd import smd
            scores = smd(self.ref_lines,
                         summ_lines,
                         word_rep='glove',
                         metric=self.metric)
            scores = [{self.metric: s} for s in scores]

        else:
            raise NotImplementedError(f"metric {self.metric} not supported")

        assert len(scores) == len(self.ref_lines)
        sd = {}
        for doc_id in range(len(self.ref_lines)):
            sd[doc_id] = {
                'doc_id': doc_id,
                'ref_summ': self.ref_lines_with_tags[doc_id],
                'system_summaries': {
                    f'{model_name}_{variant_name}': {
                        'system_summary': summ_lines_with_tags[doc_id],
                        'scores': scores[doc_id]
                    }
                }
            }
        return sd
Example #17
    def __init__(self,
                 src_path,
                 ref_path,
                 metric,
                 ref_sep,
                 fast_moverscore=False,
                 num_ref=1):
        self.src_path = src_path
        self.ref_path = ref_path
        self.metric = metric
        self.ref_sep = ref_sep
        self.num_ref = num_ref

        self.ref_lines_with_tags = read_file(ref_path)
        self.ref_lines = [
            ' '.join(
                get_sents_from_tags(ref.replace(self.ref_sep, ''),
                                    sent_start_tag='<t>',
                                    sent_end_tag='</t>'))
            for ref in self.ref_lines_with_tags
        ]

        for i, ref in enumerate(self.ref_lines):
            if len(ref) == 0:
                self.ref_lines[i] = '### DUPLICATE ###'

        self.idf_refs = None
        self.idf_hyps = None
        if metric == 'moverscore':
            from all_metrics.moverscore import get_idf_dict
            with open('all_metrics/stopwords.txt', 'r', encoding='utf-8') as f:
                self.stop_words = set(f.read().strip().split(' '))
            if fast_moverscore:
                assert src_path is not None, f"src_path must be provided for fast moverscore"
                src_lines_with_tags = read_file(src_path)
                src_lines = [
                    ' '.join(
                        get_sents_from_tags(src,
                                            sent_start_tag='<t>',
                                            sent_end_tag='</t>'))
                    for src in src_lines_with_tags
                ]
                self.idf_refs = get_idf_dict(self.ref_lines)
                self.idf_hyps = get_idf_dict(src_lines)

        if metric == 'bertscore':
            from bert_score import BERTScorer
            self.bert_scorer = BERTScorer(lang='en',
                                          rescale_with_baseline=True)

        if metric == 'js2':
            ref_sents = [
                get_sents_from_tags(ref_line.replace(ref_sep, ''),
                                    sent_start_tag='<t>',
                                    sent_end_tag='</t>')
                for ref_line in self.ref_lines_with_tags
            ]

            self.ref_freq = [compute_tf(rs, N=2) for rs in ref_sents]

        if metric == 'rwe':
            self.embs = we.load_embeddings('../data/peyrard_s3/deps.words')

def compute_bert_based_scores(test_path, path_results,
                              sentences_generated_path):
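    # Augments an existing per-image scores JSON (COCO-style metrics) with BERTScore
    # P/R/F and the per-image maximum BLEURT score, then writes it back to disk.
    # `bleurt_sc` and `bleurt_checkpoint` are assumed to be defined at module level
    # in the original source.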
    bert_scorer = BERTScorer(lang="en", rescale_with_baseline=True)
    bleurt_scorer = bleurt_sc.BleurtScorer(bleurt_checkpoint)

    with open(test_path) as json_file:
        test = json.load(json_file)

    test_sentences = defaultdict(list)
    for ref in test["annotations"]:
        image_id = ref["image_id"]
        caption = ref["caption"]
        test_sentences[image_id].append(caption)

    # get previous score of coco metrics (bleu,meteor,etc) to append bert_based_score
    scores_path = path_results

    with open(scores_path) as json_file:
        scores = json.load(json_file)

    # get previous generated sentences to calculate bertscore according to refs
    generated_sentences_path = sentences_generated_path
    with open(generated_sentences_path) as json_file:
        generated_sentences = json.load(json_file)
    total_precision = 0.0
    total_recall = 0.0
    total_fmeasure = 0.0
    total_bleurt_score = []
    for dict_image_and_caption in generated_sentences:
        image_id = dict_image_and_caption["image_id"]
        caption = [dict_image_and_caption["caption"]]
        references = [test_sentences[image_id]]
        bleurt_score_per_img = []
        for ref in references[0]:
            bleurt_score_per_img.append(
                bleurt_scorer.score([ref], caption, batch_size=None)[0])
        total_bleurt_score.append(max(bleurt_score_per_img))

        P_mul, R_mul, F_mul = bert_scorer.score(caption, references)
        precision = P_mul[0].item()
        recall = R_mul[0].item()
        f_measure = F_mul[0].item()

        total_precision += precision
        total_recall += recall
        total_fmeasure += f_measure

        # calculate bert_based_scores
        key_image_id = str(image_id)
        scores[str(key_image_id)]["BertScore_P"] = precision
        scores[key_image_id]["BertScore_R"] = recall
        scores[key_image_id]["BertScore_F"] = f_measure
        scores[key_image_id]["BLEURT"] = max(bleurt_score_per_img)
        # print("\ncaption and score", caption, f_measure)

    n_captions = len(generated_sentences)
    scores["avg_metrics"]["BertScore_P"] = total_precision / n_captions
    scores["avg_metrics"]["BertScore_R"] = total_recall / n_captions
    scores["avg_metrics"]["BertScore_F"] = total_fmeasure / n_captions
    scores["avg_metrics"]["BLEURT"] = statistics.mean(total_bleurt_score)

    # save scores dict to a json
    with open(scores_path, 'w+') as f:
        json.dump(scores, f, indent=2)

batch_size = 2
max_len = 20
top_k = 50
temperature = 1.5
generation_mode = "parallel-sequential"
leed_out_len = 5 # max_len
burnin = 250
sample = True
max_iter = 500
question = 2
ModelForQ_A_on = True
Metrics_calculation = True

#========================================== BERTScorer initialisation ==================================================
Q_metrics = [[],[],[]]
scorer = BERTScorer(model_type='bert-base-uncased')
q_refs = pickle.load(open('Metrics_files/Q_refs.pkl', 'rb'))
q3_refs = q_refs['q3_refs']
q2_refs = q_refs['q2_refs']
q1_refs = q_refs['q1_refs']
all_q_refs = [q1_refs,q2_refs,q3_refs]
all_q_cands = ['Who is she?', 'Are you okay?', 'Why?']


#==================================================== Word generation ==================================================
# Choose the prefix context
#seed_text = ugen.tokenizer.tokenize("who is she?".lower())
if Metrics_calculation:
    print('Metrics (BLEU, P, R, F1)')
    for i in range(len(all_q_cands)):
        seed_text = ugen.tokenizer.tokenize(all_q_cands[i].lower())

from bert_score import BERTScorer
import random
import pandas as pd
from utils import split_by_fullstop
from tools import start_debugger_on_exception
import torch
start_debugger_on_exception()
scorer = BERTScorer(lang="zh", rescale_with_baseline=True)
data_file = 'annotateddata/batch1.csv'
df = pd.read_csv(data_file)
# import pdb;pdb.set_trace()
stories = list(df.RESULT.dropna())
stories_split = [split_by_fullstop(x) for x in stories]
refs_pre = [x for y in stories_split for x in y if len(x)>0]
stories_split_select = [random.randint(0,len(x)-1) for x in stories_split]
stories_sentencesample = [x[y] for x,y in zip(stories_split,stories_split_select)]

stories_context = [] 
for ss,sss in zip(stories_split,stories_split_select):
    ss[sss] = '<MASK>'
    stories_context.append(ss)
stories_context = [''.join(x) for x in stories_context]  
positive_samples = [(x,y,True) for x,y in zip(stories_context,stories_sentencesample)]
cands_pre = stories_sentencesample
len_refs = len(refs_pre)
len_cands = len(cands_pre)
cands = [x for x in cands_pre for i in range(len_refs)]

refs = refs_pre*len_cands
# print(refs)
# print(cands)

    def __init__(self, chara_word='野球'):
        self.c = chara_word
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
        self.baseline_file_path = os.path.join(data_dir, 'bert-base-multilingual-cased.tsv')
        self.scorer = BERTScorer(model_type=os.path.join(data_dir, 'bert-base_mecab-ipadic-bpe-32k_whole-word-mask'), num_layers=11, lang='ja', rescale_with_baseline=True, baseline_path=self.baseline_file_path)
        self.min_rhyme = 2
Example #22
def evaluate_metric(metric, stem, remove_stop, prompt='overall'):
    ''' Compute the correlation between the human eval scores and the scores awarded by the
        eval metric.
    '''
    assert metric in ['ROUGE-1-F', 'ROUGE-2-F', 'ROUGE-L-F', 'bert-human', 'bert-score', 'bart-score', 
        'bleurt-base', 'bleurt-lg', 'mover-1', 'mover-2', 'mover-smd', 'bert-avg-score']
    stemmed_str = "_stem" if stem else ""
    stop_str = "_removestop" if remove_stop else ""
    ranks_file_path = os.path.join('learned_eval/outputs', 'wref_{}{}{}_{}_rank_correlation.csv'.format(metric, stemmed_str, stop_str, prompt))
    print('\n====={}=====\n'.format(ranks_file_path))

    ranks_file = open(ranks_file_path, 'w')
    ranks_file.write('article,summ_id, human_score, metric_score\n')

    sorted_scores = read_sorted_scores()
    input_articles, _ = read_articles()
    corr_data = np.zeros((len(sorted_scores), 3))

    stopwords_list = set(stopwords.words("english"))
    stemmer = PorterStemmer()

    # Init the metric
    if metric == 'bert-human':
        rewarder = Rewarder(os.path.join(MODEL_WEIGHT_DIR, 'sample.model'))
    elif metric.endswith('score'):   
        from bert_score import BERTScorer
        if 'bert-score' == metric:
            rewarder = BERTScorer(lang="en", rescale_with_baseline=True, model_type='roberta-large-mnli')
        elif 'bart-score' == metric:
            rewarder = BERTScorer(lang="en", model_type="facebook/bart-large-mnli", num_layers=12)
        elif 'bert-avg' in metric:
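            # bert-avg: the three scorers below (r1, r2, r3) are applied per summary
            # further down, with each scorer's outputs min-max normalised before averaging.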
            r1 = BERTScorer(lang="en", rescale_with_baseline=False, model_type='roberta-large')
            r2 = BERTScorer(lang="en", rescale_with_baseline=False, model_type='albert-xxlarge-v2')
            r3 = BERTScorer(lang="en", rescale_with_baseline=False, model_type='bart-large-mnli', num_layers=12)
    elif metric.startswith('bleurt'):
        from bleurt import score
        if 'base' in metric: 
            checkpoint = "bleurt-base-512"
        elif 'lg' in metric: 
            checkpoint = "bleurt-large-512"
        rewarder = score.BleurtScorer(checkpoint)
    elif metric.startswith('mover'):
        from moverscore import get_idf_dict, word_mover_score
        hyps = [s['sys_summ'] for score in sorted_scores.values() for s in score if s['sys_name'] != 'reference']
        refs = [s['sys_summ'] for score in sorted_scores.values() for s in score if s['sys_name'] == 'reference']
        idf_dict_hyp = get_idf_dict(hyps)
        idf_dict_ref = get_idf_dict(refs)
    elif 'rouge' in metric.lower():
        from rouge_score import rouge_scorer
        from rouge_score.scoring import BootstrapAggregator

    # Loop over each article and compute the correlation between human judgement
    # and the metric scores. 
    for i, (article_id, scores) in tqdm(enumerate(sorted_scores.items())):
        scores_list = [s for s in scores if s['sys_name'] != 'reference']
        human_ranks = [s['scores'][prompt] for s in scores_list]
        if len(human_ranks) < 2: 
            continue    # Must be at least 2 scores to compute the correlation
        ref_summ = scores_list[0]['ref']
        article = [entry['article'] for entry in input_articles if entry['id']==article_id][0]

        # Pre-processing (if necessary)
        if stem and remove_stop:
            sys_summs = [" ".join(sent2stokens_wostop(s['sys_summ'], stemmer, stopwords_list, 'english', True)) for s in scores_list]
            ref_summ = " ".join(sent2stokens_wostop(ref_summ, stemmer, stopwords_list, 'english', True))
            article = " ".join(sent2stokens_wostop(article, stemmer, stopwords_list, 'english', True))
        elif not stem and remove_stop:
            sys_summs = [" ".join(sent2tokens_wostop(s['sys_summ'], stopwords_list, 'english', True)) for s in scores_list]
            ref_summ = " ".join(sent2tokens_wostop(ref_summ, stopwords_list, 'english', True))
            article = " ".join(sent2tokens_wostop(article, stopwords_list, 'english', True))
        elif not remove_stop and stem:
            sys_summs = [" ".join(sent2stokens(s['sys_summ'], stemmer, 'english', True)) for s in scores_list]
            ref_summ = " ".join(sent2stokens(ref_summ, stemmer, 'english', True))
            article = " ".join(sent2stokens(article, stemmer, 'english', True))
        else:
            sys_summs = [s['sys_summ'] for s in scores_list]

        # Clean summaries
        summ_ids = [s['summ_id'] for s in scores_list]
        sys_summs = [text_normalization(s) for s in sys_summs]
        ref_summ = text_normalization(ref_summ)
        article = text_normalization(article)

        # Compute metric scores
        if 'rouge' in metric.lower():
            auto_metric_ranks = []
            if '1' in metric:
                rouge_metric = 'rouge1'
            elif '2' in metric:
                rouge_metric = 'rouge2'
            elif 'L' in metric:
                rouge_metric = 'rougeL'
            rew_rouge = rouge_scorer.RougeScorer([rouge_metric], use_stemmer=True)
            for ss in sys_summs:
                ss = ss.replace('. ', '\n')
                ref_summ = ref_summ.replace('. ', '\n')
                score = rew_rouge.score(ref_summ, ss)
                auto_metric_ranks.append(score[rouge_metric].fmeasure)
        if metric == 'bert-human':
            auto_metric_ranks = [rewarder(ref_summ,ss) for ss in sys_summs]
        elif metric.endswith('score'):   
            if 'bert-score' == metric:
                auto_metric_ranks = [rewarder.score([ref_summ], [ss])[-1].item() for ss in sys_summs]
            elif 'bart-score' == metric:
                auto_metric_ranks = [rewarder.score([ref_summ], [ss])[-1].item() for ss in sys_summs]
            elif 'bert-avg' in metric:
                rewarder_scores = []
                for rewarder in [r1, r2, r3]:
                    r_scores = np.array([rewarder.score([ref_summ], [ss])[-1].item() for ss in sys_summs])
                    r_scores = (r_scores - np.min(r_scores)) / (np.max(r_scores) - np.min(r_scores))
                    rewarder_scores.append(r_scores)
                auto_metric_ranks = list(np.mean(rewarder_scores, axis=0))
        elif metric.startswith('bleurt'):
            auto_metric_ranks = [rewarder.score([ref_summ], [ss])[0] for ss in sys_summs]
        elif metric.startswith('mover'):
            if '1' in metric: 
                n_gram = 1
            elif '2' in metric: 
                n_gram = 2
            else: 
                raise ValueError("smd not implemented currently")
            auto_metric_ranks = [word_mover_score([ref_summ], [ss], idf_dict_ref, idf_dict_hyp,
                                stop_words=[], n_gram=n_gram, remove_subwords=True)[0] for ss in sys_summs]
   
        for sid, amr, hr in zip(summ_ids, auto_metric_ranks, human_ranks):
            ranks_file.write('{},{},{:.2f},{:.4f}\n'.format(article_id, sid, hr, amr))

        # Compute correlations
        spearmanr_result = spearmanr(human_ranks, auto_metric_ranks)
        pearsonr_result = pearsonr(human_ranks, auto_metric_ranks)
        kendalltau_result = kendalltau(human_ranks, auto_metric_ranks)
        corr_data[i, :] = [spearmanr_result[0], pearsonr_result[0], kendalltau_result[0]]

    corr_mean_all = np.nanmean(corr_data, axis=0)
    corr_std_all = np.nanstd(corr_data, axis=0)
    print('\n====={}=====\n'.format(ranks_file_path))
    print("Correlation mean on all data spearman/pearsonr/kendall: {}".format(corr_mean_all))
    print("Correlation std on all data spearman/pearsonr/kendall: {}".format(corr_std_all))

    ranks_file.flush()
    ranks_file.close()

    return ranks_file_path
Example #23
from bert_score import BERTScorer
scorer = BERTScorer(lang="en", rescale_with_baseline=False)

import glob
human_files = "/data/private/WebNLG-models/prediction/challenge/reference.txt"

human_open = open(human_files, "r")
human_dataset = human_open.readlines()
human_open.close()

output_path = "/data/private/WebNLG-models/prediction/challenge/compare/*"
# output_path = "/data/private/WebNLG-models/prediction/challenge/my_output/*"
pred_files = glob.glob(output_path)

score_list = []
for i in range(len(pred_files)):
    cands = []
    pred_data_open = open(pred_files[i], "r")
    pred_data_dataset = pred_data_open.readlines()
    pred_data_open.close()

    P, R, F1 = scorer.score(human_dataset, pred_data_dataset)

    F1_list = list(F1.numpy())
    BERT_score = sum(F1_list) / len(F1_list)

    score_list.append(BERT_score)

for i in range(len(pred_files)):
    print(pred_files[i])
    print(score_list[i])
Example #24
def get_scores(nrows, metrics=None):
    ''' Get correlations between metric similarity and label similarity '''
    df = pd.read_csv(QQP_DATA_PATH, nrows=nrows)
    start_time = time()
    if not metrics:
        metrics = [
            'mover-1',
            'mover-2',
            'bleurt',
            'bertscore',
            'bartscore',
            'rouge1',
            'rouge2',
            'rougeLsum',
        ]
    for m in tqdm(metrics):
        if m.startswith('rouge'):
            scorer = rouge_scorer.RougeScorer(
                [met for met in metrics if met.startswith('rouge')],
                use_stemmer=True)
            scores = [
                scorer.score(r, c)[m].fmeasure
                for c, r in zip(df.question1, df.question2)
            ]
        elif m == 'bertscore':
            scorer = BERTScorer(lang="en",
                                rescale_with_baseline=True,
                                model_type='roberta-large-mnli')
            _, _, scores = scorer.score(df.question1.tolist(),
                                        df.question2.tolist())
        elif m == 'bartscore':
            scorer = BERTScorer(lang="en",
                                model_type="facebook/bart-large-mnli",
                                num_layers=12)
            _, _, scores = scorer.score(df.question1.tolist(),
                                        df.question2.tolist())
        elif m == 'bleurt':
            checkpoint = "bleurt-large-512"
            scorer = score.BleurtScorer(checkpoint)
            scores = scorer.score(df.question1, df.question2, batch_size=50)
        elif m.startswith('mover'):
            # Truncate long questions else moverscore gets OOM
            q1 = df['question1'].apply(lambda s: s[:300]).tolist()
            q2 = df['question2'].apply(lambda s: s[:300]).tolist()
            idf_dict_hyp = get_idf_dict(q1)
            idf_dict_ref = get_idf_dict(q2)
            if '1' in m:
                n_gram = 1
            else:
                n_gram = 2
            scores = word_mover_score(q2,
                                      q1,
                                      idf_dict_ref,
                                      idf_dict_hyp,
                                      stop_words=[],
                                      n_gram=n_gram,
                                      remove_subwords=True,
                                      batch_size=64)

        df[m] = scores
        print('\n' * 10, m, '\n' * 10)
        df.to_csv(QQP_OUT_PATH)
Example #25
class KobeModel(pl.LightningModule):
    def __init__(self, args):
        super(KobeModel, self).__init__()

        self.encoder = Encoder(
            vocab_size=args.text_vocab_size + args.cond_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_encoder_layers,
            dropout=args.dropout,
            mode=args.mode,
        )
        self.decoder = Decoder(
            vocab_size=args.text_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_decoder_layers,
            dropout=args.dropout,
        )
        self.lr = args.lr
        self.d_model = args.d_model
        self.loss = nn.CrossEntropyLoss(reduction="mean",
                                        ignore_index=0,
                                        label_smoothing=0.1)
        self._reset_parameters()

        self.decoding_strategy = args.decoding_strategy
        self.vocab = BertTokenizer.from_pretrained(args.text_vocab_path)
        self.bleu = BLEU(tokenize=args.tokenize)
        self.sacre_tokenizer = _get_tokenizer(args.tokenize)()
        self.bert_scorer = BERTScorer(lang=args.tokenize,
                                      rescale_with_baseline=True)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def _tokenwise_loss_acc(self, logits: torch.Tensor,
                            batch: Batched) -> Tuple[torch.Tensor, float]:
        unmask = ~batch.description_token_ids_mask.T[1:]
        unmasked_logits = logits[unmask]
        unmasked_targets = batch.description_token_ids[1:][unmask]
        acc = helpers.accuracy(unmasked_logits, unmasked_targets)
        return self.loss(logits.transpose(1, 2),
                         batch.description_token_ids[1:]), acc

    def training_step(self, batch: Batched, batch_idx: int):
        encoded = self.encoder.forward(batch)
        logits = self.decoder.forward(batch, encoded)
        loss, acc = self._tokenwise_loss_acc(logits, batch)
        self.lr_schedulers().step()
        self.log("train/loss", loss.item())
        self.log("train/acc", acc)
        return loss

    def _shared_eval_step(self, batch: Batched,
                          batch_idx: int) -> DecodedBatch:
        encoded = self.encoder.forward(batch)
        logits = self.decoder.forward(batch, encoded)
        loss, acc = self._tokenwise_loss_acc(logits, batch)

        preds = self.decoder.predict(encoded_batch=encoded,
                                     decoding_strategy=self.decoding_strategy)
        generated = self.vocab.batch_decode(preds.T.tolist(),
                                            skip_special_tokens=True)

        return DecodedBatch(
            loss=loss.item(),
            acc=acc,
            generated=generated,
            descriptions=batch.descriptions,
        )

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_eval_step(batch, batch_idx)

    def _shared_epoch_end(self, outputs: List[DecodedBatch], prefix):
        loss = np.mean([o.loss for o in outputs])
        acc = np.mean([o.acc for o in outputs])
        self.log(f"{prefix}/loss", loss)
        self.log(f"{prefix}/acc", acc)

        generated = [g for o in outputs for g in o.generated]
        references = [r for o in outputs for r in o.descriptions]

        # fmt: off
        # BLEU score
        self.log(f"{prefix}/bleu",
                 self.bleu.corpus_score(generated, [references]).score)

        # Diversity score
        self.log(
            f"{prefix}/diversity_3",
            float(
                helpers.diversity([self.sacre_tokenizer(g) for g in generated],
                                  n=3)))
        self.log(
            f"{prefix}/diversity_4",
            float(
                helpers.diversity([self.sacre_tokenizer(g) for g in generated],
                                  n=4)))
        self.log(
            f"{prefix}/diversity_5",
            float(
                helpers.diversity([self.sacre_tokenizer(g) for g in generated],
                                  n=5)))
        # fmt: on

        # BERTScore
        p, r, f = self.bert_scorer.score(generated, references)
        self.log(f"{prefix}/BERTScore_P", p.mean().item())
        self.log(f"{prefix}/BERTScore_R", r.mean().item())
        self.log(f"{prefix}/BERTScore_F", f.mean().item())

        # Examples
        columns = ["Generated", "Reference"]
        data = list(zip(generated[:256:16], references[:256:16]))
        table = wandb.Table(data=data, columns=columns)
        self.logger.experiment.log({f"examples/{prefix}": table})

    def validation_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "val")

    def test_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.lr,
                                betas=(0.9, 0.98))
        scheduler = WarmupDecayLR(optimizer,
                                  warmup_steps=10000,
                                  d_model=self.d_model)
        return [optimizer], [scheduler]
Example #26
args = parser.parse_args()

gold_path = args.gold
pred_path = args.pred

l = args.lang

with open(pred_path) as f:
    try:
        preds = [line.split(',') for line in f]
        preds = [[s.strip() for s in l] for l in preds]
    except:
        preds = [line.strip().split(',') for line in f]

with open(gold_path) as f:
    try:
        golds = [line.split(',') for line in f]
        golds = [[s.strip() for s in l] for l in golds]
    except:
        golds = [line.strip().split(',') for line in f]

scorer = BERTScorer(lang='en')

if 'tfidf' in pred_path.split('/')[-1]:
    print("in tfidf == true")
    P, R, F = get_bert_score(scorer, preds, golds)
else:
    P, R, F = get_bert_score(scorer, preds, golds)
    #P, R, F= scorer.score(cands, refs)
print(f"P={P.mean().item():f} R={R.mean().item():f} F={F.mean().item():f}")
Example #27
result = [cos_sim(i, j) for i, j in zip(adv_emb, ori_emb)]
with open('results/SimCSE.txt', 'w') as f:
    f.write('\n'.join([str(i) for i in result]))

result = [distance(i, j) for i, j in zip(adv_emb, ori_emb)]
with open('results/SimCSE_distance.txt', 'w') as f:
    f.write('\n'.join([str(i) for i in result]))

model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

adv_emb = model.encode(adv)
ori_emb = model.encode(ori)

result = [cos_sim(i, j) for i, j in zip(adv_emb, ori_emb)]
with open('results/SBERT.txt', 'w') as f:
    f.write('\n'.join([str(i) for i in result]))

result = [distance(i, j) for i, j in zip(adv_emb, ori_emb)]
with open('results/SBERT_distance.txt', 'w') as f:
    f.write('\n'.join([str(i) for i in result]))

model = BERTScorer(model_type='bert-base-uncased', idf=False)

SCORE_TYPE2IDX = {"precision": 0, "recall": 1, "f1": 2}
result = model.score(adv, ori)
result = result[SCORE_TYPE2IDX['f1']].numpy()

with open('results/BERTScore.txt', 'w') as f:
    f.write('\n'.join([str(i) for i in result]))
Example #28
def create_scorer():
    # Create scorer object for passing to get_bert_score
    scorer = BERTScorer(lang="en", rescale_with_baseline=True, model_type='roberta-base')
    return scorer
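
# Hypothetical usage sketch (sentences below are illustrative, not from the source):
# scorer = create_scorer()
# P, R, F1 = scorer.score(["the cat sat on the mat"], ["a cat was sitting on the mat"])
# print(F1.mean().item())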
Example #29
# load hypothesis
cands = [line.strip().split("\t")[1] for line in open(hyp_file, 'r')]

# prepare reference list
ref_files = os.listdir(ref_path)
ref_dict = defaultdict(list)
for name in ref_files:
    _, _, sty, index = name.split(".")
    ref_dict[index].append((sty, ref_path + name))
for index in ref_dict:
    ref_dict[index].sort(key=lambda x: x[0])
    refs_i = []
    for _, file_path in ref_dict[index]:
        refs_i += [line.strip() for line in open(file_path, 'r')]
    ref_dict[index] = refs_i
ref_list = [refs for refs in ref_dict.values()]
ref_sents = [ref for refs in ref_list for ref in refs]
ref_list = list(zip(*ref_list))

# load BERT model
scorer = BERTScorer(model_type="albert-xlarge-v2",
                    lang="en",
                    rescale_with_baseline=True,
                    idf=True,
                    idf_sents=ref_sents,
                    batch_size=32)
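
# idf=True with idf_sents computes idf weights from the flattened reference sentences;
# each candidate is then scored against its tuple of references in ref_list.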

P, R, F1 = scorer.score(cands, ref_list)
P, R, F1 = P.mean().item(), R.mean().item(), F1.mean().item()

print("P: %.4f; R: %.4f; F1: %.4f." % (P, R, F1))
class RhymeDistanceMeter:
    def __init__(self, chara_word='野球'):
        self.c = chara_word
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
        self.baseline_file_path = os.path.join(data_dir, 'bert-base-multilingual-cased.tsv')
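        # Japanese BERT (MeCab + BPE tokenisation) scored with rescale_with_baseline,
        # using a locally stored baseline TSV instead of the packaged one.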
        self.scorer = BERTScorer(model_type=os.path.join(data_dir, 'bert-base_mecab-ipadic-bpe-32k_whole-word-mask'), num_layers=11, lang='ja', rescale_with_baseline=True, baseline_path=self.baseline_file_path)
        self.min_rhyme = 2

    def throw(self, s1, s2):
        rhyme_count = self.count_rhyme(s1, s2)
        sim_s, sim_c = self.score_similarity(s1, s2)
        len_rate = self.len_rate(s1, s2)
        dist = self.calc_dist(rhyme_count, sim_s, sim_c, len_rate)
        return dist

    def most_rhyming(self, killer_phrase, candidates, topn=3):
        res = {}
        for c in candidates:
            res[c] = self.count_rhyme(killer_phrase, c)
        logger.debug(f'{res=}')
        sorted_res = sorted(res.items(), key=lambda item: item[1], reverse=True)

        return [w[0] for w in sorted_res[:topn]]

    def len_rate(self, s1, s2):
        return min(len(s1), len(s2)) / max(len(s1), len(s2))

    def count_rhyme(self, s1, s2):
        romaji1 = romanize_sentence(s1)
        romaji2 = romanize_sentence(s2)

        vowel1 = vowelize(romaji1)
        vowel2 = vowelize(romaji2)
        logger.debug(f'{vowel1=}')
        logger.debug(f'{vowel2=}')

        min_len = min(len(vowel1), len(vowel2))

        cnt = 0
        # end rhyme: count matching vowels from the end of both strings
        for i in range(1, min_len+1):
            if vowel1[-i] == vowel2[-i]:
                cnt += 1
            else:
                break
        if cnt > 0:
            return cnt

        # head rhyme: count matching vowels from the start of both strings
        for i in range(min_len):
            if vowel1[i] == vowel2[i]:
                cnt += 1
            else:
                break

        return cnt

    def score_similarity(self, s1, s2):
        refs = [s1]
        hyps = [s2]
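
        # Collect noun surface forms (MeCab POS tag 名詞) from each sentence, excluding
        # the theme word self.c; each noun is then paired with the theme word for scoring.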

        s1_nouns = [w.surface for w in tagger(s1) if (w.feature[0] == '名詞' and w.surface != self.c)]
        s2_nouns = [w.surface for w in tagger(s2) if (w.feature[0] == '名詞' and w.surface != self.c)]
        logger.debug(f'{s1_nouns=}')
        logger.debug(f'{s2_nouns=}')

        for s in s1_nouns:
            refs.append(self.c)
            hyps.append(s)

        for s in s2_nouns:
            refs.append(self.c)
            hyps.append(s)

        logger.debug(f'{refs=}')
        logger.debug(f'{hyps=}')
        P, R, F1 = self.scorer.score(refs, hyps)
        dist_s = F1[0]

        logger.debug(f'{F1[1:]=}')
        dist_c = max(F1[1:])

        return dist_s, dist_c

    def calc_dist(self, count, sim_s, sim_c, len_rate):
        logger.debug(f'{count=}')
        logger.debug(f'{sim_s=}')
        logger.debug(f'{sim_c=}')
        logger.debug(f'{len_rate=}')
        return int(count ** ((1 - sim_s) * (sim_c * 10) * (1 + len_rate)))