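# Imports assumed by the classes below (reconstructed for readability; they are not shown in this
# excerpt, and the exact module paths may differ between DeepPavlov revisions -- treat them as a sketch).
import itertools
import pickle
import re
import sqlite3
from collections import Counter, defaultdict
from logging import getLogger
from typing import Any, Dict, List, Optional, Tuple

import en_core_web_sm
import inflect
import nltk
import pymorphy2
from fuzzywuzzy import fuzz  # newer revisions use rapidfuzz instead
from hdt import HDTDocument
from nltk.corpus import stopwords

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.file import load_pickle, save_pickle
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
# Paths of the two ranker components are taken from the docstrings below; verify against the repository layout.
from deeppavlov.models.kbqa.rel_ranker_bert_infer import RelRankerBertInfer
from deeppavlov.models.kbqa.rel_ranker_infer import RelRankerInfer
from deeppavlov.models.spelling_correction.levenshtein.levenshtein_searcher import LevenshteinSearcher

log = getLogger(__name__)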
class EntityLinker(Component, Serializable):
    """
    This class extracts candidate entities from the knowledge base for the entity mentioned in the question
    and then extracts triplets from Wikidata for the extracted entity. Candidate entities are searched in a
    dictionary whose keys are titles and aliases of Wikidata entities and whose values are lists of tuples
    (entity_title, entity_id, number_of_relations). First, candidate entities are looked up in the dictionary
    by keys, where the keys are entities extracted from the question; if nothing is found, entities are searched
    in the dictionary using the Levenshtein distance between the entity and the keys (titles) of the dictionary.
    """

    def __init__(self, load_path: str,
                 inverted_index_filename: str,
                 entities_list_filename: str,
                 q2name_filename: str,
                 save_path: str = None,
                 q2descr_filename: str = None,
                 rel_ranker: RelRankerBertInfer = None,
                 build_inverted_index: bool = False,
                 kb_format: str = "hdt",
                 kb_filename: str = None,
                 label_rel: str = None,
                 descr_rel: str = None,
                 aliases_rels: List[str] = None,
                 sql_table_name: str = None,
                 sql_column_names: List[str] = None,
                 lang: str = "en",
                 use_descriptions: bool = False,
                 lemmatize: bool = False,
                 use_prefix_tree: bool = False,
                 **kwargs) -> None:
        """
        Args:
            load_path: path to folder with inverted index files
            save_path: path where to save inverted index files
            inverted_index_filename: file with dict of words (keys) and entities containing these words
            entities_list_filename: file with the list of entities from the knowledge base
            q2name_filename: name of file which maps entity id to name
            q2descr_filename: name of file which maps entity id to description
            rel_ranker: component deeppavlov.models.kbqa.rel_ranker_bert_infer
            build_inverted_index: if "true", an inverted index of the KB entities will be built
            kb_format: "hdt" or "sqlite3"
            kb_filename: file with the knowledge base which will be used for building the inverted index
            label_rel: relation in the knowledge base which connects entity ids and entity titles
            descr_rel: relation in the knowledge base which connects entity ids and entity descriptions
            aliases_rels: list of relations which connect entity ids and entity aliases
            sql_table_name: name of the table with the KB if the KB is in sqlite3 format
            sql_column_names: names of columns with subject, relation and object
            lang: language used
            use_descriptions: whether to use context and descriptions of entities for entity ranking
            lemmatize: whether to lemmatize tokens of the extracted entity
            use_prefix_tree: whether to use a prefix tree to search for entities with typos in entity labels
            **kwargs:
        """
        super().__init__(save_path=save_path, load_path=load_path)
        self.morph = pymorphy2.MorphAnalyzer()
        self.lemmatize = lemmatize
        self.use_prefix_tree = use_prefix_tree
        self.inverted_index_filename = inverted_index_filename
        self.entities_list_filename = entities_list_filename
        self.build_inverted_index = build_inverted_index
        self.q2name_filename = q2name_filename
        self.q2descr_filename = q2descr_filename
        self.kb_format = kb_format
        self.kb_filename = kb_filename
        self.label_rel = label_rel
        self.aliases_rels = aliases_rels
        self.descr_rel = descr_rel
        self.sql_table_name = sql_table_name
        self.sql_column_names = sql_column_names
        self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None
        self.entities_index: Optional[List[str]] = None
        self.q2name: Optional[List[Tuple[str]]] = None
        self.lang_str = f"@{lang}"
        if self.lang_str == "@en":
            self.stopwords = set(stopwords.words("english"))
        elif self.lang_str == "@ru":
            self.stopwords = set(stopwords.words("russian"))
        self.re_tokenizer = re.compile(r"[\w']+|[^\w ]")
        self.rel_ranker = rel_ranker
        self.use_descriptions = use_descriptions

        if self.use_prefix_tree:
            alphabet = r"!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ" + \
                       "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ" + \
                       "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗"
            dictionary_words = list(self.inverted_index.keys())
            self.searcher = LevenshteinSearcher(alphabet, dictionary_words)

        if self.build_inverted_index:
            if self.kb_format == "hdt":
                self.doc = HDTDocument(str(expand_path(self.kb_filename)))
            if self.kb_format == "sqlite3":
                self.conn = sqlite3.connect(str(expand_path(self.kb_filename)))
                self.cursor = self.conn.cursor()
            self.inverted_index_builder()
            self.save()
        else:
            self.load()

    def load(self) -> None:
        self.inverted_index = load_pickle(self.load_path / self.inverted_index_filename)
        self.entities_list = load_pickle(self.load_path / self.entities_list_filename)
        self.q2name = load_pickle(self.load_path / self.q2name_filename)

    def save(self) -> None:
        save_pickle(self.inverted_index, self.save_path / self.inverted_index_filename)
        save_pickle(self.entities_list, self.save_path / self.entities_list_filename)
        save_pickle(self.q2name, self.save_path / self.q2name_filename)
        if self.q2descr_filename is not None:
            save_pickle(self.q2descr, self.save_path / self.q2descr_filename)

    def __call__(self, entity_substr_batch: List[List[str]],
                 entity_positions_batch: List[List[List[int]]] = None,
                 context_tokens: List[List[str]] = None) -> Tuple[List[List[List[str]]], List[List[List[float]]]]:
        entity_ids_batch = []
        confidences_batch = []
        if entity_positions_batch is None:
            entity_positions_batch = [[[0] for i in range(len(entities_list))] for entities_list in entity_substr_batch]
        for entity_substr_list, entity_positions_list in zip(entity_substr_batch, entity_positions_batch):
            entity_ids_list = []
            confidences_list = []
            for entity_substr, entity_pos in zip(entity_substr_list, entity_positions_list):
                context = ""
                if self.use_descriptions:
                    context = ' '.join(context_tokens[:entity_pos[0]] + ["[ENT]"] + context_tokens[entity_pos[-1] + 1:])
                entity_ids, confidences = self.link_entity(entity_substr, context)
                entity_ids_list.append(entity_ids)
                confidences_list.append(confidences)
            entity_ids_batch.append(entity_ids_list)
            confidences_batch.append(confidences_list)

        return entity_ids_batch, confidences_batch

    def link_entity(self, entity: str, context: str = None) -> Tuple[List[str], List[float]]:
        confidences = []
        if not entity:
            entities_ids = ['None']
        else:
            candidate_entities = self.candidate_entities_inverted_index(entity)
            candidate_entities, candidate_names = self.candidate_entities_names(entity, candidate_entities)
            entities_ids, confidences, srtd_cand_ent = self.sort_found_entities(candidate_entities,
                                                                                candidate_names, entity, context)

        return entities_ids, confidences

    def candidate_entities_inverted_index(self, entity: str) -> List[Tuple[Any, Any, Any]]:
        word_tokens = nltk.word_tokenize(entity.lower())
        candidate_entities = []

        for tok in word_tokens:
            if len(tok) > 1:
                found = False
                if tok in self.inverted_index:
                    candidate_entities += self.inverted_index[tok]
                    found = True
                if self.lemmatize:
                    morph_parse_tok = self.morph.parse(tok)[0]
                    lemmatized_tok = morph_parse_tok.normal_form
                    if lemmatized_tok in self.inverted_index:
                        candidate_entities += self.inverted_index[lemmatized_tok]
                        found = True
                if not found and self.use_prefix_tree:
                    words_with_levens_1 = self.searcher.search(tok, d=1)
                    for word in words_with_levens_1:
                        candidate_entities += self.inverted_index[word[0]]
        candidate_entities = list(set(candidate_entities))
        candidate_entities = [(entity[0], self.entities_list[entity[0]], entity[1]) for entity in candidate_entities]

        return candidate_entities

    def sort_found_entities(self, candidate_entities: List[Tuple[int, str, int]],
                            candidate_names: List[List[str]],
                            entity: str,
                            context: str = None) -> Tuple[List[str], List[float], List[Tuple[str, str, int, int]]]:
        entities_ratios = []
        for candidate, entity_names in zip(candidate_entities, candidate_names):
            entity_num, entity_id, num_rels = candidate
            fuzz_ratio = max([fuzz.ratio(name.lower(), entity) for name in entity_names])
            entities_ratios.append((entity_num, entity_id, fuzz_ratio, num_rels))

        srtd_with_ratios = sorted(entities_ratios, key=lambda x: (x[2], x[3]), reverse=True)
        if self.use_descriptions:
            num_to_id = {entity_num: entity_id for entity_num, entity_id, _, _ in srtd_with_ratios[:30]}
            entity_numbers = [entity_num for entity_num, _, _, _ in srtd_with_ratios[:30]]
            scores = self.rel_ranker.rank_rels(context, entity_numbers)
            top_rels = [score[0] for score in scores]
            entity_ids = [num_to_id[num] for num in top_rels]
            confidences = [score[1] for score in scores]
        else:
            entity_ids = [ent[1] for ent in srtd_with_ratios]
            confidences = [float(ent[2]) * 0.01 for ent in srtd_with_ratios]

        return entity_ids, confidences, srtd_with_ratios

    def candidate_entities_names(self, entity: str,
                                 candidate_entities: List[Tuple[int, str, int]]) -> Tuple[List[Tuple[int, str, int]],
                                                                                          List[List[str]]]:
        entity_length = len(entity)
        candidate_names = []
        candidate_entities_filter = []
        for candidate in candidate_entities:
            entity_num = candidate[0]
            entity_id = candidate[1]
            entity_names = []
            entity_names_found = self.q2name[entity_num]
            if len(entity_names_found[0]) < 6 * entity_length:
                entity_name = entity_names_found[0]
                entity_names.append(entity_name)
                if len(entity_names_found) > 1:
                    for alias in entity_names_found[1:]:
                        entity_names.append(alias)
                candidate_names.append(entity_names)
                candidate_entities_filter.append(candidate)

        return candidate_entities_filter, candidate_names

    def inverted_index_builder(self) -> None:
        log.debug("building inverted index")
        entities_set = set()
        id_to_label_dict = defaultdict(list)
        id_to_descr_dict = defaultdict(list)  # descriptions are appended per entity id below, so this must be a defaultdict
        label_to_id_dict = {}
        label_triplets = []
        alias_triplets_list = []
        descr_triplets = []
        if self.kb_format == "hdt":
            label_triplets, c = self.doc.search_triples("", self.label_rel, "")
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    alias_triplets, c = self.doc.search_triples("", alias_rel, "")
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                descr_triplets, c = self.doc.search_triples("", self.descr_rel, "")

        if self.kb_format == "sqlite3":
            subject, relation, obj = self.sql_column_names
            query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} WHERE {relation} = "{self.label_rel}";'
            res = self.cursor.execute(query)
            label_triplets = res.fetchall()
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} WHERE {relation} = "{alias_rel}";'
                    res = self.cursor.execute(query)
                    alias_triplets = res.fetchall()
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} WHERE {relation} = "{self.descr_rel}";'
                res = self.cursor.execute(query)
                descr_triplets = res.fetchall()

        for triplets in [label_triplets] + alias_triplets_list:
            for triplet in triplets:
                entities_set.add(triplet[0])
                if triplet[2].endswith(self.lang_str):
                    label = triplet[2].replace(self.lang_str, '').replace('"', '')
                    id_to_label_dict[triplet[0]].append(label)
                    label_to_id_dict[label] = triplet[0]

        for triplet in descr_triplets:
            entities_set.add(triplet[0])
            if triplet[2].endswith(self.lang_str):
                descr = triplet[2].replace(self.lang_str, '').replace('"', '')
                id_to_descr_dict[triplet[0]].append(descr)

        popularities_dict = {}
        for entity in entities_set:
            if self.kb_format == "hdt":
                all_triplets, number_of_triplets = self.doc.search_triples(entity, "", "")
                popularities_dict[entity] = number_of_triplets
            if self.kb_format == "sqlite3":
                subject, relation, obj = self.sql_column_names
                query = f'SELECT COUNT({obj}) FROM {self.sql_table_name} WHERE {subject} = "{entity}";'
                res = self.cursor.execute(query)
                popularities_dict[entity] = res.fetchall()[0][0]

        entities_dict = {entity: n for n, entity in enumerate(entities_set)}

        inverted_index = defaultdict(list)
        for label in label_to_id_dict:
            tokens = re.findall(self.re_tokenizer, label.lower())
            for tok in tokens:
                if len(tok) > 1 and tok not in self.stopwords:
                    inverted_index[tok].append((entities_dict[label_to_id_dict[label]],
                                                popularities_dict[label_to_id_dict[label]]))
        self.inverted_index = dict(inverted_index)
        self.entities_list = list(entities_set)
        self.q2name = [id_to_label_dict[entity] for entity in self.entities_list]
        self.q2descr = []
        if id_to_descr_dict:
            self.q2descr = [id_to_descr_dict[entity] for entity in self.entities_list]
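# Illustrative usage sketch for the inverted-index EntityLinker above (not part of the original
# module). The load_path and file names are hypothetical placeholders; in practice the component
# is wired up via a DeepPavlov pipeline config, and the pickled index files plus NLTK data must
# already be present for this to run.
def _example_entity_linker_call():
    """Sketch: link entity substrings that were already extracted from a question."""
    linker = EntityLinker(load_path="~/.deeppavlov/downloads/my_kb",           # hypothetical folder
                          inverted_index_filename="inverted_index.pickle",     # hypothetical file names
                          entities_list_filename="entities_list.pickle",
                          q2name_filename="q2name.pickle",
                          build_inverted_index=False)                          # load a pre-built index
    # One utterance with two entity mentions; positions and context are only needed
    # when use_descriptions=True.
    entity_ids_batch, confidences_batch = linker([["Forrest Gump", "Robert Zemeckis"]])
    # entity_ids_batch[0][0] holds candidate entity ids for "Forrest Gump",
    # sorted by fuzzy-match ratio and entity popularity (number of relations).
    return entity_ids_batch, confidences_batch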
class KBEntityLinker(Component, Serializable): """ This class extracts from the knowledge base candidate entities for the entity mentioned in the question and then extracts triplets from Wikidata for the extracted entity. Candidate entities are searched in the dictionary where keys are titles and aliases of Wikidata entities and values are lists of tuples (entity_title, entity_id, number_of_relations). First candidate entities are searched in the dictionary by keys where the keys are entities extracted from the question, if nothing is found entities are searched in the dictionary using Levenstein distance between the entity and keys (titles) in the dictionary. """ def __init__(self, load_path: str, inverted_index_filename: str, entities_list_filename: str, q2name_filename: str, who_entities_filename: Optional[str] = None, save_path: str = None, q2descr_filename: str = None, descr_rank_score_thres: float = 0.0, freq_dict_filename: Optional[str] = None, entity_ranker: RelRankerBertInfer = None, build_inverted_index: bool = False, kb_format: str = "hdt", kb_filename: str = None, label_rel: str = None, descr_rel: str = None, aliases_rels: List[str] = None, sql_table_name: str = None, sql_column_names: List[str] = None, lang: str = "en", use_descriptions: bool = False, include_mention: bool = False, lemmatize: bool = False, use_prefix_tree: bool = False, **kwargs) -> None: """ Args: load_path: path to folder with inverted index files inverted_index_filename: file with dict of words (keys) and entities containing these words entities_list_filename: file with the list of entities from the knowledge base q2name_filename: name of file which maps entity id to name who_entities_filename: file with the list of entities in Wikidata, which can be answers to questions with "Who" pronoun, i.e. humans, literary characters etc. 
save_path: path where to save inverted index files q2descr_filename: name of file which maps entity id to description descr_rank_score_thres: if the score of the entity description is less than threshold, the entity is not added to output list freq_dict_filename: filename with frequences dictionary of Russian words entity_ranker: component deeppavlov.models.kbqa.rel_ranker_bert_infer build_inverted_index: if "true", inverted index of entities of the KB will be built kb_format: "hdt" or "sqlite3" kb_filename: file with the knowledge base, which will be used for building of inverted index label_rel: relation in the knowledge base which connects entity ids and entity titles descr_rel: relation in the knowledge base which connects entity ids and entity descriptions aliases_rels: list of relations which connect entity ids and entity aliases sql_table_name: name of the table with the KB if the KB is in sqlite3 format sql_column_names: names of columns with subject, relation and object lang: language used use_descriptions: whether to use context and descriptions of entities for entity ranking include_mention: whether to leave or delete entity mention from the sentence before passing to BERT ranker lemmatize: whether to lemmatize tokens of extracted entity use_prefix_tree: whether to use prefix tree for search of entities with typos in entity labels **kwargs: """ super().__init__(save_path=save_path, load_path=load_path) self.morph = pymorphy2.MorphAnalyzer() self.lemmatize = lemmatize self.use_prefix_tree = use_prefix_tree self.inverted_index_filename = inverted_index_filename self.entities_list_filename = entities_list_filename self.build_inverted_index = build_inverted_index self.q2name_filename = q2name_filename self.who_entities_filename = who_entities_filename self.q2descr_filename = q2descr_filename self.descr_rank_score_thres = descr_rank_score_thres self.freq_dict_filename = freq_dict_filename self.kb_format = kb_format self.kb_filename = kb_filename self.label_rel = label_rel self.aliases_rels = aliases_rels self.descr_rel = descr_rel self.sql_table_name = sql_table_name self.sql_column_names = sql_column_names self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None self.entities_index: Optional[List[str]] = None self.q2name: Optional[List[Tuple[str]]] = None self.lang_str = f"@{lang}" if self.lang_str == "@en": self.stopwords = set(stopwords.words("english")) elif self.lang_str == "@ru": self.stopwords = set(stopwords.words("russian")) self.re_tokenizer = re.compile(r"[\w']+|[^\w ]") self.entity_ranker = entity_ranker self.use_descriptions = use_descriptions self.include_mention = include_mention if self.use_descriptions and self.entity_ranker is None: raise ValueError("No entity ranker is provided!") if self.use_prefix_tree: alphabet = "!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ" + \ "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ" + \ "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗" dictionary_words = list(self.inverted_index.keys()) self.searcher = LevenshteinSearcher(alphabet, dictionary_words) if self.build_inverted_index: if self.kb_format == "hdt": self.doc = HDTDocument(str(expand_path(self.kb_filename))) elif self.kb_format == "sqlite3": self.conn = sqlite3.connect(str(expand_path(self.kb_filename))) self.cursor = self.conn.cursor() else: raise ValueError( f'unsupported kb_format value {self.kb_format}') self.inverted_index_builder() self.save() else: self.load() def load_freq_dict(self, 
freq_dict_filename: str): with open(str(expand_path(freq_dict_filename)), 'r') as fl: lines = fl.readlines() pos_freq_dict = defaultdict(list) for line in lines: line_split = line.strip('\n').split('\t') if re.match("[\d]+\.[\d]+", line_split[2]): pos_freq_dict[line_split[1]].append( (line_split[0], float(line_split[2]))) nouns_with_freq = pos_freq_dict["s"] self.nouns_dict = {noun: freq for noun, freq in nouns_with_freq} def load(self) -> None: self.inverted_index = load_pickle(self.load_path / self.inverted_index_filename) self.entities_list = load_pickle(self.load_path / self.entities_list_filename) self.q2name = load_pickle(self.load_path / self.q2name_filename) if self.who_entities_filename: self.who_entities = load_pickle(self.load_path / self.who_entities_filename) if self.freq_dict_filename: self.load_freq_dict(self.freq_dict_filename) def save(self) -> None: save_pickle(self.inverted_index, self.save_path / self.inverted_index_filename) save_pickle(self.entities_list, self.save_path / self.entities_list_filename) save_pickle(self.q2name, self.save_path / self.q2name_filename) if self.q2descr_filename is not None: save_pickle(self.q2descr, self.save_path / self.q2descr_filename) def __call__( self, entity_substr_batch: List[List[str]], entity_positions_batch: List[List[List[int]]] = None, context_tokens: List[List[str]] = None ) -> Tuple[List[List[List[str]]], List[List[List[float]]]]: entity_ids_batch = [] confidences_batch = [] if entity_positions_batch is None: entity_positions_batch = [[[0] for i in range(len(entities_list))] for entities_list in entity_substr_batch] for entity_substr_list, entity_positions_list in zip( entity_substr_batch, entity_positions_batch): entity_ids_list = [] confidences_list = [] for entity_substr, entity_pos in zip(entity_substr_list, entity_positions_list): context = "" if self.use_descriptions: if self.include_mention: context = ' '.join( context_tokens[:entity_pos[0]] + ["[ENT]"] + context_tokens[entity_pos[0]:entity_pos[-1] + 1] + ["[ENT]"] + context_tokens[entity_pos[-1] + 1:]) else: context = ' '.join(context_tokens[:entity_pos[0]] + ["[ENT]"] + context_tokens[entity_pos[-1] + 1:]) entity_ids, confidences = self.link_entity( entity_substr, context) entity_ids_list.append(entity_ids) confidences_list.append(confidences) entity_ids_batch.append(entity_ids_list) confidences_batch.append(confidences_list) return entity_ids_batch, confidences_batch def link_entity(self, entity: str, context: Optional[str] = None, template_found: Optional[str] = None, cut_entity: bool = False) -> Tuple[List[str], List[float]]: confidences = [] if not entity: entities_ids = ['None'] else: candidate_entities = self.candidate_entities_inverted_index(entity) if cut_entity and candidate_entities and len( entity.split()) > 1 and candidate_entities[0][3] == 1: entity = self.cut_entity_substr(entity) candidate_entities = self.candidate_entities_inverted_index( entity) candidate_entities, candidate_names = self.candidate_entities_names( entity, candidate_entities) entities_ids, confidences, srtd_cand_ent = self.sort_found_entities( candidate_entities, candidate_names, entity, context) if template_found: entities_ids = self.filter_entities(entities_ids, template_found) return entities_ids, confidences def cut_entity_substr(self, entity: str): word_tokens = nltk.word_tokenize(entity.lower()) word_tokens = [ word for word in word_tokens if word not in self.stopwords ] normal_form_tokens = [ self.morph.parse(word)[0].normal_form for word in word_tokens ] words_with_freq = 
[(word, self.nouns_dict.get(word, 0.0)) for word in normal_form_tokens] words_with_freq = sorted(words_with_freq, key=lambda x: x[1]) return words_with_freq[0][0] def candidate_entities_inverted_index( self, entity: str) -> List[Tuple[Any, Any, Any]]: word_tokens = nltk.word_tokenize(entity.lower()) word_tokens = [ word for word in word_tokens if word not in self.stopwords ] candidate_entities = [] for tok in word_tokens: if len(tok) > 1: found = False if tok in self.inverted_index: candidate_entities += self.inverted_index[tok] found = True if self.lemmatize: morph_parse_tok = self.morph.parse(tok)[0] lemmatized_tok = morph_parse_tok.normal_form if lemmatized_tok != tok and lemmatized_tok in self.inverted_index: candidate_entities += self.inverted_index[ lemmatized_tok] found = True if not found and self.use_prefix_tree: words_with_levens_1 = self.searcher.search(tok, d=1) for word in words_with_levens_1: candidate_entities += self.inverted_index[word[0]] candidate_entities = Counter(candidate_entities).most_common() candidate_entities = [(entity_num, self.entities_list[entity_num], entity_freq, count) for \ (entity_num, entity_freq), count in candidate_entities] return candidate_entities def sort_found_entities( self, candidate_entities: List[Tuple[int, str, int]], candidate_names: List[List[str]], entity: str, context: str = None ) -> Tuple[List[str], List[float], List[Tuple[str, str, int, int]]]: entities_ratios = [] for candidate, entity_names in zip(candidate_entities, candidate_names): entity_num, entity_id, num_rels, tokens_matched = candidate fuzz_ratio = max( [fuzz.ratio(name.lower(), entity) for name in entity_names]) entities_ratios.append( (entity_num, entity_id, tokens_matched, fuzz_ratio, num_rels)) srtd_with_ratios = sorted(entities_ratios, key=lambda x: (x[2], x[3], x[4]), reverse=True) if self.use_descriptions: log.debug(f"context {context}") id_to_score = { entity_id: (tokens_matched, score) for _, entity_id, tokens_matched, score, _ in srtd_with_ratios[:30] } entity_ids = [ entity_id for _, entity_id, _, _, _ in srtd_with_ratios[:30] ] scores = self.entity_ranker.rank_rels(context, entity_ids) entities_with_scores = [(entity_id, id_to_score[entity_id][0], id_to_score[entity_id][1], score) for entity_id, score in scores] entities_with_scores = sorted(entities_with_scores, key=lambda x: (x[1], x[2], x[3]), reverse=True) entities_with_scores = [entity for entity in entities_with_scores if \ (entity[3] > self.descr_rank_score_thres or entity[2] == 100.0)] log.debug(f"entities_with_scores {entities_with_scores[:10]}") entity_ids = [entity for entity, _, _, _ in entities_with_scores] confidences = [score for _, _, _, score in entities_with_scores] else: entity_ids = [ent[1] for ent in srtd_with_ratios] confidences = [float(ent[2]) * 0.01 for ent in srtd_with_ratios] return entity_ids, confidences, srtd_with_ratios def candidate_entities_names( self, entity: str, candidate_entities: List[Tuple[int, str, int]] ) -> Tuple[List[Tuple[int, str, int]], List[List[str]]]: entity_length = len(entity) candidate_names = [] candidate_entities_filter = [] for candidate in candidate_entities: entity_num = candidate[0] entity_names = [] entity_names_found = self.q2name[entity_num] if len(entity_names_found[0]) < 6 * entity_length: entity_name = entity_names_found[0] entity_names.append(entity_name) if len(entity_names_found) > 1: for alias in entity_names_found[1:]: entity_names.append(alias) candidate_names.append(entity_names) candidate_entities_filter.append(candidate) return 
candidate_entities_filter, candidate_names def inverted_index_builder(self) -> None: log.debug("building inverted index") entities_set = set() id_to_label_dict = defaultdict(list) id_to_descr_dict = {} label_to_id_dict = {} label_triplets = [] alias_triplets_list = [] descr_triplets = [] if self.kb_format == "hdt": label_triplets, c = self.doc.search_triples("", self.label_rel, "") if self.aliases_rels is not None: for alias_rel in self.aliases_rels: alias_triplets, c = self.doc.search_triples( "", alias_rel, "") alias_triplets_list.append(alias_triplets) if self.descr_rel is not None: descr_triplets, c = self.doc.search_triples( "", self.descr_rel, "") if self.kb_format == "sqlite3": subject, relation, obj = self.sql_column_names query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} '\ f'WHERE {relation} = "{self.label_rel}";' res = self.cursor.execute(query) label_triplets = res.fetchall() if self.aliases_rels is not None: for alias_rel in self.aliases_rels: query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} '\ f'WHERE {relation} = "{alias_rel}";' res = self.cursor.execute(query) alias_triplets = res.fetchall() alias_triplets_list.append(alias_triplets) if self.descr_rel is not None: query = f'SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} '\ f'WHERE {relation} = "{self.descr_rel}";' res = self.cursor.execute(query) descr_triplets = res.fetchall() for triplets in [label_triplets] + alias_triplets_list: for triplet in triplets: entities_set.add(triplet[0]) if triplet[2].endswith(self.lang_str): label = triplet[2].replace(self.lang_str, '').replace('"', '') id_to_label_dict[triplet[0]].append(label) label_to_id_dict[label] = triplet[0] for triplet in descr_triplets: entities_set.add(triplet[0]) if triplet[2].endswith(self.lang_str): descr = triplet[2].replace(self.lang_str, '').replace('"', '') id_to_descr_dict[triplet[0]].append(descr) popularities_dict = {} for entity in entities_set: if self.kb_format == "hdt": all_triplets, number_of_triplets = self.doc.search_triples( entity, "", "") popularities_dict[entity] = number_of_triplets if self.kb_format == "sqlite3": subject, relation, obj = self.sql_column_names query = f'SELECT COUNT({obj}) FROM {self.sql_table_name} WHERE {subject} = "{entity}";' res = self.cursor.execute(query) popularities_dict[entity] = res.fetchall()[0][0] entities_dict = {entity: n for n, entity in enumerate(entities_set)} inverted_index = defaultdict(list) for label in label_to_id_dict: tokens = re.findall(self.re_tokenizer, label.lower()) for tok in tokens: if len(tok) > 1 and tok not in self.stopwords: inverted_index[tok].append( (entities_dict[label_to_id_dict[label]], popularities_dict[label_to_id_dict[label]])) self.inverted_index = dict(inverted_index) self.entities_list = list(entities_set) self.q2name = [ id_to_label_dict[entity] for entity in self.entities_list ] self.q2descr = [] if id_to_descr_dict: self.q2descr = [ id_to_descr_dict[entity] for entity in self.entities_list ] def filter_entities(self, entities: List[str], template_found: str) -> List[str]: if template_found in ["who is xxx?", "who was xxx?"]: entities = [ entity for entity in entities if entity in self.who_entities ] if template_found in ["what is xxx?", "what was xxx?"]: entities = [ entity for entity in entities if entity not in self.who_entities ] return entities
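# Illustrative usage sketch for KBEntityLinker above (not part of the original module). Paths and
# file names are hypothetical; the real pickled files ship with the DeepPavlov KBQA configs.
def _example_kb_entity_linker_call():
    """Sketch: link one mention and filter the candidates with a question template."""
    linker = KBEntityLinker(load_path="~/.deeppavlov/downloads/wikidata",      # hypothetical folder
                            inverted_index_filename="inverted_index_eng.pickle",
                            entities_list_filename="entities_list.pickle",
                            q2name_filename="q2name_eng.pickle",
                            who_entities_filename="who_entities.pickle",       # enables "who ..." filtering
                            build_inverted_index=False)
    # link_entity returns candidate ids and confidences; passing template_found triggers
    # filter_entities, e.g. "who is xxx?" keeps only entities from the who_entities set.
    entity_ids, confidences = linker.link_entity("donald trump", template_found="who is xxx?")
    return entity_ids, confidences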
class EntityLinker(Component, Serializable): """ This class extracts from Wikidata candidate entities for the entity mentioned in the question and then extracts triplets from Wikidata for the extracted entity. Candidate entities are searched in the dictionary where keys are titles and aliases of Wikidata entities and values are lists of tuples (entity_title, entity_id, number_of_relations). First candidate entities are searched in the dictionary by keys where the keys are entities extracted from the question, if nothing is found entities are searched in the dictionary using Levenstein distance between the entity and keys (titles) in the dictionary. """ LANGUAGES = set(['rus']) def __init__(self, load_path: str, wiki_filename: str, entities_filename: str, inverted_index_filename: str, id_to_name_file: str, lemmatize: bool = True, debug: bool = False, rule_filter_entities: bool = True, use_inverted_index: bool = True, language: str = 'rus', *args, **kwargs) -> None: """ Args: load_path: path to folder with wikidata files wiki_filename: file with Wikidata triplets entities_filename: file with dict of entity titles (keys) and entity ids (values) inverted_index_filename: file with dict of words (keys) and entities containing these words (values) id_to_name_file: file with dict of entity ids (keys) and entities names and aliases (values) lemmatize: whether to lemmatize tokens of extracted entity debug: whether to print entities extracted from Wikidata rule_filter_entities: whether to filter entities which do not fit the question use_inverted_index: whether to use inverted index for entity linking language - the language of the linker (used for filtration of some questions to improve overall performance) *args: **kwargs: """ super().__init__(save_path=None, load_path=load_path) self.morph = pymorphy2.MorphAnalyzer() self.lemmatize = lemmatize self.debug = debug self.rule_filter_entities = rule_filter_entities self.use_inverted_index = use_inverted_index self._language = language if language not in self.LANGUAGES: log.warning( f'EntityLinker supports only the following languages: {self.LANGUAGES}' ) self._wiki_filename = wiki_filename self._entities_filename = entities_filename self.inverted_index_filename = inverted_index_filename self.id_to_name_file = id_to_name_file self.name_to_q: Optional[Dict[str, List[Tuple[str]]]] = None self.wikidata: Optional[Dict[str, List[List[str]]]] = None self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None self.id_to_name: Optional[Dict[str, Dict[List[str]]]] = None self.load() if self.use_inverted_index: alphabet = "abcdefghijklmnopqrstuvwxyzабвгдеёжзийклмнопрстуфхцчшщъыьэюя1234567890-_()=+!?.,/;:&@<>|#$%^*" dictionary_words = list(self.inverted_index.keys()) self.searcher = LevenshteinSearcher(alphabet, dictionary_words) def load(self) -> None: if self.use_inverted_index: with open(self.load_path / self.inverted_index_filename, 'rb') as inv: self.inverted_index = pickle.load(inv) self.inverted_index: Dict[str, List[Tuple[str]]] with open(self.load_path / self.id_to_name_file, 'rb') as i2n: self.id_to_name = pickle.load(i2n) self.id_to_name: Dict[str, Dict[List[str]]] else: with open(self.load_path / self._entities_filename, 'rb') as e: self.name_to_q = pickle.load(e) self.name_to_q: Dict[str, List[Tuple[str]]] with open(self.load_path / self._wiki_filename, 'rb') as w: self.wikidata = pickle.load(w) self.wikidata: Dict[str, List[List[str]]] def save(self) -> None: pass def __call__( self, entity: str, question_tokens: List[str] ) -> 
Tuple[List[List[List[str]]], List[str]]: confidences = [] srtd_cand_ent = [] if not entity: wiki_entities = ['None'] else: if self.use_inverted_index: candidate_entities = self.candidate_entities_inverted_index( entity) candidate_names = self.candidate_entities_names( candidate_entities) wiki_entities, confidences, srtd_cand_ent = self.sort_found_entities( candidate_entities, candidate_names, entity) else: candidate_entities = self.find_candidate_entities(entity) srtd_cand_ent = sorted(candidate_entities, key=lambda x: x[2], reverse=True) if len(srtd_cand_ent) > 0: wiki_entities = [ent[1] for ent in srtd_cand_ent] confidences = [1.0 for i in range(len(srtd_cand_ent))] srtd_cand_ent = [ (ent[0], ent[1], conf, ent[2]) for ent, conf in zip(srtd_cand_ent, confidences) ] else: candidates = self.fuzzy_entity_search(entity) candidates = list(set(candidates)) srtd_cand_ent = [(ent[0][0], ent[0][1], ent[1], ent[0][2]) for ent in candidates] srtd_cand_ent = sorted(srtd_cand_ent, key=lambda x: (x[2], x[3]), reverse=True) if len(srtd_cand_ent) > 0: wiki_entities = [ent[1] for ent in srtd_cand_ent] confidences = [ float(ent[2]) * 0.01 for ent in srtd_cand_ent ] else: wiki_entities = ["None"] confidences = [0.0] entity_triplets = self.extract_triplets_from_wiki(wiki_entities) if self.rule_filter_entities and self._language == 'rus': filtered_entities, filtered_entity_triplets = self.filter_triplets_rus( entity_triplets, question_tokens, srtd_cand_ent) if self.debug: self._log_entities(filtered_entities[:10]) return filtered_entity_triplets, confidences def _log_entities(self, srtd_cand_ent): entities_to_print = [] for name, q, ratio, n_rel in srtd_cand_ent: entities_to_print.append( f'{name}, http://wikidata.org/wiki/{q}, {ratio}, {n_rel}') log.debug('\n' + '\n'.join(entities_to_print)) def find_candidate_entities(self, entity: str) -> List[str]: candidate_entities = list(self.name_to_q.get(entity, [])) entity_split = entity.split(' ') if len(entity_split) < 6 and self.lemmatize: entity_lemm_tokens = [] for tok in entity_split: morph_parse_tok = self.morph.parse(tok)[0] lemmatized_tok = morph_parse_tok.normal_form entity_lemm_tokens.append(lemmatized_tok) masks = itertools.product([False, True], repeat=len(entity_split)) for mask in masks: entity_lemm = [] for i in range(len(entity_split)): if mask[i]: entity_lemm.append(entity_split[i]) else: entity_lemm.append(entity_lemm_tokens[i]) entity_lemm = ' '.join(entity_lemm) if entity_lemm != entity: candidate_entities += self.name_to_q.get(entity_lemm, []) candidate_entities = list(set(candidate_entities)) return candidate_entities def fuzzy_entity_search(self, entity: str) -> List[Tuple[Tuple, str]]: word_length = len(entity) candidates = [] for title in self.name_to_q: length_ratio = len(title) / word_length if length_ratio > 0.75 and length_ratio < 1.25: ratio = fuzz.ratio(title, entity) if ratio > 70: entity_candidates = self.name_to_q.get(title, []) for cand in entity_candidates: candidates.append((cand, fuzz.ratio(entity, cand[0]))) return candidates def extract_triplets_from_wiki( self, entity_ids: List[str]) -> List[List[List[str]]]: entity_triplets = [] for entity_id in entity_ids: if entity_id in self.wikidata and entity_id.startswith('Q'): triplets_for_entity = self.wikidata[entity_id] entity_triplets.append(triplets_for_entity) else: entity_triplets.append([]) return entity_triplets @staticmethod def filter_triplets_rus( entity_triplets: List[List[List[str]]], question_tokens: List[str], srtd_cand_ent: List[Tuple[str]] ) -> 
Tuple[List[Tuple[str]], List[List[List[str]]]]: question = ' '.join(question_tokens).lower() what_template = 'что ' found_what_template = False found_what_template = question.find(what_template) > -1 filtered_entity_triplets = [] filtered_entities = [] for wiki_entity, triplets_for_entity in zip(srtd_cand_ent, entity_triplets): entity_is_human = False entity_is_asteroid = False entity_is_named = False entity_title = wiki_entity[0] if entity_title[0].isupper(): entity_is_named = True property_is_instance_of = 'P31' id_for_entity_human = 'Q5' id_for_entity_asteroid = 'Q3863' for triplet in triplets_for_entity: if triplet[0] == property_is_instance_of and triplet[ 1] == id_for_entity_human: entity_is_human = True break if triplet[0] == property_is_instance_of and triplet[ 1] == id_for_entity_asteroid: entity_is_asteroid = True break if found_what_template and (entity_is_human or entity_is_named or entity_is_asteroid or wiki_entity[2] < 90): continue filtered_entity_triplets.append(triplets_for_entity) filtered_entities.append(wiki_entity) return filtered_entities, filtered_entity_triplets def candidate_entities_inverted_index(self, entity: str) -> List[Tuple[str]]: word_tokens = nltk.word_tokenize(entity) candidate_entities = [] for tok in word_tokens: if len(tok) > 1: found = False if tok in self.inverted_index: candidate_entities += self.inverted_index[tok] found = True morph_parse_tok = self.morph.parse(tok)[0] lemmatized_tok = morph_parse_tok.normal_form if lemmatized_tok != tok and lemmatized_tok in self.inverted_index: candidate_entities += self.inverted_index[lemmatized_tok] found = True if not found: words_with_levens_1 = self.searcher.search(tok, d=1) for word in words_with_levens_1: candidate_entities += self.inverted_index[word[0]] candidate_entities = list(set(candidate_entities)) return candidate_entities def candidate_entities_names( self, candidate_entities: List[Tuple[str]]) -> List[List[str]]: candidate_names = [] for candidate in candidate_entities: entity_id = candidate[0] entity_names = [self.id_to_name[entity_id]["name"]] if "aliases" in self.id_to_name[entity_id].keys(): aliases = self.id_to_name[entity_id]["aliases"] for alias in aliases: entity_names.append(alias) candidate_names.append(entity_names) return candidate_names def sort_found_entities( self, candidate_entities: List[Tuple[str]], candidate_names: List[List[str]], entity: str) -> Tuple[List[str], List[str], List[Tuple[str]]]: entities_ratios = [] for candidate, entity_names in zip(candidate_entities, candidate_names): entity_id = candidate[0] num_rels = candidate[1] entity_name = entity_names[0] morph_parse_entity = self.morph.parse(entity)[0] lemm_entity = morph_parse_entity.normal_form fuzz_ratio_lemm = max([ fuzz.ratio(name.lower(), lemm_entity.lower()) for name in entity_names ]) fuzz_ratio_nolemm = max([ fuzz.ratio(name.lower(), entity.lower()) for name in entity_names ]) fuzz_ratio = max(fuzz_ratio_lemm, fuzz_ratio_nolemm) entities_ratios.append( (entity_name, entity_id, fuzz_ratio, num_rels)) srtd_with_ratios = sorted(entities_ratios, key=lambda x: (x[2], x[3]), reverse=True) wiki_entities = [ent[1] for ent in srtd_with_ratios if ent[2] > 84] confidences = [ float(ent[2]) * 0.01 for ent in srtd_with_ratios if ent[2] > 84 ] return wiki_entities, confidences, srtd_with_ratios
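# Illustrative usage sketch for the Wikidata EntityLinker above (not part of the original module).
# File names are hypothetical placeholders for the pickled Wikidata dumps this component expects.
# use_inverted_index=False makes the linker search the title dictionary (name_to_q) directly; with
# the default (True) the lookup goes through the inverted index instead.
def _example_wikidata_linker_call():
    """Sketch: retrieve Wikidata triplets for one mention found in a Russian question."""
    linker = EntityLinker(load_path="~/.deeppavlov/downloads/wikidata_rus",    # hypothetical folder
                          wiki_filename="wikidata.pickle",
                          entities_filename="entities_dict.pickle",
                          inverted_index_filename="inverted_index.pickle",
                          id_to_name_file="id_to_name.pickle",
                          use_inverted_index=False)
    question_tokens = ["Кто", "написал", "Войну", "и", "мир", "?"]
    entity_triplets, confidences = linker("Война и мир", question_tokens)
    # entity_triplets[i] holds the Wikidata triplets of the i-th candidate entity.
    return entity_triplets, confidences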
class KBEntityLinker(Component, Serializable): """ This class extracts from the knowledge base candidate entities for the entity mentioned in the question and then extracts triplets from Wikidata for the extracted entity. Candidate entities are searched in the dictionary where keys are titles and aliases of Wikidata entities and values are lists of tuples (entity_title, entity_id, number_of_relations). First candidate entities are searched in the dictionary by keys where the keys are entities extracted from the question, if nothing is found entities are searched in the dictionary using Levenstein distance between the entity and keys (titles) in the dictionary. """ def __init__( self, load_path: str, inverted_index_filename: str, entities_list_filename: str, q2name_filename: str, types_dict_filename: Optional[str] = None, who_entities_filename: Optional[str] = None, save_path: str = None, q2descr_filename: str = None, descr_rank_score_thres: float = 0.5, freq_dict_filename: Optional[str] = None, entity_ranker: RelRankerInfer = None, build_inverted_index: bool = False, kb_format: str = "hdt", kb_filename: str = None, label_rel: str = None, descr_rel: str = None, aliases_rels: List[str] = None, sql_table_name: str = None, sql_column_names: List[str] = None, lang: str = "en", use_descriptions: bool = False, include_mention: bool = False, num_entities_to_return: int = 5, num_entities_for_bert_ranking: int = 100, lemmatize: bool = False, use_prefix_tree: bool = False, **kwargs, ) -> None: """ Args: load_path: path to folder with inverted index files inverted_index_filename: file with dict of words (keys) and entities containing these words entities_list_filename: file with the list of entities from the knowledge base q2name_filename: file which maps entity id to name types_dict_filename: file with types of entities who_entities_filename: file with the list of entities in Wikidata, which can be answers to questions with "Who" pronoun, i.e. humans, literary characters etc. 
save_path: path where to save inverted index files q2descr_filename: name of file which maps entity id to description descr_rank_score_thres: if the score of the entity description is less than threshold, the entity is not added to output list freq_dict_filename: filename with frequences dictionary of Russian words entity_ranker: component deeppavlov.models.kbqa.rel_ranker_infer build_inverted_index: if "true", inverted index of entities of the KB will be built kb_format: "hdt" or "sqlite3" kb_filename: file with the knowledge base, which will be used for building of inverted index label_rel: relation in the knowledge base which connects entity ids and entity titles descr_rel: relation in the knowledge base which connects entity ids and entity descriptions aliases_rels: list of relations which connect entity ids and entity aliases sql_table_name: name of the table with the KB if the KB is in sqlite3 format sql_column_names: names of columns with subject, relation and object lang: language used use_descriptions: whether to use context and descriptions of entities for entity ranking include_mention: whether to leave or delete entity mention from the sentence before passing to BERT ranker num_entities_to_return: how many entities for each substring the system returns lemmatize: whether to lemmatize tokens of extracted entity use_prefix_tree: whether to use prefix tree for search of entities with typos in entity labels **kwargs: """ super().__init__(save_path=save_path, load_path=load_path) self.morph = pymorphy2.MorphAnalyzer() self.lemmatize = lemmatize self.use_prefix_tree = use_prefix_tree self.inverted_index_filename = inverted_index_filename self.entities_list_filename = entities_list_filename self.build_inverted_index = build_inverted_index self.q2name_filename = q2name_filename self.types_dict_filename = types_dict_filename self.who_entities_filename = who_entities_filename self.q2descr_filename = q2descr_filename self.descr_rank_score_thres = descr_rank_score_thres self.freq_dict_filename = freq_dict_filename self.kb_format = kb_format self.kb_filename = kb_filename self.label_rel = label_rel self.aliases_rels = aliases_rels self.descr_rel = descr_rel self.sql_table_name = sql_table_name self.sql_column_names = sql_column_names self.inverted_index: Optional[Dict[str, List[Tuple[str]]]] = None self.entities_index: Optional[List[str]] = None self.q2name: Optional[List[Tuple[str]]] = None self.types_dict: Optional[Dict[str, List[str]]] = None self.lang_str = f"@{lang}" if self.lang_str == "@en": self.stopwords = set(stopwords.words("english")) elif self.lang_str == "@ru": self.stopwords = set(stopwords.words("russian")) self.re_tokenizer = re.compile(r"[\w']+|[^\w ]") self.entity_ranker = entity_ranker self.nlp = en_core_web_sm.load() self.inflect_engine = inflect.engine() self.use_descriptions = use_descriptions self.include_mention = include_mention self.num_entities_to_return = num_entities_to_return self.num_entities_for_bert_ranking = num_entities_for_bert_ranking self.black_list_what_is = { "Q277759", # book series "Q11424", # film "Q7889", # video game "Q2743", # musical theatre "Q5398426", # tv series "Q506240", # television film "Q21191270", # television series episode "Q7725634", # literary work "Q131436", # board game "Q1783817", # cooperative board game } if self.use_descriptions and self.entity_ranker is None: raise ValueError("No entity ranker is provided!") if self.use_prefix_tree: alphabet = ( 
r"!#%\&'()+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz½¿ÁÄ" + "ÅÆÇÉÎÓÖ×ÚßàáâãäåæçèéêëíîïðñòóôöøùúûüýāăąćČčĐėęěĞğĩīİıŁłńňŌōőřŚśşŠšťũūůŵźŻżŽžơưșȚțəʻ" + "ʿΠΡβγБМавдежикмностъяḤḥṇṬṭầếờợ–‘’Ⅲ−∗") dictionary_words = list(self.inverted_index.keys()) self.searcher = LevenshteinSearcher(alphabet, dictionary_words) if self.build_inverted_index: if self.kb_format == "hdt": self.doc = HDTDocument(str(expand_path(self.kb_filename))) elif self.kb_format == "sqlite3": self.conn = sqlite3.connect(str(expand_path(self.kb_filename))) self.cursor = self.conn.cursor() else: raise ValueError( f"unsupported kb_format value {self.kb_format}") self.inverted_index_builder() self.save() else: self.load() def load_freq_dict(self, freq_dict_filename: str): with open(str(expand_path(freq_dict_filename)), "r") as fl: lines = fl.readlines() pos_freq_dict = defaultdict(list) for line in lines: line_split = line.strip("\n").split("\t") if re.match(r"[\d]+\.[\d]+", line_split[2]): pos_freq_dict[line_split[1]].append( (line_split[0], float(line_split[2]))) nouns_with_freq = pos_freq_dict["s"] self.nouns_dict = {noun: freq for noun, freq in nouns_with_freq} def load(self) -> None: self.inverted_index = load_pickle(self.load_path / self.inverted_index_filename) self.entities_list = load_pickle(self.load_path / self.entities_list_filename) self.q2name = load_pickle(self.load_path / self.q2name_filename) if self.who_entities_filename: self.who_entities = load_pickle(self.load_path / self.who_entities_filename) if self.freq_dict_filename: self.load_freq_dict(self.freq_dict_filename) if self.types_dict_filename: self.types_dict = load_pickle(self.load_path / self.types_dict_filename) def save(self) -> None: save_pickle(self.inverted_index, self.save_path / self.inverted_index_filename) save_pickle(self.entities_list, self.save_path / self.entities_list_filename) save_pickle(self.q2name, self.save_path / self.q2name_filename) if self.q2descr_filename is not None: save_pickle(self.q2descr, self.save_path / self.q2descr_filename) def __call__( self, entity_substr_batch: List[List[str]], templates_batch: List[str] = None, long_context_batch: List[str] = None, entity_types_batch: List[List[List[str]]] = None, short_context_batch: List[str] = None, ) -> Tuple[List[List[List[str]]], List[List[List[float]]]]: entity_ids_batch = [] confidences_batch = [] tokens_match_conf_batch = [] if templates_batch is None: templates_batch = ["" for _ in entity_substr_batch] if long_context_batch is None: long_context_batch = ["" for _ in entity_substr_batch] if short_context_batch is None: short_context_batch = ["" for _ in entity_substr_batch] if entity_types_batch is None: entity_types_batch = [[[] for _ in entity_substr_list] for entity_substr_list in entity_substr_batch ] for entity_substr_list, template_found, long_context, entity_types_list, short_context in zip( entity_substr_batch, templates_batch, long_context_batch, entity_types_batch, short_context_batch): entity_ids_list = [] confidences_list = [] tokens_match_conf_list = [] for entity_substr, entity_types in zip(entity_substr_list, entity_types_list): entity_ids, confidences, tokens_match_conf = self.link_entity( entity_substr, long_context, short_context, template_found, entity_types) if self.num_entities_to_return == 1: if entity_ids: entity_ids_list.append(entity_ids[0]) confidences_list.append(confidences[0]) tokens_match_conf_list.append(tokens_match_conf[0]) else: entity_ids_list.append("") confidences_list.append(0.0) tokens_match_conf_list.append(0.0) 
    def __call__(
            self,
            entity_substr_batch: List[List[str]],
            templates_batch: List[str] = None,
            long_context_batch: List[str] = None,
            entity_types_batch: List[List[List[str]]] = None,
            short_context_batch: List[str] = None,
    ) -> Tuple[List[List[List[str]]], List[List[List[float]]], List[List[List[float]]]]:
        entity_ids_batch = []
        confidences_batch = []
        tokens_match_conf_batch = []
        if templates_batch is None:
            templates_batch = ["" for _ in entity_substr_batch]
        if long_context_batch is None:
            long_context_batch = ["" for _ in entity_substr_batch]
        if short_context_batch is None:
            short_context_batch = ["" for _ in entity_substr_batch]
        if entity_types_batch is None:
            entity_types_batch = [[[] for _ in entity_substr_list]
                                  for entity_substr_list in entity_substr_batch]
        for entity_substr_list, template_found, long_context, entity_types_list, short_context in zip(
                entity_substr_batch, templates_batch, long_context_batch, entity_types_batch, short_context_batch):
            entity_ids_list = []
            confidences_list = []
            tokens_match_conf_list = []
            for entity_substr, entity_types in zip(entity_substr_list, entity_types_list):
                entity_ids, confidences, tokens_match_conf = self.link_entity(
                    entity_substr, long_context, short_context, template_found, entity_types)
                if self.num_entities_to_return == 1:
                    if entity_ids:
                        entity_ids_list.append(entity_ids[0])
                        confidences_list.append(confidences[0])
                        tokens_match_conf_list.append(tokens_match_conf[0])
                    else:
                        entity_ids_list.append("")
                        confidences_list.append(0.0)
                        tokens_match_conf_list.append(0.0)
                else:
                    entity_ids_list.append(entity_ids[:self.num_entities_to_return])
                    confidences_list.append(confidences[:self.num_entities_to_return])
                    tokens_match_conf_list.append(tokens_match_conf[:self.num_entities_to_return])
            entity_ids_batch.append(entity_ids_list)
            confidences_batch.append(confidences_list)
            tokens_match_conf_batch.append(tokens_match_conf_list)
        return entity_ids_batch, confidences_batch, tokens_match_conf_batch

    def lemmatize_substr(self, text):
        lemm_text = ""
        if text:
            pr_text = self.nlp(text)
            processed_tokens = []
            for token in pr_text:
                if token.tag_ in ["NNS", "NNP"] and self.inflect_engine.singular_noun(token.text):
                    processed_tokens.append(self.inflect_engine.singular_noun(token.text))
                else:
                    processed_tokens.append(token.text)
            lemm_text = " ".join(processed_tokens)
        return lemm_text
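    # Rough example of lemmatize_substr (illustrative; the exact result depends on
    # the spaCy tagger and the inflect engine): "movies" -> "movie".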
    def link_entity(
            self,
            entity: str,
            long_context: Optional[str] = None,
            short_context: Optional[str] = None,
            template_found: Optional[str] = None,
            entity_types: List[str] = None,
            cut_entity: bool = False,
    ) -> Tuple[List[str], List[float], List[float]]:
        confidences = []
        tokens_match_conf = []
        if not entity:
            entities_ids = ["None"]
        else:
            entity_is_uttr = False
            lets_talk_phrases = ["let's talk", "let's chat", "what about", "do you know", "tell me about"]
            found_lets_talk_phrase = any([phrase in short_context for phrase in lets_talk_phrases])
            if (short_context
                    and (entity == short_context or entity == short_context[:-1] or found_lets_talk_phrase)
                    and len(entity.split()) == 1):
                lemm_entity = self.lemmatize_substr(entity)
                entity_is_uttr = True
            else:
                lemm_entity = entity
            candidate_entities = self.candidate_entities_inverted_index(lemm_entity)
            if self.types_dict:
                if entity_types:
                    entity_types = set(entity_types)
                    candidate_entities = [
                        ent for ent in candidate_entities
                        if self.types_dict.get(ent[1], set()).intersection(entity_types)
                    ]
                if template_found in ["what is xxx?", "what was xxx?"] or entity_is_uttr:
                    candidate_entities_filtered = [
                        ent for ent in candidate_entities
                        if not self.types_dict.get(ent[1], set()).intersection(self.black_list_what_is)
                    ]
                    if candidate_entities_filtered:
                        candidate_entities = candidate_entities_filtered
            if cut_entity and candidate_entities and len(lemm_entity.split()) > 1 and candidate_entities[0][3] == 1:
                lemm_entity = self.cut_entity_substr(lemm_entity)
                candidate_entities = self.candidate_entities_inverted_index(lemm_entity)
            candidate_entities, candidate_names = self.candidate_entities_names(lemm_entity, candidate_entities)
            entities_ids, confidences, tokens_match_conf, srtd_cand_ent = self.sort_found_entities(
                candidate_entities, candidate_names, lemm_entity, entity, long_context)
            if template_found:
                entities_ids = self.filter_entities(entities_ids, template_found)
        return entities_ids, confidences, tokens_match_conf

    def cut_entity_substr(self, entity: str):
        word_tokens = nltk.word_tokenize(entity.lower())
        word_tokens = [word for word in word_tokens if word not in self.stopwords]
        normal_form_tokens = [self.morph.parse(word)[0].normal_form for word in word_tokens]
        words_with_freq = [(word, self.nouns_dict.get(word, 0.0)) for word in normal_form_tokens]
        words_with_freq = sorted(words_with_freq, key=lambda x: x[1])
        return words_with_freq[0][0]

    def candidate_entities_inverted_index(self, entity: str) -> List[Tuple[Any, Any, Any]]:
        word_tokens = nltk.word_tokenize(entity.lower())
        word_tokens = [word for word in word_tokens if word not in self.stopwords]
        candidate_entities = []
        candidate_entities_for_tokens = []
        for tok in word_tokens:
            candidate_entities_for_tok = set()
            if len(tok) > 1:
                found = False
                if tok in self.inverted_index:
                    candidate_entities_for_tok = set(self.inverted_index[tok])
                    found = True
                if self.lemmatize:
                    if self.lang_str == "@ru":
                        morph_parse_tok = self.morph.parse(tok)[0]
                        lemmatized_tok = morph_parse_tok.normal_form
                    if self.lang_str == "@en":
                        lemmatized_tok = self.lemmatizer.lemmatize(tok)
                    if lemmatized_tok != tok and lemmatized_tok in self.inverted_index:
                        candidate_entities_for_tok = candidate_entities_for_tok.union(
                            set(self.inverted_index[lemmatized_tok]))
                        found = True
                if not found and self.use_prefix_tree:
                    words_with_levens_1 = self.searcher.search(tok, d=1)
                    for word in words_with_levens_1:
                        candidate_entities_for_tok = candidate_entities_for_tok.union(
                            set(self.inverted_index[word[0]]))
            candidate_entities_for_tokens.append(candidate_entities_for_tok)
        for candidate_entities_for_tok in candidate_entities_for_tokens:
            candidate_entities += list(candidate_entities_for_tok)
        candidate_entities = Counter(candidate_entities).most_common()
        candidate_entities = sorted(candidate_entities, key=lambda x: (x[0][1], x[1]), reverse=True)
        candidate_entities = candidate_entities[:1000]
        candidate_entities = [
            (entity_num, self.entities_list[entity_num], entity_freq, count)
            for (entity_num, entity_freq), count in candidate_entities
        ]
        return candidate_entities
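    # Each candidate returned by candidate_entities_inverted_index is roughly a tuple
    # (index of the entity in entities_list, entity id, number of KB triples with the
    # entity as subject, number of query tokens that matched), e.g. a hypothetical
    # (1234, "Q42", 171, 2).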
    def sort_found_entities(
            self,
            candidate_entities: List[Tuple[int, str, int]],
            candidate_names: List[List[str]],
            lemm_entity: str,
            entity: str,
            context: str = None,
    ) -> Tuple[List[str], List[float], List[float], List[Tuple[str, str, int, int]]]:
        entities_ratios = []
        lemm_entity = lemm_entity.lower()
        for candidate, entity_names in zip(candidate_entities, candidate_names):
            entity_num, entity_id, num_rels, tokens_matched = candidate
            fuzz_ratio = max([fuzz.ratio(name.lower(), lemm_entity) for name in entity_names])
            entity_tokens = re.findall(self.re_tokenizer, entity.lower())
            lemm_entity_tokens = re.findall(self.re_tokenizer, lemm_entity.lower())
            entity_tokens = {
                word for word in entity_tokens
                if (len(word) > 1 and word != "'s" and word not in self.stopwords)
            }
            lemm_entity_tokens = {
                word for word in lemm_entity_tokens
                if (len(word) > 1 and word != "'s" and word not in self.stopwords)
            }
            match_counts = []
            for name in entity_names:
                name_tokens = re.findall(self.re_tokenizer, name.lower())
                name_tokens = {
                    word for word in name_tokens
                    if (len(word) > 1 and word != "'s" and word not in self.stopwords)
                }
                entity_inters_len = len(entity_tokens.intersection(name_tokens))
                lemm_entity_inters_len = len(lemm_entity_tokens.intersection(name_tokens))
                entity_ratio_1 = 0.0
                entity_ratio_2 = 0.0
                if len(entity_tokens):
                    entity_ratio_1 = entity_inters_len / len(entity_tokens)
                    if entity_ratio_1 > 1.0 and entity_ratio_1 != 0.0:
                        entity_ratio_1 = 1.0 / entity_ratio_1
                if len(name_tokens):
                    entity_ratio_2 = entity_inters_len / len(name_tokens)
                    if entity_ratio_2 > 1.0 and entity_ratio_2 != 0.0:
                        entity_ratio_2 = 1.0 / entity_ratio_2
                lemm_entity_ratio_1 = 0.0
                lemm_entity_ratio_2 = 0.0
                if len(lemm_entity_tokens):
                    lemm_entity_ratio_1 = lemm_entity_inters_len / len(lemm_entity_tokens)
                    if lemm_entity_ratio_1 > 1.0 and lemm_entity_ratio_1 != 0.0:
                        lemm_entity_ratio_1 = 1.0 / lemm_entity_ratio_1
                if len(name_tokens):
                    lemm_entity_ratio_2 = lemm_entity_inters_len / len(name_tokens)
                    if lemm_entity_ratio_2 > 1.0 and lemm_entity_ratio_2 != 0.0:
                        lemm_entity_ratio_2 = 1.0 / lemm_entity_ratio_2
                match_count = max(entity_ratio_1, entity_ratio_2, lemm_entity_ratio_1, lemm_entity_ratio_2)
                match_counts.append(match_count)
            match_counts = sorted(match_counts, reverse=True)
            if match_counts:
                tokens_matched = match_counts[0]
            else:
                tokens_matched = 0.0
            entities_ratios.append((entity_num, entity_id, tokens_matched, fuzz_ratio, num_rels))
        srtd_with_ratios = sorted(entities_ratios, key=lambda x: (x[2], x[3], x[4]), reverse=True)
        if self.use_descriptions:
            log.debug(f"context {context}")
            id_to_score = {
                entity_id: (tokens_matched, score, num_rels)
                for _, entity_id, tokens_matched, score, num_rels
                in srtd_with_ratios[:self.num_entities_for_bert_ranking]
            }
            entity_ids = [entity_id for _, entity_id, _, _, _ in srtd_with_ratios[:self.num_entities_for_bert_ranking]]
            scores = self.entity_ranker.rank_rels(context, entity_ids)
            entities_with_scores = [
                (entity_id, id_to_score[entity_id][0], id_to_score[entity_id][1], id_to_score[entity_id][2], score)
                for entity_id, score in scores
            ]
            entities_with_scores = sorted(entities_with_scores, key=lambda x: (x[1], x[2], x[3], x[4]), reverse=True)
            entities_with_scores = [
                ent for ent in entities_with_scores
                if (ent[4] > self.descr_rank_score_thres
                    or ent[2] == 100.0
                    or (ent[1] == 1.0 and ent[2] > 92.0 and ent[3] > 20 and ent[4] > 0.2))
            ]
            log.debug(f"entities_with_scores {entities_with_scores[:10]}")
            entity_ids = [ent for ent, *_ in entities_with_scores]
            confidences = [score for *_, score in entities_with_scores]
            tokens_match_conf = [ratio for _, ratio, *_ in entities_with_scores]
        else:
            entity_ids = [ent[1] for ent in srtd_with_ratios]
            confidences = [ent[4] * 0.01 for ent in srtd_with_ratios]
            tokens_match_conf = [ent[2] for ent in srtd_with_ratios]
        return entity_ids, confidences, tokens_match_conf, srtd_with_ratios
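    # Candidates above are ordered by (token match ratio, fuzzy string ratio, number
    # of KB relations); with use_descriptions=True the top num_entities_for_bert_ranking
    # of them are additionally re-scored against the context by the entity ranker.
    # Worked example of the token match ratio (illustrative): for the substring
    # "bill gates" and the label "Bill Gates" the token sets both equal
    # {"bill", "gates"}, so the ratio is 2 / 2 = 1.0.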
    def candidate_entities_names(
            self, entity: str, candidate_entities: List[Tuple[int, str, int]]
    ) -> Tuple[List[Tuple[int, str, int]], List[List[str]]]:
        entity_length = len(entity)
        candidate_names = []
        candidate_entities_filter = []
        for candidate in candidate_entities:
            entity_num = candidate[0]
            entity_names = []
            entity_names_found = self.q2name[entity_num]
            if len(entity_names_found[0]) < 6 * entity_length:
                entity_name = entity_names_found[0]
                entity_names.append(entity_name)
                if len(entity_names_found) > 1:
                    for alias in entity_names_found[1:]:
                        entity_names.append(alias)
                candidate_names.append(entity_names)
                candidate_entities_filter.append(candidate)
        return candidate_entities_filter, candidate_names

    def inverted_index_builder(self) -> None:
        log.debug("building inverted index")
        entities_set = set()
        id_to_label_dict = defaultdict(list)
        id_to_descr_dict = defaultdict(list)
        label_to_id_dict = {}
        label_triplets = []
        alias_triplets_list = []
        descr_triplets = []
        if self.kb_format == "hdt":
            label_triplets, c = self.doc.search_triples("", self.label_rel, "")
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    alias_triplets, c = self.doc.search_triples("", alias_rel, "")
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                descr_triplets, c = self.doc.search_triples("", self.descr_rel, "")
        if self.kb_format == "sqlite3":
            subject, relation, obj = self.sql_column_names
            query = (f"SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} "
                     f'WHERE {relation} = "{self.label_rel}";')
            res = self.cursor.execute(query)
            label_triplets = res.fetchall()
            if self.aliases_rels is not None:
                for alias_rel in self.aliases_rels:
                    query = (f"SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} "
                             f'WHERE {relation} = "{alias_rel}";')
                    res = self.cursor.execute(query)
                    alias_triplets = res.fetchall()
                    alias_triplets_list.append(alias_triplets)
            if self.descr_rel is not None:
                query = (f"SELECT {subject}, {relation}, {obj} FROM {self.sql_table_name} "
                         f'WHERE {relation} = "{self.descr_rel}";')
                res = self.cursor.execute(query)
                descr_triplets = res.fetchall()
        for triplets in [label_triplets] + alias_triplets_list:
            for triplet in triplets:
                entities_set.add(triplet[0])
                if triplet[2].endswith(self.lang_str):
                    label = triplet[2].replace(self.lang_str, "").replace('"', "")
                    id_to_label_dict[triplet[0]].append(label)
                    label_to_id_dict[label] = triplet[0]
        for triplet in descr_triplets:
            entities_set.add(triplet[0])
            if triplet[2].endswith(self.lang_str):
                descr = triplet[2].replace(self.lang_str, "").replace('"', "")
                id_to_descr_dict[triplet[0]].append(descr)
        popularities_dict = {}
        for entity in entities_set:
            if self.kb_format == "hdt":
                all_triplets, number_of_triplets = self.doc.search_triples(entity, "", "")
                popularities_dict[entity] = number_of_triplets
            if self.kb_format == "sqlite3":
                subject, relation, obj = self.sql_column_names
                query = f'SELECT COUNT({obj}) FROM {self.sql_table_name} WHERE {subject} = "{entity}";'
                res = self.cursor.execute(query)
                popularities_dict[entity] = res.fetchall()[0][0]
        entities_dict = {entity: n for n, entity in enumerate(entities_set)}
        inverted_index = defaultdict(list)
        for label in label_to_id_dict:
            tokens = re.findall(self.re_tokenizer, label.lower())
            for tok in tokens:
                if len(tok) > 1 and tok not in self.stopwords:
                    inverted_index[tok].append(
                        (entities_dict[label_to_id_dict[label]], popularities_dict[label_to_id_dict[label]]))
        self.inverted_index = dict(inverted_index)
        self.entities_list = list(entities_set)
        self.q2name = [id_to_label_dict[entity] for entity in self.entities_list]
        self.q2descr = []
        if id_to_descr_dict:
            self.q2descr = [id_to_descr_dict[entity] for entity in self.entities_list]
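    # The index built above maps a lower-cased label token to a list of
    # (index of the entity in entities_list, number of triples with the entity as
    # subject), e.g. a hypothetical {"gates": [(1234, 171), (5678, 12)]}.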
    def filter_entities(self, entities: List[str], template_found: str) -> List[str]:
        if template_found in ["who is xxx?", "who was xxx?"]:
            entities = [entity for entity in entities if entity in self.who_entities]
        if template_found in ["what is xxx?", "what was xxx?"]:
            entities = [entity for entity in entities if entity not in self.who_entities]
        return entities
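    # Template-based filtering is a coarse heuristic: "who is xxx?" questions keep
    # only ids found in who_entities (presumably entities that can answer
    # "who"-questions), while "what is xxx?" questions drop those same ids, so the
    # same ambiguous substring can resolve differently depending on the template.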