Example #1
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}

        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                self._utterance_token_indexers,
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens)
        
        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)
            
        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
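For orientation, here is a minimal sketch of how a dataset reader might drive text_to_instance from its _read method. The JSON keys ("question", "db_id", "sql") are assumptions modeled on the Spider data layout, not taken from the example above:

    def _read(self, file_path: str):
        # Hypothetical driver; the JSON keys below are assumptions.
        import json
        with open(file_path) as f:
            examples = json.load(f)
        for ex in examples:
            instance = self.text_to_instance(ex["question"], ex["db_id"], ex.get("sql"))
            # text_to_instance returns None for unparsable queries when
            # _keep_if_unparsable is False, so skip those examples.
            if instance is not None:
                yield instance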
Example #2
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}

        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                {},  # no indexers here; the schema tokens are encoded via the combined utterance field below
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens)


        combined_tokens = list(db_context.tokenized_utterance)  # copy, so the original token list is not mutated
        entity_token_map = dict(zip(db_context.knowledge_graph.entities, db_context.entity_tokens))
        entity_tokens = []
        for e in db_context.knowledge_graph.entities:
            if e.startswith('column:'):
                table_name, column_name = e.split(':')[-2:]
                table_tokens = entity_token_map['table:'+table_name]
                column_tokens = entity_token_map[e]
                if column_name.startswith(table_name):
                    column_tokens = column_tokens[len(table_tokens):]
                entity_tokens.append(table_tokens + [Token(text='[unused30]')] + column_tokens)
            else:
                entity_tokens.append(entity_token_map[e])
        for e in entity_tokens:
            combined_tokens += [Token(text='[SEP]')] + e
        if len(combined_tokens) > 450:  # presumably to stay within BERT's 512-token limit
            return None
        db_context.entity_tokens = entity_tokens
        fields["utterance"] = TextField(combined_tokens, self._utterance_token_indexers)

        world = SpiderWorld(db_context, query=sql)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
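The combined input built above is a BERT-style linearization: the question tokens, then one [SEP]-prefixed segment per schema entity, with [unused30] separating a table name from its column name. A self-contained toy version using plain strings instead of Token objects (the schema names are made up):

    question = ["how", "many", "singers", "do", "we", "have", "?"]
    entity_tokens = [
        ["singer"],                        # a table entity
        ["singer", "[unused30]", "name"],  # a column entity: table + separator + column
        ["singer", "[unused30]", "age"],
    ]
    combined = list(question)
    for ent in entity_tokens:
        combined += ["[SEP]"] + ent
    print(combined[:10])
    # ['how', 'many', 'singers', 'do', 'we', 'have', '?', '[SEP]', 'singer', '[SEP]']
    # Inputs longer than 450 tokens are dropped above, presumably to stay
    # within BERT's 512-token limit.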
Example #3
    def text_to_instance(
            self,
            utterance: str,  # question
            db_id: str,
            sql: List[str] = None):
        fields: Dict[str, Field] = {}

        # db_context holds the database schema graph and its tokens,
        # together with the tokenized utterance.
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)

        # An AllenNLP Instance contains many fields, and each must be a Field
        # object (think of fields as the columns of a table, or the attributes
        # of an object). So db_context has to be wrapped in a Field, which is
        # table_field here. Since db_context.knowledge_graph is a graph, we need
        # a graph field: SpiderKnowledgeGraphField inherits from KnowledgeGraphField.
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None,  # self._max_table_tokens
            conceptnet=self._concept_word)

        world = SpiderWorld(db_context, query=sql)

        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        # action_sequence is the result of parsing the query with the grammar,
        # which is built for the specific database. all_actions is the full list
        # of grammar rules, so action_sequence is a subset of all_actions: it
        # contains exactly the rules needed to derive this query. The grammar is
        # defined in semparse/contexts/spider_db_grammar.py in a BNF-like style.
        action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            # Help ProductionRuleField: https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=productionrulefield#production-rule-field
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        # valid_actions_field is built from all_actions, i.e. the full grammar.
        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        # Assign an id to every grammar rule.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        # Map every rule in action_sequence back to its index in the full rule
        # list, so each derivation step points into fields["valid_actions"].
        if action_sequence:
            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
        else:  # action_sequence is empty: the grammar failed to parse this query
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)

        # The fields["valid_actions"] include the global rule which is the same in all SQL and database specific rules.
        # For example, 'binaryop -> ["!="]' and 'query -> [select_core, groupby_clause, limit]' are global rules.
        # 'column_name -> ["department@budget_in_billions"]' and 'col_ref -> ["department@budget_in_billions"]' are not global rules.
        # So fields["valid_actions"] is case by case but will contain the same global rules. And it must include the rules appearing in fields["action_sequence"].
        # Now the attribute of _rule_id is None. But when finish loading all data, allennlp will automatically build the vocabulary and give a unique _ruled_id for every rule.
        # But when forward the fields["valid_actions"] to the model, it will become the ProductionRule List.
        # We will find that the global rules will contain a _rule_id but non-global rules will not.
        # And in ProductionRule List, the global rules will become a tuple and non-global will become a ProductionRule obj.
        # In a tuple and ProductionRule, its[0] shows its rule value, such as: 'where_clause -> ["where", expr, where_conj]' or 'col_ref -> ["department@budget_in_billions"]'.
        # In a tuple and ProductionRule, its[1] shows whether it is global rules, such as True or False.
        # In a tuple and ProductionRule, its[2] shows _rule_id. But if it is non-global rule, it will be None.
        # In a tuple and ProductionRule, its[3] shows left rule value. For example: 'where_clause' is the left rule value of 'where_clause -> ["where", expr, where_conj]'.
        # fields["valid_actions"]   # All action / All grammar

        # fields["valid_actions"] carries almost the same information as
        # world.valid_actions, just in a different representation (world is a
        # SpiderWorld object). The project uses two interchangeable views of the
        # valid actions. The first is a list of dicts (the "list-type" actions):
        # [
        #    {'rule': 'arg_list -> [expr, ",", arg_list]', 'is_global_rule': True, ...},
        #    {'rule': 'arg_list -> [expr]',                'is_global_rule': True, ...},
        #    ...
        # ]
        # from which all rules are easily flattened into a list such as all_actions:
        # ['arg_list -> [expr, ",", arg_list]', 'arg_list -> [expr]', ...]
        #
        # The second is a dict keyed by the left-hand side, grouping the rules
        # that share it (the "dict-type" actions):
        # {
        #     'arg_list': ['[expr, ",", arg_list]', '[expr]'],
        #     ...
        # }
        # Both views describe the same set of valid actions.

        # fields["utterance"]       # TextFile for utterance
        fields[
            "action_sequence"] = action_sequence_field  # grammar rules (action) of this query, and every rule contains a total grammar set which is fields["valid_actions"].
        fields["world"] = MetadataField(
            world
        )  #Maybe just for calc the metric. # A MetadataField is a Field that does not get converted into tensors. https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=metadatafield#metadata-field
        fields["schema"] = table_field
        return Instance(fields)
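Stripped of the AllenNLP field types, the action_map/IndexField machinery in this example boils down to mapping each rule string to its position in the full rule list, with -1 as a sentinel for unparsable queries. A toy version (the rule strings are illustrative):

    all_actions = [
        'query -> [select_core]',
        'binaryop -> ["!="]',
        'col_ref -> ["department@budget_in_billions"]',
    ]
    action_sequence = ['query -> [select_core]', 'binaryop -> ["!="]']

    action_map = {rule: i for i, rule in enumerate(all_actions)}
    indices = [action_map[rule] for rule in action_sequence] or [-1]
    print(indices)  # [0, 1]; an unparsable query would yield [-1]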
Example #4
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}
        """KAIMARY"""
        # Contains
        # 1. db schema(Tables with corresponding columns)
        # 2. Tokenized utterance
        # 3. Knowledge graph(Table entities, column entities and "text" column type related token entities)
        # 4. Entity_tokens(Retrieved from entities_text from kg)
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)
        # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html#knowledge-graph-field
        # *feature extractors*
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)  # self._max_table_tokens)

        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
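The is_global_rule check above separates grammar-level rules from schema-linked ones. Below is a rough stand-in consistent with the rules quoted in Example #3's comments; the real SpiderWorld implementation may use a different criterion:

    def looks_global(rhs: str) -> bool:
        # Schema-linked rules expand to "table@column" terminals;
        # global rules do not mention schema elements.
        return '@' not in rhs

    print(looks_global('["!="]'))                             # True  (global)
    print(looks_global('["department@budget_in_billions"]'))  # False (schema-linked)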
Example #5
    def text_to_instance(self,
                         utterances: List[str],
                         db_id: str,
                         sql: List[List[str]] = None):
        fields: Dict[str, Field] = {}

        ctxts = [
            SpiderDBContext(db_id,
                            utterance,
                            tokenizer=self._tokenizer,
                            tables_file=self._tables_file,
                            dataset_path=self._dataset_path)
            for utterance in utterances
        ]

        super_utterance = ' '.join(utterances)
        hack_ctxt = SpiderDBContext(db_id,
                                    super_utterance,
                                    tokenizer=self._tokenizer,
                                    tables_file=self._tables_file,
                                    dataset_path=self._dataset_path)

        kg = SpiderKnowledgeGraphField(
            hack_ctxt.knowledge_graph,
            hack_ctxt.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=hack_ctxt.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)
        '''
        kgs = [SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                        db_context.tokenized_utterance,
                                        self._utterance_token_indexers,
                                        entity_tokens=db_context.entity_tokens,
                                        include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                        max_table_tokens=None)  # self._max_table_tokens)
                        for db_context in ctxts]
        '''
        worlds = []

        for sqli, db_context in zip(sql, ctxts):
            worlds.append(SpiderWorld(db_context, query=sqli))

        fields["utterances"] = ListField([
            TextField(db_context.tokenized_utterance,
                      self._utterance_token_indexers) for db_context in ctxts
        ])


        action_tups = [
            world.get_action_sequence_and_all_actions() for world in worlds
        ]
        action_sequences = [tup[0] for tup in action_tups]
        all_actions = [tup[1] for tup in action_tups]

        for i in range(len(action_sequences)):
            action_sequence = action_sequences[i]

            if action_sequence is None and self._keep_if_unparsable:
                # print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

            action_sequences[i] = action_sequence

        all_valid_actions_fields = []
        all_action_sequence_fields = []

        for i in range(len(all_actions)):
            production_rule_fields: List[Field] = []

            for production_rule in all_actions[i]:
                nonterminal, rhs = production_rule.split(' -> ')
                field = ProductionRuleField(production_rule,
                                            worlds[i].is_global_rule(rhs),
                                            nonterminal=nonterminal)
                production_rule_fields.append(field)

            valid_actions_field = ListField(production_rule_fields)

            all_valid_actions_fields.append(valid_actions_field)

            action_map = {
                action.rule: j  # type: ignore
                for j, action in enumerate(valid_actions_field.field_list)
            }

            index_fields: List[Field] = []

            action_sequence = action_sequences[i]

            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
            if not action_sequence:
                index_fields = [IndexField(-1, valid_actions_field)]

            action_sequence_field = ListField(index_fields)

            all_action_sequence_fields.append(action_sequence_field)

        fields["valid_actions"] = ListField(all_valid_actions_fields)
        fields["action_sequences"] = ListField(all_action_sequence_fields)
        fields["worlds"] = ListField(
            [MetadataField(world) for world in worlds])
        fields["schema"] = kg
        '''
        fields['utterances'] = ListField[TextField]
        fields['valid_actions'] = ListField[ListField[ProductionRuleField]]
        fields['action_sequences'] = ListField[ListField[IndexField]]
        fields['worlds'] = ListField[MetadataField[SpiderWorld]]
        fields['schema'] = SpiderKnowledgeGraphField
        '''

        return Instance(fields)
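As the closing comment block notes, every field in this variant is a ListField over per-utterance fields, so index i lines up across utterances, valid actions, action sequences, and worlds. A toy sketch of that pairwise alignment on plain data:

    utterances = ["how many singers ?", "list their names"]
    sqls = [["SELECT", "count(*)"], ["SELECT", "name"]]

    turns = [
        {"utterance": u, "sql": q}
        for u, q in zip(utterances, sqls)
    ]
    # fields["utterances"][i], fields["valid_actions"][i],
    # fields["action_sequences"][i] and fields["worlds"][i]
    # all describe the same turn i of the interaction.
    print(turns[1])  # {'utterance': 'list their names', 'sql': ['SELECT', 'name']}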