def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(self):
    """A typical inference logical form should produce a four-part explanation."""
    logical_form = (
        "(infer (a:sugar higher world1) (a:diabetes higher world2) (a:diabetes higher world1))"
    )
    entities = {"a:sugar": "sugar", "a:diabetes": "diabetes"}
    world_extractions = {"world1": "bill", "world2": "sue"}
    # Build a minimal knowledge graph with no neighbor edges between entities.
    graph = KnowledgeGraph(entities.keys(), {name: [] for name in entities}, entities)
    world = QuarelWorld(graph, "quarel_v1_attr_entities")
    explanation = get_explanation(logical_form, world_extractions, 0, world)
    assert len(explanation) == 4
def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(self):
    """get_explanation should yield one entry per component of the inference."""
    logical_form = (
        "(infer (a:sugar higher world1) (a:diabetes higher world2) (a:diabetes higher world1))"
    )
    entity_text = {"a:sugar": "sugar", "a:diabetes": "diabetes"}
    empty_neighbors = {key: [] for key in entity_text}
    knowledge_graph = KnowledgeGraph(entity_text.keys(), empty_neighbors, entity_text)
    world = QuarelWorld(knowledge_graph, "quarel_v1_attr_entities")
    world_extractions = {"world1": "bill", "world2": "sue"}
    answer_index = 0
    explanation = get_explanation(logical_form, world_extractions, answer_index, world)
    assert len(explanation) == 4
def get_table_knowledge_graph(self) -> KnowledgeGraph:
    """
    Lazily build and cache a ``KnowledgeGraph`` over this table.

    Entities are the table's typed column names plus string and number
    entities extracted from the question; neighbor edges link each
    question entity to the columns it may refer to.
    """
    if self._table_knowledge_graph is None:
        entities: Set[str] = set()
        neighbors: Dict[str, List[str]] = defaultdict(list)
        entity_text: Dict[str, str] = {}
        # Add all column names to entities. We'll define their neighbors to be empty lists for
        # now, and later add number and string entities as needed.
        number_columns = []
        date_columns = []
        for typed_column_name in self.column_names:
            if "number_column:" in typed_column_name or "num2_column" in typed_column_name:
                number_columns.append(typed_column_name)
            if "date_column:" in typed_column_name:
                date_columns.append(typed_column_name)
            # Add column names to entities, with no neighbors yet.
            entities.add(typed_column_name)
            neighbors[typed_column_name] = []
            # Human-readable text: drop the "<type>_column:" prefix, un-snake-case the rest.
            entity_text[typed_column_name] = typed_column_name.split(
                ":", 1)[-1].replace("_", " ")
        string_entities, numbers = self.get_entities_from_question()
        # Link each string entity from the question to the columns it was matched against.
        for entity, column_names in string_entities:
            entities.add(entity)
            for column_name in column_names:
                neighbors[entity].append(column_name)
                neighbors[column_name].append(entity)
            entity_text[entity] = entity.replace("string:", "").replace("_", " ")
        # For all numbers (except -1), we add all number and date columns as their neighbors.
        for number, _ in numbers:
            entities.add(number)
            neighbors[number].extend(number_columns + date_columns)
            for column_name in number_columns + date_columns:
                neighbors[column_name].append(number)
            entity_text[number] = number
        # Deduplicate neighbor lists (edges may have been appended more than once).
        for entity, entity_neighbors in neighbors.items():
            neighbors[entity] = list(set(entity_neighbors))
        # Add "-1" as an entity only if we have date columns in the table because we will need
        # it as a wild-card in dates. The neighbors are the date columns.
        if "-1" not in neighbors and date_columns:
            entities.add("-1")
            neighbors["-1"] = date_columns
            entity_text["-1"] = "-1"
            for date_column in date_columns:
                neighbors[date_column].append("-1")
        self._table_knowledge_graph = KnowledgeGraph(
            entities, dict(neighbors), entity_text)
    return self._table_knowledge_graph
