def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(self):
    """A typical QuaRel inference logical form should produce a 4-part explanation."""
    lf = (
        "(infer (a:sugar higher world1) (a:diabetes higher world2) (a:diabetes higher world1))"
    )
    # Two attribute entities with no graph neighbors.
    entity_map = {"a:sugar": "sugar", "a:diabetes": "diabetes"}
    extractions = {"world1": "bill", "world2": "sue"}
    graph = KnowledgeGraph(
        entity_map.keys(), {name: [] for name in entity_map}, entity_map
    )
    quarel_world = QuarelWorld(graph, "quarel_v1_attr_entities")
    explanation = get_explanation(lf, extractions, 0, quarel_world)
    assert len(explanation) == 4
Exemple #2
0
 def test_get_explanation_provides_non_empty_explanation_for_typical_inputs(
         self):
     """get_explanation on a typical QuaRel logical form yields 4 fragments."""
     logical_form = '(infer (a:sugar higher world1) (a:diabetes higher world2) (a:diabetes higher world1))'
     entities = {'a:sugar': 'sugar', 'a:diabetes': 'diabetes'}
     world_extractions = {'world1': 'bill', 'world2': 'sue'}
     answer_index = 0
     # Minimal knowledge graph: the two attribute entities, no neighbors.
     knowledge_graph = KnowledgeGraph(entities.keys(),
                                      {key: []
                                       for key in entities}, entities)
     world = QuarelWorld(knowledge_graph, "quarel_v1_attr_entities")
     explanation = get_explanation(logical_form, world_extractions,
                                   answer_index, world)
     # One explanation fragment per inference component; 4 expected here.
     assert len(explanation) == 4
    def get_table_knowledge_graph(self) -> KnowledgeGraph:
        """Build (and memoize) the knowledge graph over columns and question entities.

        Nodes are typed column names plus strings/numbers extracted from the
        question; edges connect each string to the columns it matched and each
        number to every number/date column.
        """
        if self._table_knowledge_graph is not None:
            return self._table_knowledge_graph

        graph_entities: Set[str] = set()
        graph_neighbors: Dict[str, List[str]] = defaultdict(list)
        node_text: Dict[str, str] = {}
        numeric_columns: List[str] = []
        date_cols: List[str] = []

        # Every column becomes a node with (for now) no neighbors; number and
        # date columns are remembered so question numbers can link to them.
        for column in self.column_names:
            if "number_column:" in column or "num2_column" in column:
                numeric_columns.append(column)
            if "date_column:" in column:
                date_cols.append(column)
            graph_entities.add(column)
            graph_neighbors[column] = []
            node_text[column] = column.split(":", 1)[-1].replace("_", " ")

        question_strings, question_numbers = self.get_entities_from_question()
        # Question strings connect bidirectionally to their matched columns.
        for string_entity, matched_columns in question_strings:
            graph_entities.add(string_entity)
            for column in matched_columns:
                graph_neighbors[string_entity].append(column)
                graph_neighbors[column].append(string_entity)
            node_text[string_entity] = string_entity.replace("string:", "").replace("_", " ")

        # All numbers (except -1) neighbor every number and date column.
        for number_entity, _ in question_numbers:
            graph_entities.add(number_entity)
            graph_neighbors[number_entity].extend(numeric_columns + date_cols)
            for column in numeric_columns + date_cols:
                graph_neighbors[column].append(number_entity)
            node_text[number_entity] = number_entity

        # De-duplicate the neighbor lists (ordering is not significant here).
        for node, node_neighbors in graph_neighbors.items():
            graph_neighbors[node] = list(set(node_neighbors))

        # "-1" serves as a wild-card in dates, so only add it when the table
        # actually has date columns; its neighbors are exactly those columns.
        if date_cols and "-1" not in graph_neighbors:
            graph_entities.add("-1")
            graph_neighbors["-1"] = date_cols
            node_text["-1"] = "-1"
            for date_column in date_cols:
                graph_neighbors[date_column].append("-1")

        self._table_knowledge_graph = KnowledgeGraph(
            graph_entities, dict(graph_neighbors), node_text)
        return self._table_knowledge_graph
