def infer_query(self, question, db_id):
    ### step-1: schema linking
    lemma_utterance_stanza = self.stanza_model(question)
    lemma_utterance = [word.lemma
                       for sent in lemma_utterance_stanza.sentences
                       for word in sent.words]
    db_context = SpiderDBContext(db_id,
                                 lemma_utterance,
                                 tables_file=self.schema_path,
                                 dataset_path=self.db_path,
                                 stanza_model=self.stanza_model,
                                 schemas=self.schema,
                                 original_utterance=question)
    value_match, value_alignment, exact_match, partial_match = db_context.get_db_knowledge_graph(db_id)

    item = {}
    item['interaction'] = [{
        'db_id': db_id,
        'question': question,
        'sql': '',
        'value_match': value_match,
        'value_alignment': value_alignment,
        'exact_match': exact_match,
        'partial_match': partial_match,
    }]

    ### step-2: serialization
    source_sequence, _ = build_schema_linking_data(schema=self.database_schemas[db_id],
                                                   question=question,
                                                   item=item,
                                                   turn_id=0,
                                                   linking_type='default')
    slml_question = source_sequence[0]

    ### step-3: prediction
    schemas, eval_foreign_key_maps = load_tables(self.schema_path)
    original_schemas = load_original_schemas(self.schema_path)
    alias_schema = get_alias_schema(schemas)
    rnt = decode_with_constrain(slml_question, alias_schema[db_id], self.model)
    predict_sql = rnt[0]['text'] if isinstance(rnt[0]['text'], str) else rnt[0]['text'][0]
    score = rnt[0]['score'].tolist()
    predict_sql = post_processing_sql(predict_sql,
                                      eval_foreign_key_maps[db_id],
                                      original_schemas[db_id],
                                      schemas[db_id])

    return {
        "slml_question": slml_question,
        "predict_sql": predict_sql,
        "score": score
    }
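A minimal usage sketch, assuming `inferer` is an instance of the (hypothetical) wrapper class this method belongs to, already constructed with the attributes the body reads from `self` (`stanza_model`, `schema_path`, `db_path`, `schema`, `database_schemas`, `model`); the question and `db_id` values are illustrative only:

# Hypothetical driver; all attribute names follow the method body above.
result = inferer.infer_query(question="How many singers do we have?",
                             db_id="concert_singer")
print(result["slml_question"])  # question serialized with schema-linking markup (step 2)
print(result["predict_sql"])    # SQL decoded under grammar constraints, post-processed (step 3)
print(result["score"])          # decoder score of the chosen hypothesis, as a Python list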
def text_to_instance(self, utterance: str, db_id: str, sql: List[str] = None):
    fields: Dict[str, Field] = {}
    db_context = SpiderDBContext(db_id,
                                 utterance,
                                 tokenizer=self._tokenizer,
                                 tables_file=self._tables_file,
                                 dataset_path=self._dataset_path)
    table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                            db_context.tokenized_utterance,
                                            self._utterance_token_indexers,
                                            entity_tokens=db_context.entity_tokens,
                                            include_in_vocab=False,  # TODO: self._use_table_for_vocab
                                            max_table_tokens=None)  # self._max_table_tokens

    world = SpiderWorld(db_context, query=sql)
    fields["utterance"] = TextField(db_context.tokenized_utterance,
                                    self._utterance_token_indexers)

    action_sequence, all_actions = world.get_action_sequence_and_all_actions()

    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []

    for production_rule in all_actions:
        nonterminal, rhs = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(rhs),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field

    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}

    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]

    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    fields["world"] = MetadataField(world)
    fields["schema"] = table_field
    return Instance(fields)
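For context, a sketch of how such an `Instance` is typically consumed downstream in AllenNLP (the 0.9-era API these readers appear to target; module paths differ in later versions, where `Batch` lives in `allennlp.data.batch`). `reader` stands for an instance of the dataset reader defining this method:

from allennlp.data.dataset import Batch          # allennlp.data.batch.Batch in 1.x+
from allennlp.data.vocabulary import Vocabulary

instance = reader.text_to_instance(utterance, db_id, sql)
# Building the vocabulary assigns token ids, and _rule_id values for global rules.
vocab = Vocabulary.from_instances([instance])
batch = Batch([instance])
batch.index_instances(vocab)
tensors = batch.as_tensor_dict(batch.get_padding_lengths())
# tensors["utterance"], tensors["valid_actions"], tensors["action_sequence"], ...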
def text_to_instance(self, utterance: str, db_id: str, sql: List[str] = None):
    fields: Dict[str, Field] = {}
    db_context = SpiderDBContext(db_id,
                                 utterance,
                                 tokenizer=self._tokenizer,
                                 tables_file=self._tables_file,
                                 dataset_path=self._dataset_path)
    table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                            db_context.tokenized_utterance,
                                            {},
                                            entity_tokens=db_context.entity_tokens,
                                            include_in_vocab=False,  # TODO: self._use_table_for_vocab
                                            max_table_tokens=None)  # self._max_table_tokens

    combined_tokens = [] + db_context.tokenized_utterance
    entity_token_map = dict(zip(db_context.knowledge_graph.entities, db_context.entity_tokens))
    entity_tokens = []
    for e in db_context.knowledge_graph.entities:
        if e.startswith('column:'):
            table_name, column_name = e.split(':')[-2:]
            table_tokens = entity_token_map['table:' + table_name]
            column_tokens = entity_token_map[e]
            # Drop the duplicated table-name prefix from the column tokens.
            if column_name.startswith(table_name):
                column_tokens = column_tokens[len(table_tokens):]
            entity_tokens.append(table_tokens + [Token(text='[unused30]')] + column_tokens)
        else:
            entity_tokens.append(entity_token_map[e])
    for e in entity_tokens:
        combined_tokens += [Token(text='[SEP]')] + e
    if len(combined_tokens) > 450:
        return None
    db_context.entity_tokens = entity_tokens

    fields["utterance"] = TextField(combined_tokens, self._utterance_token_indexers)

    world = SpiderWorld(db_context, query=sql)
    action_sequence, all_actions = world.get_action_sequence_and_all_actions()

    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []

    for production_rule in all_actions:
        nonterminal, _ = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(nonterminal),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field

    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}

    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]

    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    fields["world"] = MetadataField(world)
    fields["schema"] = table_field
    return Instance(fields)
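To make that serialization concrete, a minimal sketch of the combined sequence this variant builds, using an invented table `head` with a column `head_name` (names and the question are illustrative; the 450-token cap above still applies):

from allennlp.data.tokenizers import Token  # allennlp 0.9-style import

# Entity tokens as the reader would produce them (names are illustrative):
table_tokens = [Token("head")]                  # table "head"
column_tokens = [Token("head"), Token("name")]  # column "head_name", prefixed with its table

# The column name starts with its table's name, so the duplicated prefix is stripped,
# and table and column tokens are joined with the reserved wordpiece [unused30]:
column_entity = table_tokens + [Token(text="[unused30]")] + column_tokens[len(table_tokens):]

# Each entity is appended to the utterance behind a [SEP] marker:
combined = [Token(t) for t in "how many heads are older than 56 ?".split()]
for entity in [table_tokens, column_entity]:
    combined += [Token(text="[SEP]")] + entity
print([t.text for t in combined])
# [..., '?', '[SEP]', 'head', '[SEP]', 'head', '[unused30]', 'name']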
def text_to_instance(self, utterance: str, db_id: str, sql: List[str] = None):
    fields: Dict[str, Field] = {}
    if self._is_spider:
        db_context = SpiderDBContext(db_id,
                                     utterance,
                                     tokenizer=self._tokenizer,
                                     tables_file=self._tables_file,
                                     dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab
            max_table_tokens=None)  # self._max_table_tokens

        world = SpiderWorld(db_context, nl_context=None, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []

        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }
        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
    else:
        db_context = WikiDBContext(db_id,
                                   utterance,
                                   tokenizer=self._tokenizer,
                                   tables_file=self._tables_file,
                                   dataset_path=self._dataset_path)
        # print(db_context.entity_tokens)
        # TODO: WikiKnowledgeGraphField is identical to the Spider one; only the class name differs.
        table_field = WikiKnowledgeGraphField(
            db_context.knowledge_graph,
            db_context.tokenized_utterance,
            self._utterance_token_indexers,
            entity_tokens=db_context.entity_tokens,
            include_in_vocab=False,  # TODO: self._use_table_for_vocab
            max_table_tokens=None)  # self._max_table_tokens

        world = WikiWorld(db_context, nl_context=None, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance,
                                        self._utterance_token_indexers)

        # TODO: this step fails, most likely a grammar mismatch, e.g.
        #   ParseError: ['select', '1-10015132-11@Position', 'where',
        #                '1-10015132-11@School/Club Team', '=', "'value'"]
        # action_sequence comes back as None while all_actions holds the full grammar
        # generated for table 1-10015132-11: the generic SQL productions (arg_list,
        # arg_list_or_star, binaryop, boolean, expr, fname, from_clause, function,
        # group/groupby_clause, in_expr, iue, join_*, limit, number, order*, parenval,
        # query, select_*, single_source, source*, statement, string*, unaryop, value,
        # where_clause, where_conj) plus the table-specific ones
        # (col_ref/column_name -> "1-10015132-11@col0" ... "1-10015132-11@col5",
        # table_name/table_source -> '1-10015132-11').
        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        production_rule_fields: List[Field] = []
        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        index_fields: List[Field] = []
        # action: ProductionRuleField
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }
        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]

        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
    # print(fields)
    return Instance(fields)
def text_to_instance(self,
                     utterance: str,  # the question
                     db_id: str,
                     sql: List[str] = None):
    fields: Dict[str, Field] = {}
    # db_context holds the utterance (and its tokens) together with the db graph.
    db_context = SpiderDBContext(db_id,
                                 utterance,
                                 tokenizer=self._tokenizer,
                                 tables_file=self._tables_file,
                                 dataset_path=self._dataset_path)
    # In AllenNLP an Instance is a collection of Field objects (think of fields as the
    # columns of a table, or the attributes of an object).
    # So db_context has to be converted into a Field, here table_field.
    # db_context.knowledge_graph is a graph, so we need a graph field;
    # SpiderKnowledgeGraphField inherits from KnowledgeGraphField.
    table_field = SpiderKnowledgeGraphField(
        db_context.knowledge_graph,
        db_context.tokenized_utterance,
        self._utterance_token_indexers,
        entity_tokens=db_context.entity_tokens,
        include_in_vocab=False,  # TODO: self._use_table_for_vocab
        max_table_tokens=None,  # self._max_table_tokens
        conceptnet=self._concept_word)

    world = SpiderWorld(db_context, query=sql)
    fields["utterance"] = TextField(db_context.tokenized_utterance,
                                    self._utterance_token_indexers)

    # action_sequence is the parse of the query under the grammar; the grammar is
    # generated for the particular database.
    # all_actions is the full list of grammar rule strings, so action_sequence is a
    # subset of all_actions: exactly the rules needed to derive this query.
    # The grammar is defined in semparse/contexts/spider_db_grammar.py, in a BNF-like style.
    action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []

    for production_rule in all_actions:
        nonterminal, rhs = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        # ProductionRuleField docs:
        # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=productionrulefield#production-rule-field
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(rhs),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    # valid_actions_field is built from all_actions, i.e. the complete grammar.
    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field

    # Give every grammar rule an id.
    action_map = {
        action.rule: i  # type: ignore
        for i, action in enumerate(valid_actions_field.field_list)
    }
    # Store each rule of action_sequence as an index into the complete grammar,
    # so every rule can be related to the others through the full rule list.
    if action_sequence:
        for production_rule in action_sequence:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
    else:
        # action_sequence is empty: the query could not be parsed under our grammar.
        index_fields = [IndexField(-1, valid_actions_field)]
        # assert False  # ??? (debugging leftover)
    action_sequence_field = ListField(index_fields)

    # fields["valid_actions"] contains the global rules, which are identical across all
    # SQL queries and databases, plus database-specific rules.
    # For example, 'binaryop -> ["!="]' and 'query -> [select_core, groupby_clause, limit]'
    # are global rules; 'column_name -> ["department@budget_in_billions"]' and
    # 'col_ref -> ["department@budget_in_billions"]' are not.
    # So fields["valid_actions"] differs case by case but always contains the same global
    # rules, and it must include every rule appearing in fields["action_sequence"].
    # At this point _rule_id is None. Once all data are loaded, allennlp automatically
    # builds the vocabulary and assigns a unique _rule_id to every rule.
    # When fields["valid_actions"] is forwarded to the model, it becomes a list of
    # ProductionRules; there we find that global rules carry a _rule_id while
    # non-global rules do not.
    # In that list, global rules arrive as tuples and non-global ones as ProductionRule
    # objects. For either form:
    #   item[0] is the rule string, e.g. 'where_clause -> ["where", expr, where_conj]'
    #           or 'col_ref -> ["department@budget_in_billions"]';
    #   item[1] says whether it is a global rule (True/False);
    #   item[2] is the _rule_id (None for non-global rules);
    #   item[3] is the left-hand side, e.g. 'where_clause' for
    #           'where_clause -> ["where", expr, where_conj]'.
    #
    # fields["valid_actions"]: all actions, i.e. the whole grammar.
    # Its information is almost the same as world.valid_actions (world is a SpiderWorld
    # object), just in a different representation; the project uses two equivalent forms.
    # The first is a list of records (call it the list-type action):
    # [
    #   #key  #value                              #key            #value  #other pairs
    #   {rule: 'arg_list -> [expr, ",", arg_list]', is_global_rule: True, ...}
    #   {rule: 'arg_list -> [expr]',                is_global_rule: True, ...}
    #   {...}
    # ]
    # from which all valid actions are easily extracted into a plain list, e.g.
    #   all_actions: ['arg_list -> [expr, ",", arg_list]', 'arg_list -> [expr]', ...]
    # The second is a dict keyed by the left-hand side, grouping rules that share it
    # (call it the dict-type action):
    # {
    #   arg_list: [
    #     '[expr, ",", arg_list]',
    #     '[expr]'
    #   ],
    #   ...: [...]
    # }
    # To repeat: both are representations of the same valid actions.

    # fields["utterance"]: the TextField for the utterance.
    fields["action_sequence"] = action_sequence_field  # the rules (actions) deriving this query; each entry indexes into fields["valid_actions"]
    fields["world"] = MetadataField(world)  # mainly used to compute metrics; a MetadataField is a Field that does not get converted into tensors: https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=metadatafield#metadata-field
    fields["schema"] = table_field
    return Instance(fields)
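A minimal sketch of the conversion between the two action representations described in the comments above (rule strings taken from those comments; the conversion code is illustrative, not part of the reader):

from collections import defaultdict

# List-type actions: one rule string per production.
all_actions = [
    'arg_list -> [expr, ",", arg_list]',
    'arg_list -> [expr]',
    'binaryop -> ["!="]',
]

# Dict-type actions: right-hand sides grouped under their left-hand nonterminal.
dict_actions = defaultdict(list)
for rule in all_actions:
    lhs, rhs = rule.split(' -> ')
    dict_actions[lhs].append(rhs)

print(dict(dict_actions))
# {'arg_list': ['[expr, ",", arg_list]', '[expr]'], 'binaryop': ['["!="]']}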
def read_spider_split(dataset_path, table_path, database_path):
    with open(dataset_path) as f:
        split_data = json.load(f)
    print('read_spider_split', dataset_path, len(split_data))

    schemas = read_dataset_schema(table_path, stanza_model)

    interaction_list = {}
    for i, ex in enumerate(tqdm(split_data)):
        db_id = ex['db_id']

        ex['query_toks_no_value'] = normalize_original_sql(ex['query_toks_no_value'])
        turn_sql = ' '.join(ex['query_toks_no_value'])
        # Repair one query where a column name was apparently replaced by the
        # 'value' placeholder.
        turn_sql = turn_sql.replace(
            'select count ( * ) from follows group by value',
            'select count ( * ) from follows group by f1')
        ex['query_toks_no_value'] = turn_sql.split(' ')

        ex = fix_number_value(ex)
        try:
            ex['query_toks_no_value'] = disambiguate_items(db_id,
                                                           ex['query_toks_no_value'],
                                                           tables_file=table_path,
                                                           allow_aliases=False)
        except Exception:
            # Skip examples whose SQL cannot be disambiguated.
            print(ex['query_toks'])
            continue

        final_sql_parse = ' '.join(ex['query_toks_no_value'])
        final_utterance = ' '.join(ex['question_toks']).lower()

        if stanza_model is not None:
            lemma_utterance_stanza = stanza_model(final_utterance)
            lemma_utterance = [word.lemma
                               for sent in lemma_utterance_stanza.sentences
                               for word in sent.words]
            original_utterance = final_utterance
        else:
            original_utterance = lemma_utterance = final_utterance.split(' ')

        # Using db content.
        db_context = SpiderDBContext(db_id,
                                     lemma_utterance,
                                     tables_file=table_path,
                                     dataset_path=database_path,
                                     stanza_model=stanza_model,
                                     schemas=schemas,
                                     original_utterance=original_utterance)
        value_match, value_alignment, exact_match, partial_match = \
            db_context.get_db_knowledge_graph(db_id)
        if value_match != []:
            print(value_match, value_alignment)

        if db_id not in interaction_list:
            interaction_list[db_id] = []

        interaction = {}
        interaction['id'] = i
        interaction['database_id'] = db_id
        interaction['interaction'] = [{'utterance': final_utterance,
                                       'db_id': db_id,
                                       'query': ex['query'],
                                       'question': ex['question'],
                                       'sql': final_sql_parse,
                                       'value_match': value_match,
                                       'value_alignment': value_alignment,
                                       'exact_match': exact_match,
                                       'partial_match': partial_match,
                                       }]
        interaction_list[db_id].append(interaction)

    return interaction_list
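A hedged usage sketch: the paths follow the standard Spider release layout (train_spider.json, tables.json, database/) but are illustrative, and the module-level `stanza_model` the function reads is assumed to be initialized (or set to None):

interactions = read_spider_split(dataset_path='spider/train_spider.json',
                                 table_path='spider/tables.json',
                                 database_path='spider/database')
# The result is keyed by db_id; each entry wraps one example as a single-turn interaction.
for db_id, examples in interactions.items():
    turn = examples[0]['interaction'][0]
    print(db_id, turn['utterance'], '=>', turn['sql'])
    break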
def text_to_instance(self, utterance: str, db_id: str, sql: List[str] = None):
    fields: Dict[str, Field] = {}
    """KAIMARY"""
    # Contains:
    # 1. db schema (tables with corresponding columns)
    # 2. tokenized utterance
    # 3. knowledge graph (table entities, column entities, and "text" column type
    #    related token entities)
    # 4. entity_tokens (retrieved from entities_text from the kg)
    db_context = SpiderDBContext(db_id,
                                 utterance,
                                 tokenizer=self._tokenizer,
                                 tables_file=self._tables_file,
                                 dataset_path=self._dataset_path)
    # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html#knowledge-graph-field
    # *feature extractors*
    table_field = SpiderKnowledgeGraphField(
        db_context.knowledge_graph,
        db_context.tokenized_utterance,
        self._utterance_token_indexers,
        entity_tokens=db_context.entity_tokens,
        include_in_vocab=False,  # TODO: self._use_table_for_vocab
        max_table_tokens=None)  # self._max_table_tokens

    world = SpiderWorld(db_context, query=sql)
    fields["utterance"] = TextField(db_context.tokenized_utterance,
                                    self._utterance_token_indexers)

    action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []

    for production_rule in all_actions:
        nonterminal, rhs = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(rhs),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field

    action_map = {
        action.rule: i  # type: ignore
        for i, action in enumerate(valid_actions_field.field_list)
    }
    for production_rule in action_sequence:
        index_fields.append(
            IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]

    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    fields["world"] = MetadataField(world)
    fields["schema"] = table_field
    return Instance(fields)
def text_to_instance(self,
                     utterances: List[str],
                     db_id: str,
                     sql: List[List[str]] = None):
    fields: Dict[str, Field] = {}
    ctxts = [SpiderDBContext(db_id,
                             utterance,
                             tokenizer=self._tokenizer,
                             tables_file=self._tables_file,
                             dataset_path=self._dataset_path)
             for utterance in utterances]

    # Build one schema field from the concatenation of all turns.
    super_utterance = ' '.join(utterances)
    hack_ctxt = SpiderDBContext(db_id,
                                super_utterance,
                                tokenizer=self._tokenizer,
                                tables_file=self._tables_file,
                                dataset_path=self._dataset_path)
    kg = SpiderKnowledgeGraphField(hack_ctxt.knowledge_graph,
                                   hack_ctxt.tokenized_utterance,
                                   self._utterance_token_indexers,
                                   entity_tokens=hack_ctxt.entity_tokens,
                                   include_in_vocab=False,  # TODO: self._use_table_for_vocab
                                   max_table_tokens=None)
    '''
    kgs = [SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                     db_context.tokenized_utterance,
                                     self._utterance_token_indexers,
                                     entity_tokens=db_context.entity_tokens,
                                     include_in_vocab=False,  # TODO: self._use_table_for_vocab
                                     max_table_tokens=None)  # self._max_table_tokens
           for db_context in ctxts]
    '''

    worlds = []
    for i in range(len(sql)):
        sqli = sql[i]
        db_context = ctxts[i]
        world = SpiderWorld(db_context, query=sqli)
        worlds.append(world)

    fields["utterances"] = ListField([TextField(db_context.tokenized_utterance,
                                                self._utterance_token_indexers)
                                      for db_context in ctxts])

    # action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    action_tups = [world.get_action_sequence_and_all_actions() for world in worlds]
    action_sequences = [tup[0] for tup in action_tups]
    all_actions = [tup[1] for tup in action_tups]

    for i in range(len(action_sequences)):
        action_sequence = action_sequences[i]
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None
        action_sequences[i] = action_sequence

    all_valid_actions_fields = []
    all_action_sequence_fields = []
    for i in range(len(all_actions)):
        production_rule_fields: List[Field] = []
        for production_rule in all_actions[i]:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            # Use this turn's world rather than the loop variable left over from the
            # world-building loop above (all turns share db_id, so the result matches
            # the original behavior, but indexing is clearly what was intended).
            field = ProductionRuleField(production_rule,
                                        worlds[i].is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)
        valid_actions_field = ListField(production_rule_fields)
        all_valid_actions_fields.append(valid_actions_field)

        action_map = {action.rule: j  # type: ignore
                      for j, action in enumerate(valid_actions_field.field_list)}

        index_fields: List[Field] = []
        for production_rule in action_sequences[i]:
            index_fields.append(
                IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequences[i]:
            index_fields = [IndexField(-1, valid_actions_field)]
        all_action_sequence_fields.append(ListField(index_fields))

    fields["valid_actions"] = ListField(all_valid_actions_fields)
    fields["action_sequences"] = ListField(all_action_sequence_fields)
    fields["worlds"] = ListField([MetadataField(world) for world in worlds])
    fields["schema"] = kg
    '''
    fields['utterances'] = ListField[TextField]
    fields['valid_actions'] = ListField[ListField[ProductionRuleField]]
    fields['action_sequences'] = ListField[ListField[IndexField]]
    fields['worlds'] = ListField[MetadataField[SpiderWorld]]
    fields['schemas'] = ListField[SpiderKnowledgeGraphField]
    '''
    return Instance(fields)
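A sketch of how this multi-turn variant might be called, with invented values; `reader` stands for an instance of the dataset reader defining the method:

# Each turn pairs one utterance with one gold SQL token list (values illustrative).
utterances = ["Show all singers.", "Only keep those from France."]
sqls = [["select", "*", "from", "singer"],
        ["select", "*", "from", "singer", "where", "country", "=", "'value'"]]
instance = reader.text_to_instance(utterances, db_id="concert_singer", sql=sqls)
# "utterances", "valid_actions", "action_sequences", and "worlds" are parallel
# ListFields with one entry per turn, while "schema" is a single knowledge-graph
# field built from the concatenated turns (the hack_ctxt above).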