def get_table_knowledge_graph(self) -> KnowledgeGraph:
    """
    Lazily build and cache a ``KnowledgeGraph`` over this table.

    Column names come from the first row of ``self.table_data``; question
    entities come from ``self.question_entities``. All known numbers are
    linked to every number column and all known dates to every date column.
    """
    if self._table_knowledge_graph is None:
        entities: Set[str] = set()
        neighbors: Dict[str, List[str]] = defaultdict(list)
        entity_text: Dict[str, str] = {}
        # Add all column names to entities. We'll define their neighbors to be empty lists for
        # now, and later add number and string entities as needed.
        number_columns = []
        date_columns = []
        for typed_column_name in self.table_data[0].keys():
            if "number_column:" in typed_column_name or "num2_column" in typed_column_name:
                number_columns.append(typed_column_name)
            if "date_column:" in typed_column_name:
                date_columns.append(typed_column_name)
            # Add column names to entities, with no neighbors yet.
            entities.add(typed_column_name)
            neighbors[typed_column_name] = []
            entity_text[typed_column_name] = typed_column_name.split(
                ":")[-1].replace("_", " ")
        # question_entities items carry (entity, _, _, candidate column names);
        # the two ignored slots are unused here.
        for entity, _, _, column_names in self.question_entities:
            entities.add(entity)
            for column_name in column_names:
                neighbors[entity].append(column_name)
                neighbors[column_name].append(entity)
            entity_text[entity] = entity.replace("string:", "").replace("_", " ")
        # This implementation requires the number/date vocabularies to be present.
        if self._num2id is None or self._date2id is None:
            raise NotImplementedError
        # Every number is a potential neighbor of every number column.
        for number in self._num2id:
            entities.add(number)
            neighbors[number].extend(number_columns)
            for column_name in number_columns:
                neighbors[column_name].append(number)
            entity_text[number] = number
        # Every date is a potential neighbor of every date column.
        for date in self._date2id:
            entities.add(date)
            neighbors[date].extend(date_columns)
            for column_name in date_columns:
                neighbors[column_name].append(date)
            entity_text[date] = date
        # Deduplicate neighbor lists.
        for entity, entity_neighbors in neighbors.items():
            neighbors[entity] = list(set(entity_neighbors))
        self._table_knowledge_graph = KnowledgeGraph(
            entities, dict(neighbors), entity_text)
    return self._table_knowledge_graph
def get_knowledge_graph(self) -> KnowledgeGraph:
    """
    Lazily build and cache a ``KnowledgeGraph`` over paragraph relations,
    question string entities, paragraph-extracted entities, and numbers.

    Numbers get no neighbors; "-1" is always added as a date wild card.
    """
    if self._knowledge_graph is None:
        entities: Set[str] = set()
        neighbors: Dict[str, List[str]] = defaultdict(list)
        entity_text: Dict[str, str] = {}
        if self.paragraph_data:
            # Add all relation names to entities. We'll define their neighbors to be empty
            # lists for now, and later add number and string entities as needed.
            for relation_name in self.paragraph_data[0].keys():
                # Add relation names to entities, with no neighbors yet.
                entities.add(relation_name)
                neighbors[relation_name] = []
                entity_text[relation_name] = relation_name.split(":")[-1].replace("_", " ")
            string_entities, numbers = self.get_entities_from_question()
            for entity, relation_name in string_entities:
                entities.add(entity)
                neighbors[entity].append(relation_name)
                neighbors[relation_name].append(entity)
                entity_text[entity] = entity.replace("string:", "").replace("_", " ")
            # If a word embedding is passed to the constructor, we might have entities
            # extracted from the paragraph as well.
            # (note: "paragragh" is a spelling quirk kept from the original variable name)
            for paragragh_entity, relation_names in self.paragraph_tokens_to_keep:
                entity_name = f"string:{paragragh_entity}"
                entities.add(entity_name)
                neighbors[entity_name].extend(relation_names)
                for relation_name in relation_names:
                    neighbors[relation_name].append(entity_name)
                entity_text[entity_name] = paragragh_entity.lower()
            # We add all numbers without any neighbors
            for number, _ in numbers:
                entities.add(number)
                entity_text[number] = number
                neighbors[number] = []
        # Deduplicate neighbor lists (a no-op when the graph is empty).
        for entity, entity_neighbors in neighbors.items():
            neighbors[entity] = list(set(entity_neighbors))
        # We need -1 as wild cards in dates.
        entities.add("-1")
        entity_text["-1"] = "-1"
        neighbors["-1"] = []
        self._knowledge_graph = KnowledgeGraph(entities, dict(neighbors), entity_text)
    return self._knowledge_graph