Exemple #4
0
    def get_table_knowledge_graph(self) -> KnowledgeGraph:
        """Lazily build the knowledge graph over table columns, question strings,
        and the precomputed number/date entities.

        Columns come from the keys of the first row of ``self.table_data``;
        string entities come from ``self.question_entities``; numbers and dates
        come from ``self._num2id`` / ``self._date2id``.
        """
        if self._table_knowledge_graph is None:
            entities: Set[str] = set()
            neighbors: Dict[str, List[str]] = defaultdict(list)
            entity_text: Dict[str, str] = {}
            # Add all column names to entities. We'll define their neighbors to be empty lists for
            # now, and later add number and string entities as needed.
            number_columns = []
            date_columns = []
            for typed_column_name in self.table_data[0].keys():
                if "number_column:" in typed_column_name or "num2_column" in typed_column_name:
                    number_columns.append(typed_column_name)

                if "date_column:" in typed_column_name:
                    date_columns.append(typed_column_name)

                # Add column names to entities, with no neighbors yet.
                entities.add(typed_column_name)
                neighbors[typed_column_name] = []
                entity_text[typed_column_name] = typed_column_name.split(
                    ":")[-1].replace("_", " ")

            # Question string entities link both ways to the columns they matched.
            for entity, _, _, column_names in self.question_entities:
                entities.add(entity)
                for column_name in column_names:
                    neighbors[entity].append(column_name)
                    neighbors[column_name].append(entity)
                entity_text[entity] = entity.replace("string:",
                                                     "").replace("_", " ")

            # NOTE(review): NotImplementedError looks like a placeholder guard
            # for missing precomputation; callers may catch this exact type,
            # so it is left unchanged.
            if self._num2id is None or self._date2id is None:
                raise NotImplementedError
            # Numbers neighbor every number column; dates every date column.
            for number in self._num2id:
                entities.add(number)
                neighbors[number].extend(number_columns)
                for column_name in number_columns:
                    neighbors[column_name].append(number)
                entity_text[number] = number
            for date in self._date2id:
                entities.add(date)
                neighbors[date].extend(date_columns)
                for column_name in date_columns:
                    neighbors[column_name].append(date)
                entity_text[date] = date

            # De-duplicate neighbor lists (order is not preserved by set()).
            for entity, entity_neighbors in neighbors.items():
                neighbors[entity] = list(set(entity_neighbors))
            self._table_knowledge_graph = KnowledgeGraph(
                entities, dict(neighbors), entity_text)
        return self._table_knowledge_graph
Exemple #5
0
    def get_knowledge_graph(self) -> KnowledgeGraph:
        """Build (and memoize) the knowledge graph over paragraph relations.

        Nodes are relation names, question strings/numbers, paragraph-derived
        string entities, and the "-1" date wild-card; edges link strings to
        the relations they were matched against.
        """
        if self._knowledge_graph is not None:
            return self._knowledge_graph

        entity_set: Set[str] = set()
        adjacency: Dict[str, List[str]] = defaultdict(list)
        text_of: Dict[str, str] = {}
        if self.paragraph_data:
            # Relation names start out as isolated nodes; string and number
            # entities are attached to them below as needed.
            for relation in self.paragraph_data[0].keys():
                entity_set.add(relation)
                adjacency[relation] = []
                text_of[relation] = relation.split(":")[-1].replace("_", " ")

            question_strings, question_numbers = self.get_entities_from_question()
            for string_entity, relation in question_strings:
                entity_set.add(string_entity)
                adjacency[string_entity].append(relation)
                adjacency[relation].append(string_entity)
                text_of[string_entity] = string_entity.replace("string:", "").replace("_", " ")

            # Entities extracted from the paragraph itself (only present when
            # a word embedding was passed to the constructor).
            for paragraph_token, relations in self.paragraph_tokens_to_keep:
                node = f"string:{paragraph_token}"
                entity_set.add(node)
                adjacency[node].extend(relations)
                for relation in relations:
                    adjacency[relation].append(node)
                text_of[node] = paragraph_token.lower()

            # Numbers join the graph without any neighbors.
            for number, _ in question_numbers:
                entity_set.add(number)
                text_of[number] = number
                adjacency[number] = []

            # De-duplicate every neighbor list.
            for node, node_neighbors in adjacency.items():
                adjacency[node] = list(set(node_neighbors))

            # "-1" is needed as a wild-card in dates.
            entity_set.add("-1")
            text_of["-1"] = "-1"
            adjacency["-1"] = []
        self._knowledge_graph = KnowledgeGraph(entity_set, dict(adjacency), text_of)
        return self._knowledge_graph
