Example No. 1
def test_example_document_retrieval(
        kilt_trie: Trie, fairseq_wikipage_retrieval: GENREHubInterface):
    sentences = ["Einstein was a German physicist."]
    results = fairseq_wikipage_retrieval.sample(
        sentences,
        prefix_allowed_tokens_fn=lambda batch_id, sent: kilt_trie.get(
            sent.tolist()),
    )
    assert results == EXPECTED_RESULTS_DOCUMENT_RETRIEVAL
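The test above relies on pytest fixtures providing the trie and the model; a hypothetical sketch of the model fixture, mirroring the kilt_trie fixture in Example No. 6 (the checkpoint path is a placeholder):

import pytest
from genre.fairseq_model import GENRE


@pytest.fixture
def fairseq_wikipage_retrieval():
    # Load the pretrained fairseq document-retrieval checkpoint in eval mode
    # (the local path is an assumption).
    return GENRE.from_pretrained("./models/fairseq_wikipage_retrieval").eval()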
Example No. 2
    def get_trie_entity(sent, sent_orig):
        # Locate the span of the mention currently being generated.
        pointer_start, pointer_end = get_pointer_mention(sent)

        if pointer_start + 1 != pointer_end:
            # The mention between the delimiters is non-empty: decode it and
            # constrain the entity that follows to its candidate set.
            mention = decode_fn(sent[pointer_start + 1:pointer_end]).strip()

            if candidates_trie is not None:
                candidates_trie_tmp = candidates_trie
            elif mention_to_candidates_dict is not None:
                candidates_trie_tmp = Trie([
                    encode_fn(" }} [ {} ]".format(e))[1:]
                    for e in mention_to_candidates_dict.get(mention, ["NIL"])
                ])
            else:
                raise RuntimeError(
                    "Either candidates_trie or mention_to_candidates_dict "
                    "must be provided.")

            return candidates_trie_tmp.get(sent[pointer_end:])

        return []
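For context, one possible implementation of the helper get_pointer_mention used above; it is not shown in the source, so this is only a sketch assuming the mention delimiters "{" and "}" are encoded as single tokens whose ids live in a hypothetical codes mapping:

# Sketch (assumption): return the positions of the last "{" and "}" delimiter
# tokens in the partially generated sequence; `codes` is a hypothetical
# mapping from delimiter strings to token ids.
def get_pointer_mention(sent):
    pointer_start, pointer_end = -1, -1
    for i, token in enumerate(sent):
        if token == codes["{"]:
            pointer_start = i
        elif token == codes["}"]:
            pointer_end = i
    return pointer_start, pointer_end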
Example No. 3
def decode_with_constrain(sentences, schema, model):
    trie = Trie([model.encode(" {}".format(e))[1:].tolist() for e in schema])

    prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
        model,
        sentences,
        mention_trie=trie,
    )

    return model.sample(
        sentences,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    )
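A hypothetical invocation of decode_with_constrain; the schema entries, sentences, and model are placeholders:

# Hypothetical usage: constrain generation to a fixed label schema.
# `model` is assumed to be an already loaded GENRE (fairseq) model.
schema = ["person", "organization", "location"]
sentences = ["Steve Jobs founded Apple in Cupertino."]
print(decode_with_constrain(sentences, schema, model))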
Example No. 4
def add_to_trie(trie, sequence):
    # Recursively insert a token-id sequence, creating child nodes as needed.
    if sequence != []:
        if sequence[0] not in trie._leaves:
            trie._leaves[sequence[0]] = Trie([])
        add_to_trie(trie._leaves[sequence[0]], sequence[1:])
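A small usage sketch for add_to_trie, assuming a Trie variant that stores its children in a _leaves dict (as the helper above expects):

# Insert token-id sequences one at a time instead of all at construction time.
trie = Trie([])
for sequence in [[5, 7, 9], [5, 7, 11], [6, 2]]:
    add_to_trie(trie, sequence)
# Equivalent to building it in one shot: Trie([[5, 7, 9], [5, 7, 11], [6, 2]])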
Example No. 5
# hf-experiments
# @author Loreto Parisi (loretoparisi at gmail dot com)
# Copyright (c) 2021 Loreto Parisi (loretoparisi at gmail dot com)

import os
import pickle
from genre.trie import Trie
from genre.hf_model import GENRE

cache_dir = os.getenv("cache_dir", "../../models")

# load the prefix tree (trie)
with open(os.path.join(cache_dir, "kilt_titles_trie_dict.pkl"), "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

# Example: Document Retrieval
model = GENRE.from_pretrained(os.path.join(cache_dir,
                                           "hf_wikipage_retrieval")).eval()
sentences = ["Madonna was the mother of Jesus."]
out = model.sample(
    sentences,
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
print(out)
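The same Hugging Face interface also supports entity disambiguation with [START_ENT]/[END_ENT] markers; a sketch below, where the checkpoint directory name hf_entity_disambiguation_aidayago and the reuse of the title trie as the constraint are assumptions:

# Example: Entity Disambiguation (sketch; checkpoint name is an assumption)
model_ed = GENRE.from_pretrained(
    os.path.join(cache_dir, "hf_entity_disambiguation_aidayago")).eval()
sentences_ed = ["Einstein was a [START_ENT] German [END_ENT] physicist."]
out_ed = model_ed.sample(
    sentences_ed,
    # constrain the generated title with the same prefix tree of KILT titles
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
print(out_ed)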
Example No. 6
def kilt_trie():
    # load the prefix tree (trie)
    with open("./data/kilt_titles_trie_dict.pkl", "rb") as f:
        trie = Trie.load_from_dict(pickle.load(f))
    return trie
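If the pickled dictionary is not available, it can be produced once from a list of page titles; a minimal sketch, assuming the Trie instance exposes its serializable form as trie_dict and that titles and model are provided (the encoding mirrors Example No. 3; the exact token prefix depends on the model's tokenizer):

# One-off sketch for producing kilt_titles_trie_dict.pkl
# (`titles` is an assumed iterable of Wikipedia page titles, `model` an
# already loaded GENRE model).
import pickle
from genre.trie import Trie

trie = Trie([model.encode(" {}".format(t))[1:].tolist() for t in titles])
with open("./data/kilt_titles_trie_dict.pkl", "wb") as f:
    pickle.dump(trie.trie_dict, f)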
Example No. 7
sentences = ["In 1921, Einstein received a Nobel Prize."]

# get the prefix_allowed_tokens_fn with the only constraint being to annotate the original sentence (i.e., no other constraints on mentions or candidates)
# use .sample to make predictions, constraining generation with prefix_allowed_tokens_fn
prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(model, sentences)
out = model.sample(
    sentences,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(out)

# constrain the mentions with a prefix tree (no constraints on candidates)
prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
    model,
    sentences,
    mention_trie=Trie([model.encode(e)[1:].tolist() for e in [" Einstein"]]))
out = model.sample(
    sentences,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(out)

# constrain the candidates with a prefix tree (no constraints on mentions)
prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
    model,
    sentences,
    candidates_trie=Trie([
        model.encode(" }} [ {} ]".format(e))[1:].tolist()
        for e in ["Albert Einstein", "Nobel Prize in Physics", "NIL"]
    ]))
out = model.sample(
    sentences,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(out)
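Mention and candidate constraints can also be combined; in the sketch below the mention_to_candidates_dict keyword and the candidate lists are assumptions, following the pattern of Example No. 2:

# constrain both the mentions and, per mention, the allowed candidates
# (the mention_to_candidates_dict keyword is assumed from Example No. 2)
prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
    model,
    sentences,
    mention_trie=Trie([model.encode(e)[1:].tolist() for e in [" Einstein"]]),
    mention_to_candidates_dict={
        "Einstein": ["Albert Einstein", "Einstein (surname)", "NIL"],
    },
)
out = model.sample(
    sentences,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
)
print(out)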
Example No. 8
def evaluate_kilt_dataset(
    model,
    dataset,
    batch_size=4,
    beams=10,
    max_len_a=384,
    max_len_b=15,
    candidates=False,
    trie=None,
    title2id={},
    free_generation=False,
    test=False,
):

    dataset_original = deepcopy(dataset)

    gold = []
    pred = []

    iter_ = tqdm(dataset, desc="Evaluating")
    for docs in batch_it(iter_, batch_size):

        if not free_generation:
            # One trie per document in the batch: built from the document's
            # gold candidates when `candidates` is set, otherwise the shared
            # title trie passed in as `trie`.
            batch_trie = {
                i: (
                    (
                        Trie(
                            [
                                [2] + model.encode(e).tolist()[1:]
                                for e in doc["candidates"]
                            ]
                        )
                        if doc["candidates"]
                        else Trie([[2] + model.encode("NIL").tolist()[1:]])
                    )
                    if candidates
                    else trie
                )
                for i, doc in enumerate(docs)
            }

            def prefix_allowed_tokens_fn(batch_id, sent):
                return batch_trie[batch_id].get(sent.tolist())

        outputs = model.sample(
            [
                create_input(
                    doc,
                    max_len_a,
                    start_delimiter="[START_ENT]",
                    end_delimiter="[END_ENT]",
                )
                for doc in docs
            ],
            beam=beams,
            max_len_b=max_len_b,
            prefix_allowed_tokens_fn=None
            if free_generation
            else prefix_allowed_tokens_fn,
        )

        for doc, out in zip(docs, outputs):
            if not test:
                gold.append(doc["output"][0]["answer"])
                try:
                    pred.append(out[0]["text"])
                except Exception as e:
                    pred.append("NIL")
                    print(doc)
                    print(e)

            doc["output"] = [
                {
                    "answer": "",
                    "provenance": [
                        {
                            "wikipedia_id": title2id.get(prov["text"], None),
                            "title": prov["text"],
                            "score": prov["score"].item(),
                        }
                        for prov in out
                    ],
                }
            ]

        if not test:
            true_pos = 0
            for g, p in zip(gold, pred):
                if g == p and p != "NIL":
                    true_pos += 1

            precision = (
                (true_pos / len([p for p in pred if p != "NIL"]))
                if len([p for p in pred if p != "NIL"])
                else 0
            )
            recall = (true_pos / len(gold)) if len(gold) else 0
            f1 = (
                (2 * precision * recall / (precision + recall))
                if precision + recall
                else 0
            )

            iter_.set_postfix(f1=f1, prec=precision, rec=recall)

    if not test:
        kilt_dict = compute(dataset_original, dataset, ks=[1, 5], rank_keys=["title"])
        return dataset, f1, precision, recall, kilt_dict["Rprec"], kilt_dict["recall@5"]
    else:
        return dataset, 0, 0, 0, 0, 0
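A hypothetical driver for evaluate_kilt_dataset; the dataset path, its JSON-lines format, and the title2id mapping are placeholders, while model and trie follow the earlier examples:

# Hypothetical usage (paths and mappings are placeholders):
# each line of the file is a KILT-style record with "input", "output" and,
# optionally, a "candidates" list.
import json

with open("data/dev-kilt.jsonl") as f:
    dataset = [json.loads(line) for line in f]

dataset, f1, precision, recall, rprec, recall_at_5 = evaluate_kilt_dataset(
    model,
    dataset,
    batch_size=4,
    trie=trie,          # global title trie (used when candidates=False)
    title2id=title2id,  # placeholder: mapping from page title to Wikipedia id
)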
Example No. 9
def evaluate_kilt_dataset(
    model,
    dataset,
    batch_size=4,
    beams=10,
    max_len_a=128,
    max_len_b=32,
    lenpen=1,
    trie=None,
    lang_title2wikidataID={},
    wikidataID2lang_title={},
    canonical_lang_title2wikidataID={},
    wikidataID2canonical_lang_title={},
    order="title_lang",
    canonical=False,
    free_generation=False,
    mention2wikidataID={},
    candidates_lowercase=False,
    allowed_langs=[],
    desc=None,
    max_candidates=None,
    only_en_candidates=False,
    only_freebase_candidates=False,
    wikidataID2freebaseID={},
):
    gold = []
    pred = []

    iter_ = tqdm(dataset, desc="Evaluating {}".format(desc if desc else ""))

    for docs in batch_it(iter_, batch_size):

        if not free_generation:
            batch_trie = {}
            for i, doc in enumerate(docs):
                # Normalize the mention (NFKD, simplified Chinese, unify
                # middle-dot variants) before the candidate-dictionary lookup.
                mention = (unicodedata.normalize(
                    "NFKD",
                    HanziConv.toSimplified(doc["meta"]["mention"])).replace(
                        "•", "·").replace(".", "·"))

                candidates = list(mention2wikidataID.get(mention, {}).items())

                if candidates_lowercase:
                    candidates += list(
                        mention2wikidataID.get(mention.lower(), {}).items())

                candidates_tmp = defaultdict(int)
                for k, v in candidates:
                    candidates_tmp[k] += v

                candidates = [
                    e[0] for e in sorted(
                        candidates_tmp.items(),
                        key=lambda x: x[1],
                        reverse=True,
                    )
                ]

                if only_en_candidates:
                    candidates = [
                        cand for cand in candidates
                        if "en" in dict(wikidataID2lang_title[cand])
                    ]
                if only_freebase_candidates:
                    candidates = [
                        cand for cand in candidates
                        if cand in wikidataID2freebaseID
                    ]

                candidates = candidates[:max_candidates]

                if mention2wikidataID and candidates:
                    if canonical:
                        batch_bpes = [
                            [2] + model.encode("{} >> {}".format(*tuple(
                                reversed(wikidataID2canonical_lang_title[cand])
                            ) if order == "title_lang" else "{} >> {}".format(
                                *wikidataID2canonical_lang_title[cand]))
                                               ).tolist()[1:]
                            for cand in candidates
                            if cand in wikidataID2canonical_lang_title
                        ]
                    else:
                        batch_bpes = [
                            [2] +
                            model.encode("{} >> {}".format(title, lang) if
                                         order == "title_lang" else "{} >> {}".
                                         format(lang, title)).tolist()[1:]
                            for cand in candidates
                            for lang, title in wikidataID2lang_title.get(
                                cand, []) if lang in allowed_langs
                        ]

                    if batch_bpes:
                        batch_trie[i] = Trie()
                        for e in batch_bpes:
                            batch_trie[i].add(e)

                    else:
                        batch_trie[i] = trie

                else:
                    batch_trie[i] = trie

            def prefix_allowed_tokens_fn(batch_id, sent):
                return [
                    e for e in batch_trie[batch_id].get(sent.tolist())
                    if e < len(model.task.target_dictionary)
                ]

        outputs = model.sample(
            [create_input(doc, max_len_a) for doc in docs],
            beam=beams,
            lenpen=lenpen,
            max_len_b=max_len_b,
            prefix_allowed_tokens_fn=None
            if free_generation else prefix_allowed_tokens_fn,
        )

        for doc, out in zip(docs, outputs):

            doc["predictions"] = [{
                "answer":
                list([
                    canonical_lang_title2wikidataID.get(
                        tuple(
                            reversed(o["text"].split(" >> ")) if order ==
                            "title_lang" else o["text"].split(" >> ")),
                        None,
                    )
                ] if canonical else lang_title2wikidataID.get(
                    tuple(
                        reversed(o["text"].split(" >> ")) if order ==
                        "title_lang" else o["text"].split(" >> ")),
                    [None],
                )),
                "text":
                o["text"],
                "logprob":
                o["logprob"].item(),
            } for o in out]

            gold.append(doc["output"][0]["answer"])

            try:
                pred.append(doc["predictions"][0]["answer"])
            except Exception as e:
                pred.append([None])

        true_pos = 0
        for g, p in zip(gold, pred):
            if set(g).intersection(set(p)) and p != [None]:
                true_pos += 1

        precision = ((true_pos / len([p for p in pred if p != [None]]))
                     if len([p for p in pred if p != [None]]) else 0)
        recall = (true_pos / len(gold)) if len(gold) else 0
        f1 = ((2 * precision * recall / (precision + recall)) if precision +
              recall else 0)
        accuracy = [(set(g).intersection(set(p)) and p != [None])
                    or (g == [] and p == [None]) for g, p in zip(gold, pred)]
        accuracy = sum(accuracy) / len(accuracy)

        iter_.set_postfix(f1=f1, prec=precision, rec=recall, acc=accuracy)

    return dataset, f1, precision, recall, accuracy