def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
    """
    Build a ``KnowledgeGraph`` for one database: tables, their columns, and
    foreign-key links between columns.

    Question-string entities are currently disabled here (``string_entities``
    is hard-coded to an empty list; see the commented-out call below), so the
    string-entity loop is a no-op in this variant.
    """
    entities: Set[str] = set()
    neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
    entity_text: Dict[str, str] = {}
    foreign_keys_to_column: Dict[str, str] = {}
    db_schema = self.schema
    tables = db_schema.values()
    # Cache row data per database id so each database is read only once.
    if db_id not in self.db_tables_data:
        self.db_tables_data[db_id] = read_dataset_values(
            db_id, self.dataset_path, tables)
    tables_data = self.db_tables_data[db_id]
    # Map each normalized text cell value to the set of column keys containing it.
    string_column_mapping: Dict[str, set] = defaultdict(set)
    for table, table_data in tables_data.items():
        for table_row in table_data:
            for column, cell_value in zip(db_schema[table.name].columns, table_row):
                if column.column_type == 'text' and type(
                        cell_value) is str:
                    cell_value_normalized = self.normalize_string(
                        cell_value)
                    column_key = self.entity_key_for_column(
                        table.name, column)
                    string_column_mapping[cell_value_normalized].add(
                        column_key)
    # string_entities = self.get_entities_from_question(string_column_mapping)
    string_entities = []
    # Add each table and its columns as graph nodes, linked to each other.
    for table in tables:
        table_key = f"table:{table.name.lower()}"
        entities.add(table_key)
        entity_text[table_key] = table.text
        for column in db_schema[table.name].columns:
            entity_key = self.entity_key_for_column(table.name, column)
            entities.add(entity_key)
            neighbors[entity_key].add(table_key)
            neighbors[table_key].add(entity_key)
            entity_text[entity_key] = column.text
    # No-op while string_entities is empty (see above).
    for string_entity, column_keys in string_entities:
        entities.add(string_entity)
        for column_key in column_keys:
            neighbors[string_entity].add(column_key)
            neighbors[column_key].add(string_entity)
        entity_text[string_entity] = string_entity.replace("string:", "").replace(
            "_", " ")
    # loop again after we have gone through all columns to link foreign keys columns
    for table_name in db_schema.keys():
        for column in db_schema[table_name].columns:
            if column.foreign_key is None:
                continue
            other_column_table, other_column_name = column.foreign_key.split(
                ':')
            # must have exactly one by design
            other_column = [
                col for col in db_schema[other_column_table].columns
                if col.name == other_column_name
            ][0]
            entity_key = self.entity_key_for_column(table_name, column)
            other_entity_key = self.entity_key_for_column(
                other_column_table, other_column)
            neighbors[entity_key].add(other_entity_key)
            neighbors[other_entity_key].add(entity_key)
            foreign_keys_to_column[entity_key] = other_entity_key
    kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
    # Extra attribute consumed by downstream code that follows FK links.
    kg.foreign_keys_to_column = foreign_keys_to_column
    return kg
def empty_field(self) -> 'KnowledgeGraphField':
    """Return a field over an empty knowledge graph, used for padding/batching."""
    empty_graph = KnowledgeGraph(set(), {})
    return KnowledgeGraphField(empty_graph, [], self._token_indexers)