Exemple #6
0
    def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
        """Build the schema knowledge graph for database ``db_id``.

        Nodes are tables and columns; edges connect each column to its table
        and foreign-key columns to the columns they reference.  The returned
        ``KnowledgeGraph`` carries an extra ``foreign_keys_to_column``
        attribute mapping foreign-key column keys to referenced column keys.
        """
        entities: Set[str] = set()
        neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
        entity_text: Dict[str, str] = {}
        foreign_keys_to_column: Dict[str, str] = {}

        db_schema = self.schema
        tables = db_schema.values()

        # Cache each database's rows the first time it is requested.
        if db_id not in self.db_tables_data:
            self.db_tables_data[db_id] = read_dataset_values(
                db_id, self.dataset_path, tables)

        tables_data = self.db_tables_data[db_id]

        # normalized text cell value -> set of column keys it appears in.
        string_column_mapping: Dict[str, set] = defaultdict(set)

        for table, table_data in tables_data.items():
            for table_row in table_data:
                for column, cell_value in zip(db_schema[table.name].columns,
                                              table_row):
                    if column.column_type == 'text' and type(
                            cell_value) is str:
                        cell_value_normalized = self.normalize_string(
                            cell_value)
                        column_key = self.entity_key_for_column(
                            table.name, column)
                        string_column_mapping[cell_value_normalized].add(
                            column_key)

        # NOTE(review): question-string linking is disabled here, so the loop
        # over string_entities below never executes and string_column_mapping
        # is computed but unused — presumably intentional; confirm.
        # string_entities = self.get_entities_from_question(string_column_mapping)
        string_entities = []

        # Table and column nodes, with column <-> table edges.
        for table in tables:
            table_key = f"table:{table.name.lower()}"
            entities.add(table_key)
            entity_text[table_key] = table.text

            for column in db_schema[table.name].columns:
                entity_key = self.entity_key_for_column(table.name, column)
                entities.add(entity_key)
                neighbors[entity_key].add(table_key)
                neighbors[table_key].add(entity_key)
                entity_text[entity_key] = column.text

        # Dead code while string_entities is [] (see note above).
        for string_entity, column_keys in string_entities:
            entities.add(string_entity)
            for column_key in column_keys:
                neighbors[string_entity].add(column_key)
                neighbors[column_key].add(string_entity)
            entity_text[string_entity] = string_entity.replace("string:",
                                                               "").replace(
                                                                   "_", " ")

        # loop again after we have gone through all columns to link foreign keys columns
        for table_name in db_schema.keys():
            for column in db_schema[table_name].columns:
                if column.foreign_key is None:
                    continue

                other_column_table, other_column_name = column.foreign_key.split(
                    ':')

                # must have exactly one by design
                other_column = [
                    col for col in db_schema[other_column_table].columns
                    if col.name == other_column_name
                ][0]

                entity_key = self.entity_key_for_column(table_name, column)
                other_entity_key = self.entity_key_for_column(
                    other_column_table, other_column)

                neighbors[entity_key].add(other_entity_key)
                neighbors[other_entity_key].add(entity_key)

                foreign_keys_to_column[entity_key] = other_entity_key

        kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
        kg.foreign_keys_to_column = foreign_keys_to_column

        return kg
Exemple #7
0
 def empty_field(self) -> 'KnowledgeGraphField':
     # Return a field over an empty graph with no entity tokens.
     # NOTE(review): KnowledgeGraph is constructed with two args here but
     # with three (entities, neighbors, entity_text) elsewhere in this file —
     # confirm entity_text has a default in this KnowledgeGraph version.
     return KnowledgeGraphField(KnowledgeGraph(set(), {}), [],
                                self._token_indexers)
