Example No. 1
    def infer_query(self, question, db_id):
        ### Step 1: schema linking
        lemma_utterance_stanza = self.stanza_model(question)
        lemma_utterance = [
            word.lemma for sent in lemma_utterance_stanza.sentences
            for word in sent.words
        ]
        db_context = SpiderDBContext(db_id,
                                     lemma_utterance,
                                     tables_file=self.schema_path,
                                     dataset_path=self.db_path,
                                     stanza_model=self.stanza_model,
                                     schemas=self.schema,
                                     original_utterance=question)
        value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(
            db_id)

        item = {}
        item['interaction'] = [{
            'db_id': db_id,
            'question': question,
            'sql': '',
            'value_match': value_match,
            'value_alignment': value_alignment,
            'exact_match': exact_match,
            'partial_match': partial_match,
        }]

        ### Step 2: serialization
        source_sequence, _ = build_schema_linking_data(
            schema=self.database_schemas[db_id],
            question=question,
            item=item,
            turn_id=0,
            linking_type='default')
        slml_question = source_sequence[0]

        ### Step 3: prediction
        schemas, eval_foreign_key_maps = load_tables(self.schema_path)
        original_schemas = load_original_schemas(self.schema_path)
        alias_schema = get_alias_schema(schemas)
        rnt = decode_with_constrain(slml_question, alias_schema[db_id],
                                    self.model)  # returns a list of {'text': ..., 'score': ...} candidates
        predict_sql = rnt[0]['text'] if isinstance(rnt[0]['text'],
                                                   str) else rnt[0]['text'][0]
        score = rnt[0]['score'].tolist()

        predict_sql = post_processing_sql(predict_sql,
                                          eval_foreign_key_maps[db_id],
                                          original_schemas[db_id],
                                          schemas[db_id])

        return {
            "slml_question": slml_question,
            "predict_sql": predict_sql,
            "score": score
        }
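A minimal usage sketch for Example No. 1. The wrapper class name Text2SQLInferencer, its constructor arguments, and the file paths are hypothetical; only the infer_query(question, db_id) interface and the returned keys come from the code above.

# Hypothetical wrapper class exposing the infer_query method shown above.
inferencer = Text2SQLInferencer(schema_path='tables.json',  # hypothetical path
                                db_path='database/')        # hypothetical path
result = inferencer.infer_query(question='How many singers do we have?',
                                db_id='concert_singer')
print(result['slml_question'])  # schema-linked serialization of the question
print(result['predict_sql'])    # SQL produced by constrained decoding
print(result['score'])          # model score for the top prediction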
Example No. 2
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}

        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                self._utterance_token_indexers,
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens)
        
        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))  # no-op normalization, kept as in the source
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)
            
        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
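To make the action bookkeeping above concrete, here is a self-contained sketch of how the rule strings, action_map, and indices relate. The rule strings are illustrative stand-ins for what world.get_action_sequence_and_all_actions() returns.

# Illustrative rule strings; the real ones come from the SpiderWorld grammar.
all_actions = [
    'query -> [select_core]',
    'select_core -> [select_with_distinct, select_results, from_clause]',
    'column_name -> ["singer@name"]',  # database-specific, non-global rule
]
action_map = {rule: i for i, rule in enumerate(all_actions)}
action_sequence = ['query -> [select_core]', 'column_name -> ["singer@name"]']
indices = [action_map[rule] for rule in action_sequence]
print(indices)  # [0, 2]; each index becomes an IndexField into valid_actions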
Example No. 3
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}

        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                {},
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens)


        combined_tokens = list(db_context.tokenized_utterance)  # copy; extended with schema tokens below
        entity_token_map = dict(zip(db_context.knowledge_graph.entities, db_context.entity_tokens))
        entity_tokens = []
        for e in db_context.knowledge_graph.entities:
            if e.startswith('column:'):
                table_name, column_name = e.split(':')[-2:]
                table_tokens = entity_token_map['table:'+table_name]
                column_tokens = entity_token_map[e]
                if column_name.startswith(table_name):
                    column_tokens = column_tokens[len(table_tokens):]
                entity_tokens.append(table_tokens + [Token(text='[unused30]')] + column_tokens)
            else:
                entity_tokens.append(entity_token_map[e])
        for e in entity_tokens:
            combined_tokens += [Token(text='[SEP]')] + e
        if len(combined_tokens) > 450:  # keep headroom under the 512-token limit of BERT-style encoders
            return None
        db_context.entity_tokens = entity_tokens
        fields["utterance"] = TextField(combined_tokens, self._utterance_token_indexers)

        world = SpiderWorld(db_context, query=sql)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(nonterminal),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}

        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
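The concatenation in Example No. 3 builds one BERT-style input sequence: the question, then every schema entity separated by [SEP], with [unused30] between a table name and a column name. A minimal sketch of that layout, with plain strings standing in for allennlp Token objects (the schema is made up):

# Layout of combined_tokens from Example No. 3, using strings instead of Tokens.
question = ['how', 'many', 'singers']
entities = [
    ['singer'],                        # table:singer
    ['singer', '[unused30]', 'name'],  # column:singer:name
    ['singer', '[unused30]', 'age'],   # column:singer:age
]
combined = list(question)
for entity_tokens in entities:
    combined += ['[SEP]'] + entity_tokens
print(' '.join(combined))
# how many singers [SEP] singer [SEP] singer [unused30] name [SEP] singer [unused30] age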
Example No. 4
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}
        if self._is_spider:
            db_context = SpiderDBContext(db_id,
                                         utterance,
                                         tokenizer=self._tokenizer,
                                         tables_file=self._tables_file,
                                         dataset_path=self._dataset_path)
            table_field = SpiderKnowledgeGraphField(
                db_context.knowledge_graph,
                db_context.tokenized_utterance,
                self._utterance_token_indexers,
                entity_tokens=db_context.entity_tokens,
                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                max_table_tokens=None)  # self._max_table_tokens)

            world = SpiderWorld(db_context, nl_context=None, query=sql)
            fields["utterance"] = TextField(db_context.tokenized_utterance,
                                            self._utterance_token_indexers)

            action_sequence, all_actions = world.get_action_sequence_and_all_actions(
            )

            if action_sequence is None and self._keep_if_unparsable:
                # print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

            index_fields: List[Field] = []
            production_rule_fields: List[Field] = []

            for production_rule in all_actions:
                nonterminal, rhs = production_rule.split(' -> ')
                production_rule = ' '.join(production_rule.split(' '))
                field = ProductionRuleField(production_rule,
                                            world.is_global_rule(rhs),
                                            nonterminal=nonterminal)
                production_rule_fields.append(field)

            valid_actions_field = ListField(production_rule_fields)
            fields["valid_actions"] = valid_actions_field

            action_map = {
                action.rule: i  # type: ignore
                for i, action in enumerate(valid_actions_field.field_list)
            }

            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
            if not action_sequence:
                index_fields = [IndexField(-1, valid_actions_field)]

            action_sequence_field = ListField(index_fields)
            fields["action_sequence"] = action_sequence_field
            fields["world"] = MetadataField(world)
            fields["schema"] = table_field
        else:
            db_context = WikiDBContext(db_id,
                                       utterance,
                                       tokenizer=self._tokenizer,
                                       tables_file=self._tables_file,
                                       dataset_path=self._dataset_path)
            #print(db_context.entity_tokens)
            # TODO: WikiKnowledgeGraphField is identical to the corresponding Spider field; only the class name was changed.
            table_field = WikiKnowledgeGraphField(
                db_context.knowledge_graph,
                db_context.tokenized_utterance,
                self._utterance_token_indexers,
                entity_tokens=db_context.entity_tokens,
                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                max_table_tokens=None)  # self._max_table_tokens)

            world = WikiWorld(db_context, nl_context=None, query=sql)
            fields["utterance"] = TextField(db_context.tokenized_utterance,
                                            self._utterance_token_indexers)

            # TODO: this step raises an error, most likely a grammar mismatch; note that the grammar's
            # col_ref/column_name rules below only use the canonical tokens col0..col5, while the failing
            # query references human-readable column names.
            # ParseError: ['select', '1-10015132-11@Position', 'where', '1-10015132-11@School/Club Team', '=', "'value'"]
            # TODO: action_sequence: None
            #todo all_actions: ['arg_list -> [expr, ",", arg_list]', 'arg_list -> [expr]', 'arg_list_or_star -> ["*"]', 'arg_list_or_star -> [arg_list]', 'binaryop -> ["!="]', 'binaryop -> ["*"]', 'binaryop -> ["+"]', 'binaryop -> ["-"]', 'binaryop -> ["/"]', 'binaryop -> ["<"]', 'binaryop -> ["<="]', 'binaryop -> ["<>"]', 'binaryop -> ["="]', 'binaryop -> [">"]', 'binaryop -> [">="]', 'binaryop -> ["and"]', 'binaryop -> ["like"]', 'binaryop -> ["or"]', 'boolean -> ["false"]', 'boolean -> ["true"]', 'col_ref -> ["1-10015132-11@col0"]', 'col_ref -> ["1-10015132-11@col1"]', 'col_ref -> ["1-10015132-11@col2"]', 'col_ref -> ["1-10015132-11@col3"]', 'col_ref -> ["1-10015132-11@col4"]', 'col_ref -> ["1-10015132-11@col5"]', 'column_name -> ["1-10015132-11@col0"]', 'column_name -> ["1-10015132-11@col1"]', 'column_name -> ["1-10015132-11@col2"]', 'column_name -> ["1-10015132-11@col3"]', 'column_name -> ["1-10015132-11@col4"]', 'column_name -> ["1-10015132-11@col5"]', 'expr -> [in_expr]', 'expr -> [source_subq]', 'expr -> [unaryop, expr]', 'expr -> [value, "between", value, "and", value]', 'expr -> [value, "like", string]', 'expr -> [value, binaryop, expr]', 'expr -> [value]', 'fname -> ["all"]', 'fname -> ["avg"]', 'fname -> ["count"]', 'fname -> ["max"]', 'fname -> ["min"]', 'fname -> ["sum"]', 'from_clause -> ["from", source]', 'from_clause -> ["from", table_name, join_clauses]', 'function -> [fname, "(", "distinct", arg_list_or_star, ")"]', 'function -> [fname, "(", arg_list_or_star, ")"]', 'group_clause -> [expr, ",", group_clause]', 'group_clause -> [expr]', 'groupby_clause -> ["group", "by", group_clause, "having", expr]', 'groupby_clause -> ["group", "by", group_clause]', 'in_expr -> [value, "in", expr]', 'in_expr -> [value, "in", string_set]', 'in_expr -> [value, "not", "in", expr]', 'in_expr -> [value, "not", "in", string_set]', 'iue -> ["except"]', 'iue -> ["intersect"]', 'iue -> ["union"]', 'join_clause -> ["join", table_name, "on", join_condition_clause]', 'join_clauses -> [join_clause, join_clauses]', 'join_clauses -> [join_clause]', 'join_condition -> [column_name, "=", column_name]', 'join_condition_clause -> [join_condition, "and", join_condition_clause]', 'join_condition_clause -> [join_condition]', 'limit -> ["limit", non_literal_number]', 'non_literal_number -> ["1"]', 'non_literal_number -> ["2"]', 'non_literal_number -> ["3"]', 'non_literal_number -> ["4"]', 'number -> ["value"]', 'order_clause -> [ordering_term, ",", order_clause]', 'order_clause -> [ordering_term]', 'orderby_clause -> ["order", "by", order_clause]', 'ordering -> ["asc"]', 'ordering -> ["desc"]', 'ordering_term -> [expr, ordering]', 'ordering_term -> [expr]', 'parenval -> ["(", expr, ")"]', 'query -> [select_core, groupby_clause, limit]', 'query -> [select_core, groupby_clause, orderby_clause, limit]', 'query -> [select_core, groupby_clause, orderby_clause]', 'query -> [select_core, groupby_clause]', 'query -> [select_core, orderby_clause, limit]', 'query -> [select_core, orderby_clause]', 'query -> [select_core]', 'select_core -> [select_with_distinct, select_results, from_clause, where_clause]', 'select_core -> [select_with_distinct, select_results, from_clause]', 'select_core -> [select_with_distinct, select_results, where_clause]', 'select_core -> [select_with_distinct, select_results]', 'select_result -> ["*"]', 'select_result -> [column_name]', 'select_result -> [expr]', 'select_result -> [table_name, ".*"]', 'select_results -> [select_result, ",", select_results]', 'select_results -> [select_result]', 
'select_with_distinct -> ["select", "distinct"]', 'select_with_distinct -> ["select"]', 'single_source -> [source_subq]', 'single_source -> [table_name]', 'source -> [single_source, ",", source]', 'source -> [single_source]', 'source_subq -> ["(", query, ")"]', 'statement -> [query, iue, query]', 'statement -> [query]', 'string -> ["\'", "value", "\'"]', 'string_set -> ["(", string_set_vals, ")"]', 'string_set_vals -> [string, ",", string_set_vals]', 'string_set_vals -> [string]', "table_name -> ['1-10015132-11']", "table_source -> ['1-10015132-11']", 'unaryop -> ["+"]', 'unaryop -> ["-"]', 'unaryop -> ["not"]', 'value -> ["YEAR(CURDATE())"]', 'value -> [boolean]', 'value -> [column_name]', 'value -> [function]', 'value -> [number]', 'value -> [parenval]', 'value -> [string]', 'where_clause -> ["where", expr, where_conj]', 'where_clause -> ["where", expr]', 'where_conj -> ["and", expr, where_conj]', 'where_conj -> ["and", expr]']
            action_sequence, all_actions = world.get_action_sequence_and_all_actions(
            )

            if action_sequence is None and self._keep_if_unparsable:
                # print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

            production_rule_fields: List[Field] = []

            for production_rule in all_actions:

                nonterminal, rhs = production_rule.split(' -> ')
                production_rule = ' '.join(production_rule.split(' '))

                field = ProductionRuleField(production_rule,
                                            world.is_global_rule(rhs),
                                            nonterminal=nonterminal)
                production_rule_fields.append(field)

            valid_actions_field = ListField(production_rule_fields)
            fields["valid_actions"] = valid_actions_field

            index_fields: List[Field] = []

            # action: ProductionRuleField
            action_map = {
                action.rule: i  # type: ignore
                for i, action in enumerate(valid_actions_field.field_list)
            }

            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
            if not action_sequence:
                index_fields = [IndexField(-1, valid_actions_field)]

            action_sequence_field = ListField(index_fields)
            fields["action_sequence"] = action_sequence_field

            fields["world"] = MetadataField(world)
            fields["schema"] = table_field

            #print(fields)

        return Instance(fields)
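A small sketch of why the parse in the TODO above likely fails: the grammar's col_ref/column_name rules only contain the canonical tokens col0..col5 for table 1-10015132-11, while the failing query references human-readable column names. The tokens below are copied from the TODO dump.

# The grammar only knows canonical column tokens for this table...
col_refs = {'1-10015132-11@col%d' % i for i in range(6)}
# ...but the failing query uses human-readable column names.
query_tokens = ['select', '1-10015132-11@Position', 'where',
                '1-10015132-11@School/Club Team', '=', "'value'"]
unknown = [t for t in query_tokens if '@' in t and t not in col_refs]
print(unknown)  # ['1-10015132-11@Position', '1-10015132-11@School/Club Team']
# Neither token matches any col_ref/column_name rule, so parsing yields
# action_sequence = None.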
Example No. 5
    def text_to_instance(
            self,
            utterance: str,  # question
            db_id: str,
            sql: List[str] = None):
        fields: Dict[str, Field] = {}

        # db_context bundles the DB schema graph and its tokens together with the tokenized utterance.
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)

        # An AllenNLP Instance is a collection of Field objects (think of fields as the columns of a table, or the attributes of an object).
        # So db_context has to be converted into a Field object, here table_field.
        # db_context.knowledge_graph is a graph, so we need a graph field; SpiderKnowledgeGraphField inherits from KnowledgeGraphField.
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None,
            conceptnet=self._concept_word)  # self._max_table_tokens)

        world = SpiderWorld(db_context, query=sql)

        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        # action_sequence is the result of parsing the query with the grammar; the grammar is instantiated for a given database.
        # all_actions is the complete list of grammar rule strings,
        # so action_sequence is a subset of all_actions:
        # the exact rules needed to derive this query.
        # The grammar is defined in semparse/contexts/spider_db_grammar.py, in a BNF-like style.
        action_sequence, all_actions = world.get_action_sequence_and_all_actions(
        )

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            # Help ProductionRuleField: https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=productionrulefield#production-rule-field
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        # valid_actions_field wraps all_actions, i.e. the full grammar.
        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        # Assign every grammar rule an index.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        # Map every rule used in action_sequence to its index in the full rule inventory,
        # so each derivation step points back into fields["valid_actions"].
        if action_sequence:
            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
        else:  # action_sequence is empty here: the grammar could not parse this query.
            index_fields = [IndexField(-1, valid_actions_field)]
            # assert False  # (left disabled; unclear why this case occurs)
        action_sequence_field = ListField(index_fields)

        # The fields["valid_actions"] include the global rule which is the same in all SQL and database specific rules.
        # For example, 'binaryop -> ["!="]' and 'query -> [select_core, groupby_clause, limit]' are global rules.
        # 'column_name -> ["department@budget_in_billions"]' and 'col_ref -> ["department@budget_in_billions"]' are not global rules.
        # So fields["valid_actions"] is case by case but will contain the same global rules. And it must include the rules appearing in fields["action_sequence"].
        # Now the attribute of _rule_id is None. But when finish loading all data, allennlp will automatically build the vocabulary and give a unique _ruled_id for every rule.
        # But when forward the fields["valid_actions"] to the model, it will become the ProductionRule List.
        # We will find that the global rules will contain a _rule_id but non-global rules will not.
        # And in ProductionRule List, the global rules will become a tuple and non-global will become a ProductionRule obj.
        # In a tuple and ProductionRule, its[0] shows its rule value, such as: 'where_clause -> ["where", expr, where_conj]' or 'col_ref -> ["department@budget_in_billions"]'.
        # In a tuple and ProductionRule, its[1] shows whether it is global rules, such as True or False.
        # In a tuple and ProductionRule, its[2] shows _rule_id. But if it is non-global rule, it will be None.
        # In a tuple and ProductionRule, its[3] shows left rule value. For example: 'where_clause' is the left rule value of 'where_clause -> ["where", expr, where_conj]'.
        # fields["valid_actions"]   # All action / All grammar

        # The information of fields["valid_actions"] is almost the same as the world.valid_actions but using different representations (world is a SpiderWorld obj)
        # There are two kinds of valid actions in the project but their information is the same.
        # The first one is a set list: (We call it as list-type-action)
        # [   #key                #value                      #key         #value   #other key and value pairs
        #    {rule: 'arg_list -> [expr, ",", arg_list]'  ,  is_global_rule: True,            ... }
        #    {rule: 'arg_list -> [expr]'  ,                 is_global_rule: True,            ... }
        #    {...}
        # ]
        # You can easily extract the all valid action to a list, such as:
        # all_actions:
        # ['arg_list -> [expr, ",", arg_list]' , 'arg_list -> [expr]' , ... ]

        # The second one is also a dict but its key is different and it will combine the same left key value together: (We call it as dict-type-action)
        # {       #key           #value-list
        #       arg_list:[
        #                   '[expr, ",", arg_list]',
        #                   '[expr]'
        #                ]
        #
        #       ...: [...]
        #       ...
        # }
        # Say it again, they are valid actions.

        # fields["utterance"]       # TextFile for utterance
        fields[
            "action_sequence"] = action_sequence_field  # grammar rules (action) of this query, and every rule contains a total grammar set which is fields["valid_actions"].
        fields["world"] = MetadataField(
            world
        )  #Maybe just for calc the metric. # A MetadataField is a Field that does not get converted into tensors. https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=metadatafield#metadata-field
        fields["schema"] = table_field
        return Instance(fields)
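A small sketch of the two interchangeable action representations described in the comments of Example No. 5. The rules shown are illustrative; converting the dict-type form to the flat list-type form is a simple flattening.

# Dict-type representation: right-hand sides grouped by left-hand side.
dict_type = {
    'arg_list': ['[expr, ",", arg_list]', '[expr]'],
    'binaryop': ['["!="]', '["="]'],
}
# Flatten into the list-type representation of 'lhs -> rhs' strings.
list_type = ['%s -> %s' % (lhs, rhs)
             for lhs, rhss in sorted(dict_type.items())
             for rhs in rhss]
print(list_type)
# ['arg_list -> [expr, ",", arg_list]', 'arg_list -> [expr]',
#  'binaryop -> ["!="]', 'binaryop -> ["="]']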
Example No. 6
def read_spider_split(dataset_path, table_path, database_path):
    with open(dataset_path) as f:
        split_data = json.load(f)
    print('read_spider_split', dataset_path, len(split_data))

    schemas = read_dataset_schema(table_path, stanza_model)

    interaction_list = {}
    for i, ex in enumerate(tqdm(split_data)):
        db_id = ex['db_id']

        ex['query_toks_no_value'] = normalize_original_sql(ex['query_toks_no_value'])
        turn_sql = ' '.join(ex['query_toks_no_value'])
        turn_sql = turn_sql.replace('select count ( * ) from follows group by value',
                                    'select count ( * ) from follows group by f1')
        ex['query_toks_no_value'] = turn_sql.split(' ')

        ex = fix_number_value(ex)
        try:
            ex['query_toks_no_value'] = disambiguate_items(db_id, ex['query_toks_no_value'],
                                                           tables_file=table_path, allow_aliases=False)
        except Exception:  # disambiguation failed; skip this example
            print(ex['query_toks'])
            continue

        final_sql_parse = ' '.join(ex['query_toks_no_value'])
        final_utterance = ' '.join(ex['question_toks']).lower()

        if stanza_model is not None:
            lemma_utterance_stanza = stanza_model(final_utterance)
            lemma_utterance = [word.lemma for sent in lemma_utterance_stanza.sentences for word in sent.words]
            original_utterance = final_utterance
        else:
            original_utterance = lemma_utterance = final_utterance.split(' ')

        # using db content
        db_context = SpiderDBContext(db_id,
                                     lemma_utterance,
                                     tables_file=table_path,
                                     dataset_path=database_path,
                                     stanza_model=stanza_model,
                                     schemas=schemas,
                                     original_utterance=original_utterance)

        value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(db_id)

        if value_match:
            print(value_match, value_alignment)

        if db_id not in interaction_list:
            interaction_list[db_id] = []

        interaction = {}
        interaction['id'] = i
        interaction['database_id'] = db_id
        interaction['interaction'] = [{'utterance': final_utterance,
                                       'db_id': db_id,
                                       'query': ex['query'],
                                       'question': ex['question'],
                                       'sql': final_sql_parse,
                                       'value_match': value_match,
                                       'value_alignment': value_alignment,
                                       'exact_match': exact_match,
                                       'partial_match': partial_match,
                                       }]
        interaction_list[db_id].append(interaction)

    return interaction_list
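A minimal usage sketch for read_spider_split. The paths are hypothetical Spider-layout paths, and stanza_model is assumed to be defined at module level, as the function body implies.

# Hypothetical Spider dataset layout.
interaction_list = read_spider_split(dataset_path='spider/train_spider.json',
                                     table_path='spider/tables.json',
                                     database_path='spider/database')
# interaction_list maps db_id -> list of single-turn interactions.
some_db_id = next(iter(interaction_list))
first = interaction_list[some_db_id][0]['interaction'][0]
print(some_db_id, first['question'], first['sql'])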
Example No. 7
    def text_to_instance(self,
                         utterance: str,
                         db_id: str,
                         sql: List[str] = None):
        fields: Dict[str, Field] = {}
        """KAIMARY"""
        # Contains
        # 1. db schema(Tables with corresponding columns)
        # 2. Tokenized utterance
        # 3. Knowledge graph(Table entities, column entities and "text" column type related token entities)
        # 4. Entity_tokens(Retrieved from entities_text from kg)
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)
        # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html#knowledge-graph-field
        # *feature extractors*
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)  # self._max_table_tokens)

        world = SpiderWorld(db_context, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions(
        )

        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
        return Instance(fields)
Example No. 8
    def text_to_instance(self,
                         utterances: List[str],
                         db_id: str,
                         sql: List[List[str]] = None):
        fields: Dict[str, Field] = {}

        ctxts = [
            SpiderDBContext(db_id,
                            utterance,
                            tokenizer=self._tokenizer,
                            tables_file=self._tables_file,
                            dataset_path=self._dataset_path)
            for utterance in utterances
        ]

        super_utterance = ' '.join(utterances)
        hack_ctxt = SpiderDBContext(db_id,
                                    super_utterance,
                                    tokenizer=self._tokenizer,
                                    tables_file=self._tables_file,
                                    dataset_path=self._dataset_path)

        kg = SpiderKnowledgeGraphField(
            hack_ctxt.knowledge_graph,
            hack_ctxt.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=hack_ctxt.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
            max_table_tokens=None)
        '''
        kgs = [SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                        db_context.tokenized_utterance,
                                        self._utterance_token_indexers,
                                        entity_tokens=db_context.entity_tokens,
                                        include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                        max_table_tokens=None)  # self._max_table_tokens)
                        for db_context in ctxts]
        '''
        worlds = [
            SpiderWorld(db_context, query=sqli)
            for sqli, db_context in zip(sql, ctxts)
        ]

        fields["utterances"] = ListField([
            TextField(db_context.tokenized_utterance,
                      self._utterance_token_indexers) for db_context in ctxts
        ])

        #action_sequence, all_actions = world.get_action_sequence_and_all_actions()

        action_tups = [
            world.get_action_sequence_and_all_actions() for world in worlds
        ]
        action_sequences = [tup[0] for tup in action_tups]
        all_actions = [tup[1] for tup in action_tups]

        for i in range(len(action_sequences)):
            action_sequence = action_sequences[i]

            if action_sequence is None and self._keep_if_unparsable:
                # print("Parse error")
                action_sequence = []
            elif action_sequence is None:
                return None

            action_sequences[i] = action_sequence

        all_valid_actions_fields = []
        all_action_sequence_fields = []

        for i in range(len(all_actions)):
            production_rule_fields: List[Field] = []

            all_actionsi = all_actions[i]
            for production_rule in all_actionsi:
                nonterminal, rhs = production_rule.split(' -> ')
                production_rule = ' '.join(production_rule.split(' '))
                field = ProductionRuleField(production_rule,
                                            worlds[i].is_global_rule(rhs),
                                            nonterminal=nonterminal)
                production_rule_fields.append(field)

            valid_actions_field = ListField(production_rule_fields)

            all_valid_actions_fields.append(valid_actions_field)

            action_map = {
                action.rule: j  # type: ignore
                for j, action in enumerate(valid_actions_field.field_list)
            }

            index_fields: List[Field] = []

            action_sequence = action_sequences[i]

            for production_rule in action_sequence:
                index_fields.append(
                    IndexField(action_map[production_rule],
                               valid_actions_field))
            if not action_sequence:
                index_fields = [IndexField(-1, valid_actions_field)]

            action_sequence_field = ListField(index_fields)

            all_action_sequence_fields.append(action_sequence_field)

        fields["valid_actions"] = ListField(all_valid_actions_fields)
        fields["action_sequences"] = ListField(all_action_sequence_fields)
        fields["worlds"] = ListField(
            [MetadataField(world) for world in worlds])
        fields["schema"] = kg
        '''
        fields['utterances'] = ListField[TextField]
        fields['valid_actions'] = ListField[ListField[ProductionRuleField]]
        fields['action_sequences'] = ListField[ListField[IndexField]]
        fields['worlds'] = ListField[MetadataField[SpiderWorld]]
        fields['schema'] = SpiderKnowledgeGraphField
        '''

        return Instance(fields)
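A usage sketch for the multi-turn reader in Example No. 8, assuming reader is an instance of the dataset reader class that defines this method; the utterances, db_id, and tokenized SQL are illustrative.

# One instance bundles several turns over the same database.
instance = reader.text_to_instance(
    utterances=['how many singers are there ?',
                'show the names of singers older than 30 .'],
    db_id='concert_singer',  # illustrative Spider database id
    sql=[['select', 'count', '(', '*', ')', 'from', 'singer'],
         ['select', 'name', 'from', 'singer', 'where', 'age', '>', "'value'"]])
# instance is None if any turn fails to parse and _keep_if_unparsable is off;
# otherwise it holds the list-valued fields described in the docstring above.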