def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
    """
    Build a ``KnowledgeGraph`` for one database: tables, their columns, and
    foreign-key links between columns.

    Value-based (string-match) entities from the question are intentionally
    disabled (see the commented-out block below) to keep train and test
    consistent.
    """
    entities: Set[str] = set()
    neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
    entity_text: Dict[str, str] = {}
    foreign_keys_to_column: Dict[str, str] = {}
    db_schema = self.schema
    tables = db_schema.values()
    # Cache row data per database id so each database is read only once.
    if db_id not in self.db_tables_data:
        self.db_tables_data[db_id] = read_dataset_values(
            db_id, self.database_path, tables)
    tables_data = self.db_tables_data[db_id]
    # Map each normalized text cell value to the set of column keys containing it.
    string_column_mapping: Dict[str, set] = defaultdict(set)
    for table, table_data in tables_data.items():
        for table_row in table_data:
            # Special case: a leading '*' pseudo-column has no corresponding
            # cell value in the row, so skip it when pairing columns to cells.
            if db_schema[table.name].columns[0].name == '*':
                columns = db_schema[table.name].columns[1:]
            else:
                columns = db_schema[table.name].columns
            assert len(columns) == len(table_row)
            # BUGFIX: zip over `columns` (the '*'-adjusted list) rather than the
            # raw schema columns; the old code paired the '*' column with the
            # first cell and shifted every other cell onto the wrong column.
            for column, cell_value in zip(columns, table_row):
                if column.column_type == 'text' and type(
                        cell_value) is str:
                    cell_value_normalized = self.normalize_string(
                        cell_value)
                    column_key = self.entity_key_for_column(
                        table.name, column)
                    string_column_mapping[cell_value_normalized].add(
                        column_key)
    # Add each table and its columns as graph nodes, linked to each other.
    for table in tables:
        table_key = f"table:{table.name.lower()}"
        entities.add(table_key)
        entity_text[table_key] = table.text
        for column in db_schema[table.name].columns:
            entity_key = self.entity_key_for_column(table.name, column)
            entities.add(entity_key)
            neighbors[entity_key].add(table_key)
            neighbors[table_key].add(entity_key)
            entity_text[entity_key] = column.text
    # dynamic entities of values in question
    # TODO: we should disable the string match entities in train.
    # Because it will cause the inconsistent between train and test
    # value_entities = self.get_values_from_question(string_column_mapping)
    #
    # for value_repr, column_keys in value_entities:
    #     entities.add(value_repr)
    #     for column_key in column_keys:
    #         neighbors[value_repr].add(column_key)
    #         neighbors[column_key].add(value_repr)
    #     entity_text[value_repr] = value_repr.replace("string:", "").replace("_", " ")

    # loop again after we have gone through all columns to link foreign keys columns
    for table_name in db_schema.keys():
        for column in db_schema[table_name].columns:
            if column.foreign_key is None:
                continue
            # In this variant `foreign_key` is a collection of "table:column" strings.
            for foreign_key in column.foreign_key:
                other_column_table, other_column_name = foreign_key.split(
                    ':')
                # must have exactly one by design
                other_column = [
                    col for col in db_schema[other_column_table].columns
                    if col.name == other_column_name
                ][0]
                entity_key = self.entity_key_for_column(table_name, column)
                other_entity_key = self.entity_key_for_column(
                    other_column_table, other_column)
                neighbors[entity_key].add(other_entity_key)
                neighbors[other_entity_key].add(entity_key)
                foreign_keys_to_column[entity_key] = other_entity_key
    kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
    # Extra attribute consumed by downstream code that follows FK links.
    kg.foreign_keys_to_column = foreign_keys_to_column
    return kg
def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
    """
    Build a ``KnowledgeGraph`` for one database with three node types:
    table entities (``table:{name}``), column entities
    (``column:{type}:{table}:{column}``), and string entities
    (``string:{normalized_token}``) matched from the question.
    """
    entities: Set[str] = set()
    neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
    entity_text: Dict[str, str] = {}
    foreign_keys_to_column: Dict[str, str] = {}
    db_schema = self.schema
    tables = db_schema.values()
    if db_id not in self.db_tables_data:
        """KAIMARY"""
        # Put "SELECT * FROM xxx" into sqlite db to get the rows in each table
        self.db_tables_data[db_id] = read_dataset_values(
            db_id, self.dataset_path, tables)
    tables_data = self.db_tables_data[db_id]
    # Map each normalized text cell value to the set of column keys containing it.
    string_column_mapping: Dict[str, set] = defaultdict(set)
    for table, table_data in tables_data.items():
        for table_row in table_data:
            for column, cell_value in zip(db_schema[table.name].columns,
                                          table_row):
                if column.column_type == 'text' and type(
                        cell_value) is str:
                    cell_value_normalized = self.normalize_string(
                        cell_value)
                    """KAIMARY"""
                    # Formatted as: "column:{column_type}:{table_name}:{column.name}"
                    column_key = self.entity_key_for_column(
                        table.name, column)
                    """KAIMARY"""
                    # Cell value as key, column_key as the value.
                    string_column_mapping[cell_value_normalized].add(
                        column_key)
    """KAIMARY"""
    # Only return "text" column type related tokens
    # Formatted as: "string:{entity['value']}, entity['token_in_columns'])
    string_entities = self.get_entities_from_question(
        string_column_mapping)
    """KAIMARY
    three types of entities:
    TABLE entities: entity["table:{table.name}"]
    COLUMN entities: entity["column:{column_type}:{table_name}:{column.name}"]
    STRING entities: entity["string:{normalized_token_text}"]
    """
    # Add each table and its columns as graph nodes, linked to each other.
    for table in tables:
        table_key = f"table:{table.name.lower()}"
        entities.add(table_key)
        entity_text[table_key] = table.text
        for column in db_schema[table.name].columns:
            entity_key = self.entity_key_for_column(table.name, column)
            entities.add(entity_key)
            neighbors[entity_key].add(table_key)
            neighbors[table_key].add(entity_key)
            entity_text[entity_key] = column.text
    # Link each question string entity to the columns whose values matched it.
    for string_entity, column_keys in string_entities:
        entities.add(string_entity)
        for column_key in column_keys:
            neighbors[string_entity].add(column_key)
            neighbors[column_key].add(string_entity)
        entity_text[string_entity] = string_entity.replace("string:", "").replace(
            "_", " ")
    # loop again after we have gone through all columns to link foreign keys columns
    for table_name in db_schema.keys():
        for column in db_schema[table_name].columns:
            if column.foreign_key is None:
                continue
            other_column_table, other_column_name = column.foreign_key.split(
                ':')
            # must have exactly one by design
            other_column = [
                col for col in db_schema[other_column_table].columns
                if col.name == other_column_name
            ][0]
            entity_key = self.entity_key_for_column(table_name, column)
            other_entity_key = self.entity_key_for_column(
                other_column_table, other_column)
            neighbors[entity_key].add(other_entity_key)
            neighbors[other_entity_key].add(entity_key)
            foreign_keys_to_column[entity_key] = other_entity_key
    kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
    # Extra attribute consumed by downstream code that follows FK links.
    kg.foreign_keys_to_column = foreign_keys_to_column
    return kg