Exemple #8
0
    def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
        """Build the schema knowledge graph for database ``db_id``.

        Nodes are tables and columns; edges connect each column to its table
        and each foreign-key column to every column it references.  Text cell
        values are collected into ``string_column_mapping`` (value -> column
        keys).  The returned ``KnowledgeGraph`` carries an extra
        ``foreign_keys_to_column`` attribute mapping foreign-key column keys
        to the referenced column keys.
        """
        entities: Set[str] = set()
        neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
        entity_text: Dict[str, str] = {}
        foreign_keys_to_column: Dict[str, str] = {}

        db_schema = self.schema
        tables = db_schema.values()

        # Cache each database's rows the first time it is requested.
        if db_id not in self.db_tables_data:
            self.db_tables_data[db_id] = read_dataset_values(
                db_id, self.database_path, tables)

        tables_data = self.db_tables_data[db_id]

        # normalized text cell value -> set of column keys it appears in.
        string_column_mapping: Dict[str, set] = defaultdict(set)

        for table, table_data in tables_data.items():
            # Special case: a leading synthetic '*' column has no data in the
            # rows, so drop it before aligning columns with row values.  This
            # is invariant per table, so compute it once outside the row loop.
            if db_schema[table.name].columns[0].name == '*':
                columns = db_schema[table.name].columns[1:]
            else:
                columns = db_schema[table.name].columns
            for table_row in table_data:
                assert len(columns) == len(table_row)
                # BUG FIX: zip over the '*'-adjusted `columns`, not the raw
                # schema columns — the original paired every cell with the
                # wrong column whenever '*' was present.
                for column, cell_value in zip(columns, table_row):
                    if column.column_type == 'text' and type(
                            cell_value) is str:
                        cell_value_normalized = self.normalize_string(
                            cell_value)
                        column_key = self.entity_key_for_column(
                            table.name, column)
                        string_column_mapping[cell_value_normalized].add(
                            column_key)

        # Table and column nodes, with column <-> table edges.
        for table in tables:
            table_key = f"table:{table.name.lower()}"
            entities.add(table_key)
            entity_text[table_key] = table.text

            for column in db_schema[table.name].columns:
                entity_key = self.entity_key_for_column(table.name, column)
                entities.add(entity_key)
                neighbors[entity_key].add(table_key)
                neighbors[table_key].add(entity_key)
                entity_text[entity_key] = column.text

        # dynamic entities of values in question
        # TODO: we should disable the string match entities in train.
        #  Because it will cause the inconsistent between train and test
        # value_entities = self.get_values_from_question(string_column_mapping)
        #
        # for value_repr, column_keys in value_entities:
        #     entities.add(value_repr)
        #     for column_key in column_keys:
        #         neighbors[value_repr].add(column_key)
        #         neighbors[column_key].add(value_repr)
        #     entity_text[value_repr] = value_repr.replace("string:", "").replace("_", " ")

        # loop again after we have gone through all columns to link foreign keys columns
        for table_name in db_schema.keys():
            for column in db_schema[table_name].columns:
                if column.foreign_key is None:
                    continue

                # Here foreign_key is a collection of "table:column" strings.
                for foreign_key in column.foreign_key:
                    other_column_table, other_column_name = foreign_key.split(
                        ':')

                    # must have exactly one by design
                    other_column = [
                        col for col in db_schema[other_column_table].columns
                        if col.name == other_column_name
                    ][0]

                    entity_key = self.entity_key_for_column(table_name, column)
                    other_entity_key = self.entity_key_for_column(
                        other_column_table, other_column)

                    neighbors[entity_key].add(other_entity_key)
                    neighbors[other_entity_key].add(entity_key)

                    foreign_keys_to_column[entity_key] = other_entity_key

        kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
        kg.foreign_keys_to_column = foreign_keys_to_column

        return kg
    def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
        """Build the schema knowledge graph for database ``db_id``.

        Nodes are tables, columns, and question strings matched against text
        cell values; edges connect columns to their tables, strings to the
        columns they matched, and foreign-key columns to the columns they
        reference.  The result carries a ``foreign_keys_to_column`` attribute
        mapping foreign-key column keys to referenced column keys.
        """
        entities: Set[str] = set()
        neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
        entity_text: Dict[str, str] = {}
        foreign_keys_to_column: Dict[str, str] = {}

        db_schema = self.schema
        tables = db_schema.values()

        if db_id not in self.db_tables_data:
            """KAIMARY"""
            # Put "SELECT * FROM xxx" into sqlite db to get the rows in each table
            self.db_tables_data[db_id] = read_dataset_values(
                db_id, self.dataset_path, tables)

        tables_data = self.db_tables_data[db_id]

        # normalized text cell value -> set of column keys it appears in.
        string_column_mapping: Dict[str, set] = defaultdict(set)

        for table, table_data in tables_data.items():
            for table_row in table_data:
                for column, cell_value in zip(db_schema[table.name].columns,
                                              table_row):
                    if column.column_type == 'text' and type(
                            cell_value) is str:
                        cell_value_normalized = self.normalize_string(
                            cell_value)
                        """KAIMARY"""
                        # Formatted as: "column:{column_type}:{table_name}:{column.name}"
                        column_key = self.entity_key_for_column(
                            table.name, column)
                        """KAIMARY"""
                        #  Cell value as key, column_key as the value.
                        string_column_mapping[cell_value_normalized].add(
                            column_key)
        """KAIMARY"""
        # Only return "text" column type related tokens
        # Formatted as: "string:{entity['value']}, entity['token_in_columns'])
        string_entities = self.get_entities_from_question(
            string_column_mapping)
        """KAIMARY
             three types of entities:
             TABLE entities: entity["table:{table.name}"]
             COLUMN entities: entity["column:{column_type}:{table_name}:{column.name}"]
             STRING entities: entity["string:{normalized_token_text}"]
        """
        for table in tables:
            table_key = f"table:{table.name.lower()}"
            entities.add(table_key)
            entity_text[table_key] = table.text

            for column in db_schema[table.name].columns:
                entity_key = self.entity_key_for_column(table.name, column)
                entities.add(entity_key)
                neighbors[entity_key].add(table_key)
                neighbors[table_key].add(entity_key)
                entity_text[entity_key] = column.text

        # Question strings connect bidirectionally to the columns they matched.
        for string_entity, column_keys in string_entities:
            entities.add(string_entity)
            for column_key in column_keys:
                neighbors[string_entity].add(column_key)
                neighbors[column_key].add(string_entity)
            entity_text[string_entity] = string_entity.replace("string:",
                                                               "").replace(
                                                                   "_", " ")

        # loop again after we have gone through all columns to link foreign keys columns
        for table_name in db_schema.keys():
            for column in db_schema[table_name].columns:
                if column.foreign_key is None:
                    continue

                other_column_table, other_column_name = column.foreign_key.split(
                    ':')

                # must have exactly one by design
                other_column = [
                    col for col in db_schema[other_column_table].columns
                    if col.name == other_column_name
                ][0]

                entity_key = self.entity_key_for_column(table_name, column)
                other_entity_key = self.entity_key_for_column(
                    other_column_table, other_column)

                neighbors[entity_key].add(other_entity_key)
                neighbors[other_entity_key].add(entity_key)

                foreign_keys_to_column[entity_key] = other_entity_key

        kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
        kg.foreign_keys_to_column = foreign_keys_to_column

        return kg
    def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
        """Build the schema knowledge graph for database ``db_id``.

        Nodes are tables, columns, and question strings matched against text
        cell values; edges connect columns to their tables, strings to their
        matched columns, and foreign-key columns to referenced columns.  The
        result carries a ``foreign_keys_to_column`` attribute.  See the long
        walkthrough comments below for a worked explanation.
        """
        entities: Set[str] = set()
        neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
        entity_text: Dict[str, str] = {}
        foreign_keys_to_column: Dict[str, str] = {}

        db_schema = self.schema
        tables = db_schema.values()

        # Cache each database's rows the first time it is requested.
        if db_id not in self.db_tables_data:
            self.db_tables_data[db_id] = read_dataset_values(
                db_id, self.dataset_path, tables)

        tables_data = self.db_tables_data[db_id]

        # normalized text cell value -> set of column keys it appears in.
        string_column_mapping: Dict[str, set] = defaultdict(set)

        for table, table_data in tables_data.items():
            for table_row in table_data:
                for column, cell_value in zip(db_schema[table.name].columns,
                                              table_row):
                    if column.column_type == 'text' and type(
                            cell_value) is str:
                        cell_value_normalized = self.normalize_string(
                            cell_value)
                        column_key = self.entity_key_for_column(
                            table.name, column)
                        string_column_mapping[cell_value_normalized].add(
                            column_key)

        # string_entities because it only search the text column data.
        # string_entities is the column information that its value appearing in the question.
        string_entities = self.get_entities_from_question(
            string_column_mapping)

        # table.text|column.text -> table or column name
        # table_key  -> 'table':table.text
        # entity_key -> 'column':column_type:table.text:column.text. column_type can be: 'text', 'number', 'primary', 'foreign'!!!
        # entities   ->  set of table_key and entity_key. The number of items is number of table + number of columns in each table. Table name + column name.
        # entity_text->  dict of {table_key:table.text} and {entity_key:column.text}
        # neighbors  ->  dict of {entity_key:table.text} and {table_key:{all its column.text}}
        for table in tables:
            table_key = f"table:{table.name.lower()}"
            entities.add(table_key)
            entity_text[table_key] = table.text
            for column in db_schema[table.name].columns:
                entity_key = self.entity_key_for_column(table.name, column)
                entities.add(entity_key)
                neighbors[entity_key].add(table_key)
                neighbors[table_key].add(entity_key)
                entity_text[entity_key] = column.text

        # token_in_utterance-> the token in utterance that appear in the data of database
        # string_entity     -> type:token_in_utterance. type is only 'string' here.
        # column_keys       -> list of entity_key
        # column_key        -> entity_key
        # Now, entities     -> set of table_key and entity_key and string_entity
        # Now, entity_text  -> dict of {table_key:table_names} and {entity_key:column_names}. table_names and column_names come from tables.json. Not table_names_original and column_names_original!!
        # Now, neighbors    -> Plus: {string_entity:all its entity_key} and {entity_key:table.text and string_entity (if it has)} and {table_key:{all its column.text}}
        for string_entity, column_keys in string_entities:
            entities.add(string_entity)
            for column_key in column_keys:
                neighbors[string_entity].add(column_key)
                neighbors[column_key].add(string_entity)
            entity_text[string_entity] = string_entity.replace("string:",
                                                               "").replace(
                                                                   "_", " ")

        # loop again after we have gone through all columns to link foreign keys columns
        for table_name in db_schema.keys():
            for column in db_schema[table_name].columns:
                if column.foreign_key is None:
                    continue

                other_column_table, other_column_name = column.foreign_key.split(
                    ':')

                # must have exactly one by design
                other_column = [
                    col for col in db_schema[other_column_table].columns
                    if col.name == other_column_name
                ][0]

                entity_key = self.entity_key_for_column(table_name, column)
                other_entity_key = self.entity_key_for_column(
                    other_column_table, other_column)

                neighbors[entity_key].add(other_entity_key)
                neighbors[other_entity_key].add(entity_key)

                foreign_keys_to_column[entity_key] = other_entity_key

        # if we can not find tokens appearing in both the data of database and token of question:
        # But even we can find tokens, here still the same.
        #   1. the length of entities and neighbors and entity_text are all equal to number of table + number of columns in each table
        #   2. entities is table and column 'name'
        #   3. neighbors is the relationship for every column and table. For example:
        #           neighbors[table].items = [all its column name]
        #           neighbors[column].items = [table_name, relation_column];
        #               if column_1 is a primary key and column_2 and column_3 is foreign keys reference to column_1, neighbors[column_1].items = [column_1's table_name, column_2, column_3]
        #           neighbors[column_2].items = [column_2's table_name, column_1]
        #           neighbors[column_3].items = [column_3's table_name, column_1]
        #           if a column_4 does not have any relation_column, neighbors[column_4].items = [column_4's table_name];
        #   4. entity_text: dict of {table_key:table_names} and {entity_key:column_names}. table_names and column_names is a readable name different from the original name in database.

        # if we can find token appearing in both the data of database and token of question:
        #   1. the length of three objectives are still equal and its value = previous value + number of token found
        #       So every thing is the same as we can not find the token. The token expand the three objectives and nothing more, Now let's discuss the new data brought by the token.
        #   2. entities += ['string':token]. We call ['string':token] as type_with_token.
        #   3. neighbors += {type_with_token:[its columns]}. Supposing the token appearing in question and name_column and previous_name_column. here will be:
        #           neighbors += {type_with_token:[name_column,previous_name_column]}.
        #           neighbors[name_column].add(type_with_token), which is from neighbors[name_column] = {name_column's table_name} to neighbors[name_column] = {name_column's table_name, type_with_token}
        #           neighbors[previous_name_column] do the same process.
        #           neighbors[name_column] and neighbors[previous_name_column] exist for neighbors even we do not find tokens. But neighbors[type_with_token] only exist when we find tokens.
        #   4. entity_text +=  {type_with_token: token}

        # To now, we know how to create a KnowledgeGraph:
        # entities is node name in a graph.
        # entity_text is real text for the node. For example, we can call the node as 'column:text:management:temporary_acting', but its text for this node is 'temporary acting'.
        #           'temporary acting' copy from column_names in tables.json (not column_names_original). So we will analyse the text only. Node name is just a symbol.
        # neighbors is the edge in a graph.
        #           neighbors['node_name'].items=[list of (other) node name ]. Here is absolutely other!
        #           The edge contain: table-column, foreign-primary, column-token_in_quesion.
        # Their length is number of node in this graph.
        kg = KnowledgeGraph(entities, dict(neighbors),
                            entity_text)  # node name, edge, node value

        # Add a attribute for kg.
        # Example data of foreign_keys_to_column generated from table 'department_management'
        # foreign_keys_to_column['column:foreign:management:department_id'] = 'column:primary:department:department_id'
        # foreign_keys_to_column['column:foreign:management:head_id'] = 'column:primary:head:head_id'
        kg.foreign_keys_to_column = foreign_keys_to_column

        return kg
