def test_example_document_retrieval(
    kilt_trie: Trie, fairseq_wikipage_retrieval: GENREHubInterface
):
    """Check constrained document retrieval against the stored expectation.

    Generation is restricted at every step to the tokens that ``kilt_trie``
    allows after the tokens produced so far.
    """

    def allowed_next_tokens(batch_id, sent):
        # batch_id is not needed: every sentence shares the same trie.
        return kilt_trie.get(sent.tolist())

    queries = ["Einstein was a German physicist."]
    predictions = fairseq_wikipage_retrieval.sample(
        queries,
        prefix_allowed_tokens_fn=allowed_next_tokens,
    )

    assert predictions == EXPECTED_RESULTS_DOCUMENT_RETRIEVAL
def get_trie_entity(sent, sent_orig):
    """Return the tokens allowed next while generating an entity.

    The mention span is located with ``get_pointer_mention``. When the span
    is empty (its start is immediately followed by its end) nothing is
    allowed and an empty list is returned. Otherwise the candidates trie is
    queried with the tokens from the end pointer onwards.

    ``sent_orig`` is accepted but not used by this function.

    Raises:
        RuntimeError: if neither ``candidates_trie`` nor
            ``mention_to_candidates_dict`` is available.
    """
    pointer_start, pointer_end = get_pointer_mention(sent)

    # Empty mention: there is no entity to constrain.
    if pointer_start + 1 == pointer_end:
        return []

    mention = decode_fn(sent[pointer_start + 1:pointer_end]).strip()

    if candidates_trie is not None:
        active_trie = candidates_trie
    elif mention_to_candidates_dict is not None:
        # Build a one-off trie from this mention's candidates, falling back
        # to the single candidate "NIL" for unknown mentions.
        entity_names = mention_to_candidates_dict.get(mention, ["NIL"])
        active_trie = Trie(
            [encode_fn(" }} [ {} ]".format(name))[1:] for name in entity_names]
        )
    else:
        raise RuntimeError()

    return active_trie.get(sent[pointer_end:])
def decode_with_constrain(sentences, schema, model):
    """Sample from ``model`` with mentions restricted to ``schema`` entries.

    Each schema entry is encoded with a leading space, its first token is
    dropped, and the remaining token ids are stored in a mention trie that
    drives the ``prefix_allowed_tokens_fn`` passed to ``model.sample``.
    """
    encoded_entries = []
    for entry in schema:
        token_ids = model.encode(" {}".format(entry))[1:].tolist()
        encoded_entries.append(token_ids)
    mention_trie = Trie(encoded_entries)

    allowed_tokens_fn = get_prefix_allowed_tokens_fn(
        model,
        sentences,
        mention_trie=mention_trie,
    )

    return model.sample(
        sentences,
        prefix_allowed_tokens_fn=allowed_tokens_fn,
    )
def add_to_trie(trie, sequence):
    """Insert ``sequence`` into ``trie``, creating missing child nodes.

    Walks down from ``trie`` one token at a time through each node's
    ``_leaves`` mapping, adding an empty ``Trie`` child wherever a token is
    not present yet. An empty sequence leaves the trie untouched.

    The walk is iterative, so it accepts any iterable of hashable tokens
    (list, tuple, ...), does not re-slice the sequence for every token, and
    is not limited by the interpreter's recursion depth.

    Args:
        trie: root node; must expose a ``_leaves`` dict of child nodes.
        sequence: iterable of tokens to insert, in order.
    """
    node = trie
    for token in sequence:
        children = node._leaves
        if token not in children:
            children[token] = Trie([])
        node = children[token]
# hf-experiments
# @author Loreto Parisi (loretoparisi at gmail dot com)
# Copyright (c) 2021 Loreto Parisi (loretoparisi at gmail dot com)

import os
import pickle

from genre.trie import Trie
from genre.hf_model import GENRE

# Directory holding the trie pickle and the model; override it with the
# "cache_dir" environment variable.
cache_dir = os.getenv("cache_dir", "../../models")