def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
    """
    Build a ``KnowledgeGraph`` for one database.

    Nodes: ``table:{name}`` for each table, ``column:{type}:{table}:{column}``
    for each column, and ``string:{token}`` for each question token found in
    the database's text columns. Edges: table<->its columns,
    foreign-key<->referenced column, and string entity<->matching columns.
    ``entity_text`` maps each node to a readable name (from ``table.text`` /
    ``column.text``, i.e. the display names rather than the raw schema names).
    The returned graph carries an extra ``foreign_keys_to_column`` attribute
    mapping each foreign-key column node to the referenced column node, e.g.
    ``foreign_keys_to_column['column:foreign:management:department_id'] =
    'column:primary:department:department_id'``.
    """
    entities: Set[str] = set()
    neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
    entity_text: Dict[str, str] = {}
    foreign_keys_to_column: Dict[str, str] = {}
    db_schema = self.schema
    tables = db_schema.values()
    # Cache row data per database id so each database is read only once.
    if db_id not in self.db_tables_data:
        self.db_tables_data[db_id] = read_dataset_values(
            db_id, self.dataset_path, tables)
    tables_data = self.db_tables_data[db_id]
    # Map each normalized text cell value to the set of column keys containing it.
    string_column_mapping: Dict[str, set] = defaultdict(set)
    for table, table_data in tables_data.items():
        for table_row in table_data:
            for column, cell_value in zip(db_schema[table.name].columns,
                                          table_row):
                if column.column_type == 'text' and type(
                        cell_value) is str:
                    cell_value_normalized = self.normalize_string(
                        cell_value)
                    column_key = self.entity_key_for_column(
                        table.name, column)
                    string_column_mapping[cell_value_normalized].add(
                        column_key)
    # string_entities: question tokens whose text appears in some text column's
    # data, paired with the column keys they matched (may be empty).
    string_entities = self.get_entities_from_question(
        string_column_mapping)
    # Add each table and its columns as graph nodes, linked to each other.
    for table in tables:
        table_key = f"table:{table.name.lower()}"
        entities.add(table_key)
        entity_text[table_key] = table.text
        for column in db_schema[table.name].columns:
            entity_key = self.entity_key_for_column(table.name, column)
            entities.add(entity_key)
            neighbors[entity_key].add(table_key)
            neighbors[table_key].add(entity_key)
            entity_text[entity_key] = column.text
    # Link each question string entity to the columns whose values matched it.
    for string_entity, column_keys in string_entities:
        entities.add(string_entity)
        for column_key in column_keys:
            neighbors[string_entity].add(column_key)
            neighbors[column_key].add(string_entity)
        entity_text[string_entity] = string_entity.replace("string:", "").replace(
            "_", " ")
    # loop again after we have gone through all columns to link foreign keys columns
    for table_name in db_schema.keys():
        for column in db_schema[table_name].columns:
            if column.foreign_key is None:
                continue
            other_column_table, other_column_name = column.foreign_key.split(
                ':')
            # must have exactly one by design
            other_column = [
                col for col in db_schema[other_column_table].columns
                if col.name == other_column_name
            ][0]
            entity_key = self.entity_key_for_column(table_name, column)
            other_entity_key = self.entity_key_for_column(
                other_column_table, other_column)
            neighbors[entity_key].add(other_entity_key)
            neighbors[other_entity_key].add(entity_key)
            foreign_keys_to_column[entity_key] = other_entity_key
    # entities are node names, entity_text their readable labels, neighbors the edges.
    kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
    kg.foreign_keys_to_column = foreign_keys_to_column
    return kg