Exemple #11
0
    def get_db_knowledge_graph(self, db_id: str) -> KnowledgeGraph:
        """Build the schema knowledge graph for WikiSQL database ``db_id``.

        Nodes are tables, columns, and question strings matched against text
        cell values; edges connect columns to their tables and strings to the
        columns they matched.  The foreign-key linking loop is kept for
        generality but is a no-op for WikiSQL (see comment below).
        """
        entities: Set[str] = set()
        neighbors: Dict[str, OrderedSet[str]] = defaultdict(OrderedSet)
        entity_text: Dict[str, str] = {}  # display text for each table/column node
        foreign_keys_to_column: Dict[str, str] = {}

        db_schema = self.schema
        tables = db_schema.values()

        # TODO: read the concrete row data of the wiki database.
        if db_id not in self.db_tables_data:
            self.db_tables_data[db_id] = read_wiki_dataset_values(
                db_id, self.tables_file, tables)

        # dict[table, List[tuple(row)]]
        tables_data = self.db_tables_data[db_id]

        # Maps each normalized text cell value to the columns it appears in.
        string_column_mapping: Dict[str, set] = defaultdict(set)

        for table, table_data in tables_data.items():
            for table_row in table_data:
                for column, cell_value in zip(db_schema[table.name].columns,
                                              table_row):
                    if column.column_type == 'text' and type(
                            cell_value) is str:
                        cell_value_normalized = self.normalize_string(
                            cell_value)
                        column_key = self.entity_key_for_column(
                            table.name, column)
                        string_column_mapping[cell_value_normalized].add(
                            column_key)

        # Returns every entity found in the question (string -> column keys).
        string_entities = self.get_entities_from_question(
            string_column_mapping)

        for table in tables:
            table_key = f"table:{table.name.lower()}"
            entities.add(table_key)  # for WikiSQL these are all just the db_id
            entity_text[
                table_key] = table.text  # table.name = table.text = db_id

            for column in db_schema[table.name].columns:
                # entity_key is the column's unique key ("signature").
                entity_key = self.entity_key_for_column(table.name, column)
                # entities holds both table keys and column (entity) keys.
                entities.add(entity_key)
                # neighbors records the table <-> column relationship.
                neighbors[entity_key].add(table_key)
                neighbors[table_key].add(entity_key)
                entity_text[entity_key] = column.text
        # entities also includes the string entities found in the question.
        for string_entity, column_keys in string_entities:
            entities.add(string_entity)
            for column_key in column_keys:
                neighbors[string_entity].add(column_key)
                neighbors[column_key].add(string_entity)
            entity_text[string_entity] = string_entity.replace("string:",
                                                               "").replace(
                                                                   "_", " ")

        # loop again after we have gone through all columns to link foreign keys columns
        # For WikiSQL, foreign_key is always None, so this loop never fires.
        for table_name in db_schema.keys():
            for column in db_schema[table_name].columns:
                if column.foreign_key is None:
                    continue

                other_column_table, other_column_name = column.foreign_key.split(
                    ':')

                # must have exactly one by design
                other_column = [
                    col for col in db_schema[other_column_table].columns
                    if col.name == other_column_name
                ][0]

                entity_key = self.entity_key_for_column(table_name, column)
                other_entity_key = self.entity_key_for_column(
                    other_column_table, other_column)

                neighbors[entity_key].add(other_entity_key)
                neighbors[other_entity_key].add(entity_key)

                foreign_keys_to_column[entity_key] = other_entity_key

        kg = KnowledgeGraph(entities, dict(neighbors), entity_text)
        kg.foreign_keys_to_column = foreign_keys_to_column

        return kg
