Example #1
    def __init__(self,
                 load_path: str,
                 wiki_filename: str,
                 entities_filename: str,
                 inverted_index_filename: str,
                 id_to_name_file: str,
                 lemmatize: bool = True,
                 debug: bool = False,
                 rule_filter_entities: bool = True,
                 use_inverted_index: bool = True,
                 language: str = 'rus',
                 *args,
                 **kwargs) -> None:
        """

        Args:
            load_path: path to folder with wikidata files
            wiki_filename: file with Wikidata triplets
            entities_filename: file with dict of entity titles (keys) and entity ids (values)
            inverted_index_filename: file with dict of words (keys) and entities containing these words (values)
            id_to_name_file: file with dict of entity ids (keys) and entities names and aliases (values)
            lemmatize: whether to lemmatize tokens of extracted entity
            debug: whether to print entities extracted from Wikidata
            rule_filter_entities: whether to filter entities which do not fit the question
            use_inverted_index: whether to use inverted index for entity linking
            language: the language of the linker (used to filter out some questions and improve overall performance)
            *args:
            **kwargs:
        """
        super().__init__(save_path=None, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.debug = debug
        self.rule_filter_entities = rule_filter_entities
        self.use_inverted_index = use_inverted_index
        self._language = language
        if language not in self.LANGUAGES:
            log.warning(
                f'EntityLinker supports only the following languages: {self.LANGUAGES}'
            )

        self._wiki_filename = wiki_filename
        self._entities_filename = entities_filename
        self.inverted_index_filename = inverted_index_filename
        self.id_to_name_file = id_to_name_file

        self.name_to_q: Optional[Dict[str, List[Tuple[str]]]] = None
        self.wikidata: Optional[Dict[str, List[List[str]]]] = None
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.id_to_name: Optional[Dict[str, Dict[str, List[str]]]] = None
        self.load()
        if self.use_inverted_index:
            alphabet = "abcdefghijklmnopqrstuvwxyzабвгдеёжзийклмнопрстуфхцчшщъыьэюя1234567890-_()=+!?.,/;:&@<>|#$%^*"
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)
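
Below is a minimal, self-contained sketch of the lookup strategy this constructor sets up: exact token lookup in the inverted index first, with an edit-distance fallback for misspelled tokens. The toy index and the use of difflib as a stand-in for LevenshteinSearcher are illustrative assumptions, not the component's actual implementation.

import difflib

# toy inverted index: token -> entity ids whose titles contain that token (hypothetical data)
inverted_index = {
    "пушкин": ["Q7200"],
    "москва": ["Q649"],
}

def lookup(token: str):
    if token in inverted_index:                       # exact hit
        return inverted_index[token]
    # typo fallback: difflib stands in for LevenshteinSearcher.search(token, d=1)
    close = difflib.get_close_matches(token, list(inverted_index), n=1, cutoff=0.8)
    return inverted_index[close[0]] if close else []

print(lookup("масква"))  # -> ['Q649'] via the fuzzy fallback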
Example #2
    def __init__(self, load_path: str,
                 inverted_index_filename: str,
                 entities_list_filename: str,
                 q2name_filename: str,
                 save_path: str = None,
                 q2descr_filename: str = None,
                 rel_ranker: RelRankerBertInfer = None,
                 build_inverted_index: bool = False,
                 kb_format: str = "hdt",
                 kb_filename: str = None,
                 label_rel: str = None,
                 descr_rel: str = None,
                 aliases_rels: List[str] = None,
                 sql_table_name: str = None,
                 sql_column_names: List[str] = None,
                 lang: str = "en",
                 use_descriptions: bool = False,
                 lemmatize: bool = False,
                 use_prefix_tree: bool = False,
                 **kwargs) -> None:
        """

        Args:
            load_path: path to folder with inverted index files
            save_path: path where to save inverted index files
            inverted_index_filename: file with dict of words (keys) and entities containing these words (values)
            entities_list_filename: file with the list of entities from the knowledge base
            q2name_filename: name of file which maps entity id to name
            q2descr_filename: name of file which maps entity id to description
            rel_ranker: component deeppavlov.models.kbqa.rel_ranker_bert_infer
            build_inverted_index: if True, the inverted index of entities of the KB will be built
            kb_format: "hdt" or "sqlite3"
            kb_filename: file with the knowledge base, used for building the inverted index
            label_rel: relation in the knowledge base which connects entity ids and entity titles
            descr_rel: relation in the knowledge base which connects entity ids and entity descriptions
            aliases_rels: list of relations which connect entity ids and entity aliases
            sql_table_name: name of the table with the KB if the KB is in sqlite3 format
            sql_column_names: names of columns with subject, relation and object
            lang: language of entity labels in the knowledge base ("en" or "ru")
            use_descriptions: whether to use context and descriptions of entities for entity ranking
            lemmatize: whether to lemmatize tokens of extracted entity
            use_prefix_tree: whether to use prefix tree for search of entities with typos in entity labels
            **kwargs:
        """
        super().__init__(save_path=save_path, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.use_prefix_tree = use_prefix_tree
        self.inverted_index_filename = inverted_index_filename
        self.entities_list_filename = entities_list_filename
        self.build_inverted_index = build_inverted_index
        self.q2name_filename = q2name_filename
        self.q2descr_filename = q2descr_filename
        self.kb_format = kb_format
        self.kb_filename = kb_filename
        self.label_rel = label_rel
        self.aliases_rels = aliases_rels
        self.descr_rel = descr_rel
        self.sql_table_name = sql_table_name
        self.sql_column_names = sql_column_names
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.entities_index: Optional[List[str]] = None
        self.q2name: Optional[List[Tuple[str]]] = None
        self.lang_str = f"@{lang}"
        if self.lang_str == "@en":
            self.stopwords = set(stopwords.words("english"))
        elif self.lang_str == "@ru":
            self.stopwords = set(stopwords.words("russian"))
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.rel_ranker = rel_ranker
        self.use_descriptions = use_descriptions

        if self.use_prefix_tree:
            alphabet = "!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ" + \
                       "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ" + \
                       "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗"
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)

        if self.build_inverted_index:
            if self.kb_format == "hdt":
                self.doc = HDTDocument(str(expand_path(self.kb_filename)))
            if self.kb_format == "sqlite3":
                self.conn = sqlite3.connect(str(expand_path(self.kb_filename)))
                self.cursor = self.conn.cursor()
            self.inverted_index_builder()
            self.save()
        else:
            self.load()
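
The two parameter sets below sketch how this constructor might be configured: one builds the inverted index from an hdt knowledge base and saves it, the other loads a previously built index. File names, paths, and the label relation are made-up placeholders, not the component's actual defaults.

# hypothetical configuration for building the index from an hdt knowledge base
build_config = dict(
    load_path="~/.deeppavlov/downloads/wikidata",
    save_path="~/.deeppavlov/downloads/wikidata",
    inverted_index_filename="inverted_index.pickle",
    entities_list_filename="entities_list.pickle",
    q2name_filename="q2name.pickle",
    build_inverted_index=True,
    kb_format="hdt",
    kb_filename="wikidata.hdt",
    label_rel="http://www.w3.org/2000/01/rdf-schema#label",  # assumed label relation
    lang="en",
)

# hypothetical configuration that only loads previously pickled index files
load_config = dict(
    load_path="~/.deeppavlov/downloads/wikidata",
    inverted_index_filename="inverted_index.pickle",
    entities_list_filename="entities_list.pickle",
    q2name_filename="q2name.pickle",
    build_inverted_index=False,   # skip building, just load() the pickled files
)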
Example #3
class EntityLinker(Component, Serializable):
    """
        This class extracts from the knowledge base candidate entities for the entity mentioned in the question and then
        extracts triplets from Wikidata for the extracted entity. Candidate entities are searched in the dictionary where 
        keys are titles and aliases of Wikidata entities and values are lists of tuples (entity_title, entity_id,
        number_of_relations). First candidate entities are searched in the dictionary by keys where the keys are
        entities extracted from the question, if nothing is found entities are searched in the dictionary using
        Levenstein distance between the entity and keys (titles) in the dictionary.
    """

    def __init__(self, load_path: str,
                 inverted_index_filename: str,
                 entities_list_filename: str,
                 q2name_filename: str,
                 save_path: str = None,
                 q2descr_filename: str = None,
                 rel_ranker: RelRankerBertInfer = None,
                 build_inverted_index: bool = False,
                 kb_format: str = "hdt",
                 kb_filename: str = None,
                 label_rel: str = None,
                 descr_rel: str = None,
                 aliases_rels: List[str] = None,
                 sql_table_name: str = None,
                 sql_column_names: List[str] = None,
                 lang: str = "en",
                 use_descriptions: bool = False,
                 lemmatize: bool = False,
                 use_prefix_tree: bool = False,
                 **kwargs) -> None:
        """

        Args:
            load_path: path to folder with inverted index files
            save_path: path where to save inverted index files
            inverted_index_filename: file with dict of words (keys) and entities containing these words (values)
            entities_list_filename: file with the list of entities from the knowledge base
            q2name_filename: name of file which maps entity id to name
            q2descr_filename: name of file which maps entity id to description
            rel_ranker: component deeppavlov.models.kbqa.rel_ranker_bert_infer
            build_inverted_index: if True, the inverted index of entities of the KB will be built
            kb_format: "hdt" or "sqlite3"
            kb_filename: file with the knowledge base, used for building the inverted index
            label_rel: relation in the knowledge base which connects entity ids and entity titles
            descr_rel: relation in the knowledge base which connects entity ids and entity descriptions
            aliases_rels: list of relations which connect entity ids and entity aliases
            sql_table_name: name of the table with the KB if the KB is in sqlite3 format
            sql_column_names: names of columns with subject, relation and object
            lang: language of entity labels in the knowledge base ("en" or "ru")
            use_descriptions: whether to use context and descriptions of entities for entity ranking
            lemmatize: whether to lemmatize tokens of extracted entity
            use_prefix_tree: whether to use prefix tree for search of entities with typos in entity labels
            **kwargs:
        """
        super().__init__(save_path=save_path, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.use_prefix_tree = use_prefix_tree
        self.inverted_index_filename = inverted_index_filename
        self.entities_list_filename = entities_list_filename
        self.build_inverted_index = build_inverted_index
        self.q2name_filename = q2name_filename
        self.q2descr_filename = q2descr_filename
        self.kb_format = kb_format
        self.kb_filename = kb_filename
        self.label_rel = label_rel
        self.aliases_rels = aliases_rels
        self.descr_rel = descr_rel
        self.sql_table_name = sql_table_name
        self.sql_column_names = sql_column_names
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.entities_index: Optional[List[str]] = None
        self.q2name: Optional[List[Tuple[str]]] = None
        self.lang_str = f"@{lang}"
        if self.lang_str == "@en":
            self.stopwords = set(stopwords.words("english"))
        elif self.lang_str == "@ru":
            self.stopwords = set(stopwords.words("russian"))
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.rel_ranker = rel_ranker
        self.use_descriptions = use_descriptions

        if self.use_prefix_tree:
            alphabet = "!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ" + \
                       "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ" + \
                       "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗"
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)

        if self.build_inverted_index:
            if self.kb_format == "hdt":
                self.doc = HDTDocument(str(expand_path(self.kb_filename)))
            if self.kb_format == "sqlite3":
                self.conn = sqlite3.connect(str(expand_path(self.kb_filename)))
                self.cursor = self.conn.cursor()
            self.inverted_index_builder()
            self.save()
        else:
            self.load()

    def load(self) -> None:
        self.inverted_index = load_pickle(self.load_path / self.inverted_index_filename)
        self.entities_list = load_pickle(self.load_path / self.entities_list_filename)
        self.q2name = load_pickle(self.load_path / self.q2name_filename)

    def save(self) -> None:
        save_pickle(self.inverted_index, self.save_path / self.inverted_index_filename)
        save_pickle(self.entities_list, self.save_path / self.entities_list_filename)
        save_pickle(self.q2name, self.save_path / self.q2name_filename)
        if self.q2descr_filename is not None:
            save_pickle(self.q2descr, self.save_path / self.q2descr_filename)

    def __call__(self, entity_substr_batch: List[List[str]], entity_positions_batch: List[List[List[int]]] = None,
                       context_tokens: List[List[str]] = None) -> Tuple[List[List[List[str]]], List[List[List[float]]]]:
        entity_ids_batch = []
        confidences_batch = []
        if entity_positions_batch is None:
            entity_positions_batch = [[[0] for i in range(len(entities_list))] for entities_list in entity_substr_batch]
        for entity_substr_list, entity_positions_list in zip(entity_substr_batch, entity_positions_batch):
            entity_ids_list = []
            confidences_list = []
            for entity_substr, entity_pos in zip(entity_substr_list, entity_positions_list):
                context = ""
                if self.use_descriptions:
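                    # replace the mention span with a single [ENT] marker in the context passed to the BERT ranker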
                    context = ' '.join(context_tokens[:entity_pos[0]]+["[ENT]"]+context_tokens[entity_pos[-1]+1:])
                entity_ids, confidences = self.link_entity(entity_substr, context)
                entity_ids_list.append(entity_ids)
                confidences_list.append(confidences)
            entity_ids_batch.append(entity_ids_list)
            confidences_batch.append(confidences_list)

        return entity_ids_batch, confidences_batch

    def link_entity(self, entity: str, context: str = None) -> Tuple[List[str], List[float]]:
        confidences = []
        if not entity:
            entities_ids = ['None']
        else:
            candidate_entities = self.candidate_entities_inverted_index(entity)
            candidate_entities, candidate_names = self.candidate_entities_names(entity, candidate_entities)
            entities_ids, confidences, srtd_cand_ent = self.sort_found_entities(candidate_entities,
                                                                                 candidate_names, entity, context)

        return entities_ids, confidences

    def candidate_entities_inverted_index(self, entity: str) -> List[Tuple[Any, Any, Any]]:
        word_tokens = nltk.word_tokenize(entity.lower())
        candidate_entities = []
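        # each inverted_index entry maps a lowercased token to (entity_num, num_rels) pairs,
        # where entity_num is an index into self.entities_list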

        for tok in word_tokens:
            if len(tok) > 1:
                found = False
                if tok in self.inverted_index:
                    candidate_entities += self.inverted_index[tok]
                    found = True

                if self.lemmatize:
                    morph_parse_tok = self.morph.parse(tok)[0]
                    lemmatized_tok = morph_parse_tok.normal_form
                    if lemmatized_tok in self.inverted_index:
                        candidate_entities += self.inverted_index[lemmatized_tok]
                        found = True

                if not found and self.use_prefix_tree:
                    words_with_levens_1 = self.searcher.search(tok, d=1)
                    for word in words_with_levens_1:
                        candidate_entities += self.inverted_index[word[0]]
        candidate_entities = list(set(candidate_entities))
        candidate_entities = [(candidate[0], self.entities_list[candidate[0]], candidate[1])
                              for candidate in candidate_entities]

        return candidate_entities

    def sort_found_entities(self, candidate_entities: List[Tuple[int, str, int]],
                            candidate_names: List[List[str]],
                            entity: str, context: str = None) -> Tuple[List[str], List[float], List[Tuple[str, str, int, int]]]:
        entities_ratios = []
        for candidate, entity_names in zip(candidate_entities, candidate_names):
            entity_num, entity_id, num_rels = candidate
            fuzz_ratio = max([fuzz.ratio(name.lower(), entity) for name in entity_names])
            entities_ratios.append((entity_num, entity_id, fuzz_ratio, num_rels))

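        # rank candidates by fuzzy-match ratio first, then by popularity (number of relations)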
        srtd_with_ratios = sorted(entities_ratios, key=lambda x: (x[2], x[3]), reverse=True)
        if self.use_descriptions:
            num_to_id = {entity_num: entity_id for entity_num, entity_id, _, _ in srtd_with_ratios[:30]}
            entity_numbers = [entity_num for entity_num, _, _, _ in srtd_with_ratios[:30]]
            scores = self.rel_ranker.rank_rels(context, entity_numbers)
            top_rels = [score[0] for score in scores]
            entity_ids = [num_to_id[num] for num in top_rels]
            confidences = [score[1] for score in scores]
        else:
            entity_ids = [ent[1] for ent in srtd_with_ratios]
            confidences = [float(ent[2]) * 0.01 for ent in srtd_with_ratios]

        return entity_ids, confidences, srtd_with_ratios

    def candidate_entities_names(self, entity: str,
          candidate_entities: List[Tuple[int, str, int]]) -> Tuple[List[Tuple[int, str, int]], List[List[str]]]:
        entity_length = len(entity)
        candidate_names = []
        candidate_entities_filter = []
        for candidate in candidate_entities:
            entity_num = candidate[0]
            entity_id = candidate[1]
            entity_names = []
            
            entity_names_found = self.q2name[entity_num]
            if len(entity_names_found[0]) < 6 * entity_length:
                entity_name = entity_names_found[0]
                entity_names.append(entity_name)
                if len(entity_names_found) > 1:
                    for alias in entity_names_found[1:]:
                        entity_names.append(alias)
                candidate_names.append(entity_names)
                candidate_entities_filter.append(candidate)

        return candidate_entities_filter, candidate_names

    def inverted_index_builder(self) -> None:
        log.debug("building inverted index")
        entities_set = set()
        id_to_label_dict = defaultdict(list)
        id_to_descr_dict = defaultdict(list)
        label_to_id_dict = {}
        label_triplets = []
        alias_triplets_list = []
        descr_triplets = []
        if self.kb_format == "hdt":
            label_triplets, c = self.doc.search_triples("", self.label_rel, "")
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    alias_triplets, c = self.doc.search_triples("", alias_rel, "")
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                descr_triplets, c = self.doc.search_triples("", self.descr_rel, "")

        if self.kb_format == "sqlite3":
            subject, relation, obj = self.sql_column_names
            query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} WHERE {relation} = "{self.label_rel}";'
            res = self.cursor.execute(query)
            label_triplets = res.fetchall()
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} WHERE {relation} = "{alias_rel}";'
                    res = self.cursor.execute(query)
                    alias_triplets = res.fetchall()
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} WHERE {relation} = "{self.descr_rel}";'
                res = self.cursor.execute(query)
                descr_triplets = res.fetchall()

        for triplets in [label_triplets] + alias_triplets_list:
            for triplet in triplets:
                entities_set.add(triplet[0])
                if triplet[2].endswith(self.lang_str):
                    label = triplet[2].replace(self.lang_str, '').replace('"', '')
                    id_to_label_dict[triplet[0]].append(label)
                    label_to_id_dict[label] = triplet[0]

        for triplet in descr_triplets:
            entities_set.add(triplet[0])
            if triplet[2].endswith(self.lang_str):
                descr = triplet[2].replace(self.lang_str, '').replace('"', '')
                id_to_descr_dict[triplet[0]].append(descr)

        popularities_dict = {}
        for entity in entities_set:
            if self.kb_format == "hdt":
                all_triplets, number_of_triplets = self.doc.search_triples(entity, "", "")
                popularities_dict[entity] = number_of_triplets
            if self.kb_format == "sqlite3":
                subject, relation, obj = self.sql_column_names
                query = f'SELECT COUNT({obj}) FROM {self.sql_table_name} WHERE {subject} = "{entity}";'
                res = self.cursor.execute(query)
                popularities_dict[entity] = res.fetchall()[0][0]

        entities_dict = {entity: n for n, entity in enumerate(entities_set)}
            
        inverted_index = defaultdict(list)
        for label in label_to_id_dict:
            tokens = re.findall(self.re_tokenizer, label.lower())
            for tok in tokens:
                if len(tok) > 1 and tok not in self.stopwords:
                    inverted_index[tok].append((entities_dict[label_to_id_dict[label]],
                                                popularities_dict[label_to_id_dict[label]]))
        self.inverted_index = dict(inverted_index)
        self.entities_list = list(entities_set)
        self.q2name = [id_to_label_dict[entity] for entity in self.entities_list]
        self.q2descr = []
        if id_to_descr_dict:
            self.q2descr = [id_to_descr_dict[entity] for entity in self.entities_list]
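
The short, self-contained sketch below walks through the data structures that inverted_index_builder produces, using three made-up label triplets in place of a real knowledge base; popularity counts are also invented and stopword filtering is omitted for brevity.

from collections import defaultdict
import re

label_triplets = [
    ("Q90", "rdfs:label", '"Paris"@en'),        # hypothetical triplets
    ("Q142", "rdfs:label", '"France"@en'),
    ("Q7251", "rdfs:label", '"Alan Turing"@en'),
]
lang_str = "@en"
popularities = {"Q90": 120, "Q142": 300, "Q7251": 80}   # made-up numbers of outgoing triplets

label_to_id, id_to_label = {}, defaultdict(list)
for subj, _, obj in label_triplets:
    if obj.endswith(lang_str):
        label = obj.replace(lang_str, "").replace('"', "")
        id_to_label[subj].append(label)
        label_to_id[label] = subj

entities_list = list(id_to_label)                       # entity_num -> entity id
entities_dict = {ent: n for n, ent in enumerate(entities_list)}

inverted_index = defaultdict(list)
for label, ent in label_to_id.items():
    for tok in re.findall(r"[\w']+|[^\w ]", label.lower()):
        if len(tok) > 1:
            inverted_index[tok].append((entities_dict[ent], popularities[ent]))

print(dict(inverted_index))   # {'paris': [(0, 120)], 'france': [(1, 300)], 'alan': [(2, 80)], 'turing': [(2, 80)]}
q2name = [id_to_label[ent] for ent in entities_list]    # q2name[entity_num] -> [label, aliases...]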
Example #4
class KBEntityLinker(Component, Serializable):
    """
        This class extracts from the knowledge base candidate entities for the entity mentioned in the question and then
        extracts triplets from Wikidata for the extracted entity. Candidate entities are searched in the dictionary
        where keys are titles and aliases of Wikidata entities and values are lists of tuples (entity_title, entity_id,
        number_of_relations). First candidate entities are searched in the dictionary by keys where the keys are
        entities extracted from the question, if nothing is found entities are searched in the dictionary using
        Levenstein distance between the entity and keys (titles) in the dictionary.
    """
    def __init__(self,
                 load_path: str,
                 inverted_index_filename: str,
                 entities_list_filename: str,
                 q2name_filename: str,
                 who_entities_filename: Optional[str] = None,
                 save_path: str = None,
                 q2descr_filename: str = None,
                 descr_rank_score_thres: float = 0.0,
                 freq_dict_filename: Optional[str] = None,
                 entity_ranker: RelRankerBertInfer = None,
                 build_inverted_index: bool = False,
                 kb_format: str = "hdt",
                 kb_filename: str = None,
                 label_rel: str = None,
                 descr_rel: str = None,
                 aliases_rels: List[str] = None,
                 sql_table_name: str = None,
                 sql_column_names: List[str] = None,
                 lang: str = "en",
                 use_descriptions: bool = False,
                 include_mention: bool = False,
                 lemmatize: bool = False,
                 use_prefix_tree: bool = False,
                 **kwargs) -> None:
        """

        Args:
            load_path: path to folder with inverted index files
            inverted_index_filename: file with dict of words (keys) and entities containing these words (values)
            entities_list_filename: file with the list of entities from the knowledge base
            q2name_filename: name of file which maps entity id to name
            who_entities_filename: file with the list of entities in Wikidata, which can be answers to questions
                with "Who" pronoun, i.e. humans, literary characters etc.
            save_path: path where to save inverted index files
            q2descr_filename: name of file which maps entity id to description
            descr_rank_score_thres: if the score of the entity description is below this threshold, the entity is not
                added to the output list
            freq_dict_filename: file with the frequency dictionary of Russian words
            entity_ranker: component deeppavlov.models.kbqa.rel_ranker_bert_infer
            build_inverted_index: if True, the inverted index of entities of the KB will be built
            kb_format: "hdt" or "sqlite3"
            kb_filename: file with the knowledge base, used for building the inverted index
            label_rel: relation in the knowledge base which connects entity ids and entity titles
            descr_rel: relation in the knowledge base which connects entity ids and entity descriptions
            aliases_rels: list of relations which connect entity ids and entity aliases
            sql_table_name: name of the table with the KB if the KB is in sqlite3 format
            sql_column_names: names of columns with subject, relation and object
            lang: language of entity labels in the knowledge base ("en" or "ru")
            use_descriptions: whether to use context and descriptions of entities for entity ranking
            include_mention: whether to keep the entity mention in the sentence passed to the BERT ranker
            lemmatize: whether to lemmatize tokens of extracted entity
            use_prefix_tree: whether to use prefix tree for search of entities with typos in entity labels
            **kwargs:
        """
        super().__init__(save_path=save_path, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.use_prefix_tree = use_prefix_tree
        self.inverted_index_filename = inverted_index_filename
        self.entities_list_filename = entities_list_filename
        self.build_inverted_index = build_inverted_index
        self.q2name_filename = q2name_filename
        self.who_entities_filename = who_entities_filename
        self.q2descr_filename = q2descr_filename
        self.descr_rank_score_thres = descr_rank_score_thres
        self.freq_dict_filename = freq_dict_filename
        self.kb_format = kb_format
        self.kb_filename = kb_filename
        self.label_rel = label_rel
        self.aliases_rels = aliases_rels
        self.descr_rel = descr_rel
        self.sql_table_name = sql_table_name
        self.sql_column_names = sql_column_names
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.entities_index: Optional[List[str]] = None
        self.q2name: Optional[List[Tuple[str]]] = None
        self.lang_str = f"@{lang}"
        if self.lang_str == "@en":
            self.stopwords = set(stopwords.words("english"))
        elif self.lang_str == "@ru":
            self.stopwords = set(stopwords.words("russian"))
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.entity_ranker = entity_ranker
        self.use_descriptions = use_descriptions
        self.include_mention = include_mention
        if self.use_descriptions and self.entity_ranker is None:
            raise ValueError("No entity ranker is provided!")

        if self.use_prefix_tree:
            alphabet = "!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ" + \
                       "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ" + \
                       "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗"
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)

        if self.build_inverted_index:
            if self.kb_format == "hdt":
                self.doc = HDTDocument(str(expand_path(self.kb_filename)))
            elif self.kb_format == "sqlite3":
                self.conn = sqlite3.connect(str(expand_path(self.kb_filename)))
                self.cursor = self.conn.cursor()
            else:
                raise ValueError(
                    f'unsupported kb_format value {self.kb_format}')
            self.inverted_index_builder()
            self.save()
        else:
            self.load()

    def load_freq_dict(self, freq_dict_filename: str):
        with open(str(expand_path(freq_dict_filename)), 'r') as fl:
            lines = fl.readlines()
        pos_freq_dict = defaultdict(list)
        for line in lines:
            line_split = line.strip('\n').split('\t')
            if re.match("[\d]+\.[\d]+", line_split[2]):
                pos_freq_dict[line_split[1]].append(
                    (line_split[0], float(line_split[2])))
        nouns_with_freq = pos_freq_dict["s"]
        self.nouns_dict = {noun: freq for noun, freq in nouns_with_freq}

    def load(self) -> None:
        self.inverted_index = load_pickle(self.load_path /
                                          self.inverted_index_filename)
        self.entities_list = load_pickle(self.load_path /
                                         self.entities_list_filename)
        self.q2name = load_pickle(self.load_path / self.q2name_filename)
        if self.who_entities_filename:
            self.who_entities = load_pickle(self.load_path /
                                            self.who_entities_filename)
        if self.freq_dict_filename:
            self.load_freq_dict(self.freq_dict_filename)

    def save(self) -> None:
        save_pickle(self.inverted_index,
                    self.save_path / self.inverted_index_filename)
        save_pickle(self.entities_list,
                    self.save_path / self.entities_list_filename)
        save_pickle(self.q2name, self.save_path / self.q2name_filename)
        if self.q2descr_filename is not None:
            save_pickle(self.q2descr, self.save_path / self.q2descr_filename)

    def __call__(
        self,
        entity_substr_batch: List[List[str]],
        entity_positions_batch: List[List[List[int]]] = None,
        context_tokens: List[List[str]] = None
    ) -> Tuple[List[List[List[str]]], List[List[List[float]]]]:
        entity_ids_batch = []
        confidences_batch = []
        if entity_positions_batch is None:
            entity_positions_batch = [[[0] for i in range(len(entities_list))]
                                      for entities_list in entity_substr_batch]
        for entity_substr_list, entity_positions_list in zip(
                entity_substr_batch, entity_positions_batch):
            entity_ids_list = []
            confidences_list = []
            for entity_substr, entity_pos in zip(entity_substr_list,
                                                 entity_positions_list):
                context = ""
                if self.use_descriptions:
                    if self.include_mention:
                        context = ' '.join(
                            context_tokens[:entity_pos[0]] + ["[ENT]"] +
                            context_tokens[entity_pos[0]:entity_pos[-1] + 1] +
                            ["[ENT]"] + context_tokens[entity_pos[-1] + 1:])
                    else:
                        context = ' '.join(context_tokens[:entity_pos[0]] +
                                           ["[ENT]"] +
                                           context_tokens[entity_pos[-1] + 1:])
                entity_ids, confidences = self.link_entity(
                    entity_substr, context)
                entity_ids_list.append(entity_ids)
                confidences_list.append(confidences)
            entity_ids_batch.append(entity_ids_list)
            confidences_batch.append(confidences_list)

        return entity_ids_batch, confidences_batch

    def link_entity(self,
                    entity: str,
                    context: Optional[str] = None,
                    template_found: Optional[str] = None,
                    cut_entity: bool = False) -> Tuple[List[str], List[float]]:
        confidences = []
        if not entity:
            entities_ids = ['None']
        else:
            candidate_entities = self.candidate_entities_inverted_index(entity)
            if cut_entity and candidate_entities and len(
                    entity.split()) > 1 and candidate_entities[0][3] == 1:
                entity = self.cut_entity_substr(entity)
                candidate_entities = self.candidate_entities_inverted_index(
                    entity)
            candidate_entities, candidate_names = self.candidate_entities_names(
                entity, candidate_entities)
            entities_ids, confidences, srtd_cand_ent = self.sort_found_entities(
                candidate_entities, candidate_names, entity, context)
            if template_found:
                entities_ids = self.filter_entities(entities_ids,
                                                    template_found)

        return entities_ids, confidences

    def cut_entity_substr(self, entity: str):
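        # return the token of the mention with the lowest corpus frequency (words missing from the
        # frequency dictionary count as 0), assuming the rarest word is the most entity-specific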
        word_tokens = nltk.word_tokenize(entity.lower())
        word_tokens = [
            word for word in word_tokens if word not in self.stopwords
        ]
        normal_form_tokens = [
            self.morph.parse(word)[0].normal_form for word in word_tokens
        ]
        words_with_freq = [(word, self.nouns_dict.get(word, 0.0))
                           for word in normal_form_tokens]
        words_with_freq = sorted(words_with_freq, key=lambda x: x[1])
        return words_with_freq[0][0]

    def candidate_entities_inverted_index(
            self, entity: str) -> List[Tuple[Any, Any, Any]]:
        word_tokens = nltk.word_tokenize(entity.lower())
        word_tokens = [
            word for word in word_tokens if word not in self.stopwords
        ]
        candidate_entities = []

        for tok in word_tokens:
            if len(tok) > 1:
                found = False
                if tok in self.inverted_index:
                    candidate_entities += self.inverted_index[tok]
                    found = True

                if self.lemmatize:
                    morph_parse_tok = self.morph.parse(tok)[0]
                    lemmatized_tok = morph_parse_tok.normal_form
                    if lemmatized_tok != tok and lemmatized_tok in self.inverted_index:
                        candidate_entities += self.inverted_index[
                            lemmatized_tok]
                        found = True

                if not found and self.use_prefix_tree:
                    words_with_levens_1 = self.searcher.search(tok, d=1)
                    for word in words_with_levens_1:
                        candidate_entities += self.inverted_index[word[0]]
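        # count how many of the mention's tokens matched each (entity_num, num_rels) candidate;
        # Counter.most_common orders candidates by that count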
        candidate_entities = Counter(candidate_entities).most_common()
        candidate_entities = [(entity_num, self.entities_list[entity_num], entity_freq, count) for \
                                                (entity_num, entity_freq), count in candidate_entities]

        return candidate_entities

    def sort_found_entities(
        self,
        candidate_entities: List[Tuple[int, str, int]],
        candidate_names: List[List[str]],
        entity: str,
        context: str = None
    ) -> Tuple[List[str], List[float], List[Tuple[str, str, int, int]]]:
        entities_ratios = []
        for candidate, entity_names in zip(candidate_entities,
                                           candidate_names):
            entity_num, entity_id, num_rels, tokens_matched = candidate
            fuzz_ratio = max(
                [fuzz.ratio(name.lower(), entity) for name in entity_names])
            entities_ratios.append(
                (entity_num, entity_id, tokens_matched, fuzz_ratio, num_rels))

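        # rank by number of matched mention tokens, then fuzzy-match ratio, then popularity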
        srtd_with_ratios = sorted(entities_ratios,
                                  key=lambda x: (x[2], x[3], x[4]),
                                  reverse=True)
        if self.use_descriptions:
            log.debug(f"context {context}")
            id_to_score = {
                entity_id: (tokens_matched, score)
                for _, entity_id, tokens_matched, score, _ in
                srtd_with_ratios[:30]
            }
            entity_ids = [
                entity_id for _, entity_id, _, _, _ in srtd_with_ratios[:30]
            ]
            scores = self.entity_ranker.rank_rels(context, entity_ids)
            entities_with_scores = [(entity_id, id_to_score[entity_id][0],
                                     id_to_score[entity_id][1], score)
                                    for entity_id, score in scores]
            entities_with_scores = sorted(entities_with_scores,
                                          key=lambda x: (x[1], x[2], x[3]),
                                          reverse=True)
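            # keep a candidate only if its description score passes the threshold or its label
            # matched the mention exactly (fuzzy ratio of 100)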
            entities_with_scores = [entity for entity in entities_with_scores if \
                                   (entity[3] > self.descr_rank_score_thres or entity[2] == 100.0)]
            log.debug(f"entities_with_scores {entities_with_scores[:10]}")
            entity_ids = [entity for entity, _, _, _ in entities_with_scores]
            confidences = [score for _, _, _, score in entities_with_scores]
        else:
            entity_ids = [ent[1] for ent in srtd_with_ratios]
            confidences = [float(ent[2]) * 0.01 for ent in srtd_with_ratios]

        return entity_ids, confidences, srtd_with_ratios

    def candidate_entities_names(
        self, entity: str, candidate_entities: List[Tuple[int, str, int]]
    ) -> Tuple[List[Tuple[int, str, int]], List[List[str]]]:
        entity_length = len(entity)
        candidate_names = []
        candidate_entities_filter = []
        for candidate in candidate_entities:
            entity_num = candidate[0]
            entity_names = []

            entity_names_found = self.q2name[entity_num]
            if len(entity_names_found[0]) < 6 * entity_length:
                entity_name = entity_names_found[0]
                entity_names.append(entity_name)
                if len(entity_names_found) > 1:
                    for alias in entity_names_found[1:]:
                        entity_names.append(alias)
                candidate_names.append(entity_names)
                candidate_entities_filter.append(candidate)

        return candidate_entities_filter, candidate_names

    def inverted_index_builder(self) -> None:
        log.debug("building inverted index")
        entities_set = set()
        id_to_label_dict = defaultdict(list)
        id_to_descr_dict = {}
        label_to_id_dict = {}
        label_triplets = []
        alias_triplets_list = []
        descr_triplets = []
        if self.kb_format == "hdt":
            label_triplets, c = self.doc.search_triples("", self.label_rel, "")
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    alias_triplets, c = self.doc.search_triples(
                        "", alias_rel, "")
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                descr_triplets, c = self.doc.search_triples(
                    "", self.descr_rel, "")

        if self.kb_format == "sqlite3":
            subject, relation, obj = self.sql_column_names
            query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} '\
                    f'WHERE {relation} = "{self.label_rel}";'
            res = self.cursor.execute(query)
            label_triplets = res.fetchall()
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} '\
                            f'WHERE {relation} = "{alias_rel}";'
                    res = self.cursor.execute(query)
                    alias_triplets = res.fetchall()
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} '\
                        f'WHERE {relation} = "{self.descr_rel}";'
                res = self.cursor.execute(query)
                descr_triplets = res.fetchall()

        for triplets in [label_triplets] + alias_triplets_list:
            for triplet in triplets:
                entities_set.add(triplet[0])
                if triplet[2].endswith(self.lang_str):
                    label = triplet[2].replace(self.lang_str,
                                               '').replace('"', '')
                    id_to_label_dict[triplet[0]].append(label)
                    label_to_id_dict[label] = triplet[0]

        for triplet in descr_triplets:
            entities_set.add(triplet[0])
            if triplet[2].endswith(self.lang_str):
                descr = triplet[2].replace(self.lang_str, '').replace('"', '')
                id_to_descr_dict[triplet[0]].append(descr)

        popularities_dict = {}
        for entity in entities_set:
            if self.kb_format == "hdt":
                all_triplets, number_of_triplets = self.doc.search_triples(
                    entity, "", "")
                popularities_dict[entity] = number_of_triplets
            if self.kb_format == "sqlite3":
                subject, relation, obj = self.sql_column_names
                query = f'SELECT COUNT({obj}) FROM {self.sql_table_name} WHERE {subject} = "{entity}";'
                res = self.cursor.execute(query)
                popularities_dict[entity] = res.fetchall()[0][0]

        entities_dict = {entity: n for n, entity in enumerate(entities_set)}

        inverted_index = defaultdict(list)
        for label in label_to_id_dict:
            tokens = re.findall(self.re_tokenizer, label.lower())
            for tok in tokens:
                if len(tok) > 1 and tok not in self.stopwords:
                    inverted_index[tok].append(
                        (entities_dict[label_to_id_dict[label]],
                         popularities_dict[label_to_id_dict[label]]))
        self.inverted_index = dict(inverted_index)
        self.entities_list = list(entities_set)
        self.q2name = [
            id_to_label_dict[entity] for entity in self.entities_list
        ]
        self.q2descr = []
        if id_to_descr_dict:
            self.q2descr = [
                id_to_descr_dict[entity] for entity in self.entities_list
            ]

    def filter_entities(self, entities: List[str],
                        template_found: str) -> List[str]:
        if template_found in ["who is xxx?", "who was xxx?"]:
            entities = [
                entity for entity in entities if entity in self.who_entities
            ]
        if template_found in ["what is xxx?", "what was xxx?"]:
            entities = [
                entity for entity in entities
                if entity not in self.who_entities
            ]
        return entities
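
As a stand-alone illustration of how __call__ above assembles the [ENT]-marked context for the BERT ranker, the helper below mirrors the two include_mention branches; the tokens and positions are made up.

def build_context(tokens, start, end, include_mention=False):
    if include_mention:
        # keep the mention, wrapped in [ENT] markers
        return " ".join(tokens[:start] + ["[ENT]"] + tokens[start:end + 1] + ["[ENT]"] + tokens[end + 1:])
    # drop the mention and leave a single [ENT] placeholder
    return " ".join(tokens[:start] + ["[ENT]"] + tokens[end + 1:])

tokens = ["who", "wrote", "war", "and", "peace", "?"]
print(build_context(tokens, 2, 4))                        # who wrote [ENT] ?
print(build_context(tokens, 2, 4, include_mention=True))  # who wrote [ENT] war and peace [ENT] ?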
Example #5
    def __init__(self,
                 load_path: str,
                 inverted_index_filename: str,
                 entities_list_filename: str,
                 q2name_filename: str,
                 types_dict_filename: Optional[str] = None,
                 who_entities_filename: Optional[str] = None,
                 save_path: str = None,
                 q2descr_filename: str = None,
                 descr_rank_score_thres: float = 0.01,
                 freq_dict_filename: Optional[str] = None,
                 entity_ranker: RelRankerBertInfer = None,
                 build_inverted_index: bool = False,
                 kb_format: str = "hdt",
                 kb_filename: str = None,
                 label_rel: str = None,
                 descr_rel: str = None,
                 aliases_rels: List[str] = None,
                 sql_table_name: str = None,
                 sql_column_names: List[str] = None,
                 lang: str = "en",
                 use_descriptions: bool = False,
                 include_mention: bool = False,
                 num_entities_to_return: int = 5,
                 lemmatize: bool = False,
                 use_prefix_tree: bool = False,
                 **kwargs) -> None:
        """

        Args:
            load_path: path to folder with inverted index files
            inverted_index_filename: file with dict of words (keys) and entities containing these words (values)
            entities_list_filename: file with the list of entities from the knowledge base
            q2name_filename: file which maps entity id to name
            types_dict_filename: file with types of entities
            who_entities_filename: file with the list of entities in Wikidata, which can be answers to questions
                with "Who" pronoun, i.e. humans, literary characters etc.
            save_path: path where to save inverted index files
            q2descr_filename: name of file which maps entity id to description
            descr_rank_score_thres: if the score of the entity description is below this threshold, the entity is not
                added to the output list
            freq_dict_filename: file with the frequency dictionary of Russian words
            entity_ranker: component deeppavlov.models.kbqa.rel_ranker_bert_infer
            build_inverted_index: if True, the inverted index of entities of the KB will be built
            kb_format: "hdt" or "sqlite3"
            kb_filename: file with the knowledge base, used for building the inverted index
            label_rel: relation in the knowledge base which connects entity ids and entity titles
            descr_rel: relation in the knowledge base which connects entity ids and entity descriptions
            aliases_rels: list of relations which connect entity ids and entity aliases
            sql_table_name: name of the table with the KB if the KB is in sqlite3 format
            sql_column_names: names of columns with subject, relation and object
            lang: language of entity labels in the knowledge base ("en" or "ru")
            use_descriptions: whether to use context and descriptions of entities for entity ranking
            include_mention: whether to keep the entity mention in the sentence passed to the BERT ranker
            num_entities_to_return: how many entities for each substring the system returns
            lemmatize: whether to lemmatize tokens of extracted entity
            use_prefix_tree: whether to use prefix tree for search of entities with typos in entity labels
            **kwargs:
        """
        super().__init__(save_path=save_path, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.use_prefix_tree = use_prefix_tree
        self.inverted_index_filename = inverted_index_filename
        self.entities_list_filename = entities_list_filename
        self.build_inverted_index = build_inverted_index
        self.q2name_filename = q2name_filename
        self.types_dict_filename = types_dict_filename
        self.who_entities_filename = who_entities_filename
        self.q2descr_filename = q2descr_filename
        self.descr_rank_score_thres = descr_rank_score_thres
        self.freq_dict_filename = freq_dict_filename
        self.kb_format = kb_format
        self.kb_filename = kb_filename
        self.label_rel = label_rel
        self.aliases_rels = aliases_rels
        self.descr_rel = descr_rel
        self.sql_table_name = sql_table_name
        self.sql_column_names = sql_column_names
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.entities_index: Optional[List[str]] = None
        self.q2name: Optional[List[Tuple[str]]] = None
        self.types_dict: Optional[Dict[str, List[str]]] = None
        self.lang_str = f"@{lang}"
        if self.lang_str == "@en":
            self.stopwords = set(stopwords.words("english"))
        elif self.lang_str == "@ru":
            self.stopwords = set(stopwords.words("russian"))
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.entity_ranker = entity_ranker
        self.use_descriptions = use_descriptions
        self.include_mention = include_mention
        self.num_entities_to_return = num_entities_to_return
        if self.use_descriptions and self.entity_ranker is None:
            raise ValueError("No entity ranker is provided!")

        if self.use_prefix_tree:
            alphabet = "!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ" + \
                       "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ" + \
                       "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗"
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)

        if self.build_inverted_index:
            if self.kb_format == "hdt":
                self.doc = HDTDocument(str(expand_path(self.kb_filename)))
            elif self.kb_format == "sqlite3":
                self.conn = sqlite3.connect(str(expand_path(self.kb_filename)))
                self.cursor = self.conn.cursor()
            else:
                raise ValueError(
                    f'unsupported kb_format value {self.kb_format}')
            self.inverted_index_builder()
            self.save()
        else:
            self.load()
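
A hypothetical illustration of the mapping that types_dict_filename is expected to hold, consistent with the Dict[str, List[str]] annotation above: entity id to the ids of its types. The concrete ids below are examples only.

types_dict = {
    "Q7251": ["Q5"],            # Alan Turing -> human
    "Q90": ["Q515", "Q5119"],   # Paris -> city, capital
}

def has_type(entity_id: str, type_id: str) -> bool:
    return type_id in types_dict.get(entity_id, [])

print(has_type("Q7251", "Q5"))  # True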
Example #6
class EntityLinker(Component, Serializable):
    """
        This class extracts from Wikidata candidate entities for the entity mentioned in the question and then extracts
        triplets from Wikidata for the extracted entity. Candidate entities are searched in the dictionary where keys
        are titles and aliases of Wikidata entities and values are lists of tuples (entity_title, entity_id,
        number_of_relations). First candidate entities are searched in the dictionary by keys where the keys are
        entities extracted from the question, if nothing is found entities are searched in the dictionary using
        Levenstein distance between the entity and keys (titles) in the dictionary.
    """

    LANGUAGES = set(['rus'])

    def __init__(self,
                 load_path: str,
                 wiki_filename: str,
                 entities_filename: str,
                 inverted_index_filename: str,
                 id_to_name_file: str,
                 lemmatize: bool = True,
                 debug: bool = False,
                 rule_filter_entities: bool = True,
                 use_inverted_index: bool = True,
                 language: str = 'rus',
                 *args,
                 **kwargs) -> None:
        """

        Args:
            load_path: path to folder with wikidata files
            wiki_filename: file with Wikidata triplets
            entities_filename: file with dict of entity titles (keys) and entity ids (values)
            inverted_index_filename: file with dict of words (keys) and entities containing these words (values)
            id_to_name_file: file with dict of entity ids (keys) and entities names and aliases (values)
            lemmatize: whether to lemmatize tokens of extracted entity
            debug: whether to print entities extracted from Wikidata
            rule_filter_entities: whether to filter entities which do not fit the question
            use_inverted_index: whether to use inverted index for entity linking
            language: the language of the linker (used to filter out some questions and improve overall performance)
            *args:
            **kwargs:
        """
        super().__init__(save_path=None, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.debug = debug
        self.rule_filter_entities = rule_filter_entities
        self.use_inverted_index = use_inverted_index
        self._language = language
        if language not in self.LANGUAGES:
            log.warning(
                f'EntityLinker supports only the following languages: {self.LANGUAGES}'
            )

        self._wiki_filename = wiki_filename
        self._entities_filename = entities_filename
        self.inverted_index_filename = inverted_index_filename
        self.id_to_name_file = id_to_name_file

        self.name_to_q: Optional[Dict[str, List[Tuple[str]]]] = None
        self.wikidata: Optional[Dict[str, List[List[str]]]] = None
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.id_to_name: Optional[Dict[str, Dict[str, List[str]]]] = None
        self.load()
        if self.use_inverted_index:
            alphabet = "abcdefghijklmnopqrstuvwxyzабвгдеёжзийклмнопрстуфхцчшщъыьэюя1234567890-_()=+!?.,/;:&@<>|#$%^*"
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)

    def load(self) -> None:
        if self.use_inverted_index:
            with open(self.load_path / self.inverted_index_filename,
                      'rb') as inv:
                self.inverted_index = pickle.load(inv)
                self.inverted_index: Dict[str, List[Tuple[str]]]
            with open(self.load_path / self.id_to_name_file, 'rb') as i2n:
                self.id_to_name = pickle.load(i2n)
                self.id_to_name: Dict[str, Dict[str, List[str]]]
        else:
            with open(self.load_path / self._entities_filename, 'rb') as e:
                self.name_to_q = pickle.load(e)
                self.name_to_q: Dict[str, List[Tuple[str]]]
        with open(self.load_path / self._wiki_filename, 'rb') as w:
            self.wikidata = pickle.load(w)
            self.wikidata: Dict[str, List[List[str]]]

    def save(self) -> None:
        pass

    def __call__(
            self, entity: str, question_tokens: List[str]
    ) -> Tuple[List[List[List[str]]], List[str]]:
        confidences = []
        srtd_cand_ent = []
        if not entity:
            wiki_entities = ['None']
        else:
            if self.use_inverted_index:
                candidate_entities = self.candidate_entities_inverted_index(
                    entity)
                candidate_names = self.candidate_entities_names(
                    candidate_entities)
                wiki_entities, confidences, srtd_cand_ent = self.sort_found_entities(
                    candidate_entities, candidate_names, entity)
            else:
                candidate_entities = self.find_candidate_entities(entity)

                srtd_cand_ent = sorted(candidate_entities,
                                       key=lambda x: x[2],
                                       reverse=True)
                if len(srtd_cand_ent) > 0:
                    wiki_entities = [ent[1] for ent in srtd_cand_ent]
                    confidences = [1.0 for i in range(len(srtd_cand_ent))]
                    srtd_cand_ent = [
                        (ent[0], ent[1], conf, ent[2])
                        for ent, conf in zip(srtd_cand_ent, confidences)
                    ]
                else:
                    candidates = self.fuzzy_entity_search(entity)
                    candidates = list(set(candidates))
                    srtd_cand_ent = [(ent[0][0], ent[0][1], ent[1], ent[0][2])
                                     for ent in candidates]
                    srtd_cand_ent = sorted(srtd_cand_ent,
                                           key=lambda x: (x[2], x[3]),
                                           reverse=True)

                    if len(srtd_cand_ent) > 0:
                        wiki_entities = [ent[1] for ent in srtd_cand_ent]
                        confidences = [
                            float(ent[2]) * 0.01 for ent in srtd_cand_ent
                        ]
                    else:
                        wiki_entities = ["None"]
                        confidences = [0.0]

        entity_triplets = self.extract_triplets_from_wiki(wiki_entities)
        # default to the unfiltered candidates so the variables are defined
        # even when rule-based filtering is not applied
        filtered_entities, filtered_entity_triplets = srtd_cand_ent, entity_triplets
        if self.rule_filter_entities and self._language == 'rus':
            filtered_entities, filtered_entity_triplets = self.filter_triplets_rus(
                entity_triplets, question_tokens, srtd_cand_ent)
        if self.debug:
            self._log_entities(filtered_entities[:10])

        return filtered_entity_triplets, confidences

    def _log_entities(self, srtd_cand_ent):
        entities_to_print = []
        for name, q, ratio, n_rel in srtd_cand_ent:
            entities_to_print.append(
                f'{name}, http://wikidata.org/wiki/{q}, {ratio}, {n_rel}')
        log.debug('\n' + '\n'.join(entities_to_print))

    def find_candidate_entities(self, entity: str) -> List[Tuple[str]]:
        candidate_entities = list(self.name_to_q.get(entity, []))
        entity_split = entity.split(' ')
        if len(entity_split) < 6 and self.lemmatize:
            entity_lemm_tokens = []
            for tok in entity_split:
                morph_parse_tok = self.morph.parse(tok)[0]
                lemmatized_tok = morph_parse_tok.normal_form
                entity_lemm_tokens.append(lemmatized_tok)
            masks = itertools.product([False, True], repeat=len(entity_split))
            for mask in masks:
                entity_lemm = []
                for i in range(len(entity_split)):
                    if mask[i]:
                        entity_lemm.append(entity_split[i])
                    else:
                        entity_lemm.append(entity_lemm_tokens[i])
                entity_lemm = ' '.join(entity_lemm)
                if entity_lemm != entity:
                    candidate_entities += self.name_to_q.get(entity_lemm, [])
        candidate_entities = list(set(candidate_entities))

        return candidate_entities

    def fuzzy_entity_search(self, entity: str) -> List[Tuple[Tuple, int]]:
        word_length = len(entity)
        candidates = []
        for title in self.name_to_q:
            length_ratio = len(title) / word_length
            if 0.75 < length_ratio < 1.25:
                ratio = fuzz.ratio(title, entity)
                if ratio > 70:
                    entity_candidates = self.name_to_q.get(title, [])
                    for cand in entity_candidates:
                        candidates.append((cand, fuzz.ratio(entity, cand[0])))
        return candidates
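
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: the same length-ratio
# pre-filter and fuzz.ratio cut-off that fuzzy_entity_search applies, run on
# a tiny hypothetical name_to_q dictionary.  Assumes fuzz comes from
# fuzzywuzzy (rapidfuzz exposes an identical fuzz.ratio).
from fuzzywuzzy import fuzz

toy_name_to_q = {"эверест": [("Эверест", "Q513", 120)]}  # hypothetical titles
query = "еверест"                                         # mention with a typo
found = []
for title, cands in toy_name_to_q.items():
    if 0.75 < len(title) / len(query) < 1.25 and fuzz.ratio(title, query) > 70:
        for cand in cands:
            found.append((cand, fuzz.ratio(query, cand[0])))
# found -> [(("Эверест", "Q513", 120), 86)]
# ---------------------------------------------------------------------------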

    def extract_triplets_from_wiki(
            self, entity_ids: List[str]) -> List[List[List[str]]]:
        entity_triplets = []
        for entity_id in entity_ids:
            if entity_id in self.wikidata and entity_id.startswith('Q'):
                triplets_for_entity = self.wikidata[entity_id]
                entity_triplets.append(triplets_for_entity)
            else:
                entity_triplets.append([])

        return entity_triplets

    @staticmethod
    def filter_triplets_rus(
        entity_triplets: List[List[List[str]]], question_tokens: List[str],
        srtd_cand_ent: List[Tuple[str]]
    ) -> Tuple[List[Tuple[str]], List[List[List[str]]]]:

        question = ' '.join(question_tokens).lower()
        what_template = 'что '
        found_what_template = question.find(what_template) > -1
        filtered_entity_triplets = []
        filtered_entities = []
        for wiki_entity, triplets_for_entity in zip(srtd_cand_ent,
                                                    entity_triplets):
            entity_is_human = False
            entity_is_asteroid = False
            entity_is_named = False
            entity_title = wiki_entity[0]
            if entity_title[0].isupper():
                entity_is_named = True
            property_is_instance_of = 'P31'
            id_for_entity_human = 'Q5'
            id_for_entity_asteroid = 'Q3863'
            for triplet in triplets_for_entity:
                if triplet[0] == property_is_instance_of and triplet[
                        1] == id_for_entity_human:
                    entity_is_human = True
                    break
                if triplet[0] == property_is_instance_of and triplet[
                        1] == id_for_entity_asteroid:
                    entity_is_asteroid = True
                    break
            if found_what_template and (entity_is_human or entity_is_named
                                        or entity_is_asteroid
                                        or wiki_entity[2] < 90):
                continue
            filtered_entity_triplets.append(triplets_for_entity)
            filtered_entities.append(wiki_entity)

        return filtered_entities, filtered_entity_triplets
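
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: the instance-of check
# that filter_triplets_rus performs for every candidate, on hypothetical
# triplets.  P31 is Wikidata's "instance of" property and Q5 is "human".
toy_triplets = [["P31", "Q5"], ["P19", "Q649"]]  # hypothetical facts of one entity
is_human = any(rel == "P31" and obj == "Q5" for rel, obj in toy_triplets)
# is_human -> True, so for a "что ..." ("what ...") question this candidate
# would be dropped by the filter above.
# ---------------------------------------------------------------------------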

    def candidate_entities_inverted_index(self,
                                          entity: str) -> List[Tuple[str]]:
        word_tokens = nltk.word_tokenize(entity)
        candidate_entities = []

        for tok in word_tokens:
            if len(tok) > 1:
                found = False
                if tok in self.inverted_index:
                    candidate_entities += self.inverted_index[tok]
                    found = True
                morph_parse_tok = self.morph.parse(tok)[0]
                lemmatized_tok = morph_parse_tok.normal_form
                if lemmatized_tok != tok and lemmatized_tok in self.inverted_index:
                    candidate_entities += self.inverted_index[lemmatized_tok]
                    found = True
                if not found:
                    words_with_levens_1 = self.searcher.search(tok, d=1)
                    for word in words_with_levens_1:
                        candidate_entities += self.inverted_index[word[0]]
        candidate_entities = list(set(candidate_entities))

        return candidate_entities

    def candidate_entities_names(
            self, candidate_entities: List[Tuple[str]]) -> List[List[str]]:
        candidate_names = []
        for candidate in candidate_entities:
            entity_id = candidate[0]
            entity_names = [self.id_to_name[entity_id]["name"]]
            if "aliases" in self.id_to_name[entity_id].keys():
                aliases = self.id_to_name[entity_id]["aliases"]
                for alias in aliases:
                    entity_names.append(alias)
            candidate_names.append(entity_names)

        return candidate_names

    def sort_found_entities(
            self, candidate_entities: List[Tuple[str]],
            candidate_names: List[List[str]],
            entity: str) -> Tuple[List[str], List[str], List[Tuple[str]]]:
        entities_ratios = []
        # the mention is the same for every candidate, so lemmatize it once
        morph_parse_entity = self.morph.parse(entity)[0]
        lemm_entity = morph_parse_entity.normal_form
        for candidate, entity_names in zip(candidate_entities,
                                           candidate_names):
            entity_id = candidate[0]
            num_rels = candidate[1]
            entity_name = entity_names[0]
            fuzz_ratio_lemm = max([
                fuzz.ratio(name.lower(), lemm_entity.lower())
                for name in entity_names
            ])
            fuzz_ratio_nolemm = max([
                fuzz.ratio(name.lower(), entity.lower())
                for name in entity_names
            ])
            fuzz_ratio = max(fuzz_ratio_lemm, fuzz_ratio_nolemm)
            entities_ratios.append(
                (entity_name, entity_id, fuzz_ratio, num_rels))

        srtd_with_ratios = sorted(entities_ratios,
                                  key=lambda x: (x[2], x[3]),
                                  reverse=True)
        wiki_entities = [ent[1] for ent in srtd_with_ratios if ent[2] > 84]
        confidences = [
            float(ent[2]) * 0.01 for ent in srtd_with_ratios if ent[2] > 84
        ]

        return wiki_entities, confidences, srtd_with_ratios
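
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the classes on this page: how
# sort_found_entities scores a candidate both against the raw mention and
# against its lemma.  The candidate names are hypothetical; pymorphy2 and
# fuzzywuzzy are the same libraries the class relies on.
import pymorphy2
from fuzzywuzzy import fuzz

morph = pymorphy2.MorphAnalyzer()
mention = "москве"                                   # inflected mention
lemm_mention = morph.parse(mention)[0].normal_form   # -> "москва"
candidate_names = ["Москва", "Московский кремль"]    # hypothetical entity names
best = max(
    max(fuzz.ratio(name.lower(), lemm_mention), fuzz.ratio(name.lower(), mention))
    for name in candidate_names
)
# best -> 100: the lemma matches the title "Москва" exactly, so the candidate
# passes the ratio > 84 cut-off used in sort_found_entities.
# ---------------------------------------------------------------------------
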
Example #7
0
class KBEntityLinker(Component, Serializable):
    """
    This class extracts from the knowledge base candidate entities for the entity mentioned in the question and then
    extracts triplets from Wikidata for the extracted entity. Candidate entities are searched in the dictionary
    where keys are titles and aliases of Wikidata entities and values are lists of tuples (entity_title, entity_id,
    number_of_relations). First candidate entities are searched in the dictionary by keys where the keys are
    entities extracted from the question, if nothing is found entities are searched in the dictionary using
    Levenstein distance between the entity and keys (titles) in the dictionary.
    """
    def __init__(
        self,
        load_path: str,
        inverted_index_filename: str,
        entities_list_filename: str,
        q2name_filename: str,
        types_dict_filename: Optional[str] = None,
        who_entities_filename: Optional[str] = None,
        save_path: str = None,
        q2descr_filename: str = None,
        descr_rank_score_thres: float = 0.5,
        freq_dict_filename: Optional[str] = None,
        entity_ranker: RelRankerInfer = None,
        build_inverted_index: bool = False,
        kb_format: str = "hdt",
        kb_filename: str = None,
        label_rel: str = None,
        descr_rel: str = None,
        aliases_rels: List[str] = None,
        sql_table_name: str = None,
        sql_column_names: List[str] = None,
        lang: str = "en",
        use_descriptions: bool = False,
        include_mention: bool = False,
        num_entities_to_return: int = 5,
        num_entities_for_bert_ranking: int = 100,
        lemmatize: bool = False,
        use_prefix_tree: bool = False,
        **kwargs,
    ) -> None:
        """

        Args:
            load_path: path to folder with inverted index files
            inverted_index_filename: file with dict of words (keys) and entities containing these words
            entities_list_filename: file with the list of entities from the knowledge base
            q2name_filename: file which maps entity id to name
            types_dict_filename: file with types of entities
            who_entities_filename: file with the list of entities in Wikidata which can be answers to questions
                with the "Who" pronoun, e.g. humans, literary characters, etc.
            save_path: path where to save inverted index files
            q2descr_filename: name of file which maps entity id to description
            descr_rank_score_thres: if the score of the entity description is less than the threshold, the entity
                is not added to the output list
            freq_dict_filename: file with a frequency dictionary of Russian words
            entity_ranker: component deeppavlov.models.kbqa.rel_ranker_infer
            build_inverted_index: if "true", inverted index of entities of the KB will be built
            kb_format: "hdt" or "sqlite3"
            kb_filename: file with the knowledge base, which will be used for building the inverted index
            label_rel: relation in the knowledge base which connects entity ids and entity titles
            descr_rel: relation in the knowledge base which connects entity ids and entity descriptions
            aliases_rels: list of relations which connect entity ids and entity aliases
            sql_table_name: name of the table with the KB if the KB is in sqlite3 format
            sql_column_names: names of columns with subject, relation and object
            lang: language of entity labels in the knowledge base and of the stopword list ("en" or "ru")
            use_descriptions: whether to use context and descriptions of entities for entity ranking
            include_mention: whether to keep the entity mention in the sentence passed to the BERT ranker
                (if False, the mention is removed)
            num_entities_to_return: how many entities the system returns for each substring
            num_entities_for_bert_ranking: how many candidate entities are passed to the BERT-based ranker
            lemmatize: whether to lemmatize tokens of extracted entity
            use_prefix_tree: whether to use prefix tree for search of entities with typos in entity labels
            **kwargs:
        """
        super().__init__(save_path=save_path, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.use_prefix_tree = use_prefix_tree
        self.inverted_index_filename = inverted_index_filename
        self.entities_list_filename = entities_list_filename
        self.build_inverted_index = build_inverted_index
        self.q2name_filename = q2name_filename
        self.types_dict_filename = types_dict_filename
        self.who_entities_filename = who_entities_filename
        self.q2descr_filename = q2descr_filename
        self.descr_rank_score_thres = descr_rank_score_thres
        self.freq_dict_filename = freq_dict_filename
        self.kb_format = kb_format
        self.kb_filename = kb_filename
        self.label_rel = label_rel
        self.aliases_rels = aliases_rels
        self.descr_rel = descr_rel
        self.sql_table_name = sql_table_name
        self.sql_column_names = sql_column_names
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.entities_index: Optional[List[str]] = None
        self.q2name: Optional[List[Tuple[str]]] = None
        self.types_dict: Optional[Dict[str, List[str]]] = None
        self.lang_str = f"@{lang}"
        if self.lang_str == "@en":
            self.stopwords = set(stopwords.words("english"))
        elif self.lang_str == "@ru":
            self.stopwords = set(stopwords.words("russian"))
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.entity_ranker = entity_ranker
        self.nlp = en_core_web_sm.load()
        self.inflect_engine = inflect.engine()
        self.use_descriptions = use_descriptions
        self.include_mention = include_mention
        self.num_entities_to_return = num_entities_to_return
        self.num_entities_for_bert_ranking = num_entities_for_bert_ranking
        self.black_list_what_is = {
            "Q277759",  # book series
            "Q11424",  # film
            "Q7889",  # video game
            "Q2743",  # musical theatre
            "Q5398426",  # tv series
            "Q506240",  # television film
            "Q21191270",  # television series episode
            "Q7725634",  # literary work
            "Q131436",  # board game
            "Q1783817",  # cooperative board game
        }
        if self.use_descriptions and self.entity_ranker is None:
            raise ValueError("No entity ranker is provided!")

        if self.use_prefix_tree:
            alphabet = (
                r"!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ"
                +
                "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ"
                + "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗")
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)

        if self.build_inverted_index:
            if self.kb_format == "hdt":
                self.doc = HDTDocument(str(expand_path(self.kb_filename)))
            elif self.kb_format == "sqlite3":
                self.conn = sqlite3.connect(str(expand_path(self.kb_filename)))
                self.cursor = self.conn.cursor()
            else:
                raise ValueError(
                    f"unsupported kb_format value {self.kb_format}")
            self.inverted_index_builder()
            self.save()
        else:
            self.load()

    def load_freq_dict(self, freq_dict_filename: str):
        with open(str(expand_path(freq_dict_filename)), "r") as fl:
            lines = fl.readlines()
        pos_freq_dict = defaultdict(list)
        for line in lines:
            line_split = line.strip("\n").split("\t")
            if re.match(r"[\d]+\.[\d]+", line_split[2]):
                pos_freq_dict[line_split[1]].append(
                    (line_split[0], float(line_split[2])))
        nouns_with_freq = pos_freq_dict["s"]
        self.nouns_dict = {noun: freq for noun, freq in nouns_with_freq}
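
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: the tab-separated
# word / POS / frequency format that load_freq_dict expects, parsed from an
# in-memory list instead of a file.  The rows are hypothetical; "s" is the
# noun tag picked out by the method.
import re
from collections import defaultdict

toy_rows = ["дом\ts\t230.5", "красивый\ta\t120.1"]  # hypothetical rows
pos_freq = defaultdict(list)
for row in toy_rows:
    word, pos, freq = row.split("\t")
    if re.match(r"[\d]+\.[\d]+", freq):
        pos_freq[pos].append((word, float(freq)))
nouns_dict = dict(pos_freq["s"])
# nouns_dict -> {"дом": 230.5}
# ---------------------------------------------------------------------------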

    def load(self) -> None:
        self.inverted_index = load_pickle(self.load_path /
                                          self.inverted_index_filename)
        self.entities_list = load_pickle(self.load_path /
                                         self.entities_list_filename)
        self.q2name = load_pickle(self.load_path / self.q2name_filename)
        if self.who_entities_filename:
            self.who_entities = load_pickle(self.load_path /
                                            self.who_entities_filename)
        if self.freq_dict_filename:
            self.load_freq_dict(self.freq_dict_filename)
        if self.types_dict_filename:
            self.types_dict = load_pickle(self.load_path /
                                          self.types_dict_filename)

    def save(self) -> None:
        save_pickle(self.inverted_index,
                    self.save_path / self.inverted_index_filename)
        save_pickle(self.entities_list,
                    self.save_path / self.entities_list_filename)
        save_pickle(self.q2name, self.save_path / self.q2name_filename)
        if self.q2descr_filename is not None:
            save_pickle(self.q2descr, self.save_path / self.q2descr_filename)

    def __call__(
        self,
        entity_substr_batch: List[List[str]],
        templates_batch: List[str] = None,
        long_context_batch: List[str] = None,
        entity_types_batch: List[List[List[str]]] = None,
        short_context_batch: List[str] = None,
    ) -> Tuple[List[List[List[str]]], List[List[List[float]]], List[List[List[float]]]]:
        entity_ids_batch = []
        confidences_batch = []
        tokens_match_conf_batch = []
        if templates_batch is None:
            templates_batch = ["" for _ in entity_substr_batch]
        if long_context_batch is None:
            long_context_batch = ["" for _ in entity_substr_batch]
        if short_context_batch is None:
            short_context_batch = ["" for _ in entity_substr_batch]
        if entity_types_batch is None:
            entity_types_batch = [[[] for _ in entity_substr_list]
                                  for entity_substr_list in entity_substr_batch
                                  ]
        for entity_substr_list, template_found, long_context, entity_types_list, short_context in zip(
                entity_substr_batch, templates_batch, long_context_batch,
                entity_types_batch, short_context_batch):
            entity_ids_list = []
            confidences_list = []
            tokens_match_conf_list = []
            for entity_substr, entity_types in zip(entity_substr_list,
                                                   entity_types_list):
                entity_ids, confidences, tokens_match_conf = self.link_entity(
                    entity_substr, long_context, short_context, template_found,
                    entity_types)
                if self.num_entities_to_return == 1:
                    if entity_ids:
                        entity_ids_list.append(entity_ids[0])
                        confidences_list.append(confidences[0])
                        tokens_match_conf_list.append(tokens_match_conf[0])
                    else:
                        entity_ids_list.append("")
                        confidences_list.append(0.0)
                        tokens_match_conf_list.append(0.0)
                else:
                    entity_ids_list.append(
                        entity_ids[:self.num_entities_to_return])
                    confidences_list.append(
                        confidences[:self.num_entities_to_return])
                    tokens_match_conf_list.append(
                        tokens_match_conf[:self.num_entities_to_return])
            entity_ids_batch.append(entity_ids_list)
            confidences_batch.append(confidences_list)
            tokens_match_conf_batch.append(tokens_match_conf_list)

        return entity_ids_batch, confidences_batch, tokens_match_conf_batch

    def lemmatize_substr(self, text):
        lemm_text = ""
        if text:
            pr_text = self.nlp(text)
            processed_tokens = []
            for token in pr_text:
                if token.tag_ in ["NNS", "NNP"]:
                    # singular_noun returns False for tokens it cannot singularize
                    singular = self.inflect_engine.singular_noun(token.text)
                    processed_tokens.append(singular or token.text)
                else:
                    processed_tokens.append(token.text)
            lemm_text = " ".join(processed_tokens)
        return lemm_text
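
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: the inflect call behind
# lemmatize_substr.  singular_noun returns the singular form for a plural
# noun and False otherwise, which is why the method keeps the original token
# when the call fails.
import inflect

engine = inflect.engine()
token = "movies"
normalized = engine.singular_noun(token) or token  # -> "movie"
# engine.singular_noun("movie") returns False, so singular tokens stay as-is.
# ---------------------------------------------------------------------------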

    def link_entity(
        self,
        entity: str,
        long_context: Optional[str] = None,
        short_context: Optional[str] = None,
        template_found: Optional[str] = None,
        entity_types: List[str] = None,
        cut_entity: bool = False,
    ) -> Tuple[List[str], List[float], List[float]]:
        confidences = []
        tokens_match_conf = []
        if not entity:
            entities_ids = ["None"]
        else:
            entity_is_uttr = False
            lets_talk_phrases = [
                "let's talk", "let's chat", "what about", "do you know",
                "tell me about"
            ]
            found_lets_talk_phrase = short_context is not None and any(
                phrase in short_context for phrase in lets_talk_phrases)
            if (short_context and
                (entity == short_context or entity == short_context[:-1]
                 or found_lets_talk_phrase) and len(entity.split()) == 1):
                lemm_entity = self.lemmatize_substr(entity)
                entity_is_uttr = True
            else:
                lemm_entity = entity

            candidate_entities = self.candidate_entities_inverted_index(
                lemm_entity)
            if self.types_dict:
                if entity_types:
                    entity_types = set(entity_types)
                    candidate_entities = [
                        ent
                        for ent in candidate_entities if self.types_dict.get(
                            ent[1], set()).intersection(entity_types)
                    ]
                if template_found in ["what is xxx?", "what was xxx?"
                                      ] or entity_is_uttr:
                    candidate_entities_filtered = [
                        ent for ent in candidate_entities
                        if not self.types_dict.get(ent[1], set()).intersection(
                            self.black_list_what_is)
                    ]
                    if candidate_entities_filtered:
                        candidate_entities = candidate_entities_filtered
            if cut_entity and candidate_entities and len(
                    lemm_entity.split()) > 1 and candidate_entities[0][3] == 1:
                lemm_entity = self.cut_entity_substr(lemm_entity)
                candidate_entities = self.candidate_entities_inverted_index(
                    lemm_entity)
            candidate_entities, candidate_names = self.candidate_entities_names(
                lemm_entity, candidate_entities)
            entities_ids, confidences, tokens_match_conf, srtd_cand_ent = self.sort_found_entities(
                candidate_entities, candidate_names, lemm_entity, entity,
                long_context)
            if template_found:
                entities_ids = self.filter_entities(entities_ids,
                                                    template_found)

        return entities_ids, confidences, tokens_match_conf

    def cut_entity_substr(self, entity: str):
        word_tokens = nltk.word_tokenize(entity.lower())
        word_tokens = [
            word for word in word_tokens if word not in self.stopwords
        ]
        normal_form_tokens = [
            self.morph.parse(word)[0].normal_form for word in word_tokens
        ]
        words_with_freq = [(word, self.nouns_dict.get(word, 0.0))
                           for word in normal_form_tokens]
        words_with_freq = sorted(words_with_freq, key=lambda x: x[1])
        return words_with_freq[0][0]
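
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: cut_entity_substr keeps
# the least frequent noun of a multi-word substring, on hypothetical
# frequencies (rarer nouns are more likely to be the entity head).
toy_nouns_dict = {"фильм": 900.0, "гамп": 0.5}  # hypothetical noun frequencies
tokens = ["фильм", "гамп"]
rarest = sorted(((w, toy_nouns_dict.get(w, 0.0)) for w in tokens),
                key=lambda x: x[1])[0][0]
# rarest -> "гамп"
# ---------------------------------------------------------------------------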

    def candidate_entities_inverted_index(
            self, entity: str) -> List[Tuple[Any, Any, Any]]:
        word_tokens = nltk.word_tokenize(entity.lower())
        word_tokens = [
            word for word in word_tokens if word not in self.stopwords
        ]
        candidate_entities = []

        candidate_entities_for_tokens = []
        for tok in word_tokens:
            candidate_entities_for_tok = set()
            if len(tok) > 1:
                found = False
                if tok in self.inverted_index:
                    candidate_entities_for_tok = set(self.inverted_index[tok])
                    found = True

                if self.lemmatize:
                    if self.lang_str == "@ru":
                        morph_parse_tok = self.morph.parse(tok)[0]
                        lemmatized_tok = morph_parse_tok.normal_form
                    if self.lang_str == "@en":
                        lemmatized_tok = self.lemmatizer.lemmatize(tok)

                    if lemmatized_tok != tok and lemmatized_tok in self.inverted_index:
                        candidate_entities_for_tok = candidate_entities_for_tok.union(
                            set(self.inverted_index[lemmatized_tok]))
                        found = True

                if not found and self.use_prefix_tree:
                    words_with_levens_1 = self.searcher.search(tok, d=1)
                    for word in words_with_levens_1:
                        candidate_entities_for_tok = candidate_entities_for_tok.union(
                            set(self.inverted_index[word[0]]))
                candidate_entities_for_tokens.append(
                    candidate_entities_for_tok)

        for candidate_entities_for_tok in candidate_entities_for_tokens:
            candidate_entities += list(candidate_entities_for_tok)
        candidate_entities = Counter(candidate_entities).most_common()
        candidate_entities = sorted(candidate_entities,
                                    key=lambda x: (x[0][1], x[1]),
                                    reverse=True)
        candidate_entities = candidate_entities[:1000]
        candidate_entities = [
            (entity_num, self.entities_list[entity_num], entity_freq, count)
            for (entity_num, entity_freq), count in candidate_entities
        ]

        return candidate_entities
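
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: how the inverted index
# yields ranked candidates.  Each posting is (entity_index, popularity);
# Counter.most_common counts how many query tokens hit an entity, and the
# sort prefers popular entities matched by more tokens.  All data below is
# hypothetical.
from collections import Counter

toy_inverted_index = {
    "forrest": [(0, 250)],
    "gump": [(0, 250), (1, 40)],
}
postings = []
for tok in ["forrest", "gump"]:
    postings += toy_inverted_index.get(tok, [])
ranked = sorted(Counter(postings).most_common(),
                key=lambda x: (x[0][1], x[1]), reverse=True)
# ranked -> [((0, 250), 2), ((1, 40), 1)]: entity 0 is hit by both tokens.
# ---------------------------------------------------------------------------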

    def sort_found_entities(
        self,
        candidate_entities: List[Tuple[int, str, int]],
        candidate_names: List[List[str]],
        lemm_entity: str,
        entity: str,
        context: str = None,
    ) -> Tuple[List[str], List[float], List[Tuple[str, str, int, int]]]:
        entities_ratios = []
        lemm_entity = lemm_entity.lower()
        for candidate, entity_names in zip(candidate_entities,
                                           candidate_names):
            entity_num, entity_id, num_rels, tokens_matched = candidate
            fuzz_ratio = max([
                fuzz.ratio(name.lower(), lemm_entity) for name in entity_names
            ])
            entity_tokens = re.findall(self.re_tokenizer, entity.lower())
            lemm_entity_tokens = re.findall(self.re_tokenizer,
                                            lemm_entity.lower())
            entity_tokens = {
                word
                for word in entity_tokens if (len(word) > 1 and word != "'s"
                                              and word not in self.stopwords)
            }
            lemm_entity_tokens = {
                word
                for word in lemm_entity_tokens if
                (len(word) > 1 and word != "'s" and word not in self.stopwords)
            }
            match_counts = []
            for name in entity_names:
                name_tokens = re.findall(self.re_tokenizer, name.lower())
                name_tokens = {
                    word
                    for word in name_tokens if (len(word) > 1 and word != "'s"
                                                and word not in self.stopwords)
                }
                entity_inters_len = len(
                    entity_tokens.intersection(name_tokens))
                lemm_entity_inters_len = len(
                    lemm_entity_tokens.intersection(name_tokens))

                entity_ratio_1 = 0.0
                entity_ratio_2 = 0.0
                if len(entity_tokens):
                    entity_ratio_1 = entity_inters_len / len(entity_tokens)
                    if entity_ratio_1 > 1.0 and entity_ratio_1 != 0.0:
                        entity_ratio_1 = 1.0 / entity_ratio_1
                if len(name_tokens):
                    entity_ratio_2 = entity_inters_len / len(name_tokens)
                    if entity_ratio_2 > 1.0 and entity_ratio_2 != 0.0:
                        entity_ratio_2 = 1.0 / entity_ratio_2

                lemm_entity_ratio_1 = 0.0
                lemm_entity_ratio_2 = 0.0
                if len(lemm_entity_tokens):
                    lemm_entity_ratio_1 = lemm_entity_inters_len / len(
                        lemm_entity_tokens)
                    if lemm_entity_ratio_1 > 1.0 and lemm_entity_ratio_1 != 0.0:
                        lemm_entity_ratio_1 = 1.0 / lemm_entity_ratio_1
                if len(name_tokens):
                    lemm_entity_ratio_2 = lemm_entity_inters_len / len(
                        name_tokens)
                    if lemm_entity_ratio_2 > 1.0 and lemm_entity_ratio_2 != 0.0:
                        lemm_entity_ratio_2 = 1.0 / lemm_entity_ratio_2

                match_count = max(entity_ratio_1, entity_ratio_2,
                                  lemm_entity_ratio_1, lemm_entity_ratio_2)
                match_counts.append(match_count)
            match_counts = sorted(match_counts, reverse=True)
            if match_counts:
                tokens_matched = match_counts[0]
            else:
                tokens_matched = 0.0

            entities_ratios.append(
                (entity_num, entity_id, tokens_matched, fuzz_ratio, num_rels))

        srtd_with_ratios = sorted(entities_ratios,
                                  key=lambda x: (x[2], x[3], x[4]),
                                  reverse=True)
        if self.use_descriptions:
            log.debug(f"context {context}")
            id_to_score = {
                entity_id: (tokens_matched, score, num_rels)
                for _, entity_id, tokens_matched, score, num_rels in
                srtd_with_ratios[:self.num_entities_for_bert_ranking]
            }
            entity_ids = [
                entity_id for _, entity_id, _, _, _ in
                srtd_with_ratios[:self.num_entities_for_bert_ranking]
            ]
            scores = self.entity_ranker.rank_rels(context, entity_ids)
            entities_with_scores = [
                (entity_id, id_to_score[entity_id][0],
                 id_to_score[entity_id][1], id_to_score[entity_id][2], score)
                for entity_id, score in scores
            ]
            entities_with_scores = sorted(entities_with_scores,
                                          key=lambda x:
                                          (x[1], x[2], x[3], x[4]),
                                          reverse=True)

            entities_with_scores = [
                ent for ent in entities_with_scores
                if (ent[4] > self.descr_rank_score_thres or ent[2] == 100.0 or
                    (ent[1] == 1.0 and ent[2] > 92.0 and ent[3] > 20
                     and ent[4] > 0.2))
            ]
            log.debug(f"entities_with_scores {entities_with_scores[:10]}")
            entity_ids = [ent for ent, *_ in entities_with_scores]
            confidences = [score for *_, score in entities_with_scores]
            tokens_match_conf = [
                ratio for _, ratio, *_ in entities_with_scores
            ]
        else:
            entity_ids = [ent[1] for ent in srtd_with_ratios]
            confidences = [ent[4] * 0.01 for ent in srtd_with_ratios]
            tokens_match_conf = [ent[2] for ent in srtd_with_ratios]

        return entity_ids, confidences, tokens_match_conf, srtd_with_ratios

    def candidate_entities_names(
        self, entity: str, candidate_entities: List[Tuple[int, str, int]]
    ) -> Tuple[List[Tuple[int, str, int]], List[List[str]]]:
        entity_length = len(entity)
        candidate_names = []
        candidate_entities_filter = []
        for candidate in candidate_entities:
            entity_num = candidate[0]
            entity_names = []

            entity_names_found = self.q2name[entity_num]
            if len(entity_names_found[0]) < 6 * entity_length:
                entity_name = entity_names_found[0]
                entity_names.append(entity_name)
                if len(entity_names_found) > 1:
                    for alias in entity_names_found[1:]:
                        entity_names.append(alias)
                candidate_names.append(entity_names)
                candidate_entities_filter.append(candidate)

        return candidate_entities_filter, candidate_names

    def inverted_index_builder(self) -> None:
        log.debug("building inverted index")
        entities_set = set()
        id_to_label_dict = defaultdict(list)
        id_to_descr_dict = defaultdict(list)
        label_to_id_dict = {}
        label_triplets = []
        alias_triplets_list = []
        descr_triplets = []
        if self.kb_format == "hdt":
            label_triplets, c = self.doc.search_triples("", self.label_rel, "")
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    alias_triplets, c = self.doc.search_triples(
                        "", alias_rel, "")
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                descr_triplets, c = self.doc.search_triples(
                    "", self.descr_rel, "")

        if self.kb_format == "sqlite3":
            subject, relation, obj = self.sql_column_names
            query = (
                f"SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} "
                f'WHERE {relation} = "{self.label_rel}";')
            res = self.cursor.execute(query)
            label_triplets = res.fetchall()
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    query = (
                        f"SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} "
                        f'WHERE {relation} = "{alias_rel}";')
                    res = self.cursor.execute(query)
                    alias_triplets = res.fetchall()
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                query = (
                    f"SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} "
                    f'WHERE {relation} = "{self.descr_rel}";')
                res = self.cursor.execute(query)
                descr_triplets = res.fetchall()

        for triplets in [label_triplets] + alias_triplets_list:
            for triplet in triplets:
                entities_set.add(triplet[0])
                if triplet[2].endswith(self.lang_str):
                    label = triplet[2].replace(self.lang_str,
                                               "").replace('"', "")
                    id_to_label_dict[triplet[0]].append(label)
                    label_to_id_dict[label] = triplet[0]

        for triplet in descr_triplets:
            entities_set.add(triplet[0])
            if triplet[2].endswith(self.lang_str):
                descr = triplet[2].replace(self.lang_str, "").replace('"', "")
                id_to_descr_dict[triplet[0]].append(descr)

        popularities_dict = {}
        for entity in entities_set:
            if self.kb_format == "hdt":
                all_triplets, number_of_triplets = self.doc.search_triples(
                    entity, "", "")
                popularities_dict[entity] = number_of_triplets
            if self.kb_format == "sqlite3":
                subject, relation, obj = self.sql_column_names
                query = f'SELECT COUNT({obj}) FROM {self.sql_table_name} WHERE {subject} = "{entity}";'
                res = self.cursor.execute(query)
                popularities_dict[entity] = res.fetchall()[0][0]

        entities_dict = {entity: n for n, entity in enumerate(entities_set)}

        inverted_index = defaultdict(list)
        for label in label_to_id_dict:
            tokens = re.findall(self.re_tokenizer, label.lower())
            for tok in tokens:
                if len(tok) > 1 and tok not in self.stopwords:
                    inverted_index[tok].append(
                        (entities_dict[label_to_id_dict[label]],
                         popularities_dict[label_to_id_dict[label]]))
        self.inverted_index = dict(inverted_index)
        self.entities_list = list(entities_set)
        self.q2name = [
            id_to_label_dict[entity] for entity in self.entities_list
        ]
        self.q2descr = []
        if id_to_descr_dict:
            self.q2descr = [
                id_to_descr_dict[entity] for entity in self.entities_list
            ]
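
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: the core of
# inverted_index_builder on hypothetical labels.  Every token of an entity
# label points to (entity_index, popularity) so later lookups can be ranked
# by popularity.
import re
from collections import defaultdict

toy_labels = {"Forrest Gump": "Q134773", "Tom Hanks": "Q2263"}  # hypothetical label -> id
toy_popularity = {"Q134773": 250, "Q2263": 900}                 # hypothetical triplet counts
entity_index = {qid: n for n, qid in enumerate(toy_popularity)}

toy_inverted_index = defaultdict(list)
for label, qid in toy_labels.items():
    for tok in re.findall(r"[\w']+|[^\w ]", label.lower()):
        if len(tok) > 1:
            toy_inverted_index[tok].append((entity_index[qid], toy_popularity[qid]))
# toy_inverted_index["gump"] -> [(0, 250)]
# ---------------------------------------------------------------------------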

    def filter_entities(self, entities: List[str],
                        template_found: str) -> List[str]:
        if template_found in ["who is xxx?", "who was xxx?"]:
            entities = [
                entity for entity in entities if entity in self.who_entities
            ]
        if template_found in ["what is xxx?", "what was xxx?"]:
            entities = [
                entity for entity in entities
                if entity not in self.who_entities
            ]
        return entities
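
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the class above: the template-based filter
# on a hypothetical who_entities set.  For "who is xxx?" questions only
# entities known to denote persons (or person-like entities) are kept; for
# "what is xxx?" questions they are excluded.
toy_who_entities = {"Q2263"}            # hypothetical person ids
candidates = ["Q2263", "Q134773"]
who_answers = [e for e in candidates if e in toy_who_entities]       # -> ["Q2263"]
what_answers = [e for e in candidates if e not in toy_who_entities]  # -> ["Q134773"]
# ---------------------------------------------------------------------------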