def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
    """
    Build a ``KnowledgeGraph`` for a WikiSQL database: the table, its
    columns, and string entities matched between the question and the
    table's text-column values.
    """
    entities: Set[str] = set()
    neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
    # Holds the display text for each table/column node.
    entity_text: Dict[str, str] = {}
    foreign_keys_to_column: Dict[str, str] = {}
    db_schema = self.schema
    tables = db_schema.values()
    # TODO: read the concrete row data for the wiki database.
    if db_id not in self.db_tables_data:
        self.db_tables_data[db_id] = read_wiki_dataset_values(
            db_id, self.tables_file, tables)  # dict[table, List[tuple(row)]]
    tables_data = self.db_tables_data[db_id]
    # Maps each normalized cell value to the column(s) it appears in.
    string_column_mapping: Dict[str, set] = defaultdict(set)
    for table, table_data in tables_data.items():
        for table_row in table_data:
            for column, cell_value in zip(db_schema[table.name].columns,
                                          table_row):
                if column.column_type == 'text' and type(
                        cell_value) is str:
                    cell_value_normalized = self.normalize_string(
                        cell_value)
                    column_key = self.entity_key_for_column(
                        table.name, column)
                    string_column_mapping[cell_value_normalized].add(
                        column_key)
    # Returns all entities found in the question (entity -> matched columns).
    string_entities = self.get_entities_from_question(
        string_column_mapping)
    for table in tables:
        table_key = f"table:{table.name.lower()}"
        entities.add(table_key)
        # For WikiSQL there is no difference: table.name == table.text == db_id.
        entity_text[
            table_key] = table.text
        for column in db_schema[table.name].columns:
            # entity_key is the column's unique signature.
            entity_key = self.entity_key_for_column(table.name, column)
            # entities holds both table keys and column keys.
            entities.add(entity_key)
            # neighbors records the table<->column relationship.
            neighbors[entity_key].add(table_key)
            neighbors[table_key].add(entity_key)
            entity_text[entity_key] = column.text
    # entities also include the string entities found in the question.
    for string_entity, column_keys in string_entities:
        entities.add(string_entity)
        for column_key in column_keys:
            neighbors[string_entity].add(column_key)
            neighbors[column_key].add(string_entity)
        entity_text[string_entity] = string_entity.replace("string:", "").replace(
            "_", " ")
    # loop again after we have gone through all columns to link foreign keys columns
    # For WikiSQL every foreign_key is None, so this loop is a no-op.
    for table_name in db_schema.keys():
        for column in db_schema[table_name].columns:
            if column.foreign_key is None:
                continue
            other_column_table, other_column_name = column.foreign_key.split(
                ':')
            # must have exactly one by design
            other_column = [
                col for col in db_schema[other_column_table].columns
                if col.name == other_column_name
            ][0]
            entity_key = self.entity_key_for_column(table_name, column)
            other_entity_key = self.entity_key_for_column(
                other_column_table, other_column)
            neighbors[entity_key].add(other_entity_key)
            neighbors[other_entity_key].add(entity_key)
            foreign_keys_to_column[entity_key] = other_entity_key
    kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
    kg.foreign_keys_to_column = foreign_keys_to_column
    return kg
