Example No. 1
def load_id2title_mongo(name, path, overwrite=False):
    mongo_id2t = MongoBackedDict(dbname=name + ".id2t")
    # TODO Maybe you can use the same db and its reverse?
    mongo_t2id = MongoBackedDict(dbname=name + ".t2id")
    # TODO fix below
    redirect_set = None
    if mongo_id2t.size() == 0 or mongo_t2id.size() == 0 or overwrite:
        logging.info("db not found at %s. creating ...", path)
        id2t, t2id = {}, {}
        redirect_set = set([])
        for line in open(path):
            parts = line.strip().split("\t")
            if len(parts) != 3:
                logging.info("bad line %s", line)
                continue
            # page_id, title = parts
            page_id, page_title, is_redirect = parts
            id2t[page_id] = page_title
            t2id[page_title] = page_id
            if is_redirect == "1":
                redirect_set.add(page_title)
        mongo_id2t.bulk_insert(regular_map=id2t, insert_freq=len(id2t))
        mongo_t2id.bulk_insert(regular_map=t2id, insert_freq=len(t2id))
        # obj = id2t, t2id, redirect_set
        # save(pkl_path, obj)
        # logging.info("saving id2t pkl to %s", pkl_path)
    logging.info("id2t of size %d", mongo_id2t.size())
    logging.info("t2id of size %d", mongo_t2id.size())
    return mongo_id2t, mongo_t2id, redirect_set
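A minimal call sketch for the loader above; the db name and file path are placeholders, and the input is assumed to be the 3-column page_id/title/is_redirect TSV the function parses:

id2t, t2id, redirects = load_id2title_mongo(
    name="enwiki-20190701",
    path="data/enwiki/idmap/enwiki-20190701.id2t")
# look up a title by page id and a page id by title (keys are illustrative)
title = id2t["12345"] if "12345" in id2t else None
page_id = t2id["Some_Title"] if "Some_Title" in t2id else None
# `redirects` is only populated when the collections are (re)built in this call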
Example No. 2
def load_redirects_mongo(name, path, overwrite=False):
    # pkl_path = path + ".pkl"
    # if os.path.exists(pkl_path):
    # logging.info("pkl found! loading map %s", pkl_path)
    # r2t = load(pkl_path)
    # else:
    mongo_r2t = MongoBackedDict(dbname=name)
    if mongo_r2t.size() == 0 or overwrite:
        logging.info("db not found at %s. creating ...", path)
        f = open(path)
        r2t = {}
        err = 0
        logging.info("pkl not found ...")
        logging.info("loading map from %s", path)
        for idx, l in enumerate(f):
            parts = l.strip().split("\t")
            if len(parts) != 2:
                logging.info("error on line %d %s", idx, parts)
                err += 1
                continue
            redirect, title = parts
            if redirect in r2t:
                logging.info("duplicate keys! was this on purpose?")
            r2t[redirect] = title
        logging.info("map of size %d loaded %d err lines", len(r2t), err)
        mongo_r2t.bulk_insert(regular_map=r2t, insert_freq=len(r2t))
    logging.info("r2t of size %d", mongo_r2t.size())
    return mongo_r2t
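A usage sketch, assuming the input is a 2-column redirect-to-title TSV (the name and path are placeholders):

r2t = load_redirects_mongo(name="enwiki_redirects",
                           path="data/enwiki/idmap/enwiki-20190701.r2t")
# resolve a redirect to its canonical title (key is illustrative)
canonical = r2t["Some_Redirect"] if "Some_Redirect" in r2t else None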
Example No. 3
def load_prob_map_mongo(out_prefix,
                        kind,
                        dbname=None,
                        force_rewrite=False,
                        hostname="localhost"):
    path = out_prefix + "." + kind
    if dbname is None:
        dbname = path
    logging.info("dbname is %s", dbname)
    probmap = MongoBackedDict(dbname=dbname, hostname=hostname)
    logging.info("reading collection %s", path)
    if probmap.size() > 0 and not force_rewrite:
        logging.info(
            "collection already exists in db (size=%d). returning ...",
            probmap.size())
        return probmap
    else:
        if force_rewrite:
            logging.info("dropping existing collection in db.")
            probmap.drop_collection()
        # mmap = defaultdict(lambda: defaultdict(float))
        mmap = {}
        for idx, line in enumerate(open(path)):
            parts = line.split("\t")
            if idx > 0 and idx % 1000000 == 0:
                logging.info("read line %d", idx)
            if len(parts) != 4:
                logging.info("error on line %d: %s", idx, line)
                continue
            y, x, prob, _ = parts
            if y not in mmap:
                mmap[y] = {}
            mmap[y][x] = float(prob)
        for y in list(mmap.keys()):
            # TODO will below ever be false?
            # if y not in probmap:
            # Nested dict keys cannot have '.' and '$' in mongodb
            # tmpdict = {x: mmap[y][x] for x in mmap[y]}
            if len(mmap[y]) > 5000:
                logging.info(
                    "string %s can link to %d items (>5k)... skipping", y,
                    len(mmap[y]))
                # mmap[y] = []
                # continue
                del mmap[y]
            else:
                # tmpdict = [(x, mmap[y][x]) for x in mmap[y]]
                mmap[y] = list(mmap[y].items())
                # try:
                #     probmap[y] = tmpdict
                # except DocumentTooLarge as e:
                #     print(y, len(tmpdict))
                #     print(e)
        probmap.bulk_insert(regular_map=mmap, insert_freq=len(mmap))
    return probmap
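A read-back sketch for the map above; because nested keys cannot contain '.' or '$' in MongoDB, each value is stored as a list of (title, prob) pairs. The prefix, kind, and lookup key below are placeholders:

p2t = load_prob_map_mongo("data/enwiki/probmap/enwiki-20190701", "p2t")
pairs = p2t["some surface string"] if "some surface string" in p2t else []
best = max(pairs, key=lambda kv: kv[1])[0] if pairs else None  # highest-probability title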
Example No. 4
def load_map_mongo(path, overwrite=False):
    m = MongoBackedDict(dbname=path)
    rev_m = None
    if m.size() == 0 or overwrite:
        logging.info("dropping existing collection ...")
        m.drop_collection()
        tmp = {}
        # logging.info("pkl not found ...")
        logging.info("loading map from %s", path)
        # NOTE: `abs_path` is assumed to be a module-level prefix defined elsewhere;
        # as written, this line raises UnboundLocalError.
        abs_path = abs_path + path[5:]
        f = open(abs_path)
        err = 0
        for idx, l in enumerate(f):
            parts = l.strip().split("\t")
            if len(parts) != 2:
                logging.info("error on line %d %s", idx, parts)
                err += 1
                continue
            k, v = parts
            if k in tmp:
                logging.info("duplicate key %s was this on purpose?", k)
            tmp[k] = v
        rev_m = {v: k for k, v in tmp.items()}
        logging.info("inserting map of size %d to mongo (%d err lines)",
                     len(tmp), err)
        m.bulk_insert(regular_map=tmp, insert_freq=len(tmp))
    return m, rev_m
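A call sketch; the path and key are placeholders, the in-memory reverse map is only built when the collection is (re)created, and the rebuild branch also relies on the module-level abs_path prefix noted above:

m, rev_m = load_map_mongo("data/frwiki/idmap/fr2entitles")
en_title = m["Some_Foreign_Title"] if "Some_Foreign_Title" in m else None
fr_title = rev_m.get(en_title) if rev_m is not None else None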
Example No. 5
def load_prob_map_mongo(out_prefix, kind, force_rewrite=False):
    path = out_prefix + "." + kind
    probmap = MongoBackedDict(dbname=path)
    logging.info("reading collection %s", path)
    if probmap.size() > 0 and not force_rewrite:
        logging.info(
            "collection already exists in db (size=%d). returning ...",
            probmap.size())
        return probmap
    else:
        if force_rewrite:
            logging.info("dropping existing collection in db.")
            probmap.drop_collection()
        mmap = defaultdict(lambda: defaultdict(float))
        for idx, line in enumerate(open(path)):
            parts = line.split("\t")
            if idx > 0 and idx % 1000000 == 0:
                logging.info("read line %d", idx)
            if len(parts) != 4:
                logging.info("error on line %d: %s", idx, line)
                continue
            y, x, prob, _ = parts
            mmap[y][x] = float(prob)
        for y in mmap:
            # TODO will below ever be false?
            # if y not in probmap:
            # Nested dict keys cannot have '.' and '$' in mongodb
            # tmpdict = {x: mmap[y][x] for x in mmap[y]}
            if len(mmap[y]) > 10000:
                logging.info(
                    "string %s can link to %d items (>10k)... skipping", y,
                    len(mmap[y]))
                mmap[y] = []
                continue
            tmpdict = [(x, mmap[y][x]) for x in mmap[y]]
            try:
                probmap[y] = tmpdict
            except DocumentTooLarge as e:
                print(y, len(tmpdict))
                print(e)
    return probmap
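A minimal import header this variant assumes; MongoBackedDict is the project's own wrapper, so its import path is omitted:

import logging
from collections import defaultdict
from pymongo.errors import DocumentTooLarge
# MongoBackedDict is imported from the project's own module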
Example No. 6
class Inlinks:
    """
    reads the outlinks file and computes the inlinks dictionary from it.
    saves it in a pickled dict for fast access.
    """
    def __init__(self,
                 normalizer=None,
                 overwrite=False,
                 links_file=None,
                 pkl_path="/shared/bronte/upadhya3/tac2018/inlinks.pkl"):
        # normalizer = TitleNormalizer() if normalizer is None else normalizer
        if links_file is None:
            links_file = "/shared/preprocessed/cddunca2/wikipedia/outlinks.t2t"
        self.normalizer = normalizer
        self.inlinks = MongoBackedDict(dbname="enwiki_inlinks")
        if self.inlinks.size() == 0 or overwrite:
            self.inlinks.drop_collection()
            start = time.time()
            logging.info("loading from file %s", links_file)
            self.load_link_info(links_file=links_file)
            logging.info("created in %d secs", time.time() - start)

    def load_link_info(self, links_file):
        logging.info("loading links %s ...", links_file)
        bad = 0
        mmap = {}
        for idx, line in enumerate(open(links_file)):
            if idx > 0 and idx % 1000000 == 0:
                logging.info("read %d", idx)
            line = line.strip().split('\t')
            if len(line) != 2:
                # logging.info("skipping bad line %s", line)
                bad += 1
                if bad % 10000 == 0:
                    logging.info("bad %d total %d", bad, idx)
                continue
            src = line[0]
            trgs = line[1].split(' ')
            for trg in trgs:
                if trg not in mmap:
                    mmap[trg] = []
                mmap[trg].append(src)
        logging.info("inserting regular map into mongo")
        self.inlinks.bulk_insert(regular_map=mmap, insert_freq=len(mmap))
        # DONT DO THIS! this inserts one by one, which is slow
        # for trg in mmap:
        #     self.inlinks[trg] = mmap[trg]
        logging.info("mongo map made")
Example No. 7
class Wiki2Lorelei:
    def __init__(self, ilcode, overwrite=False):
        cll_name = "wiki2eid_il" + ilcode
        self.wiki2eids = MongoBackedDict(dbname=cll_name)
        self.normalizer = TitleNormalizer()
        if overwrite:
            self.wiki2eids.drop_collection()
            logging.info("computing wiki2eids map ...")
            self.compute_map(ilcode)
        logging.info("wiki2eids map loaded (size=%d)", self.wiki2eids.size())

    # @profile
    def compute_map(self, ilcode):
        basepath = "/shared/corpora/corporaWeb/lorelei/evaluation-2019/"
        kbfile = basepath + "il{}/source/kb/IL{}_kb/data/entities.tab".format(
            ilcode, ilcode)
        tmp_map = {}
        for idx, line in enumerate(open(kbfile)):

            if idx > 0 and idx % 100000 == 0:
                logging.info("read %d lines", idx)

            parts = line.rstrip('\n').split('\t')
            # `fields` is assumed to be the module-level list of KB column names
            if len(parts) < len(fields):
                logging.info("bad line %d nfields:%d expected:%d", idx,
                             len(parts), len(fields))
                continue

            kbentry = {}
            for field, v in zip(fields, parts):
                if len(v) != 0:
                    kbentry[field] = v

            eid = kbentry["entityid"]
            title = get_normalized_wikititle_kbentry(
                title_normalizer=self.normalizer, kbentry=kbentry)

            if title == NULL_TITLE:
                continue

            if title not in tmp_map:
                tmp_map[title] = []
            tmp_map[title].append(eid)
        self.wiki2eids.bulk_insert(regular_map=tmp_map,
                                   insert_freq=len(tmp_map))
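A usage sketch; the IL code and lookup key are illustrative, and the class assumes the module-level `fields`, `NULL_TITLE`, and normalizer helpers referenced above:

w2l = Wiki2Lorelei(ilcode="11", overwrite=True)
eids = (w2l.wiki2eids["Some_Title"]
        if "Some_Title" in w2l.wiki2eids else [])  # KB entity ids for a wiki title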
Example No. 8
def load_nicknames(path="/shared/experiments/upadhya3/ppoudyaltest/wiki_list", overwrite=False):
    nicknames = MongoBackedDict(dbname="nicknames")
    if nicknames.size() == 0 or overwrite:
        nn_map = {}
        # populate nn_map
        for idx, line in enumerate(open(path)):
            parts = line.strip().split('\t')
            if idx > 0 and idx % 10000 == 0:
                logging.info("read %d lines", idx)
            # if len(parts)!=3:
            #     logging.info("bad line %s",line)
            #     continue
            title, tid = parts[:2]
            fr_strs = parts[2:]
            # print(title,tid,fr_strs)
            for fr_str in fr_strs:
                if fr_str not in nn_map:
                    nn_map[fr_str] = title
        nicknames.bulk_insert(regular_map=nn_map, insert_freq=len(nn_map))
    return nicknames
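A call sketch using the default path; the lookup key is illustrative:

nicknames = load_nicknames()
title = nicknames["some nickname"] if "some nickname" in nicknames else None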
Example No. 9
class CandGen:
    def __init__(self,
                 lang=None,
                 year=None,
                 inlinks=None,
                 tsl=None,
                 wiki_cg=None,
                 tsl_concept_pair=None,
                 tsl_translit_dict=None,
                 spellchecker=None):
        self.init_counters()
        self.lang = lang
        self.year = year
        self.inlinks = inlinks
        self.spellchecker = spellchecker
        self.cheap_dict = dictionary[lang] if lang in dictionary else {}
        self.translit_model = tsl
        self.concept_pair, self.translit_dict = tsl_concept_pair, tsl_translit_dict
        self.wiki_cg = wiki_cg
        self.en_normalizer = TitleNormalizer(lang="en")

    def load_or_checker(self):
        # assumes pyspellchecker's SpellChecker is imported at module level
        or_spell = SpellChecker(language=None, distance=2)
        or_spell.word_frequency.load_dictionary(
            'spellchecker/or_entity_dictionary.gz')
        return or_spell

    def load_kb(self, kbdir):
        self.m = MongoBackedDict(
            dbname='data/enwiki/idmap/enwiki-20190701.id2t.t2id')
        self.en_t2id = MongoBackedDict(
            dbname="data/enwiki/idmap/enwiki-20190701.id2t.t2id")
        self.en_id2t = MongoBackedDict(
            dbname="data/enwiki/idmap/enwiki-20190701.id2t.id2t")
        en_id2t_filepath = os.path.join(kbdir, 'enwiki', 'idmap',
                                        f'enwiki-{self.year}.id2t')
        self.fr2entitles = MongoBackedDict(dbname=f"{self.lang}2entitles")
        fr2entitles_filepath = os.path.join(kbdir, f'{self.lang}wiki', 'idmap',
                                            'fr2entitles')
        self.t2id = MongoBackedDict(dbname=f"{self.lang}_t2id")

        if self.en_t2id.size() == 0 or self.en_id2t.size() == 0:
            logging.info(f'Loading en t2id and id2t...')
            en_id2t = []
            ent2id = defaultdict(list)
            for line in tqdm(open(en_id2t_filepath)):
                parts = line.strip().split("\t")
                if len(parts) != 3:
                    logging.info("bad line %s", line)
                    continue
                page_id, page_title, is_redirect = parts
                key = page_title.replace('_', ' ').lower()
                ent2id[key].append(page_id)
                en_id2t.append({
                    'key': page_id,
                    'value': {
                        'page_id': page_id,
                        'name': page_title,
                        'searchname': key
                    },
                    'redirect': is_redirect
                })
            ent2id_list = []
            for k, v in ent2id.items():
                ent2id_list.append({'key': k, 'value': v})
            logging.info("inserting %d entries into english t2id",
                         len(ent2id_list))
            self.en_t2id.cll.insert_many(ent2id_list)
            self.en_t2id.cll.create_index([("key", pymongo.HASHED)])

            logging.info("inserting %d entries into english id2t",
                         len(en_id2t))
            self.en_id2t.cll.insert_many(en_id2t)
            self.en_id2t.cll.create_index([("key", pymongo.HASHED)])

        if self.fr2entitles.size() == 0:
            logging.info(f'Loading fr2entitles and {self.lang} t2id...')
            fr2en = []
            t2id = []
            f = open(fr2entitles_filepath)
            for idx, l in enumerate(f):
                parts = l.strip().split("\t")
                if len(parts) != 2:
                    logging.info("error on line %d %s", idx, parts)
                    continue
                frtitle, entitle = parts
                key = frtitle.replace('_', ' ').lower()
                enkey = entitle.replace('_', ' ').lower()
                fr2en.append({
                    "key": key,
                    "value": {
                        'frtitle': frtitle,
                        'entitle': entitle,
                        'enkey': enkey
                    }
                })
                t2id.append({"key": key, "value": self.en_t2id[enkey]})
            logging.info(f"inserting %d entries into {self.lang}2entitles",
                         len(fr2en))
            self.fr2entitles.cll.insert_many(fr2en)
            self.fr2entitles.cll.create_index([("key", pymongo.HASHED)])
            logging.info(f"inserting %d entries into {self.lang} t2id",
                         len(t2id))
            self.t2id.cll.insert_many(t2id)
            self.t2id.cll.create_index([("key", pymongo.HASHED)])

    def extract_cands(self, cands):
        wiki_titles, wids, wid_cprobs = [], [], []
        for cand in cands:
            wikititle, p_t_given_s, p_s_given_t = cand.en_title, cand.p_t_given_s, cand.p_s_given_t
            nrm_title = self.en_normalizer.normalize(wikititle)
            if nrm_title == K.NULL_TITLE:  # REMOVED or nrm_title not in en_normalizer.title2id
                logging.info("bad cand %s nrm=%s", wikititle, nrm_title)
                continue
            wiki_id = self.en_normalizer.title2id[nrm_title]
            # if wiki_id is none
            if wiki_id is None:
                wiki_id = self.en_normalizer.title2id[wikititle]
                if wiki_id is None:
                    continue
                wiki_titles.append(wikititle)
                wids.append(wiki_id)
                wid_cprobs.append(p_t_given_s)
                continue
            wiki_titles.append(nrm_title)
            wids.append(wiki_id)
            wid_cprobs.append(p_t_given_s)
        return wiki_titles, wids, wid_cprobs

    def init_counters(self):
        self.eng_words = 0
        self.nils = 0
        self.no_wikis = 0
        self.trans_hits = 0
        self.total, self.total_hits, self.prior_correct, self.nil_correct = 0, 0, 0, 0

    def init_l2s_map(self, eids, args=None):
        l2s_map = {}
        for eid in eids:
            title = self.en_id2t[eid]
            # print(eid, title)
            # if eid is None:
            #     print("eid none", eid)
            # if title is None:
            #     print("title none", title)
            try:
                key = "|".join([eid, title])
            except:
                print("eid", eid, title)
                print("end")
                continue
            if not key in l2s_map:
                # if title not in self.inlinks:
                #     logging.info("not in inlinks %s, keeping []", title)
                #     inlinks = []
                # else:
                #     inlinks = self.inlinks[title]
                l2s_map[key] = 100
        # for k in l2s_map:
        #     if args.inlink:
        #         l2s_map[k] /= Z
        #     else:
        #         l2s_map[k] = 1
        return l2s_map

    def cross_check_score(self, l2s_map, eids):
        freq = dict(Counter(eids))
        for cand, v in l2s_map.items():
            cand_eid = cand.split("|")[0]
            l2s_map[cand] = l2s_map[cand] * (
                3**freq[cand_eid]) if cand_eid in eids else l2s_map[cand] * 0.1
        return l2s_map

    def wiki_contain_score(self, l2s_map, query_str, args):
        for cand in l2s_map.keys():
            cand_name = cand.split("|")[1]
            score = 1
            if cand_name in args.eid2wikisummary:
                summary = args.eid2wikisummary[cand_name]
            else:
                try:
                    summary = wikipedia.summary(cand_name)
                    args.eid2wikisummary.cll.insert_one({
                        "key": cand_name,
                        "value": summary
                    })
                except:
                    args.eid2wikisummary.cll.insert_one({
                        "key": cand_name,
                        "value": ""
                    })
                    summary = ""
            # check summary contains the query
            score = score * 2 if query_str + "," in summary else score * 1
            l2s_map[cand] *= score
        return l2s_map

    def bert_score(self, l2smap, query_emb, l2s_map, args):
        cand2sim = {}
        max_cand = None
        max_sim = -1000
        for cand in l2smap:
            cand_eid = cand.split("|")[0]
            cand_name = cand.split('|')[1]
            # request summary
            if cand_eid in args.eid2wikisummary:
                summary = args.eid2wikisummary[cand_eid]
                entity_wiki = args.eid2wikisummary_entity[cand_eid]
            else:
                summary, entity_wiki = get_wiki_summary(
                    f"https://en.wikipedia.org/?curid={cand_eid}")
                args.eid2wikisummary.cll.insert_one({
                    "key": cand_eid,
                    "value": summary
                })
                args.eid2wikisummary_entity.cll.insert_one({
                    "key": cand_eid,
                    "value": entity_wiki
                })
            summary = summary.lower()
            # bert
            try:
                cand_name = entity_wiki
            except:
                print("error")
            if cand_name in summary:
                cand_emb = s2maskedvec(mask_sents(cand_name, summary))
                sim = cosine_similarity([cand_emb], [query_emb])[0][0]
                cand2sim[cand] = sim
                if sim > max_sim:
                    max_sim = sim
                    max_cand = cand
            elif cand_name.lower() in summary:
                cand_emb = s2maskedvec(mask_sents(cand_name.lower(), summary))
                sim = cosine_similarity([cand_emb], [query_emb])[0][0]
                cand2sim[cand] = sim
                if sim > max_sim:
                    max_sim = sim
                    max_cand = cand
            else:
                logging.info("not in summary", cand_name)
                continue

        if len(cand2sim) > 1:
            l2s_map[max_cand] *= 1.5
        # logging.info("cand2sim", cand2sim)
        return l2s_map

    def get_context(self, query_str, text, k=10):
        if query_str in text:
            tokenizer = MWETokenizer()
            query_str_tokens = tuple(query_str.split())
            query_str_dashed = "_".join(query_str_tokens)
            tokenizer.add_mwe(query_str_tokens)
            text_token = tokenizer.tokenize(text.split())
            try:
                t_start = text_token.index(query_str_dashed)
            except:
                return None, None, None
            t_end = t_start + 1
            start_index = max(t_start - k, 0)
            end_index = min(t_end + k, len(text_token))
            text_token_query = text_token[start_index:t_start] + text_token[
                t_end + 1:end_index]
            context = " ".join(text_token_query)
            context_mention = text_token[start_index:t_start] + [
                query_str
            ] + text_token[t_end + 1:end_index]
            context_mention = " ".join(context_mention)
            return context, text_token_query, context_mention
        else:
            return None, None, None

    def get_l2s_map(self, eids_google, eids_google_maps, eids_wikicg,
                    eids_total, ner_type, query_str, text, args):
        l2s_map = self.init_l2s_map(eids_total, args=args)
        # check if candidates were generated
        if len(l2s_map) == 0 or len(l2s_map) == 1:
            return l2s_map

        l2s_map = self.cross_check_score(
            l2s_map, eids_google + eids_google_maps + eids_wikicg)

        #update score
        for cand in l2s_map.copy().keys():
            cand_eid = cand.split("|")[0]
            score = 1
            l2s_map[cand] *= score

        logging.info("Processed looping candidates")

        # get context:
        if args.bert:
            context, context_tokens, context_mention = self.get_context(
                query_str, text, k=10)

        # check context bert
        if args.bert and context is not None:
            query_emb = s2maskedvec(mask_sents(query_str, context_mention))
            l2s_map = self.bert_score(l2s_map, query_emb, l2s_map, args)
            logging.info("Processed candidates bert")

        # Normalize
        sum_s = sum(list(l2s_map.values()))
        for can, s in l2s_map.items():
            l2s_map[can] = s / sum_s
        return l2s_map

    def correct_surf(self, token):
        region_list = [
            "district of", "district", "city of", "state of", "province of",
            "division", "city", "valley", "province"
        ]
        token = token.lower()
        for i in region_list:
            token = token.replace(i, "").strip()
        return token

    def default_type(self, l2s_map, max_cand):
        l2s_map = dict(
            sorted(l2s_map.items(), key=operator.itemgetter(1), reverse=True))
        # del l2s_map[max_cand]
        max_cand_name, max_cand_eid = max_cand.split("|")[1], max_cand.split(
            "|")[0]
        max_cand_name = self.correct_surf(max_cand_name)
        if "feature_class" in self.en_id2t[max_cand_eid]:
            eid_type = self.en_id2t[max_cand_eid]["feature_class"]
        else:
            return max_cand
        capital_features = ["PPLA", "PPLA2", "PPLC"]
        district_set = pickle.load(
            open(
                f"/shared/experiments/xyu71/lorelei2017/src/lorelei_kb/IL11-12-District/dis_il{self.lang}.pickle",
                "rb"))
        if "feature_code" in self.en_id2t[max_cand_eid]:
            eid_fcode = self.en_id2t[max_cand_eid]["feature_code"][2:]
        else:
            eid_fcode = ""

        if self.lang == "ilo":
            if eid_fcode not in capital_features:
                if max_cand_name in district_set and eid_type == "P":
                    for k, v in l2s_map.items():
                        k_name, k_id = self.correct_surf(
                            k.split("|")[1]), k.split("|")[0]
                        if "feature_class" in self.en_id2t[k_id]:
                            if k_name == max_cand_name and self.en_id2t[k_id][
                                    "feature_class"] == "A":
                                return k
        elif self.lang == "or":
            if eid_type == "P":
                for k, v in l2s_map.items():
                    k_name, k_id = self.correct_surf(
                        k.split("|")[1]), k.split("|")[0]
                    if "feature_class" in self.en_id2t[k_id]:
                        if k_name == max_cand_name and self.en_id2t[k_id][
                                "feature_class"] == "A":
                            return k
        return max_cand

    def get_maxes_l2s_map(self, l2s_map):
        # pick max
        if len(l2s_map) == 0:
            max_cand, max_score = "NIL", 1.0
        else:
            maxes_l2s_map = {
                cand: score
                for cand, score in l2s_map.items()
                if score == max(l2s_map.values())
            }
            max_cand = list(maxes_l2s_map.keys())[0]
            max_score = l2s_map[max_cand]
        return max_cand, max_score

    def compute_hits_for_ta(self, docta, outfile, only_nils=False, args=None):
        if not args.overwrite:
            if os.path.exists(outfile):
                logging.error("file %s exists ... skipping", outfile)
                return
        try:
            ner_view = docta.get_view("NER_CONLL")
        except:
            return
        candgen_view_json = copy.deepcopy(ner_view.as_json)
        text = docta.text
        predict_mode = True

        if "constituents" not in candgen_view_json["viewData"][0]:
            return
        for idx, cons in enumerate(
                candgen_view_json["viewData"][0]["constituents"]):
            self.total += 1
            query_str = cons["tokens"]
            # query_str = clean_query(query_str)
            ner_type = cons["label"]

            eids_google, eids_wikicg, eids_google_maps = [], [], []

            # swj query
            #
            eids_google, eids_wikicg, eids_google_maps = self.get_lorelei_candidates(
                query_str, ner_type=ner_type, args=args)
            eids_total = eids_google + eids_google_maps + eids_wikicg

            # for eid in eids_wikicg:
            #     if eid is None:
            #         print("error: ", eid, query_str)
            #         1/0

            logging.info("got %d candidates for query:%s",
                         len(set(eids_total)), query_str)

            l2s_map = self.get_l2s_map(eids_google,
                                       eids_google_maps,
                                       eids_wikicg,
                                       eids_total,
                                       ner_type=ner_type,
                                       query_str=query_str,
                                       text=text,
                                       args=args)

            logging.info(
                f"got {len(l2s_map)} candidates after ranking for {query_str}: {l2s_map}"
            )
            max_cand, max_score = self.get_maxes_l2s_map(l2s_map)

            if len(l2s_map) > 0 and args.lang in ["ilo", "or"]:
                max_cand_default = self.default_type(l2s_map, max_cand)
                if max_cand_default != max_cand:
                    max_score = 1.0
                    max_cand = max_cand_default
                logging.info(f"got candidate after default")

            if len(l2s_map) > 0:
                cons["labelScoreMap"] = l2s_map
            cons["label"] = max_cand
            cons["score"] = max_score

        candgen_view_json["viewName"] = "CANDGEN"
        candgen_view = View(candgen_view_json, docta.get_tokens)
        docta.view_dictionary["CANDGEN"] = candgen_view
        docta_json = docta.as_json
        with open(outfile, 'w', encoding='utf-8') as f:
            json.dump(docta_json, f, ensure_ascii=False, indent=True)
        # json.dump(docta_json, open(outfile, "w"), indent=True)
        self.report(predict_mode=predict_mode)

    def get_lorelei_candidates(self,
                               query_str,
                               romanized_query_str=None,
                               ner_type=None,
                               args=None):
        # SKB+SG
        desuf_query_str = remove_suffix(query_str, self.lang)
        dot_query_str_list = correct_surface(query_str, self.lang)
        desuf_dot_query_str_list = correct_surface(desuf_query_str, self.lang)

        if args.wikicg:
            wiki_titles, wids, wid_cprobs = self.extract_cands(
                self.wiki_cg.get_candidates(surface=query_str))
            eids_wikicg = wids
        else:
            eids_wikicg = []

        # SG suffix dot suffix+dot
        if args.google == 1:
            eids_google, g_wikititles = self._get_candidates_google(
                query_str, top_num=args.google_top)
            # if not eids_google:
            for i in [desuf_query_str
                      ] + dot_query_str_list + desuf_dot_query_str_list:
                e, t = self._get_candidates_google(i, top_num=args.google_top)
                eids_google += e
                g_wikititles += t
            eids_google = list(set(eids_google))
        else:
            eids_google = []

        # logging.info("got %d candidates for query:%s from google", len(eids_google), query_str)

        eids_google_maps = []
        if ner_type in ['GPE', 'LOC'] and args.google_map:
            google_map_name = query2gmap_api(query_str, self.lang)
            eids_google_maps += self._exact_match_kb(google_map_name, args)
            eids_google_maps += self._get_candidates_google(
                google_map_name, lang='en', top_num=args.google_top)[0]
            # if not eids_google_maps:
            google_map_name_suf = query2gmap_api(desuf_query_str, self.lang)
            google_map_name_dot = [
                query2gmap_api(k, self.lang) for k in dot_query_str_list
            ]
            google_map_name_suf_dot = [
                query2gmap_api(k, self.lang) for k in desuf_dot_query_str_list
            ]
            eids_google_maps += self._exact_match_kb(google_map_name_suf, args)
            eids_google_maps += self._get_candidates_google(
                google_map_name_suf, lang='en', top_num=args.google_top)[0]

            eids_google_maps += [
                h for k in google_map_name_dot
                for h in self._exact_match_kb(k, args)
            ]
            eids_google_maps += [
                h for k in google_map_name_dot
                for h in self._get_candidates_google(
                    k, lang='en', top_num=args.google_top)[0]
            ]

            eids_google_maps += [
                h for k in google_map_name_suf_dot
                for h in self._exact_match_kb(k, args)
            ]
            eids_google_maps += [
                h for k in google_map_name_suf_dot
                for h in self._get_candidates_google(
                    k, lang='en', top_num=args.google_top)[0]
            ]
        eids_google_maps = list(set(eids_google_maps))
        # logging.info("got %d candidates for query:%s from google map", len(set(eids_google_maps)), query_str)

        return eids_google, eids_wikicg, eids_google_maps

    def _get_candidates_google(self, surface, top_num=1, lang=None):
        eids = []
        wikititles = []
        if surface is None or len(surface) < 2:
            return eids, wikititles
        if surface in self.cheap_dict:
            surface = self.cheap_dict[surface]
        if lang is None:
            lang = self.lang
        en_surfaces = query2enwiki(surface, lang)[:top_num]
        wikititles = []
        for ent in en_surfaces:
            ent_normalize = ent.replace(" ", "_")
            wikititles.append(ent_normalize)
            if ent_normalize not in self.en_t2id:
                print("bad title")
                continue
            else:
                eids.append(self.en_t2id[ent_normalize])
        return eids, wikititles

    def _exact_match_kb(self, surface, args):
        eids = []
        if surface is None:
            return eids
        # surface = surface.lower()
        # logging.info("===> query string:%s len:%d", surface, len(surface))
        if len(surface) < 2:
            # logging.info("too short a query string")
            return []
        # Exact Match
        eids += self.get_phrase_cands(surface)
        return eids

    def get_phrase_cands(self, surf):
        surf = surf.lower()
        ans = []
        if surf in self.t2id:
            cands = self.t2id[surf]
            # logging.info("#phrase cands geoname %d for %s", len(cands), surf)
            ans += cands
        if surf in self.en_t2id:
            cands = self.en_t2id[surf]
            ans += cands
        return ans

    def report(self, predict_mode):
        if not predict_mode:
            logging.info("total_hits %d/%d=%.3f", self.total_hits, self.total,
                         self.total_hits / self.total)
            logging.info("prior correct %d/%d=%.3f", self.prior_correct,
                         self.total, self.prior_correct / self.total)
            logging.info("nil correct %d/%d=%.3f", self.nil_correct,
                         self.total, self.nil_correct / self.total)
        else:
            logging.info("saw total %d", self.total)

    def get_romanized(self, cons, rom_view):
        overlap_cons = rom_view.get_overlapping_constituents(
            cons['start'], cons['end'])
        romanized = " ".join([c["label"] for c in overlap_cons[1:-1]])
        logging.info("str:%s romanized:%s", cons["tokens"], romanized)
        return romanized
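A rough wiring sketch for the class above; every argument is project-specific, so the objects and paths below are placeholders rather than the actual pipeline:

cg = CandGen(lang="or", year="20190701", inlinks=inlinks, wiki_cg=wiki_cg)
cg.load_kb(kbdir="data")                  # populates the mongo-backed title/id maps
# docta, outfile and args come from the surrounding pipeline
cg.compute_hits_for_ta(docta, "out/doc.json", args=args)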
Example No. 10
class AbstractIndex:
    def __init__(self, index_name, kbfile, overwrite=False, ngramorders=[]):
        self.name2ent = MongoBackedDict(dbname=index_name + ".phrase")
        self.word2ent = MongoBackedDict(dbname=index_name + ".word")
        self.ngram2ent = {}
        self.kbfile = kbfile
        self.ngramorders = ngramorders

        for i in self.ngramorders:
            self.ngram2ent[i] = MongoBackedDict(dbname=index_name +
                                                ".ngram-{}".format(i))
        index_type = None
        indices = ([self.name2ent, self.word2ent] +
                   [self.ngram2ent[i] for i in self.ngramorders])
        all_empty = all(i.size() == 0 for i in indices)
        if overwrite or all_empty:
            self.name2ent.drop_collection()
            self.word2ent.drop_collection()
            for i in self.ngramorders:
                self.ngram2ent[i].drop_collection()
            index_type = "all"
        else:
            # TODO The logic here is messed up
            if self.name2ent.size() == 0:
                self.name2ent.drop_collection()
                index_type = "name2ent"

            if self.word2ent.size() == 0:
                self.word2ent.drop_collection()
                index_type = "word2ent"

            for i in self.ngramorders:
                if self.ngram2ent[i].size() == 0:
                    self.ngram2ent[i].drop_collection()
                    index_type = "ngram2ent"

        if index_type is not None:
            start = time.time()
            logging.info("loading from file %s", index_name)
            self.load_kb(index_type=index_type)
            logging.info("created in %d secs", time.time() - start)
        logging.info("%s loaded", index_name)

    def process_kb(self):
        raise NotImplementedError

    def load_kb(self, index_type):
        name_map = {}
        word_map = {}
        ngram_map = {}
        logging.info("index type:%s", index_type)
        for i in self.ngramorders:
            ngram_map[i] = {}
        try:
            for names, eid in self.process_kb():

                names = set(names)
                if index_type in ["all", "name2ent"]:
                    add_to_dict(names, eid, name_map)

                if index_type in ["all", "word2ent"]:
                    toks = set([tok for n in names for tok in n.split(" ")])
                    add_to_dict(toks, eid, word_map)

                if index_type in ["all", "ngram2ent"]:
                    for i in self.ngramorders:
                        ngramset = set([
                            gram for n in names
                            for gram in getngrams(n, ngram=i)
                        ])
                        add_to_dict(ngramset, eid, ngram_map[i])

            self.put_in_mongo(index_type, name_map, word_map, ngram_map)
        except KeyboardInterrupt:
            logging.info("ending prematurely.")
            self.put_in_mongo(index_type, name_map, word_map, ngram_map)

    def put_in_mongo(self, index_type, name_map, word_map, ngram_map):
        if index_type in ["all", "name2ent"]:
            self.name2ent.bulk_insert(name_map,
                                      insert_freq=len(name_map),
                                      value_func=lambda x: list(x))
        if index_type in ["all", "word2ent"]:
            self.word2ent.bulk_insert(word_map,
                                      insert_freq=len(word_map),
                                      value_func=lambda x: list(x))
        if index_type in ["all", "ngram2ent"]:
            for i in self.ngramorders:
                ngram_map[i] = self.prune_map(ngram_map[i])
                self.ngram2ent[i].bulk_insert(ngram_map[i],
                                              insert_freq=len(ngram_map[i]),
                                              value_func=lambda x: list(x))

    def prune_map(self, nmap):
        # dict changes during iteration, so take care
        for k in list(nmap.keys()):
            if len(nmap[k]) > 10000:
                logging.info("pruning entry for %s len=%d", k, len(nmap[k]))
                del nmap[k]
        return nmap
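Since process_kb is abstract, a concrete index has to yield (names, eid) pairs; a minimal subclass sketch under an assumed, hypothetical KB file format:

class TsvIndex(AbstractIndex):
    def process_kb(self):
        # hypothetical format: eid<TAB>name1|name2|...
        for line in open(self.kbfile):
            eid, names = line.rstrip("\n").split("\t")
            yield names.split("|"), eid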
Example No. 11
class CandGen:
    def __init__(self,
                 lang=None,
                 year=None,
                 wiki_cg=None,
                 inlinks=None,
                 tsl=None,
                 tsl_concept_pair=None,
                 tsl_translit_dict=None,
                 spellchecker=None,
                 classifier=None):
        self.init_counters()
        self.lang = lang
        self.year = year
        self.inlinks = inlinks
        self.spellchecker = spellchecker
        self.classifier = classifier
        self.cheap_dict = dictionary[lang]
        self.translit_model = tsl
        self.concept_pair, self.translit_dict = tsl_concept_pair, tsl_translit_dict
        self.wiki_cg = wiki_cg
        self.en_normalizer = TitleNormalizer(lang="en")

    def load_or_checker(self):
        # assumes pyspellchecker's SpellChecker is imported at module level
        or_spell = SpellChecker(language=None, distance=2)
        or_spell.word_frequency.load_dictionary(
            'spellchecker/or_entity_dictionary.gz')
        return or_spell

    def load_kb(self, kbdir):
        # self.m = MongoBackedDict(dbname='data/enwiki/idmap/enwiki-20190701.id2t.t2id')
        self.en_t2id = MongoBackedDict(dbname="en_t2id")
        self.en_id2t = MongoBackedDict(dbname="en_id2t")
        en_id2t_filepath = os.path.join(kbdir, "enwiki", "idmap",
                                        f'enwiki-{self.year}.id2t')
        self.fr2entitles = MongoBackedDict(dbname=f"{self.lang}2entitles")
        fr2entitles_filepath = os.path.join(kbdir, f'{self.lang}wiki', 'idmap',
                                            'fr2entitles')
        self.t2id = MongoBackedDict(dbname=f"{self.lang}_t2id")

        if self.en_t2id.size() == 0 or self.en_id2t.size() == 0:
            logging.info(f'Loading en t2id and id2t...')
            self.en_t2id.drop_collection()
            self.en_id2t.drop_collection()
            en_id2t = []
            ent2id = defaultdict(list)
            for line in tqdm(open(en_id2t_filepath)):
                parts = line.strip().split("\t")
                if len(parts) != 3:
                    logging.info("bad line %s", line)
                    continue
                page_id, page_title, is_redirect = parts
                key = page_title.replace('_', ' ').lower()
                ent2id[key].append(page_id)
                en_id2t.append({
                    'key': page_id,
                    'value': {
                        'page_id': page_id,
                        'name': page_title,
                        'searchname': key
                    },
                    'redirect': is_redirect
                })
            ent2id_list = []
            for k, v in ent2id.items():
                ent2id_list.append({'key': k, 'value': v})
            logging.info("inserting %d entries into english t2id",
                         len(ent2id_list))
            self.en_t2id.cll.insert_many(ent2id_list)
            self.en_t2id.cll.create_index([("key", pymongo.HASHED)])

            logging.info("inserting %d entries into english id2t",
                         len(en_id2t))
            self.en_id2t.cll.insert_many(en_id2t)
            self.en_id2t.cll.create_index([("key", pymongo.HASHED)])

        if self.fr2entitles.size() == 0:
            logging.info(f'Loading fr2entitles and {self.lang} t2id...')
            fr2en = []
            t2id = []
            f = open(fr2entitles_filepath)
            for idx, l in enumerate(f):
                parts = l.strip().split("\t")
                if len(parts) != 2:
                    logging.info("error on line %d %s", idx, parts)
                    continue
                frtitle, entitle = parts
                key = frtitle.replace('_', ' ').lower()
                enkey = entitle.replace('_', ' ').lower()
                fr2en.append({
                    "key": key,
                    "value": {
                        'frtitle': frtitle,
                        'entitle': entitle,
                        'enkey': enkey
                    }
                })
                t2id.append({"key": key, "value": self.en_t2id[enkey]})
            logging.info(f"inserting %d entries into {self.lang}2entitles",
                         len(fr2en))
            self.fr2entitles.cll.insert_many(fr2en)
            self.fr2entitles.cll.create_index([("key", pymongo.HASHED)])
            logging.info(f"inserting %d entries into {self.lang} t2id",
                         len(t2id))
            self.t2id.cll.insert_many(t2id)
            self.t2id.cll.create_index([("key", pymongo.HASHED)])

    def init_counters(self):
        self.eng_words = 0
        self.nils = 0
        self.no_wikis = 0
        self.trans_hits = 0
        self.total, self.total_hits, self.prior_correct, self.nil_correct = 0, 0, 0, 0

    def get_country_code(self):
        if (self.lang == 'si'):
            country_code = "LK"
        elif (self.lang == 'rw'):
            country_code = "RW"
        elif (self.lang == 'or'):
            country_code = "IN"
        elif (self.lang == 'ilo'):
            country_code = "PH"
        else:
            raise ValueError('country code not provided')
        return country_code

    def clean_query(self, query_str):
        if self.lang == 'rw':
            if 'North' == query_str[-5:]:
                query_str = query_str[:-5] + ' ' + query_str[-5:]

        if query_str[-1] == ",":
            query_str = query_str[:-1]
        f = re.compile(r'(#|\(|\)|@)')
        query_str = f.sub(' ', query_str)
        query_str = re.sub(r'\s+', ' ', query_str).strip()
        return query_str.lower()

    def init_l2s_map(self, eids, args=None):

        l2s_map = {}
        for eid in eids:
            flag = False
            title = (self.en_id2t[eid]["name"] if eid in self.en_id2t else
                     eid.replace(' ', '_')).lower()
            if 'category:' in title:
                continue
            key = "|".join([eid, title])
            # for exist_key in l2s_map:
            #     if title == exist_key.split('|')[1]:
            #         l2s_map[exist_key] += 100
            #         flag = True
            if not flag and not key in l2s_map:
                # if title not in self.inlinks:
                #     logging.info("not in inlinks %s, keeping []", title)
                #     inlinks = []
                # else:
                #     inlinks = self.inlinks[title]
                l2s_map[key] = 100
        return l2s_map

    def merge_l2s_map_by_entity(self, l2s_map):
        l2s_map_new = {}
        for key in l2s_map:
            flag = False
            title = key.split('|')[1]
            for exist_key in l2s_map_new:
                if title == exist_key.split('|')[1]:
                    l2s_map_new[exist_key] += l2s_map[key]
                    flag = True
            if not flag and not key in l2s_map_new:
                l2s_map_new[key] = l2s_map[key]
        return l2s_map_new

    def cross_check_score(self, l2s_map, eids):
        freq = dict(Counter(eids))
        for cand, v in l2s_map.items():
            cand_eid = cand.split("|")[0]
            l2s_map[cand] = l2s_map[cand] * (
                3**freq[cand_eid]) if cand_eid in eids else l2s_map[cand] * 0.1
        return l2s_map

    def wiki_contain_score(self, l2s_map, query_str, args):
        for cand in l2s_map.keys():
            cand_name = cand.split("|")[1]
            score = 1
            if cand_name in args.eid2wikisummary:
                summary = args.eid2wikisummary[cand_name]
            else:
                try:
                    summary = wikipedia.summary(cand_name)
                    args.eid2wikisummary.cll.insert_one({
                        "key": cand_name,
                        "value": summary
                    })
                except:
                    args.eid2wikisummary.cll.insert_one({
                        "key": cand_name,
                        "value": ""
                    })
                    summary = ""
            # check summary contains the query
            score = score * 2 if query_str + "," in summary else score * 1
            l2s_map[cand] *= score
        return l2s_map

    def bert_score(self, l2smap, query_emb, l2s_map, args):
        cand2sim = {}
        max_cand = None
        max_sim = -1000
        no_sum = 0
        not_in_sum = 0
        for cand in l2smap:
            cand_title = cand.split("|")[0]
            # request summary
            # entity_wiki = cand_eid
            if cand_title in args.eid2wikisummary:
                summary = args.eid2wikisummary[cand_title]
            else:
                # summary = get_wiki_summary(f"https://en.wikipedia.org/?curid={quote(cand_title)}")
                summary = get_wiki_summary(cand_title)
                args.eid2wikisummary.cll.insert_one({
                    "key": cand_title,
                    "value": summary
                })
                # args.eid2wikisummary_entity.cll.insert_one({"key": cand_eid, "value": entity_wiki})
            summary = summary.lower()
            cand_name = cand_title.lower()
            if summary == '':
                no_sum += 1
            # bert
            else:
                if cand_name in summary:
                    cand_emb = s2maskedvec(mask_sents(cand_name, summary))
                    sim = cosine_similarity([cand_emb], [query_emb])[0][0]
                    cand2sim[cand] = sim
                    if sim > max_sim:
                        max_sim = sim
                        max_cand = cand
                else:
                    not_in_sum += 1
                    continue

        logging.info(
            f"{no_sum} / {len(l2s_map)} dont have summary, {not_in_sum} / {len(l2s_map) - no_sum} not in summary"
        )

        if len(cand2sim) > 1:
            l2s_map[max_cand] *= 3 * (len(l2s_map) - no_sum -
                                      not_in_sum) / len(l2s_map)
        # logging.info("cand2sim", cand2sim)
        return l2s_map

    def get_context(self, query_str, text, k=10):
        if query_str in text:
            tokenizer = MWETokenizer()
            query_str_tokens = tuple(query_str.split())
            query_str_dashed = "_".join(query_str_tokens)
            tokenizer.add_mwe(query_str_tokens)
            text_token = tokenizer.tokenize(text.split())
            try:
                t_start = text_token.index(query_str_dashed)
            except:
                return None, None, None
            t_end = t_start + 1
            start_index = max(t_start - k, 0)
            end_index = min(t_end + k, len(text_token))
            text_token_query = text_token[start_index:t_start] + text_token[
                t_end + 1:end_index]
            context = " ".join(text_token_query)
            context_mention = text_token[start_index:t_start] + [
                query_str
            ] + text_token[t_end + 1:end_index]
            context_mention = " ".join(context_mention)
            return context, text_token_query, context_mention
        else:
            logging.info('error, query not in text')
            return None, None, None

    def get_l2s_map(self, eids, eids_google, eids_hindi, eids_google_maps,
                    eids_spell, eids_gtrans, eids_trans, eids_wikicg,
                    eids_total, ner_type, query_str, text, args):
        if args.wikidata:
            l2s_map = self.init_l2s_map(eids_wikicg + eids + eids_google +
                                        eids_hindi + eids_google_maps +
                                        eids_spell + eids_gtrans + eids_trans,
                                        args=args)
        else:
            l2s_map = self.init_l2s_map(eids_total, args=args)

        # check if candidates were generated
        if len(l2s_map) <= 1:
            return l2s_map

        if ner_type in ['GPE', 'LOC']:
            l2s_map = self.cross_check_score(l2s_map,
                                             eids_google + eids_google_maps)

        # if self.lang == 'rw':
        #     l2s_map = self.cross_check_score(l2s_map, eids + eids_google)
        # else:
        if args.wikidata:
            l2s_map = self.cross_check_score(l2s_map, eids_wikicg)
        else:
            l2s_map = self.cross_check_score(
                l2s_map, eids + eids_google + eids_google_maps)

        feat2id = {}
        cand2link = {}
        # True vs NER
        type_mismatches = [('GPE', 'PER'), ('LOC', 'PER'), ('PER', 'GPE'),
                           ('PER', 'LOC')]

        #update score
        for cand in l2s_map.copy().keys():
            cand_eid, cand_text = cand.split("|")
            score = 1

            if self.lang == 'rw':
                # score *= max(2 - abs(len(cand_text.split('_')) - len(query_str.split())), 1)
                # score *= 3 if len(set(query_str.split()).intersection(set(cand_text.split('_')))) > 0 else 1
                if 'rwanda' in cand_text.split('_'):
                    score *= 3

            # # True vs NER
            # type_tup = (self.en_id2t[cand_eid]['entity_type'], ner_type)
            # check_notmatch = type_tup in type_mismatches
            # if args.mtype:
            #     if check_notmatch:
            #         del l2s_map[cand]
            #         continue

            # country_code = self.get_country_code()
            # if 'country_code' in self.en_id2t[cand_eid]:
            #     score = score*1 if self.en_id2t[cand_eid]['country_code'] == country_code else score*0.3
            # else:
            #     score *= 0.5

            # if "admin1_code_name" in self.en_id2t[cand_eid]:
            #     if self.lang=="or":
            #         score = score * 1 if self.en_id2t[cand_eid]['admin1_code_name'] == "Odisha" else score * 0.6
            #     elif self.lang=="ilo":
            #         score = score * 1 if self.en_id2t[cand_eid]['admin1_code_name'] == "Ilocos" else score * 0.6

            # # Link add weight
            # if ner_type in ['LOC', 'GPE']:
            #     if 'external_link' in self.en_id2t[cand_eid]:
            #         links = self.en_id2t[cand_eid]["external_link"].split("|")
            #         link = [l for l in links if '.wikipedia.org' in l][:1]
            #         if link:
            #             cand2link[cand] = link[0]
            #         if self.en_id2t[cand_eid]["external_link"]:
            #             score *= 1.5
            # if ner_type in ['ORG']:
            #     if 'org_website' in self.en_id2t[cand_eid]:
            #         if self.en_id2t[cand_eid]['org_website']:
            #             score *= 1.5

            l2s_map[cand] *= score

            # Serve Classifier
            if args.classifier and ner_type in ['GPE', 'LOC']:
                if 'feature_class' in self.en_id2t[cand_eid]:
                    if self.en_id2t[cand_eid]["feature_class"] == "P":
                        feat2id.setdefault("P", set([])).add(cand)
                    if self.en_id2t[cand_eid]["feature_class"] == "A":
                        feat2id.setdefault("A", set([])).add(cand)

        logging.info("Processed looping candidates")

        # check wiki contain for cand2links
        if args.wiki_contain:
            if self.lang == 'or':
                wiki_lang_str = "Odia: "
            elif self.lang == 'ilo':
                wiki_lang_str = "Ilokano: "
            for cand in cand2link:
                cand_eid = cand.split("|")[0]
                if cand_eid in args.eid2wikisummary:
                    summary = args.eid2wikisummary[cand_eid]
                else:
                    summary = get_wiki_summary(cand2link[cand])
                    args.eid2wikisummary.cll.insert_one({
                        "key": cand_eid,
                        "value": summary
                    })
                if wiki_lang_str + query_str in summary:
                    l2s_map[cand] = l2s_map[cand] * 3
            logging.info("Processed candidates wiki contain")

        clas = args.classifier and ner_type in ['GPE', 'LOC']
        # get context:
        if clas or args.bert:
            context, context_tokens, context_mention = self.get_context(
                query_str, text, k=10)

        # check classifier
        if clas and (context_tokens is not None):
            if len(feat2id.keys()) == 2:
                predicted = feature_map[self.classifier.sent2pred(
                    context_tokens, self.lang)]
                for cand in feat2id[predicted]:
                    l2s_map[cand] *= 1.5
            logging.info("Processed candidates classifier")

        # check context bert
        if args.bert and context is not None:
            logging.info("Processing candidates bert")
            query_emb = s2maskedvec(mask_sents(query_str, context_mention))
            l2s_map = self.bert_score(l2s_map, query_emb, l2s_map, args)

        # For rw, merge l2s_map
        if self.lang == 'rw':
            l2s_map = self.merge_l2s_map_by_entity(l2s_map)

        # Normalize
        sum_s = sum(list(l2s_map.values()))
        for can, s in l2s_map.items():
            l2s_map[can] = s / sum_s
        return l2s_map

    def correct_surf(self, token):
        region_list = [
            "district of", "district", "city of", "state of", "province of",
            "division", "city", "valley", "province"
        ]
        token = token.lower()
        for i in region_list:
            token = token.replace(i, "").strip()
        return token

    def default_type(self, l2s_map, max_cand):
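        # For 'ilo' and 'or': if the top-ranked candidate is a populated place
        # (feature class "P") whose cleaned name matches an administrative-division
        # candidate (feature class "A") in the map, prefer the "A" candidate; for
        # 'ilo' this only applies to non-capital places listed in the district set.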
        l2s_map = dict(
            sorted(l2s_map.items(), key=operator.itemgetter(1), reverse=True))
        # del l2s_map[max_cand]
        max_cand_name, max_cand_eid = max_cand.split("|")[1], max_cand.split(
            "|")[0]
        max_cand_name = self.correct_surf(max_cand_name)
        if "feature_class" in self.en_id2t[max_cand_eid]:
            eid_type = self.en_id2t[max_cand_eid]["feature_class"]
        else:
            return max_cand
        capital_features = ["PPLA", "PPLA2", "PPLC"]
        district_set = pickle.load(
            open(
                f"/shared/experiments/xyu71/lorelei2017/src/lorelei_kb/IL11-12-District/dis_il{self.lang}.pickle",
                "rb"))
        if "feature_code" in self.en_id2t[max_cand_eid]:
            eid_fcode = self.en_id2t[max_cand_eid]["feature_code"][2:]
        else:
            eid_fcode = ""

        if self.lang == "ilo":
            if eid_fcode not in capital_features:
                if max_cand_name in district_set and eid_type == "P":
                    for k, v in l2s_map.items():
                        k_name, k_id = self.correct_surf(
                            k.split("|")[1]), k.split("|")[0]
                        if "feature_class" in self.en_id2t[k_id]:
                            if k_name == max_cand_name and self.en_id2t[k_id][
                                    "feature_class"] == "A":
                                return k
        elif self.lang == "or":
            if eid_type == "P":
                for k, v in l2s_map.items():
                    k_name, k_id = self.correct_surf(
                        k.split("|")[1]), k.split("|")[0]
                    if "feature_class" in self.en_id2t[k_id]:
                        if k_name == max_cand_name and self.en_id2t[k_id][
                                "feature_class"] == "A":
                            return k
        return max_cand

    def get_maxes_l2s_map(self, l2s_map):
        # pick max
        if len(l2s_map) == 0:
            max_cand, max_score = "NIL", 1.0
        else:
            maxes_l2s_map = {
                cand: score
                for cand, score in l2s_map.items()
                if score == max(l2s_map.values())
            }
            max_cand = list(maxes_l2s_map.keys())[0]
            max_score = l2s_map[max_cand]
        return max_cand, max_score

    def compute_hits_for_ta(self,
                            docta,
                            outfile=None,
                            only_nils=False,
                            args=None):
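        # For every NER_CONLL constituent: clean the surface form, gather candidates
        # from the various sources, rank them into a labelScoreMap, and attach the
        # result as a CANDGEN view (optionally dumped to `outfile`).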
        try:
            ner_view = docta.get_view("NER_CONLL")
            # rom_view = docta.get_view("ROMANIZATION")
        except Exception:
            return
        candgen_view_json = copy.deepcopy(ner_view.as_json)
        text = docta.text
        predict_mode = True

        if "constituents" not in candgen_view_json["viewData"][0]:
            return
        for idx, cons in enumerate(
                candgen_view_json["viewData"][0]["constituents"]):
            self.total += 1
            orig_query_str = cons["tokens"]
            ner_type = cons["label"]
            # rom_query_str = self.get_romanized(cons, rom_view)

            query_str = self.clean_query(orig_query_str)
            eids, eids_google, eids_hindi, eids_google_maps, eids_spell, eids_gtrans, eids_trans, eids_wikicg = [], [], [], [], [], [], [], []

            # query_str = "ଭଦ୍ରକରେ"
            # self.lang = 'or'
            # ner_type = 'LOC'

            mention_cheap_dict = {}  #mention2eid[self.lang]
            if query_str in mention_cheap_dict:
                eids_total = mention_cheap_dict[query_str]
            else:
                eids, eids_google, eids_hindi, eids_google_maps, eids_spell, eids_gtrans, eids_trans, eids_wikicg = self.get_lorelei_candidates(
                    orig_query_str, query_str, ner_type=ner_type, args=args)
                eids_total = eids + eids_google + eids_hindi + eids_google_maps + eids_spell + eids_gtrans + eids_trans + eids_wikicg

            logging.info("got %d candidates for query:%s",
                         len(set(eids_total)), orig_query_str)

            # suggestion: incorporate inlink counts into the scoring function?

            l2s_map = self.get_l2s_map(eids,
                                       eids_google,
                                       eids_hindi,
                                       eids_google_maps,
                                       eids_spell,
                                       eids_gtrans,
                                       eids_trans,
                                       eids_wikicg,
                                       eids_total,
                                       ner_type=ner_type,
                                       query_str=orig_query_str,
                                       text=text,
                                       args=args)
            l2s_map = dict((x, y) for x, y in sorted(
                l2s_map.items(), key=operator.itemgetter(1), reverse=True))
            logging.info(
                f"got {len(l2s_map)} candidates after ranking for {orig_query_str}: {l2s_map}"
            )
            max_cand, max_score = self.get_maxes_l2s_map(l2s_map)

            # if len(l2s_map) > 0 and args.lang in ["ilo", "or"]:
            #     max_cand_default = self.default_type(l2s_map, max_cand)
            #     if max_cand_default != max_cand:
            #         max_score = 1.0
            #         max_cand = max_cand_default
            #     logging.info(f"got candidate after default")

            if len(l2s_map) > 0:
                # do not send empty label2scoremaps!
                cons["labelScoreMap"] = l2s_map
            cons["label"] = max_cand
            cons["score"] = max_score

        candgen_view_json["viewName"] = "CANDGEN"
        candgen_view = View(candgen_view_json, docta.get_tokens)
        docta.view_dictionary["CANDGEN"] = candgen_view
        docta_json = docta.as_json
        if outfile is not None:
            with open(outfile, 'w', encoding='utf-8') as f:
                json.dump(docta_json, f, ensure_ascii=False, indent=True)
        # json.dump(docta_json, open(outfile, "w"), indent=True)
        # self.report(predict_mode=predict_mode)
        return candgen_view_json

    def get_lorelei_candidates(self,
                               orig_query_str,
                               query_str,
                               romanized_query_str=None,
                               ner_type=None,
                               args=None):
        # Cheap Dict
        if query_str in self.cheap_dict:
            logging.info("found %s in dictionary!", query_str)
            query_str = self.cheap_dict[query_str]

        # SKB+SG
        desuf_query_str = remove_suffix(query_str, self.lang)
        dot_query_str_list = correct_surface(query_str, self.lang)
        desuf_dot_query_str_list = correct_surface(desuf_query_str, self.lang)
        # SKB suffix dot suffix+dot
        eids = self._exact_match_kb(query_str, args)
        # if not eids:
        eids += self._exact_match_kb(desuf_query_str, args)
        for i in dot_query_str_list:
            eids += self._exact_match_kb(i, args)
        for i in desuf_dot_query_str_list:
            eids += self._exact_match_kb(i, args)
        eids = list(set(eids))
        # logging.info("got %d candidates for query:%s from exact match", len(eids), query_str)

        # WikiCG candidates are gathered first so that eids_wikicg is defined
        # before the wikidata gating below checks it
        if args.wikicg:
            wiki_titles, wids, wid_cprobs = self._extract_ptm_cands(
                self.wiki_cg.get_candidates(surface=orig_query_str))
            # eids_wikicg = wids
            eids_wikicg = []
            for w in wids:
                if self.en_id2t[w]["name"] not in eids_wikicg:
                    eids_wikicg.append(self.en_id2t[w]["name"])
            if args.google + args.google_map == 0:
                eids = []
        else:
            eids_wikicg = []

        eids_google = []
        if args.google and ((not args.wikidata) or
                            (args.wikidata and len(eids_wikicg) == 0)):
            # SG suffix dot suffix+dot
            eids_google, g_wikititles = self._get_candidates_google(
                query_str, top_num=args.google_top)
            # eids_googlew, g_wikititlesw = self._get_candidates_google( query_str + ' wiki', top_num=args.google_top)
            # eids_googlewl, g_wikititleswl = self._get_candidates_google(query_str + ' wiki ' + lang2whole.get(self.lang, self.lang), top_num=args.google_top)
            # eids_googlewc, g_wikititleswc = self._get_candidates_google(query_str + ' wiki ' + lang2country.get(self.lang, lang2whole.get(self.lang, self.lang)), top_num=args.google_top)
            # for e in eids_googlew + eids_googlewl +  eids_googlewc:
            #     if e not in eids_google:
            #         eids_google.append(e)
            # for t in g_wikititlesw + g_wikititleswl + g_wikititleswc:
            #     if t not in g_wikititles:
            #         g_wikititles.append(t)

            # if not eids_google:
            for i in [desuf_query_str
                      ] + dot_query_str_list + desuf_dot_query_str_list:
                es, ts = self._get_candidates_google(i,
                                                     top_num=args.google_top)
                for e in es:
                    if e not in eids_google:
                        eids_google.append(e)
                for t in ts:
                    if t not in g_wikititles:
                        g_wikititles.append(t)

        eids_google = list(set(eids_google))
        logging.info("got %d candidates for query:%s from google",
                     len(eids_google), query_str)

        eids_google_maps = []
        if ner_type in ['GPE', 'LOC'] and args.google_map and (
            (not args.wikidata) or (args.wikidata and len(eids_wikicg) == 0)):
            google_map_name = query2gmap_api(query_str, self.lang)
            eids_google_maps += self._exact_match_kb(google_map_name, args)
            eids_google_maps += self._get_candidates_google(
                google_map_name, lang='en', top_num=args.google_top)[0]
            # if not eids_google_maps:
            google_map_name_suf = query2gmap_api(desuf_query_str, self.lang)
            google_map_name_dot = [
                query2gmap_api(k, self.lang) for k in dot_query_str_list
            ]
            google_map_name_suf_dot = [
                query2gmap_api(k, self.lang) for k in desuf_dot_query_str_list
            ]
            eids_google_maps += self._exact_match_kb(google_map_name_suf, args)
            eids_google_maps += self._get_candidates_google(
                google_map_name_suf, lang='en', top_num=args.google_top)[0]

            eids_google_maps += [
                h for k in google_map_name_dot
                for h in self._exact_match_kb(k, args)
            ]
            eids_google_maps += [
                h for k in google_map_name_dot
                for h in self._get_candidates_google(
                    k, lang='en', top_num=args.google_top)[0]
            ]

            eids_google_maps += [
                h for k in google_map_name_suf_dot
                for h in self._exact_match_kb(k, args)
            ]
            eids_google_maps += [
                h for k in google_map_name_suf_dot
                for h in self._get_candidates_google(
                    k, lang='en', top_num=args.google_top)[0]
            ]
        eids_google_maps = list(set(eids_google_maps))
        # logging.info("got %d candidates for query:%s from google map", len(set(eids_google_maps)), query_str)

        eids_hindi = []
        if args.pivoting:
            if self.lang == 'or':
                if len(eids) + len(eids_google) + len(eids_google_maps) == 0:
                    orgin2hin = or2hindi(query_str)
                    eids_hindi += self._get_candidates_google(
                        orgin2hin, lang='hi', top_num=args.google_top)[0]
                    # if not eids_hindi:
                    suf2hin = or2hindi(desuf_query_str)
                    dot2hin = [or2hindi(k) for k in dot_query_str_list]
                    suf_dot2hin = [
                        or2hindi(k) for k in desuf_dot_query_str_list
                    ]
                    eids_hindi += self._get_candidates_google(
                        suf2hin, lang='hi', top_num=args.google_top)[0]
                    eids_hindi += [
                        h for k in dot2hin
                        for h in self._get_candidates_google(
                            k, lang='hi', top_num=args.google_top)[0]
                    ]
                    eids_hindi += [
                        h for k in suf_dot2hin
                        for h in self._get_candidates_google(
                            k, lang='hi', top_num=args.google_top)[0]
                    ]
            else:
                if len(eids) + len(eids_google) + len(eids_google_maps) == 0:
                    eids_hindi += self._get_candidates_google(
                        query_str + ' wiki',
                        top_num=args.google_top,
                        include_all_lang=True)[0]
            eids_hindi = list(set(eids_hindi))
        # logging.info("got %d candidates for query:%s from hindi", len(eids_hindi), query_str)

        eids_spell = []
        if args.spell and len(eids + eids_google + eids_google_maps +
                              eids_hindi) == 0:
            self.or_checker = self.load_or_checker()
            if len(query_str) < 7:
                corrected = ' '.join([
                    self.or_checker.correction(token)
                    for token in query_str.split()
                ])
                if corrected != query_str:
                    eids_spell += self._exact_match_kb(corrected, args)
                    eids_spell += self._get_candidates_google(
                        corrected, lang=self.lang, top_num=args.google_top)[0]
                    # if not eids_spell:
                    spell_suf = ' '.join([
                        self.or_checker.correction(token)
                        for token in desuf_query_str.split()
                    ])
                    spell_dot = [
                        ' '.join(self.or_checker.correction(token)
                                 for token in tokens.split())
                        for tokens in dot_query_str_list
                    ]
                    spell_suf_dot = [
                        ' '.join(self.or_checker.correction(token)
                                 for token in tokens.split())
                        for tokens in desuf_dot_query_str_list
                    ]

                    eids_spell += self._exact_match_kb(spell_suf, args)
                    eids_spell += self._get_candidates_google(
                        spell_suf, lang=self.lang, top_num=args.google_top)[0]

                    eids_spell += [
                        h for k in spell_dot
                        for h in self._exact_match_kb(k, args)
                    ]
                    eids_spell += [
                        h for k in spell_dot
                        for h in self._get_candidates_google(
                            k, lang=self.lang, top_num=args.google_top)[0]
                    ]

                    eids_spell += [
                        h for k in spell_suf_dot
                        for h in self._exact_match_kb(k, args)
                    ]
                    eids_spell += [
                        h for k in spell_suf_dot
                        for h in self._get_candidates_google(
                            k, lang=self.lang, top_num=args.google_top)[0]
                    ]
            eids_spell = list(set(eids_spell))
        # logging.info("got %d candidates for query:%s from spell", len(set(eids_spell)), query_str)

        eids_gtrans = []
        if args.g_trans and len(eids + eids_google + eids_google_maps +
                                eids_hindi + eids_spell) == 0:
            il2gt = MongoBackedDict(dbname=f'il{self.lang}toGT')
            il2per = MongoBackedDict(dbname=f'il{self.lang}toPER')
            il2org = MongoBackedDict(dbname=f'il{self.lang}toORG')

            gt2kb = MongoBackedDict(dbname=f'gt_name2id_il{self.lang}')
            per2kb = MongoBackedDict(dbname=f'peo_name2id_il{self.lang}')
            org2kb = MongoBackedDict(dbname=f'org_name2id_il{self.lang}')

            if query_str in il2gt:
                eids_gtrans += gt2kb[il2gt[query_str]]
            if query_str in il2per:
                eids_gtrans += [per2kb[il2per[query_str]]]
            if query_str in il2org:
                eids_gtrans += [org2kb[il2org[query_str]]]
        eids_gtrans = list(set(eids_gtrans))
        # logging.info("got %d candidates for query:%s from google trans", len(set(eids_gtrans)), query_str)

        eids_trans = []
        if len(eids + eids_google + eids_google_maps + eids_hindi +
               eids_spell + eids_gtrans) == 0 and args.tsl:
            translited = phrase_translit(query_str, self.concept_pair,
                                         self.translit_model,
                                         self.spellchecker, self.translit_dict)
            for item in translited:
                eids_trans += self._exact_match_kb(item, args)
                eids_trans += self._get_candidates_google(
                    item, lang='en', top_num=args.google_top)[0]
            eids_trans = list(set(eids_trans))
        # logging.info("got %d candidates for query:%s from transliteration", len(set(eids_trans)), query_str)

        return eids, eids_google, eids_hindi, eids_google_maps, eids_spell, eids_gtrans, eids_trans, eids_wikicg

    def _get_candidates_google(self,
                               surface,
                               top_num=1,
                               lang=None,
                               include_all_lang=False):
        eids = []
        wikititles = []
        if surface is None or len(surface) < 2:
            return eids, wikititles
        if surface in self.cheap_dict:
            surface = self.cheap_dict[surface]
        if lang is None:
            lang = self.lang
        en_surfaces = query2enwiki(surface,
                                   lang,
                                   include_all_lang=include_all_lang)[:top_num]
        for en in en_surfaces:
            if en not in wikititles:
                wikititles.append(en)
        # wikititles += en_surfaces
        eids = wikititles
        # query_str_list = en_surfaces
        # for surf in en_surfaces:
        #     query_str_list += correct_surface(surf, self.lang)

        # for s in query_str_list:
        # eids += [self.m[s.replace(' ', '_')]]
        # eids += self.get_phrase_cands(s)
        # logging.info("#direct cands (phrase): %d", len(eids))
        return eids, wikititles

    def _extract_ptm_cands(self, cands):
        wiki_titles, wids, wid_cprobs = [], [], []
        for cand in cands:
            wikititle, p_t_given_s, p_s_given_t = cand.en_title, cand.p_t_given_s, cand.p_s_given_t
            nrm_title = self.en_normalizer.normalize(wikititle)
            if nrm_title == K.NULL_TITLE:  # REMOVED or nrm_title not in en_normalizer.title2id
                logging.info("bad cand %s nrm=%s", wikititle, nrm_title)
                continue
            wiki_id = self.en_normalizer.title2id[nrm_title]
            # if wiki_id is none
            if wiki_id is None:
                wiki_id = self.en_normalizer.title2id[wikititle]
                if wiki_id is None:
                    continue
                wiki_titles.append(wikititle)
                wids.append(wiki_id)
                wid_cprobs.append(p_t_given_s)
                continue
            wiki_titles.append(nrm_title)
            wids.append(wiki_id)
            wid_cprobs.append(p_t_given_s)
        return wiki_titles, wids, wid_cprobs

    def _exact_match_kb(self, surface, args):
        eids = []
        if surface is None:
            return eids
        # surface = surface.lower()
        # logging.info("===> query string:%s len:%d", surface, len(surface))
        if len(surface) < 2:
            # logging.info("too short a query string")
            return []
        if surface in self.cheap_dict:
            # logging.info("found %s in dictionary!", surface)
            surface = self.cheap_dict[surface]

        # Exact Match
        eids += self.get_phrase_cands(surface)
        return eids

    def get_phrase_cands(self, surf):
        surf = surf.lower()
        ans = []
        if surf in self.t2id:
            cands = self.t2id[surf]
            # logging.info("#phrase cands geoname %d for %s", len(cands), surf)
            ans += cands
        if surf in self.en_t2id:
            cands = self.en_t2id[surf]
            if len(cands) > 0:
                cand = self.en_id2t[cands[0]]["name"]
                ans += [cand]
        return ans

    def report(self, predict_mode):
        if not predict_mode:
            logging.info("total_hits %d/%d=%.3f", self.total_hits, self.total,
                         self.total_hits / self.total)
            logging.info("prior correct %d/%d=%.3f", self.prior_correct,
                         self.total, self.prior_correct / self.total)
            logging.info("nil correct %d/%d=%.3f", self.nil_correct,
                         self.total, self.nil_correct / self.total)
        else:
            logging.info("saw total %d", self.total)

    def get_romanized(self, cons, rom_view):
        overlap_cons = rom_view.get_overlapping_constituents(
            cons['start'], cons['end'])
        romanized = " ".join([c["label"] for c in overlap_cons[1:-1]])
        logging.info("str:%s romanized:%s", cons["tokens"], romanized)
        return romanized
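
# A minimal, self-contained sketch (not part of the original example) of the
# "eid|title" score-map convention the methods above operate on: per-signal
# multiplicative boosts, normalization to a probability-like distribution,
# and selection of the top candidate. All keys and boost values here are
# invented for illustration.
import operator

l2s_map = {"E1|adilabad": 100.0, "E2|adilabad district": 100.0}
l2s_map["E2|adilabad district"] *= 1.5               # e.g. a classifier boost
total = sum(l2s_map.values())
l2s_map = {cand: score / total for cand, score in l2s_map.items()}
ranked = dict(sorted(l2s_map.items(), key=operator.itemgetter(1), reverse=True))
max_cand, max_score = next(iter(ranked.items()))     # the code above falls back to ("NIL", 1.0) when empty
print(max_cand, round(max_score, 3))                 # -> E2|adilabad district 0.6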
Exemplo n.º 12
0
class CandGen:
    def __init__(self,
                 lang=None,
                 year=None,
                 wiki_cg=None,
                 google_api_cx=None,
                 google_api_key=None):
        self.lang = lang
        self.year = year
        self.wiki_cg = wiki_cg
        self.en_normalizer = TitleNormalizer(lang="en")
        self.google_api_cx = google_api_cx
        self.google_api_key = google_api_key

    def load_kb(self, kbdir):
        self.en_t2id = MongoBackedDict(dbname=f"en_t2id")
        self.en_id2t = MongoBackedDict(dbname=f"en_id2t")
        en_id2t_filepath = os.path.join(kbdir, "enwiki", "idmap",
                                        f'enwiki-{self.year}.id2t')
        self.fr2entitles = MongoBackedDict(dbname=f"{self.lang}2entitles")
        fr2entitles_filepath = os.path.join(kbdir, f'{self.lang}wiki', 'idmap',
                                            f'fr2entitles')
        self.t2id = MongoBackedDict(dbname=f"{self.lang}_t2id")

        if self.en_t2id.size() == 0 or self.en_id2t.size() == 0:
            logging.info(f'Loading en t2id and id2t...')
            self.en_t2id.drop_collection()
            self.en_id2t.drop_collection()
            en_id2t = []
            ent2id = defaultdict(list)
            for line in tqdm(open(en_id2t_filepath)):
                parts = line.strip().split("\t")
                if len(parts) != 3:
                    logging.info("bad line %s", line)
                    continue
                page_id, page_title, is_redirect = parts
                key = page_title.replace('_', ' ').lower()
                ent2id[key].append(page_id)
                en_id2t.append({
                    'key': page_id,
                    'value': {
                        'page_id': page_id,
                        'name': page_title,
                        'searchname': key
                    },
                    'redirect': is_redirect
                })
            ent2id_list = []
            for k, v in ent2id.items():
                ent2id_list.append({'key': k, 'value': v})
            logging.info("inserting %d entries into english t2id",
                         len(ent2id_list))
            self.en_t2id.cll.insert_many(ent2id_list)
            self.en_t2id.cll.create_index([("key", pymongo.HASHED)])

            logging.info("inserting %d entries into english id2t",
                         len(en_id2t))
            self.en_id2t.cll.insert_many(en_id2t)
            self.en_id2t.cll.create_index([("key", pymongo.HASHED)])

        if self.fr2entitles.size() == 0:
            logging.info(f'Loading fr2entitles and {self.lang} t2id...')
            fr2en = []
            t2id = []
            f = open(fr2entitles_filepath)
            for idx, l in enumerate(f):
                parts = l.strip().split("\t")
                if len(parts) != 2:
                    logging.info("error on line %d %s", idx, parts)
                    continue
                frtitle, entitle = parts
                key = frtitle.replace('_', ' ').lower()
                enkey = entitle.replace('_', ' ').lower()
                fr2en.append({
                    "key": key,
                    "value": {
                        'frtitle': frtitle,
                        'entitle': entitle,
                        'enkey': enkey
                    }
                })
                t2id.append({"key": key, "value": self.en_t2id[enkey]})
            logging.info(f"inserting %d entries into {self.lang}2entitles",
                         len(fr2en))
            self.fr2entitles.cll.insert_many(fr2en)
            self.fr2entitles.cll.create_index([("key", pymongo.HASHED)])
            logging.info(f"inserting %d entries into {self.lang} t2id",
                         len(t2id))
            self.t2id.cll.insert_many(t2id)
            self.t2id.cll.create_index([("key", pymongo.HASHED)])

    def clean_query(self, query_str):
        if self.lang == 'rw':
            if 'North' == query_str[-5:]:
                query_str = query_str[:-5] + ' ' + query_str[-5:]

        if query_str[-1] == ",":
            query_str = query_str[:-1]
        f = re.compile(r'(#|\(|\)|@)')
        query_str = f.sub(' ', query_str)
        query_str = re.sub(r'\s+', ' ', query_str).strip()
        return query_str.lower()

    def init_l2s_map(self, eids, args=None):
        l2s_map = {}
        for eid in eids:
            title = (self.en_id2t[eid]["name"] if eid in self.en_id2t else
                     eid.replace(' ', '_')).lower()
            if 'category:' in title:
                continue
            key = "|".join([eid, title])
            if key not in l2s_map:
                l2s_map[key] = 100
        return l2s_map

    def cross_check_score(self, l2s_map, eids):
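        # Boost candidates whose eid also appears in `eids` (by 3**frequency) and
        # down-weight the rest (x0.1).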
        freq = dict(Counter(eids))
        for cand, v in l2s_map.items():
            cand_eid = cand.split("|")[0]
            l2s_map[cand] = l2s_map[cand] * (
                3**freq[cand_eid]) if cand_eid in eids else l2s_map[cand] * 0.1
        return l2s_map

    def bert_score(self, l2smap, query_emb, l2s_map, args):
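        # For each candidate, fetch (and cache in args.eid2wikisummary) its wiki
        # summary; if the candidate name occurs in the summary, embed it with
        # masked-sentence BERT and compare to the query embedding by cosine
        # similarity. The single best-matching candidate gets its score boosted,
        # scaled by the fraction of candidates that had usable summaries.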
        cand2sim = {}
        max_cand = None
        max_sim = -1000
        no_sum = 0
        not_in_sum = 0
        for cand in l2smap:
            cand_title = cand.split("|")[0]
            # request summary
            if cand_title in args.eid2wikisummary:
                summary = args.eid2wikisummary[cand_title]
            else:
                summary = get_wiki_summary(cand_title)
                args.eid2wikisummary.cll.insert_one({
                    "key": cand_title,
                    "value": summary
                })
            summary = summary.lower()
            cand_name = cand_title.lower()
            if summary == '':
                no_sum += 1
            # bert
            else:
                if cand_name in summary:
                    cand_emb = s2maskedvec(mask_sents(cand_name, summary))
                    sim = cosine_similarity([cand_emb], [query_emb])[0][0]
                    cand2sim[cand] = sim
                    if sim > max_sim:
                        max_sim = sim
                        max_cand = cand
                else:
                    not_in_sum += 1
                    continue
        logging.info(
            f"{no_sum} / {len(l2s_map)} dont have summary, {not_in_sum} / {len(l2s_map) - no_sum} not in summary"
        )
        if len(cand2sim) > 1:
            l2s_map[max_cand] *= 3 * (len(l2s_map) - no_sum -
                                      not_in_sum) / len(l2s_map)
        return l2s_map

    def get_context(self, query_str, text, k=10):
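        # Treat the query as a multi-word expression and return a window of roughly
        # k tokens on each side of the mention: (context without the mention, its
        # token list, context with the mention), or (None, None, None) if the query
        # cannot be located in the text.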
        if query_str in text:
            tokenizer = MWETokenizer()
            query_str_tokens = tuple(query_str.split())
            query_str_dashed = "_".join(query_str_tokens)
            tokenizer.add_mwe(query_str_tokens)
            text_token = tokenizer.tokenize(text.split())
            try:
                t_start = text_token.index(query_str_dashed)
            except ValueError:
                return None, None, None
            t_end = t_start + 1
            start_index = max(t_start - k, 0)
            end_index = min(t_end + k, len(text_token))
            text_token_query = text_token[start_index:t_start] + text_token[
                t_end + 1:end_index]
            context = " ".join(text_token_query)
            context_mention = text_token[start_index:t_start] + [
                query_str
            ] + text_token[t_end + 1:end_index]
            context_mention = " ".join(context_mention)
            return context, text_token_query, context_mention
        else:
            logging.info('error, query not in text')
            return None, None, None

    def get_l2s_map(self, eids, eids_google, eids_pivot, eids_google_maps,
                    eids_wikicg, eids_total, ner_type, query_str, text, args):
        if args.wikidata:
            l2s_map = self.init_l2s_map(eids_wikicg + eids + eids_google +
                                        eids_pivot + eids_google_maps,
                                        args=args)
        else:
            l2s_map = self.init_l2s_map(eids_total, args=args)

        # check if any candidates were generated
        if len(l2s_map) <= 1:
            return l2s_map
        if ner_type in ['GPE', 'LOC']:
            l2s_map = self.cross_check_score(l2s_map,
                                             eids_google + eids_google_maps)
        if args.wikidata:
            l2s_map = self.cross_check_score(l2s_map, eids_wikicg)
        else:
            l2s_map = self.cross_check_score(
                l2s_map, eids + eids_google + eids_google_maps)
        # update score
        for cand in l2s_map.copy().keys():
            cand_eid, cand_text = cand.split("|")
            score = 1
            if self.lang == 'rw':
                if 'rwanda' in cand_text.split('_'):
                    score *= 3
            l2s_map[cand] *= score
        logging.info("Processed looping candidates")

        # get context:
        if args.bert:
            context, context_tokens, context_mention = self.get_context(
                query_str, text, k=10)
            # check context bert
            if context is not None:
                logging.info("Processing candidates bert")
                query_emb = s2maskedvec(mask_sents(query_str, context_mention))
                l2s_map = self.bert_score(l2s_map, query_emb, l2s_map, args)

        # Normalize
        sum_s = sum(list(l2s_map.values()))
        for can, s in l2s_map.items():
            l2s_map[can] = s / sum_s
        return l2s_map

    def correct_surf(self, token):
        region_list = [
            "district of", "district", "city of", "state of", "province of",
            "division", "city", "valley", "province"
        ]
        token = token.lower()
        for i in region_list:
            token = token.replace(i, "").strip()
        return token

    def get_maxes_l2s_map(self, l2s_map):
        # pick max
        if len(l2s_map) == 0:
            max_cand, max_score = "NIL", 1.0
        else:
            maxes_l2s_map = {
                cand: score
                for cand, score in l2s_map.items()
                if score == max(l2s_map.values())
            }
            max_cand = list(maxes_l2s_map.keys())[0]
            max_score = l2s_map[max_cand]
        return max_cand, max_score

    def compute_hits_for_ta(self, input_json, outfile=None, args=None):
        output_json = copy.deepcopy(input_json)
        ner_entities = output_json["NER"]
        text = output_json['text']

        for idx, cons in enumerate(ner_entities):
            orig_query_str = cons["tokens"]
            ner_type = cons["ner_type"]

            query_str = self.clean_query(orig_query_str)
            eids, eids_google, eids_pivot, eids_google_maps, eids_wikicg = self.get_all_candidates(
                orig_query_str, query_str, ner_type=ner_type, args=args)
            eids_total = eids + eids_google + eids_pivot + eids_google_maps + eids_wikicg

            logging.info("got %d candidates for query:%s",
                         len(set(eids_total)), orig_query_str)

            l2s_map = self.get_l2s_map(eids,
                                       eids_google,
                                       eids_pivot,
                                       eids_google_maps,
                                       eids_wikicg,
                                       eids_total,
                                       ner_type=ner_type,
                                       query_str=orig_query_str,
                                       text=text,
                                       args=args)
            l2s_map = dict((x, y) for x, y in sorted(
                l2s_map.items(), key=operator.itemgetter(1), reverse=True))
            logging.info(
                f"got {len(l2s_map)} candidates after ranking for {orig_query_str}: {l2s_map}"
            )
            max_cand, max_score = self.get_maxes_l2s_map(l2s_map)

            if len(l2s_map) > 0:
                # do not send empty label2scoremaps!
                cons["labelScoreMap"] = l2s_map
            cons["label"] = max_cand
            cons["score"] = max_score
        if outfile is not None:
            with open(outfile, 'w') as f:
                json.dump(output_json, f)
        return output_json

    def get_all_candidates(self,
                           orig_query_str,
                           query_str,
                           ner_type=None,
                           args=None):
        # SKB+SG
        desuf_query_str = remove_suffix(query_str, self.lang)
        dot_query_str_list = correct_surface(query_str, self.lang)
        desuf_dot_query_str_list = correct_surface(desuf_query_str, self.lang)
        # SKB suffix dot suffix+dot
        eids = self._exact_match_kb(query_str, args)
        eids += self._exact_match_kb(desuf_query_str, args)
        for i in dot_query_str_list:
            eids += self._exact_match_kb(i, args)
        for i in desuf_dot_query_str_list:
            eids += self._exact_match_kb(i, args)
        eids = list(set(eids))

        if args.wikicg:
            wiki_titles, wids, wid_cprobs = self._extract_ptm_cands(
                self.wiki_cg.get_candidates(surface=orig_query_str))
            eids_wikicg = []
            for w in wids:
                if self.en_id2t[w]["name"] not in eids_wikicg:
                    eids_wikicg.append(self.en_id2t[w]["name"])
            if args.google + args.google_map == 0:
                eids = []
        else:
            eids_wikicg = []

        eids_google = []
        if args.google and ((not args.wikidata) or
                            (args.wikidata and len(eids_wikicg) == 0)):
            # SG suffix dot suffix+dot
            eids_google, g_wikititles = self._get_candidates_google(
                query_str, top_num=args.google_top)

            # if not eids_google:
            for i in [desuf_query_str
                      ] + dot_query_str_list + desuf_dot_query_str_list:
                es, ts = self._get_candidates_google(i,
                                                     top_num=args.google_top)
                for e in es:
                    if e not in eids_google:
                        eids_google.append(e)
                for t in ts:
                    if t not in g_wikititles:
                        g_wikititles.append(t)
        eids_google = list(set(eids_google))
        logging.info("got %d candidates for query:%s from google",
                     len(eids_google), query_str)

        eids_google_maps = []
        if ner_type in ['GPE', 'LOC'] and args.google_map and (
            (not args.wikidata) or (args.wikidata and len(eids_wikicg) == 0)):
            google_map_name = query2gmap_api(query_str, self.lang,
                                             self.google_api_key)
            eids_google_maps += self._exact_match_kb(google_map_name, args)
            eids_google_maps += self._get_candidates_google(
                google_map_name, lang='en', top_num=args.google_top)[0]
            google_map_name_suf = query2gmap_api(desuf_query_str, self.lang,
                                                 self.google_api_key)
            google_map_name_dot = [
                query2gmap_api(k, self.lang, self.google_api_key)
                for k in dot_query_str_list
            ]
            google_map_name_suf_dot = [
                query2gmap_api(k, self.lang, self.google_api_key)
                for k in desuf_dot_query_str_list
            ]
            eids_google_maps += self._exact_match_kb(google_map_name_suf, args)
            eids_google_maps += self._get_candidates_google(
                google_map_name_suf, lang='en', top_num=args.google_top)[0]
            eids_google_maps += [
                h for k in google_map_name_dot
                for h in self._exact_match_kb(k, args)
            ]
            eids_google_maps += [
                h for k in google_map_name_dot
                for h in self._get_candidates_google(
                    k, lang='en', top_num=args.google_top)[0]
            ]
            eids_google_maps += [
                h for k in google_map_name_suf_dot
                for h in self._exact_match_kb(k, args)
            ]
            eids_google_maps += [
                h for k in google_map_name_suf_dot
                for h in self._get_candidates_google(
                    k, lang='en', top_num=args.google_top)[0]
            ]
        eids_google_maps = list(set(eids_google_maps))
        logging.info("got %d candidates for query:%s from google map",
                     len(set(eids_google_maps)), query_str)

        eids_pivot = []
        if args.pivoting:
            if self.lang == 'or':
                if len(eids) + len(eids_google) + len(eids_google_maps) == 0:
                    orgin2hin = or2hindi(query_str)
                    eids_pivot += self._get_candidates_google(
                        orgin2hin, lang='hi', top_num=args.google_top)[0]
                    suf2hin = or2hindi(desuf_query_str)
                    dot2hin = [or2hindi(k) for k in dot_query_str_list]
                    suf_dot2hin = [
                        or2hindi(k) for k in desuf_dot_query_str_list
                    ]
                    eids_pivot += self._get_candidates_google(
                        suf2hin, lang='hi', top_num=args.google_top)[0]
                    eids_pivot += [
                        h for k in dot2hin
                        for h in self._get_candidates_google(
                            k, lang='hi', top_num=args.google_top)[0]
                    ]
                    eids_pivot += [
                        h for k in suf_dot2hin
                        for h in self._get_candidates_google(
                            k, lang='hi', top_num=args.google_top)[0]
                    ]
            else:
                if len(eids) + len(eids_google) + len(eids_google_maps) == 0:
                    eids_pivot += self._get_candidates_google(
                        query_str + ' wiki',
                        top_num=args.google_top,
                        include_all_lang=True)[0]
            eids_pivot = list(set(eids_pivot))
        logging.info("got %d candidates for query:%s from pivoting",
                     len(eids_pivot), query_str)

        return eids, eids_google, eids_pivot, eids_google_maps, eids_wikicg

    def _get_candidates_google(self,
                               surface,
                               top_num=1,
                               lang=None,
                               include_all_lang=False):
        eids = []
        wikititles = []
        if surface is None or len(surface) < 2:
            return eids, wikititles
        if lang is None:
            lang = self.lang
        en_surfaces = query2enwiki(surface,
                                   lang,
                                   self.google_api_key,
                                   self.google_api_cx,
                                   include_all_lang=include_all_lang)[:top_num]
        for en in en_surfaces:
            if en not in wikititles:
                wikititles.append(en)
        eids = wikititles
        return eids, wikititles

    def _extract_ptm_cands(self, cands):
        wiki_titles, wids, wid_cprobs = [], [], []
        for cand in cands:
            wikititle, p_t_given_s, p_s_given_t = cand.en_title, cand.p_t_given_s, cand.p_s_given_t
            nrm_title = self.en_normalizer.normalize(wikititle)
            if nrm_title == K.NULL_TITLE:
                logging.info("bad cand %s nrm=%s", wikititle, nrm_title)
                continue
            wiki_id = self.en_normalizer.title2id[nrm_title]
            if wiki_id is None:
                wiki_id = self.en_normalizer.title2id[wikititle]
                if wiki_id is None:
                    continue
                wiki_titles.append(wikititle)
                wids.append(wiki_id)
                wid_cprobs.append(p_t_given_s)
                continue
            wiki_titles.append(nrm_title)
            wids.append(wiki_id)
            wid_cprobs.append(p_t_given_s)
        return wiki_titles, wids, wid_cprobs

    def _exact_match_kb(self, surface, args=None):
        eids = []
        if surface is None:
            return eids
        if len(surface) < 2:
            return []

        # Exact Match
        eids += self.get_phrase_cands(surface)
        return eids

    def get_phrase_cands(self, surf):
        surf = surf.lower()
        ans = []
        if surf in self.t2id:
            cands = self.t2id[surf]
            ans += cands
        if surf in self.en_t2id:
            cands = self.en_t2id[surf]
            if len(cands) > 0:
                cand = self.en_id2t[cands[0]]["name"]
                ans += [cand]
        return ans
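
# A minimal usage sketch of the CandGen class above (assumptions: the module-level
# imports used by the class are available, a MongoDB instance backs MongoBackedDict,
# and the kb path, year string, and API credentials are placeholders, not real values;
# a wiki candidate generator object would be needed if args.wikicg were enabled).
from argparse import Namespace

my_args = Namespace(google=1, google_map=0, wikicg=0, wikidata=0,
                    pivoting=0, bert=0, google_top=5)
cg = CandGen(lang="or", year="2019", wiki_cg=None,
             google_api_cx="hypothetical-cx", google_api_key="hypothetical-key")
cg.load_kb(kbdir="/path/to/kbdir")  # expects enwiki/ and {lang}wiki/ idmap dumps
doc = {"text": "A sentence mentioning Bhadrak in the source language.",
       "NER": [{"tokens": "Bhadrak", "ner_type": "GPE"}]}
output_json = cg.compute_hits_for_ta(doc, outfile="doc.candgen.json", args=my_args)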