Exemple #12
0
    def __init__(
        self,
        lazy: bool = False,
        sample: int = -1,
        lf_syntax: Optional[str] = None,
        replace_world_entities: bool = False,
        align_world_extractions: bool = False,
        gold_world_extractions: bool = False,
        tagger_only: bool = False,
        denotation_only: bool = False,
        world_extraction_model: Optional[str] = None,
        skip_attributes_regex: Optional[str] = None,
        entity_bits_mode: Optional[str] = None,
        entity_types: Optional[List[str]] = None,
        lexical_cues: Optional[List[str]] = None,
        tokenizer: Optional[Tokenizer] = None,
        question_token_indexers: Optional[Dict[str, TokenIndexer]] = None,
    ) -> None:
        """
        Configure the QuaRel dataset reader and build its base ``QuarelWorld``.

        Most arguments are simply stored on ``self`` for use by other reader
        methods (e.g. ``text_to_instance``).  The constructor additionally:

        - builds a placeholder ``KnowledgeGraph``/``QuarelWorld`` from
          ``lf_syntax``;
        - if ``lf_syntax`` contains ``"_attr_entities"``, derives "dynamic"
          attribute entities (``a:<attribute>``) from the world's
          ``qr_coeff_sets`` plus optional lexical cues, and rebuilds the
          knowledge graph and world around them;
        - optionally loads a world-tagger extraction model when
          ``world_extraction_model`` is given.

        NOTE(review): ``lf_syntax`` defaults to ``None`` but is used below in
        ``"_attr_entities" in lf_syntax``, which raises ``TypeError`` for
        ``None`` — callers are evidently expected to always supply it.
        """
        super().__init__(lazy=lazy)
        # Fall back to default tokenizer / single-id token indexers when none given.
        self._tokenizer = tokenizer or WordTokenizer()
        self._question_token_indexers = question_token_indexers or {
            "tokens": SingleIdTokenIndexer()
        }
        # Entities are indexed with the same indexers as the question tokens.
        self._entity_token_indexers = self._question_token_indexers
        self._sample = sample
        self._replace_world_entities = replace_world_entities
        self._lf_syntax = lf_syntax
        self._entity_bits_mode = entity_bits_mode
        self._align_world_extractions = align_world_extractions
        self._gold_world_extractions = gold_world_extractions
        self._entity_types = entity_types
        self._tagger_only = tagger_only
        self._denotation_only = denotation_only
        # Pre-compile the attribute-skipping pattern once, if configured.
        self._skip_attributes_regex = None
        if skip_attributes_regex is not None:
            self._skip_attributes_regex = re.compile(skip_attributes_regex)
        self._lexical_cues = lexical_cues

        # Recording of entities in categories relevant for tagging
        all_entities = {}
        all_entities["world"] = ["world1", "world2"]
        # TODO: Clarify this into an appropriate parameter
        self._collapse_tags = ["world"]

        self._all_entities = None
        if entity_types is not None:
            if self._entity_bits_mode == "collapsed":
                # Collapsed mode tags by entity *type* rather than individual entity.
                self._all_entities = entity_types
            else:
                # Flatten the per-type entity lists into a single list.
                # NOTE(review): assumes every requested type has an entry in
                # all_entities (currently only "world") — otherwise KeyError.
                self._all_entities = [e for t in entity_types for e in all_entities[t]]

        logger.info(f"all_entities = {self._all_entities}")

        # Base world, depending on LF syntax only
        self._knowledge_graph = KnowledgeGraph(
            entities={"placeholder"}, neighbors={}, entity_text={"placeholder": "placeholder"}
        )
        self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        # Decide dynamic entities, if any
        self._dynamic_entities: Dict[str, str] = dict()
        self._use_attr_entities = False
        if "_attr_entities" in lf_syntax:
            self._use_attr_entities = True
            qr_coeff_sets = self._world.qr_coeff_sets
            for qset in qr_coeff_sets:
                for attribute in qset:
                    # Attributes matching the skip pattern get no dynamic entity.
                    if (
                        self._skip_attributes_regex is not None
                        and self._skip_attributes_regex.search(attribute)
                    ):
                        continue
                    # Get text associated with each entity, both from entity identifier and
                    # associated lexical cues, if any
                    entity_strings = [words_from_entity_string(attribute).lower()]
                    if self._lexical_cues is not None:
                        for key in self._lexical_cues:
                            if attribute in LEXICAL_CUES[key]:
                                entity_strings += LEXICAL_CUES[key][attribute]
                    self._dynamic_entities["a:" + attribute] = " ".join(entity_strings)

        # Update world to include dynamic entities
        if self._use_attr_entities:
            logger.info(f"dynamic_entities = {self._dynamic_entities}")
            # Dynamic entities start with no graph neighbors.
            neighbors: Dict[str, List[str]] = {key: [] for key in self._dynamic_entities}
            self._knowledge_graph = KnowledgeGraph(
                entities=set(self._dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=self._dynamic_entities,
            )
            # Rebuild the world so it reflects the dynamic entity graph.
            self._world = QuarelWorld(self._knowledge_graph, self._lf_syntax)

        self._stemmer = PorterStemmer().stemmer

        # Optional pretrained tagger used to extract "world" mentions from questions.
        self._world_tagger_extractor = None
        self._extract_worlds = False
        if world_extraction_model is not None:
            logger.info("Loading world tagger model...")
            self._extract_worlds = True
            self._world_tagger_extractor = WorldTaggerExtractor(world_extraction_model)
            logger.info("Done loading world tagger model!")

        # Convenience regex for recognizing attributes
        self._attr_regex = re.compile(r"""\((\w+) (high|low|higher|lower)""")
Exemple #13
0
    def text_to_instance(
        self,  # type: ignore
        question: str,
        logical_forms: List[str] = None,
        additional_metadata: Dict[str, Any] = None,
        world_extractions: Dict[str, Union[str, List[str]]] = None,
        entity_literals: Dict[str, Union[str, List[str]]] = None,
        tokenized_question: List[Token] = None,
        debug_counter: int = None,
        qr_spec_override: List[Dict[str, int]] = None,
        dynamic_entities_override: Dict[str, str] = None,
    ) -> Instance:
        """
        Build an ``Instance`` from a QuaRel question (plus optional gold
        logical forms, world extractions, and tagging literals).

        The produced fields depend on reader configuration:

        - tagger-only mode: just ``tokens`` (+ ``tags`` when
          ``entity_literals`` is given) and ``metadata``;
        - otherwise: ``question``, ``table``, ``world``, ``actions``, and
          optionally ``denotation_target``, ``entity_bits``, and
          ``target_action_sequences``.

        ``qr_spec_override`` / ``dynamic_entities_override`` build a one-off
        world instead of the reader-level defaults computed in ``__init__``.

        Raises
        ------
        ValueError
            If ``self._entity_bits_mode`` is set to an unrecognized value
            while ``world_extractions`` is provided.
        """
        tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata["question_tokens"] = [token.text for token in tokenized_question]
        if world_extractions is not None:
            additional_metadata["world_extractions"] = world_extractions
        question_field = TextField(tokenized_question, self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
            knowledge_graph = KnowledgeGraph(
                entities=set(dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=dynamic_entities,
            )
            world = QuarelWorld(knowledge_graph, self._lf_syntax, qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(
            knowledge_graph,
            tokenized_question,
            self._entity_token_indexers,
            tokenizer=self._tokenizer,
        )

        if self._tagger_only:
            fields: Dict[str, Field] = {"tokens": question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(
                    self._all_entities, table_field, entity_literals, tokenized_question
                )
                # BUGFIX: debug_counter defaults to None, and `None > 0` raises
                # TypeError on Python 3 — guard the comparison explicitly.
                if debug_counter is not None and debug_counter > 0:
                    logger.info(f"raw entity tags = {entity_tags}")
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields["tags"] = SequenceLabelField(entity_tags_bio, question_field)
                additional_metadata["tags_gold"] = entity_tags_bio
            additional_metadata["words"] = [x.text for x in tokenized_question]
            fields["metadata"] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        # One ProductionRuleField per possible grammar action; entity-specific
        # rules are marked as non-global.
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(" -> ")
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            "question": question_field,
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata["answer_index"], skip_indexing=True)
            fields["denotation_target"] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(
                ["world1", "world2"], table_field, world_extractions, tokenized_question
            )
            # Map each per-token tag (0 = none, 1 = world1, 2 = world2) to a bit vector.
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]
            else:
                # BUGFIX: previously an unrecognized mode left entity_bits_v
                # unbound, crashing with a NameError below — fail loudly instead.
                raise ValueError(f"Unknown entity_bits_mode: {self._entity_bits_mode}")
            fields["entity_bits"] = ArrayField(np.array(entity_bits_v))

        if logical_forms:
            action_map = {
                action.rule: i for i, action in enumerate(action_field.field_list)  # type: ignore
            }
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = [
                        IndexField(action_map[production_rule], action_field)
                        for production_rule in action_sequence
                    ]
                except KeyError as error:
                    # A rule outside the grammar: skip this logical form rather than fail.
                    logger.info(f"Missing production rule: {error.args}, skipping logical form")
                    logger.info(f"Question was: {question}")
                    logger.info(f"Logical form was: {logical_form}")
                    continue
                action_sequence_fields.append(ListField(index_fields))
            # BUGFIX: only add the field when at least one logical form parsed;
            # an empty ListField is invalid at batching/padding time.
            if action_sequence_fields:
                fields["target_action_sequences"] = ListField(action_sequence_fields)
        # additional_metadata is always a dict by this point (initialized above).
        fields["metadata"] = MetadataField(additional_metadata)
        return Instance(fields)