def __init__(
    self,
    lazy: bool = False,
    sample: int = -1,
    lf_syntax: str = None,
    replace_world_entities: bool = False,
    align_world_extractions: bool = False,
    gold_world_extractions: bool = False,
    tagger_only: bool = False,
    denotation_only: bool = False,
    world_extraction_model: Optional[str] = None,
    skip_attributes_regex: Optional[str] = None,
    entity_bits_mode: Optional[str] = None,
    entity_types: Optional[List[str]] = None,
    lexical_cues: List[str] = None,
    tokenizer: Tokenizer = None,
    question_token_indexers: Dict[str, TokenIndexer] = None,
) -> None:
    """
    Dataset reader configuration for QuaRel-style questions.

    Parameters
    ----------
    lazy : whether instances are read lazily (passed to the base reader).
    sample : number of examples to sample (-1 presumably means "all" — the
        consuming code is outside this view).
    lf_syntax : logical-form syntax name; if it contains "_attr_entities",
        dynamic attribute entities are built below.
    replace_world_entities, align_world_extractions, gold_world_extractions :
        flags stored for use elsewhere in the reader.
    tagger_only / denotation_only : restrict instances to tagging / denotation.
    world_extraction_model : path/id of a world-tagger model to load, if any.
    skip_attributes_regex : attributes matching this regex are excluded from
        the dynamic entities.
    entity_bits_mode, entity_types, lexical_cues : entity tagging options.
    tokenizer / question_token_indexers : standard AllenNLP tokenization knobs.
    """
    super().__init__(lazy=lazy)
    self._tokenizer = tokenizer or WordTokenizer()
    self._question_token_indexers = question_token_indexers or {
        "tokens": SingleIdTokenIndexer()
    }
    # Entities are indexed the same way as question tokens.
    self._entity_token_indexers = self._question_token_indexers
    self._sample = sample
    self._replace_world_entities = replace_world_entities
    self._lf_syntax = lf_syntax
    self._entity_bits_mode = entity_bits_mode
    self._align_world_extractions = align_world_extractions
    self._gold_world_extractions = gold_world_extractions
    self._entity_types = entity_types
    self._tagger_only = tagger_only
    self._denotation_only = denotation_only
    self._skip_attributes_regex = None
    if skip_attributes_regex is not None:
        self._skip_attributes_regex = re.compile(skip_attributes_regex)
    self._lexical_cues = lexical_cues
    # Recording of entities in categories relevant for tagging
    all_entities = {}
    all_entities["world"] = ["world1", "world2"]
    # TODO: Clarify this into an appropriate parameter
    self._collapse_tags = ["world"]
    self._all_entities = None
    if entity_types is not None:
        if self._entity_bits_mode == "collapsed":
            self._all_entities = entity_types
        else:
            # NOTE(review): only the "world" key exists in all_entities above;
            # other entity_types values would raise KeyError here.
            self._all_entities = [e for t in entity_types for e in all_entities[t]]
    logger.info(f"all_entities = {self._all_entities}")
    # Base world, depending on LF syntax only; placeholder graph until
    # dynamic entities (if any) replace it below.
    self._knowledge_graph = KnowledgeGraph(
        entities={"placeholder"}, neighbors={}, entity_text={"placeholder": "placeholder"}
    )
    self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)
    # Decide dynamic entities, if any
    self._dynamic_entities: Dict[str, str] = dict()
    self._use_attr_entities = False
    if "_attr_entities" in lf_syntax:
        self._use_attr_entities = True
        qr_coeff_sets = self._world.qr_coeff_sets
        for qset in qr_coeff_sets:
            for attribute in qset:
                if (
                    self._skip_attributes_regex is not None
                    and self._skip_attributes_regex.search(attribute)
                ):
                    continue
                # Get text associated with each entity, both from entity identifier and
                # associated lexical cues, if any
                entity_strings = [words_from_entity_string(attribute).lower()]
                if self._lexical_cues is not None:
                    for key in self._lexical_cues:
                        if attribute in LEXICAL_CUES[key]:
                            entity_strings += LEXICAL_CUES[key][attribute]
                self._dynamic_entities["a:" + attribute] = " ".join(entity_strings)
    # Update world to include dynamic entities
    if self._use_attr_entities:
        logger.info(f"dynamic_entities = {self._dynamic_entities}")
        neighbors: Dict[str, List[str]] = {key: [] for key in self._dynamic_entities}
        self._knowledge_graph = KnowledgeGraph(
            entities=set(self._dynamic_entities.keys()),
            neighbors=neighbors,
            entity_text=self._dynamic_entities,
        )
        self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)
    self._stemmer = PorterStemmer().stemmer
    self._world_tagger_extractor = None
    self._extract_worlds = False
    if world_extraction_model is not None:
        logger.info("Loading world tagger model...")
        self._extract_worlds = True
        self._world_tagger_extractor = WorldTaggerExtractor(world_extraction_model)
        logger.info("Done loading world tagger model!")
    # Convenience regex for recognizing attributes
    self._attr_regex = re.compile(r"""\((\w+) (high|low|higher|lower)""")