# Load the prefix tree (trie) of allowed titles.
# NOTE: pickle.load executes arbitrary code from the file; only point
# cache_dir at data you trust.
trie_path = os.path.join(cache_dir, "kilt_titles_trie_dict.pkl")
with open(trie_path, "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

# Example: Document Retrieval
model_path = os.path.join(cache_dir, "hf_wikipage_retrieval")
model = GENRE.from_pretrained(model_path).eval()


def allowed_next_tokens(batch_id, sent):
    """Return the tokens the trie allows after the tokens generated so far."""
    return trie.get(sent.tolist())


sentences = ["Madonna was the mother of Jesus."]
out = model.sample(
    sentences,
    prefix_allowed_tokens_fn=allowed_next_tokens,
)
print(out)
def kilt_trie():
    """Load the KILT titles prefix tree from its pickled dict.

    Reads ``./data/kilt_titles_trie_dict.pkl`` relative to the current
    working directory and rebuilds the trie with ``Trie.load_from_dict``.
    """
    with open("./data/kilt_titles_trie_dict.pkl", "rb") as handle:
        trie_dict = pickle.load(handle)
    return Trie.load_from_dict(trie_dict)
sentences = ["In 1921, Einstein received a Nobel Prize."] # get the prefix_allowed_tokens_fn with the only constraints to annotate the original sentence (i.e., no other constrains on mention nor candidates) # use .sample to make predictions constraining using prefix_allowed_tokens_fn prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(model, sentences) out = model.sample( sentences, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, ) print(out) # constrain the mentions with a prefix tree (no constrains on candidates) prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn( model, sentences, mention_trie=Trie([model.encode(e)[1:].tolist() for e in [" Einstein"]])) out = model.sample( sentences, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, ) print(out) # constrain the candidates with a prefix tree (no constrains on mentions) prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn( model, sentences, candidates_trie=Trie([ model.encode(" }} [ {} ]".format(e))[1:].tolist() for e in ["Albert Einstein", "Nobel Prize in Physics", "NIL"] ])) out = model.sample(
def evaluate_kilt_dataset(
    model,
    dataset,
    batch_size=4,
    beams=10,
    max_len_a=384,
    max_len_b=15,
    candidates=False,
    trie=None,
    title2id={},
    free_generation=False,
    test=False,
):
    """Run (optionally constrained) generation over a KILT dataset and score it.

    Every document's ``"output"`` entry is overwritten in place with the
    model's ranked provenance list; a deep copy of the untouched dataset is
    kept so ``compute`` can compare predictions against the gold data.

    Args:
        model: object exposing ``encode`` and ``sample``.
        dataset: list of documents; each needs ``"output"`` (unless ``test``)
            and ``"candidates"`` (when ``candidates`` is true).
        batch_size: number of documents sent to ``model.sample`` at once.
        beams: beam size forwarded to ``model.sample``.
        max_len_a: forwarded to ``create_input``.
        max_len_b: forwarded to ``model.sample``.
        candidates: build a per-document trie from ``doc["candidates"]``
            (or from the single string "NIL" when that list is empty)
            instead of using the shared ``trie``.
        trie: shared prefix tree used when ``candidates`` is false.
        title2id: mapping from predicted title to wikipedia id. It is only
            read, never mutated, so the shared default is safe.
        free_generation: disable prefix constraints entirely.
        test: skip all gold-based metrics (no gold answers are read).

    Returns:
        ``(dataset, f1, precision, recall, Rprec, recall@5)``; the five
        scores are all ``0`` when ``test`` is true.
    """
    dataset_original = deepcopy(dataset)

    # Bound up front so an empty dataset no longer reaches the return
    # statement with these names undefined.
    f1 = precision = recall = 0

    # Running totals: metrics are refreshed after each batch without
    # re-scanning every earlier prediction.
    true_pos = 0
    num_gold = 0
    num_predicted = 0  # predictions that are not "NIL"

    def build_candidate_trie(doc):
        # Every sequence is prefixed with token id 2 and drops the first
        # encoded token (presumably swapping BOS for the decoder start
        # token; confirm against the model's dictionary).
        names = doc["candidates"] if doc["candidates"] else ["NIL"]
        return Trie([[2] + model.encode(name).tolist()[1:] for name in names])

    iter_ = tqdm(dataset, desc="Evaluating")
    for docs in batch_it(iter_, batch_size):
        constraint_fn = None
        if not free_generation:
            batch_trie = {
                i: build_candidate_trie(doc) if candidates else trie
                for i, doc in enumerate(docs)
            }

            def constraint_fn(batch_id, sent):
                return batch_trie[batch_id].get(sent.tolist())

        outputs = model.sample(
            [
                create_input(
                    doc,
                    max_len_a,
                    start_delimiter="[START_ENT]",
                    end_delimiter="[END_ENT]",
                )
                for doc in docs
            ],
            beam=beams,
            max_len_b=max_len_b,
            prefix_allowed_tokens_fn=constraint_fn,
        )

        for doc, out in zip(docs, outputs):
            if not test:
                # Read the gold answer before doc["output"] is overwritten.
                gold_answer = doc["output"][0]["answer"]
                try:
                    predicted = out[0]["text"]
                except Exception as e:
                    # Best effort: a missing/malformed top hypothesis is
                    # scored as "NIL" and reported instead of aborting.
                    predicted = "NIL"
                    print(doc)
                    print(e)

                num_gold += 1
                if predicted != "NIL":
                    num_predicted += 1
                    if gold_answer == predicted:
                        true_pos += 1

            doc["output"] = [
                {
                    "answer": "",
                    "provenance": [
                        {
                            "wikipedia_id": title2id.get(prov["text"], None),
                            "title": prov["text"],
                            "score": prov["score"].item(),
                        }
                        for prov in out
                    ],
                }
            ]

        if not test:
            precision = (true_pos / num_predicted) if num_predicted else 0
            recall = (true_pos / num_gold) if num_gold else 0
            f1 = (
                (2 * precision * recall / (precision + recall))
                if precision + recall
                else 0
            )
            iter_.set_postfix(f1=f1, prec=precision, rec=recall)

    if test:
        return dataset, 0, 0, 0, 0, 0

    kilt_dict = compute(dataset_original, dataset, ks=[1, 5], rank_keys=["title"])
    return dataset, f1, precision, recall, kilt_dict["Rprec"], kilt_dict["recall@5"]
def evaluate_kilt_dataset(
    model,
    dataset,
    batch_size=4,
    beams=10,
    max_len_a=128,
    max_len_b=32,
    lenpen=1,
    trie=None,
    lang_title2wikidataID={},
    wikidataID2lang_title={},
    canonical_lang_title2wikidataID={},
    wikidataID2canonical_lang_title={},
    order="title_lang",
    canonical=False,
    free_generation=False,
    mention2wikidataID={},
    candidates_lowercase=False,
    allowed_langs=[],
    desc=None,
    max_candidates=None,
    only_en_candidates=False,
    only_freebase_candidates=False,
    wikidataID2freebaseID={},
):
    """Run multilingual (optionally constrained) generation and score it.

    For every document a ``"predictions"`` list is written in place with the
    resolved answer ids, the generated text and its log-probability.
    Precision, recall, F1 and accuracy are refreshed after every batch.

    Unless ``free_generation`` is set, each document is constrained by a
    prefix tree. When ``mention2wikidataID`` yields candidates for the
    document's normalised mention, a per-document trie is built from their
    ``"<title> >> <lang>"`` strings (or ``"<lang> >> <title>"`` when
    ``order`` is not ``"title_lang"``); otherwise the shared ``trie`` is
    used.

    The mapping/list defaults are only read, never mutated, so sharing the
    default objects between calls is safe.

    Args:
        model: object exposing ``encode``, ``sample`` and
            ``task.target_dictionary``.
        dataset: list of documents with ``"meta"["mention"]`` and
            ``"output"[0]["answer"]``.
        batch_size, beams, lenpen, max_len_b: forwarded to batching and
            ``model.sample``.
        max_len_a: forwarded to ``create_input``.
        trie: shared fallback prefix tree.
        lang_title2wikidataID / canonical_lang_title2wikidataID: map a
            ``(lang, title)`` tuple to answer id(s).
        wikidataID2lang_title / wikidataID2canonical_lang_title: map an id
            to its ``(lang, title)`` pairs / canonical pair.
        order: ``"title_lang"`` or anything else for lang-first strings.
        canonical: use the canonical mappings.
        free_generation: disable prefix constraints.
        mention2wikidataID: mention -> {id: count} candidate table.
        candidates_lowercase: also look the lower-cased mention up.
        allowed_langs: languages kept for non-canonical candidates.
        desc: suffix for the progress-bar description.
        max_candidates: keep only the top-N candidates (``None`` = all).
        only_en_candidates: keep candidates that have an "en" title.
        only_freebase_candidates: keep candidates present in
            ``wikidataID2freebaseID``.

    Returns:
        ``(dataset, f1, precision, recall, accuracy)``.
    """
    title_first = order == "title_lang"

    def format_lang_title(pair):
        # ``pair`` is ordered (lang, title); emit it in the requested order.
        ordered = tuple(reversed(pair)) if title_first else tuple(pair)
        return "{} >> {}".format(*ordered)

    def to_bpes(text):
        # Prefix with token id 2 and drop the first encoded token
        # (presumably swapping BOS for the decoder start token; confirm).
        return [2] + model.encode(text).tolist()[1:]

    # Bound up front so an empty dataset no longer reaches the return
    # statement with these names undefined.
    f1 = precision = recall = accuracy = 0

    # Running totals: metrics are refreshed after each batch without
    # re-scanning every earlier prediction.
    true_pos = 0
    num_gold = 0
    num_predicted = 0  # predictions different from [None]
    num_correct = 0

    iter_ = tqdm(dataset, desc="Evaluating {}".format(desc if desc else ""))
    for docs in batch_it(iter_, batch_size):
        constraint_fn = None
        if not free_generation:
            batch_trie = {}
            for i, doc in enumerate(docs):
                mention = (
                    unicodedata.normalize(
                        "NFKD", HanziConv.toSimplified(doc["meta"]["mention"])
                    )
                    .replace("•", "·")
                    .replace(".", "·")
                )

                scored = list(mention2wikidataID.get(mention, {}).items())
                if candidates_lowercase:
                    scored += list(
                        mention2wikidataID.get(mention.lower(), {}).items()
                    )

                # Merge duplicate ids by summing their counts, then rank by
                # total count (highest first).
                totals = defaultdict(int)
                for cand_id, count in scored:
                    totals[cand_id] += count
                candidates = [
                    cand_id
                    for cand_id, _ in sorted(
                        totals.items(), key=lambda item: item[1], reverse=True
                    )
                ]

                if only_en_candidates:
                    candidates = [
                        cand
                        for cand in candidates
                        if "en" in dict(wikidataID2lang_title[cand])
                    ]
                if only_freebase_candidates:
                    candidates = [
                        cand for cand in candidates if cand in wikidataID2freebaseID
                    ]
                candidates = candidates[:max_candidates]

                if not (mention2wikidataID and candidates):
                    batch_trie[i] = trie
                    continue

                if canonical:
                    # The order choice is made on the pair BEFORE formatting.
                    # Previously the conditional sat inside ``.format(*...)``
                    # and the lang-first branch unpacked an already formatted
                    # string character by character.
                    batch_bpes = [
                        to_bpes(
                            format_lang_title(wikidataID2canonical_lang_title[cand])
                        )
                        for cand in candidates
                        if cand in wikidataID2canonical_lang_title
                    ]
                else:
                    batch_bpes = [
                        to_bpes(format_lang_title((lang, title)))
                        for cand in candidates
                        for lang, title in wikidataID2lang_title.get(cand, [])
                        if lang in allowed_langs
                    ]

                if batch_bpes:
                    batch_trie[i] = Trie()
                    for bpes in batch_bpes:
                        batch_trie[i].add(bpes)
                else:
                    # No usable candidate string: fall back to the shared trie.
                    batch_trie[i] = trie

            def constraint_fn(batch_id, sent):
                # Drop ids that fall outside the target dictionary.
                vocab_size = len(model.task.target_dictionary)
                return [
                    token
                    for token in batch_trie[batch_id].get(sent.tolist())
                    if token < vocab_size
                ]

        outputs = model.sample(
            [create_input(doc, max_len_a) for doc in docs],
            beam=beams,
            lenpen=lenpen,
            max_len_b=max_len_b,
            prefix_allowed_tokens_fn=constraint_fn,
        )

        for doc, out in zip(docs, outputs):
            predictions = []
            for hyp in out:
                parts = hyp["text"].split(" >> ")
                key = tuple(reversed(parts) if title_first else parts)
                if canonical:
                    answer = [canonical_lang_title2wikidataID.get(key, None)]
                else:
                    answer = list(lang_title2wikidataID.get(key, [None]))
                predictions.append(
                    {
                        "answer": answer,
                        "text": hyp["text"],
                        "logprob": hyp["logprob"].item(),
                    }
                )
            doc["predictions"] = predictions

            gold_answer = doc["output"][0]["answer"]
            try:
                predicted = doc["predictions"][0]["answer"]
            except Exception:
                # No hypothesis was produced: score it as "no answer".
                predicted = [None]

            has_prediction = predicted != [None]
            hit = bool(
                set(gold_answer).intersection(set(predicted)) and has_prediction
            )

            num_gold += 1
            if has_prediction:
                num_predicted += 1
            if hit:
                true_pos += 1
            # A correct abstention (no gold answer, nothing predicted) also
            # counts towards accuracy.
            if hit or (gold_answer == [] and predicted == [None]):
                num_correct += 1

        precision = (true_pos / num_predicted) if num_predicted else 0
        recall = (true_pos / num_gold) if num_gold else 0
        f1 = (
            (2 * precision * recall / (precision + recall))
            if precision + recall
            else 0
        )
        # ``num_gold`` can be zero only if a batch was empty.
        accuracy = (num_correct / num_gold) if num_gold else 0

        iter_.set_postfix(f1=f1, prec=precision, rec=recall, acc=accuracy)

    return dataset, f1, precision, recall, accuracy