def text_to_instance(
    self,  # type: ignore
    question: str,
    logical_forms: List[str] = None,
    additional_metadata: Dict[str, Any] = None,
    world_extractions: Dict[str, Union[str, List[str]]] = None,
    entity_literals: Dict[str, Union[str, List[str]]] = None,
    tokenized_question: List[Token] = None,
    debug_counter: int = None,
    qr_spec_override: List[Dict[str, int]] = None,
    dynamic_entities_override: Dict[str, str] = None,
) -> Instance:
    """
    Convert a question (plus optional logical forms / extractions) into an
    ``Instance``.

    Returns a tagger instance (``tokens``/``tags``/``metadata``) when the
    reader is in tagger-only mode; otherwise a semantic-parsing instance
    with ``question``, ``table``, ``world``, ``actions`` and, when available,
    denotation targets, entity bits, and target action sequences.
    """
    tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
    additional_metadata = additional_metadata or dict()
    additional_metadata["question_tokens"] = [token.text for token in tokenized_question]
    if world_extractions is not None:
        additional_metadata["world_extractions"] = world_extractions
    question_field = TextField(tokenized_question, self._question_token_indexers)

    if qr_spec_override is not None or dynamic_entities_override is not None:
        # Dynamically specify theory and/or entities
        dynamic_entities = dynamic_entities_override or self._dynamic_entities
        neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
        knowledge_graph = KnowledgeGraph(
            entities=set(dynamic_entities.keys()),
            neighbors=neighbors,
            entity_text=dynamic_entities,
        )
        world = QuarelWorld(knowledge_graph, self._lf_syntax, qr_coeff_sets=qr_spec_override)
    else:
        knowledge_graph = self._knowledge_graph
        world = self._world

    table_field = KnowledgeGraphField(
        knowledge_graph,
        tokenized_question,
        self._entity_token_indexers,
        tokenizer=self._tokenizer,
    )

    if self._tagger_only:
        fields: Dict[str, Field] = {"tokens": question_field}
        if entity_literals is not None:
            entity_tags = self._get_entity_tags(
                self._all_entities, table_field, entity_literals, tokenized_question
            )
            # BUGFIX: debug_counter defaults to None, and `None > 0` raises
            # TypeError on Python 3; guard against None explicitly.
            if debug_counter is not None and debug_counter > 0:
                logger.info(f"raw entity tags = {entity_tags}")
            entity_tags_bio = self._convert_tags_bio(entity_tags)
            fields["tags"] = SequenceLabelField(entity_tags_bio, question_field)
            additional_metadata["tags_gold"] = entity_tags_bio
        additional_metadata["words"] = [x.text for x in tokenized_question]
        fields["metadata"] = MetadataField(additional_metadata)
        return Instance(fields)

    world_field = MetadataField(world)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(" -> ")
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {
        "question": question_field,
        "table": table_field,
        "world": world_field,
        "actions": action_field,
    }

    if self._denotation_only:
        denotation_field = LabelField(additional_metadata["answer_index"], skip_indexing=True)
        fields["denotation_target"] = denotation_field

    if self._entity_bits_mode is not None and world_extractions is not None:
        entity_bits = self._get_entity_tags(
            ["world1", "world2"], table_field, world_extractions, tokenized_question
        )
        # Encode each tag (0 = none, 1 = world1, 2 = world2) per the chosen mode.
        if self._entity_bits_mode == "simple":
            entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple_collapsed":
            entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple3":
            entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]
        # NOTE(review): an unrecognized mode would leave entity_bits_v unbound here.
        entity_bits_field = ArrayField(np.array(entity_bits_v))
        fields["entity_bits"] = entity_bits_field

    if logical_forms:
        action_map = {
            action.rule: i for i, action in enumerate(action_field.field_list)  # type: ignore
        }
        action_sequence_fields: List[Field] = []
        for logical_form in logical_forms:
            expression = world.parse_logical_form(logical_form)
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                # Skip logical forms that use productions missing from the action space.
                logger.info(f"Missing production rule: {error.args}, skipping logical form")
                logger.info(f"Question was: {question}")
                logger.info(f"Logical form was: {logical_form}")
                continue
        fields["target_action_sequences"] = ListField(action_sequence_fields)

    fields["metadata"] = MetadataField(additional_metadata or {})
    return Instance(fields)