def test_doubly_nested_field_works(self):
    field1 = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    field2 = ProductionRuleField('NP -> test', is_global_rule=True)
    field3 = ProductionRuleField('VP -> eat', is_global_rule=False)
    list_field = ListField([ListField([field1, field2, field3]),
                            ListField([field1, field2])])
    list_field.index(self.vocab)
    padding_lengths = list_field.get_padding_lengths()
    tensors = list_field.as_tensor(padding_lengths)
    assert isinstance(tensors, list)
    assert len(tensors) == 2
    assert isinstance(tensors[0], list)
    assert len(tensors[0]) == 3
    assert isinstance(tensors[1], list)
    assert len(tensors[1]) == 3

    tensor_tuple = tensors[0][0]
    assert tensor_tuple[0] == 'S -> [NP, VP]'
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.s_rule_index])

    tensor_tuple = tensors[0][1]
    assert tensor_tuple[0] == 'NP -> test'
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.np_index])

    tensor_tuple = tensors[0][2]
    assert tensor_tuple[0] == 'VP -> eat'
    assert tensor_tuple[1] is False
    assert tensor_tuple[2] is None

    tensor_tuple = tensors[1][0]
    assert tensor_tuple[0] == 'S -> [NP, VP]'
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.s_rule_index])

    tensor_tuple = tensors[1][1]
    assert tensor_tuple[0] == 'NP -> test'
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.np_index])

    # This item was just padding.
    tensor_tuple = tensors[1][2]
    assert tensor_tuple[0] == ''
    assert tensor_tuple[1] is False
    assert tensor_tuple[2] is None
def text_to_instance(self,
                     question_tokens: List[Token],
                     logical_form: str,
                     question_entities: List[str],
                     question_predicates: List[str]) -> Instance:
    context = LCQuADContext(self.executor, question_tokens, question_entities, question_predicates)
    language = LCQuADLanguage(context)
    print("CONSTANT:" + str(language._functions['http://dbpedia.org/ontology/creator']))
    target_action_sequence = language.logical_form_to_action_sequence(logical_form)
    production_rule_fields = [ProductionRuleField(rule, is_global_rule=True)
                              for rule in language.all_possible_productions()]
    action_field = ListField(production_rule_fields)
    action_map = {action.rule: i for i, action in enumerate(production_rule_fields)}
    target_action_sequence_field = ListField([IndexField(action_map[a], action_field)
                                              for a in target_action_sequence])
    fields = {'question': TextField(question_tokens, self.token_indexers),
              'question_entities': MetadataField(question_entities),
              'question_predicates': MetadataField(question_predicates),
              'world': MetadataField(language),
              'actions': action_field,
              'target_action_sequence': target_action_sequence_field}
    return Instance(fields)
def text_to_instance(self,
                     question_tokens: List[Token],
                     logical_form: str,
                     question_entities: List[str],
                     question_predicates: List[str]) -> Optional[Instance]:
    try:
        if any([True for i in self.ontology_types if i in logical_form]):
            return None
        if "intersection" in logical_form or "contains" in logical_form or "count" in logical_form:
            return None
        for old, new in zip(self.original_predicates, self.predicates):
            logical_form = logical_form.replace(old, new)
        question_entities = question_entities  # + list(self.ontology_types)
        import random
        random.shuffle(question_predicates)
        context = LCQuADContext(self.executor, question_tokens, question_entities, question_predicates)
        language = LCQuADLanguage(context)
        # print("CONSTANT:" + str(language._functions['http://dbpedia.org/ontology/creator']))
        target_action_sequence = language.logical_form_to_action_sequence(logical_form)
        # labelled_results = language.execute_action_sequence(target_action_sequence)
        # if isinstance(labelled_results, set) and len(labelled_results) > 1000:
        #     return None
        production_rule_fields = [ProductionRuleField(rule, is_global_rule=True)
                                  for rule in language.all_possible_productions()]
        action_field = ListField(production_rule_fields)
        action_map = {action.rule: i for i, action in enumerate(production_rule_fields)}
        target_action_sequence_field = ListField([ListField([IndexField(action_map[a], action_field)
                                                             for a in target_action_sequence])])
        fields = {'question': TextField(question_tokens, self.token_indexers),
                  'question_entities': MetadataField(question_entities),
                  'question_predicates': MetadataField(question_predicates),
                  'world': MetadataField(language),
                  'actions': action_field,
                  'target_action_sequences': target_action_sequence_field,
                  'logical_forms': MetadataField([logical_form])
                  # 'labelled_results': MetadataField(labelled_results),
                  }
        return Instance(fields)
    except ParsingError:
        print(logical_form)
        return None
def test_field_counts_vocab_items_correctly(self):
    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    namespace_token_counts = defaultdict(lambda: defaultdict(int))
    field.count_vocab_items(namespace_token_counts)
    assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 1

    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=False)
    namespace_token_counts = defaultdict(lambda: defaultdict(int))
    field.count_vocab_items(namespace_token_counts)
    assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 0
def text_to_instance(self,  # type: ignore
                     query: List[str],
                     derived_cols: List[Tuple[str, str]],
                     derived_tables: List[str],
                     prelinked_entities: Dict[str, Dict[str, str]] = None,
                     sql: List[str] = None,
                     spans: List[Tuple[int, int]] = None) -> Instance:
    # pylint: disable=arguments-differ
    fields: Dict[str, Field] = {}
    tokens_tokenized = self._token_tokenizer.tokenize(' '.join(query))
    tokens = TextField(tokens_tokenized, self._token_indexers)
    fields["tokens"] = tokens

    spans_field: List[Field] = []
    spans = self._fix_spans_coverage(spans, len(tokens_tokenized))
    for start, end in spans:
        spans_field.append(SpanField(start, end, tokens))
    span_list_field: ListField = ListField(spans_field)
    fields["spans"] = span_list_field

    if sql is not None:
        action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(
            query=sql,
            derived_cols=derived_cols,
            derived_tables=derived_tables,
            prelinked_entities=prelinked_entities)
        if action_sequence is None and self._keep_if_unparsable:
            print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, _ = production_rule.split(' ->')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    self._world.is_global_rule(nonterminal),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}
    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]
    # if not action_sequence and re.findall(r"COUNT \( \* \) (?:<|>|<>|=) 0", " ".join(sql)):
    #     index_fields = [IndexField(-2, valid_actions_field)]
    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    return Instance(fields)
def test_as_tensor_produces_correct_output(self):
    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    field.index(self.vocab)
    tensor_tuple = field.as_tensor(field.get_padding_lengths())
    assert isinstance(tensor_tuple, tuple)
    assert len(tensor_tuple) == 4
    assert tensor_tuple[0] == 'S -> [NP, VP]'
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.s_rule_index])

    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=False)
    field.index(self.vocab)
    tensor_tuple = field.as_tensor(field.get_padding_lengths())
    assert isinstance(tensor_tuple, tuple)
    assert len(tensor_tuple) == 4
    assert tensor_tuple[0] == 'S -> [NP, VP]'
    assert tensor_tuple[1] is False
    assert tensor_tuple[2] is None
def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
    question_tokens = self._read_tokens_from_json_list(json_obj['question_tokens'])
    question_field = TextField(question_tokens, self._question_token_indexers)
    table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(json_obj['table_lines'],
                                                                        question_tokens)
    entity_tokens = [self._read_tokens_from_json_list(token_list)
                     for token_list in json_obj['entity_texts']]
    table_field = KnowledgeGraphField(table_knowledge_graph,
                                      question_tokens,
                                      tokenizer=None,
                                      token_indexers=self._table_token_indexers,
                                      entity_tokens=entity_tokens,
                                      linking_features=json_obj['linking_features'],
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    world = WikiTablesWorld(table_knowledge_graph)
    world_field = MetadataField(world)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    example_string_field = MetadataField(json_obj['example_lisp_string'])
    fields = {'question': question_field,
              'table': table_field,
              'world': world_field,
              'actions': action_field,
              'example_lisp_string': example_string_field}

    if 'target_action_sequences' in json_obj:
        action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
        action_sequence_fields: List[Field] = []
        for sequence in json_obj['target_action_sequences']:
            index_fields: List[Field] = []
            for production_rule in sequence:
                index_fields.append(IndexField(action_map[production_rule], action_field))
            action_sequence_fields.append(ListField(index_fields))
        fields['target_action_sequences'] = ListField(action_sequence_fields)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     logical_forms: List[str],
                     table_lines: List[List[str]],
                     question: str) -> Instance:
    # pylint: disable=arguments-differ
    tokenized_question = self._tokenizer.tokenize(question.lower())
    tokenized_question.insert(0, Token(START_SYMBOL))
    tokenized_question.append(Token(END_SYMBOL))
    question_field = TextField(tokenized_question, self._question_token_indexers)
    table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
    world = WikiTablesLanguage(table_context)

    action_sequences_list: List[List[str]] = []
    action_sequence_fields_list: List[TextField] = []
    for logical_form in logical_forms:
        try:
            action_sequence = world.logical_form_to_action_sequence(logical_form)
            action_sequence = reader_utils.make_bottom_up_action_sequence(action_sequence,
                                                                          world.is_nonterminal)
            action_sequence_field = TextField([Token(rule) for rule in action_sequence],
                                              self._rule_indexers)
            action_sequences_list.append(action_sequence)
            action_sequence_fields_list.append(action_sequence_field)
        except ParsingError as error:
            logger.debug(f'Parsing error: {error.message}, skipping logical form')
            logger.debug(f'Question was: {question}')
            logger.debug(f'Logical form was: {logical_form}')
            logger.debug(f'Table info was: {table_lines}')
        except:
            logger.error(logical_form)
            raise

    if not action_sequences_list:
        return None

    all_production_rule_fields: List[List[Field]] = []
    for action_sequence in action_sequences_list:
        all_production_rule_fields.append([])
        for production_rule in action_sequence:
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
            all_production_rule_fields[-1].append(field)
    action_field = ListField([ListField(production_rule_fields)
                              for production_rule_fields in all_production_rule_fields])

    fields = {'action_sequences': ListField(action_sequence_fields_list),
              'target_tokens': question_field,
              'world': MetadataField(world),
              'actions': action_field}
    return Instance(fields)
def text_to_instance(self,
                     utterance: str,
                     db_id: str,
                     sql: List[str] = None):
    fields: Dict[str, Field] = {}
    db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                 tables_file=self._tables_file, dataset_path=self._dataset_path)
    table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                            db_context.tokenized_utterance,
                                            self._utterance_token_indexers,
                                            entity_tokens=db_context.entity_tokens,
                                            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                            max_table_tokens=None)  # self._max_table_tokens

    world = SpiderWorld(db_context, query=sql)
    fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

    action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, rhs = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(rhs),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}
    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]
    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    fields["world"] = MetadataField(world)
    fields["schema"] = table_field
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     query: List[str],
                     sql: text2sql_utils.SqlData = None) -> Instance:
    # pylint: disable=arguments-differ
    fields: Dict[str, Field] = {}
    tokens = TextField([Token(t) for t in query], self._token_indexers)
    fields["tokens"] = tokens

    if sql is not None:
        try:
            action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(sql.sql)
        except ParseError:
            return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, _ = production_rule.split(' ->')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule, self._world.is_global_rule(nonterminal))
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}

    for production_rule in action_sequence:
        # Temporarily skipping this production to make a PR smaller. The next PR will constrain
        # the strings produced to be from the table, but at the moment they are blank, so they
        # aren't present in the global actions.
        # TODO(Mark): fix the above.
        if production_rule.startswith("string"):
            continue
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     query: List[str],
                     prelinked_entities: Dict[str, Dict[str, str]] = None,
                     sql: List[str] = None) -> Instance:
    fields: Dict[str, Field] = {}
    tokens = TextField([Token(t) for t in query], self._token_indexers)
    fields["tokens"] = tokens

    if sql is not None:
        action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(sql,
                                                                                       prelinked_entities)
        if action_sequence is None and self._keep_if_unparsable:
            print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, _ = production_rule.split(" ->")
        production_rule = " ".join(production_rule.split(" "))
        field = ProductionRuleField(production_rule,
                                    self._world.is_global_rule(nonterminal),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}
    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]
    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     query: List[str],
                     prelinked_entities: Dict[str, Dict[str, str]] = None,
                     sql: List[str] = None) -> Instance:
    # pylint: disable=arguments-differ
    fields: Dict[str, Field] = {}
    tokens = TextField([Token(t) for t in query], self._token_indexers)
    fields["tokens"] = tokens

    if sql is not None:
        try:
            action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(sql,
                                                                                           prelinked_entities)
        except ParseError:
            return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, _ = production_rule.split(' ->')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    self._world.is_global_rule(nonterminal),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}
    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    return Instance(fields)
def text_to_instance(self,
                     utterances: List[str],
                     db_id: str,
                     sql: List[List[str]] = None):
    fields: Dict[str, Field] = {}
    ctxts = [SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                             tables_file=self._tables_file, dataset_path=self._dataset_path)
             for utterance in utterances]
    super_utterance = ' '.join(utterances)
    hack_ctxt = SpiderDBContext(db_id, super_utterance, tokenizer=self._tokenizer,
                                tables_file=self._tables_file, dataset_path=self._dataset_path)
    kg = SpiderKnowledgeGraphField(hack_ctxt.knowledge_graph,
                                   hack_ctxt.tokenized_utterance,
                                   self._utterance_token_indexers,
                                   entity_tokens=hack_ctxt.entity_tokens,
                                   include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                   max_table_tokens=None)
    '''
    kgs = [SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                     db_context.tokenized_utterance,
                                     self._utterance_token_indexers,
                                     entity_tokens=db_context.entity_tokens,
                                     include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                     max_table_tokens=None)  # self._max_table_tokens
           for db_context in ctxts]
    '''
    worlds = []
    for i in range(len(sql)):
        sqli = sql[i]
        db_context = ctxts[i]
        world = SpiderWorld(db_context, query=sqli)
        worlds.append(world)

    fields["utterances"] = ListField([TextField(db_context.tokenized_utterance,
                                                self._utterance_token_indexers)
                                      for db_context in ctxts])

    # action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    action_tups = [world.get_action_sequence_and_all_actions() for world in worlds]
    action_sequences = [tup[0] for tup in action_tups]
    all_actions = [tup[1] for tup in action_tups]

    for i in range(len(action_sequences)):
        action_sequence = action_sequences[i]
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None
        action_sequences[i] = action_sequence

    all_valid_actions_fields = []
    all_action_sequence_fields = []
    for i in range(len(all_actions)):
        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []
        all_actionsi = all_actions[i]
        for production_rule in all_actionsi:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)
        valid_actions_field = ListField(production_rule_fields)
        all_valid_actions_fields.append(valid_actions_field)
        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}
        index_fields: List[Field] = []
        action_sequence = action_sequences[i]
        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)
        all_action_sequence_fields.append(action_sequence_field)

    fields["valid_actions"] = ListField(all_valid_actions_fields)
    fields["action_sequences"] = ListField(all_action_sequence_fields)
    fields["worlds"] = ListField([MetadataField(world) for world in worlds])
    fields["schema"] = kg
    '''
    fields['utterances'] = ListField[TextField]
    fields['valid_actions'] = ListField[ListField[ProductionRuleField]]
    fields['action_sequences'] = ListField[ListField[IndexField]]
    fields['worlds'] = ListField[MetadataField[SpiderWorld]]
    fields['schemas'] = ListField[SpiderKnowledgeGraphField]
    '''
    return Instance(fields)
def text_to_instance(self,
                     utterance: str,
                     db_id: str,
                     sql: List[str] = None):
    fields: Dict[str, Field] = {}
    db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                 tables_file=self._tables_file, dataset_path=self._dataset_path)
    table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                            db_context.tokenized_utterance,
                                            {},
                                            entity_tokens=db_context.entity_tokens,
                                            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                            max_table_tokens=None)  # self._max_table_tokens

    combined_tokens = [] + db_context.tokenized_utterance
    entity_token_map = dict(zip(db_context.knowledge_graph.entities, db_context.entity_tokens))
    entity_tokens = []
    for e in db_context.knowledge_graph.entities:
        if e.startswith('column:'):
            table_name, column_name = e.split(':')[-2:]
            table_tokens = entity_token_map['table:' + table_name]
            column_tokens = entity_token_map[e]
            if column_name.startswith(table_name):
                column_tokens = column_tokens[len(table_tokens):]
            entity_tokens.append(table_tokens + [Token(text='[unused30]')] + column_tokens)
        else:
            entity_tokens.append(entity_token_map[e])
    for e in entity_tokens:
        combined_tokens += [Token(text='[SEP]')] + e
    if len(combined_tokens) > 450:
        return None
    db_context.entity_tokens = entity_tokens

    fields["utterance"] = TextField(combined_tokens, self._utterance_token_indexers)

    world = SpiderWorld(db_context, query=sql)
    action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, _ = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(nonterminal),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}
    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]
    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    fields["world"] = MetadataField(world)
    fields["schema"] = table_field
    return Instance(fields)
def text_to_instance(self,
                     utterance: str,
                     db_id: str,
                     sql: List[str] = None):
    fields: Dict[str, Field] = {}
    if self._is_spider:
        db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                     tables_file=self._tables_file, dataset_path=self._dataset_path)
        table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                                db_context.tokenized_utterance,
                                                self._utterance_token_indexers,
                                                entity_tokens=db_context.entity_tokens,
                                                include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                                max_table_tokens=None)  # self._max_table_tokens

        world = SpiderWorld(db_context, nl_context=None, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        index_fields: List[Field] = []
        production_rule_fields: List[Field] = []
        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field
        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}
        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
    else:
        db_context = WikiDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                   tables_file=self._tables_file, dataset_path=self._dataset_path)
        # print(db_context.entity_tokens)
        # TODO: This WikiKnowledgeGraphField is identical to the corresponding Spider one; only the class name was changed.
        table_field = WikiKnowledgeGraphField(db_context.knowledge_graph,
                                              db_context.tokenized_utterance,
                                              self._utterance_token_indexers,
                                              entity_tokens=db_context.entity_tokens,
                                              include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                              max_table_tokens=None)  # self._max_table_tokens

        world = WikiWorld(db_context, nl_context=None, query=sql)
        fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

        # TODO: This step raises an error, probably because the grammar does not match, e.g.
        # ParseError: ['select', '1-10015132-11@Position', 'where', '1-10015132-11@School/Club Team', '=', "'value'"]
        # TODO: action_sequence: None
        # TODO: all_actions: ['arg_list -> [expr, ",", arg_list]', 'arg_list -> [expr]', 'arg_list_or_star -> ["*"]', 'arg_list_or_star -> [arg_list]', 'binaryop -> ["!="]', 'binaryop -> ["*"]', 'binaryop -> ["+"]', 'binaryop -> ["-"]', 'binaryop -> ["/"]', 'binaryop -> ["<"]', 'binaryop -> ["<="]', 'binaryop -> ["<>"]', 'binaryop -> ["="]', 'binaryop -> [">"]', 'binaryop -> [">="]', 'binaryop -> ["and"]', 'binaryop -> ["like"]', 'binaryop -> ["or"]', 'boolean -> ["false"]', 'boolean -> ["true"]',
        # 'col_ref -> ["1-10015132-11@col0"]', 'col_ref -> ["1-10015132-11@col1"]', 'col_ref -> ["1-10015132-11@col2"]', 'col_ref -> ["1-10015132-11@col3"]', 'col_ref -> ["1-10015132-11@col4"]', 'col_ref -> ["1-10015132-11@col5"]', 'column_name -> ["1-10015132-11@col0"]', 'column_name -> ["1-10015132-11@col1"]', 'column_name -> ["1-10015132-11@col2"]', 'column_name -> ["1-10015132-11@col3"]', 'column_name -> ["1-10015132-11@col4"]', 'column_name -> ["1-10015132-11@col5"]',
        # 'expr -> [in_expr]', 'expr -> [source_subq]', 'expr -> [unaryop, expr]', 'expr -> [value, "between", value, "and", value]', 'expr -> [value, "like", string]', 'expr -> [value, binaryop, expr]', 'expr -> [value]', 'fname -> ["all"]', 'fname -> ["avg"]', 'fname -> ["count"]', 'fname -> ["max"]', 'fname -> ["min"]', 'fname -> ["sum"]', 'from_clause -> ["from", source]', 'from_clause -> ["from", table_name, join_clauses]', 'function -> [fname, "(", "distinct", arg_list_or_star, ")"]', 'function -> [fname, "(", arg_list_or_star, ")"]',
        # 'group_clause -> [expr, ",", group_clause]', 'group_clause -> [expr]', 'groupby_clause -> ["group", "by", group_clause, "having", expr]', 'groupby_clause -> ["group", "by", group_clause]', 'in_expr -> [value, "in", expr]', 'in_expr -> [value, "in", string_set]', 'in_expr -> [value, "not", "in", expr]', 'in_expr -> [value, "not", "in", string_set]', 'iue -> ["except"]', 'iue -> ["intersect"]', 'iue -> ["union"]', 'join_clause -> ["join", table_name, "on", join_condition_clause]', 'join_clauses -> [join_clause, join_clauses]', 'join_clauses -> [join_clause]', 'join_condition -> [column_name, "=", column_name]', 'join_condition_clause -> [join_condition, "and", join_condition_clause]', 'join_condition_clause -> [join_condition]',
        # 'limit -> ["limit", non_literal_number]', 'non_literal_number -> ["1"]', 'non_literal_number -> ["2"]', 'non_literal_number -> ["3"]', 'non_literal_number -> ["4"]', 'number -> ["value"]', 'order_clause -> [ordering_term, ",", order_clause]', 'order_clause -> [ordering_term]', 'orderby_clause -> ["order", "by", order_clause]', 'ordering -> ["asc"]', 'ordering -> ["desc"]', 'ordering_term -> [expr, ordering]', 'ordering_term -> [expr]', 'parenval -> ["(", expr, ")"]',
        # 'query -> [select_core, groupby_clause, limit]', 'query -> [select_core, groupby_clause, orderby_clause, limit]', 'query -> [select_core, groupby_clause, orderby_clause]', 'query -> [select_core, groupby_clause]', 'query -> [select_core, orderby_clause, limit]', 'query -> [select_core, orderby_clause]', 'query -> [select_core]', 'select_core -> [select_with_distinct, select_results, from_clause, where_clause]', 'select_core -> [select_with_distinct, select_results, from_clause]', 'select_core -> [select_with_distinct, select_results, where_clause]', 'select_core -> [select_with_distinct, select_results]', 'select_result -> ["*"]', 'select_result -> [column_name]', 'select_result -> [expr]', 'select_result -> [table_name, ".*"]', 'select_results -> [select_result, ",", select_results]', 'select_results -> [select_result]', 'select_with_distinct -> ["select", "distinct"]', 'select_with_distinct -> ["select"]',
        # 'single_source -> [source_subq]', 'single_source -> [table_name]', 'source -> [single_source, ",", source]', 'source -> [single_source]', 'source_subq -> ["(", query, ")"]', 'statement -> [query, iue, query]', 'statement -> [query]', 'string -> ["\'", "value", "\'"]', 'string_set -> ["(", string_set_vals, ")"]', 'string_set_vals -> [string, ",", string_set_vals]', 'string_set_vals -> [string]', "table_name -> ['1-10015132-11']", "table_source -> ['1-10015132-11']", 'unaryop -> ["+"]', 'unaryop -> ["-"]', 'unaryop -> ["not"]',
        # 'value -> ["YEAR(CURDATE())"]', 'value -> [boolean]', 'value -> [column_name]', 'value -> [function]', 'value -> [number]', 'value -> [parenval]', 'value -> [string]', 'where_clause -> ["where", expr, where_conj]', 'where_clause -> ["where", expr]', 'where_conj -> ["and", expr, where_conj]', 'where_conj -> ["and", expr]']
        action_sequence, all_actions = world.get_action_sequence_and_all_actions()
        if action_sequence is None and self._keep_if_unparsable:
            # print("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

        production_rule_fields: List[Field] = []
        for production_rule in all_actions:
            nonterminal, rhs = production_rule.split(' -> ')
            production_rule = ' '.join(production_rule.split(' '))
            field = ProductionRuleField(production_rule,
                                        world.is_global_rule(rhs),
                                        nonterminal=nonterminal)
            production_rule_fields.append(field)

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        index_fields: List[Field] = []
        # action: ProductionRuleField
        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(valid_actions_field.field_list)}
        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
        if not action_sequence:
            index_fields = [IndexField(-1, valid_actions_field)]
        action_sequence_field = ListField(index_fields)
        fields["action_sequence"] = action_sequence_field
        fields["world"] = MetadataField(world)
        fields["schema"] = table_field
    # print(fields)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     question: str,
                     table_lines: List[str],
                     example_lisp_string: str = None,
                     dpd_output: List[str] = None,
                     tokenized_question: List[Token] = None) -> Instance:
    """
    Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV
    files, which we use for training.

    Parameters
    ----------
    question : ``str``
        Input question
    table_lines : ``List[str]``
        The table content itself, as a list of rows. See
        ``TableQuestionKnowledgeGraph.read_from_lines`` for the expected format.
    example_lisp_string : ``str``, optional
        The original (lisp-formatted) example string in the WikiTableQuestions dataset. This
        comes directly from the ``.examples`` file provided with the dataset. We pass this to
        SEMPRE for evaluating logical forms during training. It isn't otherwise used for
        anything.
    dpd_output : List[str], optional
        List of logical forms, produced by dynamic programming on denotations. Not required
        during test.
    tokenized_question : ``List[Token]``, optional
        If you have already tokenized the question, you can pass that in here, so we don't
        duplicate that work. You might, for example, do batch processing on the questions in
        the whole dataset, then pass the result in here.
    """
    # pylint: disable=arguments-differ
    tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question, self._question_token_indexers)
    metadata: Dict[str, Any] = {"question_tokens": [x.text for x in tokenized_question]}
    metadata["original_table"] = "\n".join(table_lines)
    table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(table_lines, tokenized_question)
    table_metadata = MetadataField(table_lines)
    table_field = KnowledgeGraphField(table_knowledge_graph,
                                      tokenized_question,
                                      self._table_token_indexers,
                                      tokenizer=self._tokenizer,
                                      feature_extractors=self._linking_feature_extractors,
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    world = WikiTablesWorld(table_knowledge_graph)
    world_field = MetadataField(world)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {'question': question_field,
              'metadata': MetadataField(metadata),
              'table': table_field,
              'world': world_field,
              'actions': action_field}
    if self._include_table_metadata:
        fields['table_metadata'] = table_metadata
    if example_lisp_string:
        fields['example_lisp_string'] = MetadataField(example_lisp_string)

    # We'll make each target action sequence a List[IndexField], where the index is into
    # the action list we made above. We need to ignore the type here because mypy doesn't
    # like `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
    if dpd_output:
        action_sequence_fields: List[Field] = []
        for logical_form in dpd_output:
            if not self._should_keep_logical_form(logical_form):
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            try:
                expression = world.parse_logical_form(logical_form)
            except ParsingError as error:
                logger.debug(f'Parsing error: {error.message}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            except:
                logger.error(logical_form)
                raise
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                logger.debug(f'Logical form was: {logical_form}')
                continue
            if len(action_sequence_fields) >= self._max_dpd_logical_forms:
                break

        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms. We should skip this instance. Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields['target_action_sequences'] = ListField(action_sequence_fields)

    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        for agenda_string in world.get_agenda():
            agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            agenda_index_fields = [IndexField(-1, action_field)]
        fields['agenda'] = ListField(agenda_index_fields)
    return Instance(fields)
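# Added illustration (not part of the original reader): a minimal, standard-library-only sketch of
# the nesting used for 'target_action_sequences' above -- one list of indices per candidate logical
# form, where each index points into the flat action list. The helper name and the toy grammar
# below are made up for this example.
def _sketch_target_action_indices(all_actions, action_sequences):
    """Map each action sequence to indices into ``all_actions``, mirroring ListField[ListField[IndexField]]."""
    action_map = {rule: i for i, rule in enumerate(all_actions)}
    return [[action_map[rule] for rule in sequence] for sequence in action_sequences]


# Two candidate derivations over a three-rule toy grammar index to [[0, 1], [0, 2]].
assert _sketch_target_action_indices(
    ['S -> [NP, VP]', 'NP -> test', 'VP -> eat'],
    [['S -> [NP, VP]', 'NP -> test'], ['S -> [NP, VP]', 'VP -> eat']]) == [[0, 1], [0, 2]]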
def test_batch_tensors_does_not_modify_list(self):
    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    field.index(self.vocab)
    padding_lengths = field.get_padding_lengths()
    tensor_dict1 = field.as_tensor(padding_lengths)

    field = ProductionRuleField('NP -> test', is_global_rule=True)
    field.index(self.vocab)
    padding_lengths = field.get_padding_lengths()
    tensor_dict2 = field.as_tensor(padding_lengths)

    tensor_list = [tensor_dict1, tensor_dict2]
    assert field.batch_tensors(tensor_list) == tensor_list
def test_padding_lengths_are_computed_correctly(self):
    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    field.index(self.vocab)
    assert field.get_padding_lengths() == {}
def text_to_instance(self,
                     utterance: str,
                     db_id: str,
                     sql: List[str] = None):
    fields: Dict[str, Field] = {}
    # KAIMARY
    # db_context contains:
    # 1. The db schema (tables with their corresponding columns)
    # 2. The tokenized utterance
    # 3. The knowledge graph (table entities, column entities, and token entities related to "text" column types)
    # 4. entity_tokens (retrieved from the knowledge graph's entities_text)
    db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                 tables_file=self._tables_file, dataset_path=self._dataset_path)
    # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html#knowledge-graph-field
    # *feature extractors*
    table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                            db_context.tokenized_utterance,
                                            self._utterance_token_indexers,
                                            entity_tokens=db_context.entity_tokens,
                                            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                            max_table_tokens=None)  # self._max_table_tokens

    world = SpiderWorld(db_context, query=sql)
    fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

    action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, rhs = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(rhs),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}
    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]
    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    fields["world"] = MetadataField(world)
    fields["schema"] = table_field
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     utterances: List[str],
                     sql_query_labels: List[str] = None) -> Instance:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    utterances: ``List[str]``, required.
        List of utterances in the interaction, the last element is the current utterance.
    sql_query_labels: ``List[str]``, optional
        The SQL queries that are given as labels during training or validation.
    """
    if self._num_turns_to_concatenate:
        utterances[-1] = f' {END_OF_UTTERANCE_TOKEN} '.join(utterances[-self._num_turns_to_concatenate:])

    utterance = utterances[-1]
    action_sequence: List[str] = []

    if not utterance:
        return None
    world = AtisWorld(utterances=utterances)

    if sql_query_labels:
        # If there are multiple SQL queries given as labels, we use the shortest one for training.
        sql_query = min(sql_query_labels, key=len)
        try:
            action_sequence = world.get_action_sequence(sql_query)
        except ParseError:
            logger.debug('Parsing error')

    tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
    utterance_field = TextField(tokenized_utterance, self._token_indexers)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        nonterminal, _ = production_rule.split(' ->')
        # The whitespaces are not semantically meaningful, so we filter them out.
        production_rule = ' '.join([token for token in production_rule.split(' ') if token != 'ws'])
        field = ProductionRuleField(production_rule, self._is_global_rule(nonterminal))
        production_rule_fields.append(field)

    action_field = ListField(production_rule_fields)
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(action_field.field_list)}
    index_fields: List[Field] = []
    world_field = MetadataField(world)
    fields = {'utterance': utterance_field,
              'actions': action_field,
              'world': world_field,
              'linking_scores': ArrayField(world.linking_scores)}

    if sql_query_labels is not None:
        fields['sql_queries'] = MetadataField(sql_query_labels)
        if action_sequence and not self._keep_if_unparseable:
            for production_rule in action_sequence:
                index_fields.append(IndexField(action_map[production_rule], action_field))
            action_sequence_field = ListField(index_fields)
            fields['target_action_sequence'] = action_sequence_field
        elif not self._keep_if_unparseable:
            # If we are given a SQL query, but we are unable to parse it, and we do not explicitly
            # ask to keep it, then we skip the instance.
            return None
    return Instance(fields)
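# Added illustration (not part of the original reader): a standard-library-only sketch of the two
# small preprocessing steps documented above -- concatenating the last few turns of an interaction
# and picking the shortest gold SQL query. The token string and helper name are placeholders; the
# real reader uses END_OF_UTTERANCE_TOKEN and its own configuration.
END_OF_UTTERANCE = '@@EOU@@'  # placeholder for END_OF_UTTERANCE_TOKEN


def _sketch_prepare_turn(utterances, sql_query_labels=None, num_turns_to_concatenate=2):
    """Return the (possibly concatenated) current utterance and the shortest SQL label, if any."""
    current = f' {END_OF_UTTERANCE} '.join(utterances[-num_turns_to_concatenate:])
    shortest_sql = min(sql_query_labels, key=len) if sql_query_labels else None
    return current, shortest_sql


current_utterance, gold_sql = _sketch_prepare_turn(
    ['show me flights to denver', 'only morning ones'],
    ['SELECT flight_id FROM flight WHERE departure_time < 1200', 'SELECT flight_id FROM flight'])
assert gold_sql == 'SELECT flight_id FROM flight'
assert current_utterance == 'show me flights to denver @@EOU@@ only morning ones'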
def text_to_instance(self,  # type: ignore
                     question: str,
                     logical_forms: List[str] = None,
                     additional_metadata: Dict[str, Any] = None,
                     world_extractions: Dict[str, Union[str, List[str]]] = None,
                     entity_literals: Dict[str, Union[str, List[str]]] = None,
                     tokenized_question: List[Token] = None,
                     debug_counter: int = None,
                     qr_spec_override: List[Dict[str, int]] = None,
                     dynamic_entities_override: Dict[str, str] = None) -> Instance:
    tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
    additional_metadata = additional_metadata or dict()
    additional_metadata["question_tokens"] = [token.text for token in tokenized_question]
    if world_extractions is not None:
        additional_metadata["world_extractions"] = world_extractions
    question_field = TextField(tokenized_question, self._question_token_indexers)

    if qr_spec_override is not None or dynamic_entities_override is not None:
        # Dynamically specify theory and/or entities
        dynamic_entities = dynamic_entities_override or self._dynamic_entities
        neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
        knowledge_graph = KnowledgeGraph(entities=set(dynamic_entities.keys()),
                                         neighbors=neighbors,
                                         entity_text=dynamic_entities)
        world = QuarelWorld(knowledge_graph, self._lf_syntax, qr_coeff_sets=qr_spec_override)
    else:
        knowledge_graph = self._knowledge_graph
        world = self._world

    table_field = KnowledgeGraphField(knowledge_graph,
                                      tokenized_question,
                                      self._entity_token_indexers,
                                      tokenizer=self._tokenizer)

    if self._tagger_only:
        fields: Dict[str, Field] = {"tokens": question_field}
        if entity_literals is not None:
            entity_tags = self._get_entity_tags(self._all_entities, table_field,
                                                entity_literals, tokenized_question)
            if debug_counter > 0:
                logger.info(f"raw entity tags = {entity_tags}")
            entity_tags_bio = self._convert_tags_bio(entity_tags)
            fields["tags"] = SequenceLabelField(entity_tags_bio, question_field)
            additional_metadata["tags_gold"] = entity_tags_bio
        additional_metadata["words"] = [x.text for x in tokenized_question]
        fields["metadata"] = MetadataField(additional_metadata)
        return Instance(fields)

    world_field = MetadataField(world)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(" -> ")
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {"question": question_field,
              "table": table_field,
              "world": world_field,
              "actions": action_field}

    if self._denotation_only:
        denotation_field = LabelField(additional_metadata["answer_index"], skip_indexing=True)
        fields["denotation_target"] = denotation_field

    if self._entity_bits_mode is not None and world_extractions is not None:
        entity_bits = self._get_entity_tags(["world1", "world2"], table_field,
                                            world_extractions, tokenized_question)
        if self._entity_bits_mode == "simple":
            entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple_collapsed":
            entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple3":
            entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]
        entity_bits_field = ArrayField(np.array(entity_bits_v))
        fields["entity_bits"] = entity_bits_field

    if logical_forms:
        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(action_field.field_list)}
        action_sequence_fields: List[Field] = []
        for logical_form in logical_forms:
            expression = world.parse_logical_form(logical_form)
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                logger.info(f"Missing production rule: {error.args}, skipping logical form")
                logger.info(f"Question was: {question}")
                logger.info(f"Logical form was: {logical_form}")
                continue
        fields["target_action_sequences"] = ListField(action_sequence_fields)

    fields["metadata"] = MetadataField(additional_metadata or {})
    return Instance(fields)
def text_to_instance(self,
                     utterance: str,  # the question
                     db_id: str,
                     sql: List[str] = None):
    fields: Dict[str, Field] = {}
    # db_context is the db graph plus its tokens: the utterance and the database schema graph.
    db_context = SpiderDBContext(db_id, utterance, tokenizer=self._tokenizer,
                                 tables_file=self._tables_file, dataset_path=self._dataset_path)
    # In AllenNLP an Instance is made of Field objects (think of fields as columns in a table, or
    # attributes of an object), so db_context has to be wrapped in a Field, here table_field.
    # db_context.knowledge_graph is a graph, so we need a graph field; SpiderKnowledgeGraphField
    # inherits from KnowledgeGraphField.
    table_field = SpiderKnowledgeGraphField(db_context.knowledge_graph,
                                            db_context.tokenized_utterance,
                                            self._utterance_token_indexers,
                                            entity_tokens=db_context.entity_tokens,
                                            include_in_vocab=False,  # TODO: self._use_table_for_vocab,
                                            max_table_tokens=None,   # self._max_table_tokens
                                            conceptnet=self._concept_word)

    world = SpiderWorld(db_context, query=sql)
    fields["utterance"] = TextField(db_context.tokenized_utterance, self._utterance_token_indexers)

    # action_sequence is the result of parsing the query with the grammar, and the grammar is built
    # for the specific database. all_actions is the complete list of grammar rules, so
    # action_sequence is a subset of all_actions: the rules needed to derive this query.
    # The grammar is defined in semparse/contexts/spider_db_grammar.py, in a BNF-like style.
    action_sequence, all_actions = world.get_action_sequence_and_all_actions()
    if action_sequence is None and self._keep_if_unparsable:
        # print("Parse error")
        action_sequence = []
    elif action_sequence is None:
        return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []
    for production_rule in all_actions:
        nonterminal, rhs = production_rule.split(' -> ')
        production_rule = ' '.join(production_rule.split(' '))
        # ProductionRuleField docs:
        # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=productionrulefield#production-rule-field
        field = ProductionRuleField(production_rule,
                                    world.is_global_rule(rhs),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    # valid_actions_field is generated from all_actions, i.e. the whole grammar.
    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field

    # Give every grammar rule an id.
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}

    # Turn every rule in action_sequence into an index into the full grammar, so each derivation
    # step can be looked up against the complete rule inventory.
    if action_sequence:
        for production_rule in action_sequence:
            index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    else:
        # action_sequence is empty here, which means our grammar could not parse this query.
        index_fields = [IndexField(-1, valid_actions_field)]
        # assert False
    action_sequence_field = ListField(index_fields)

    # fields["valid_actions"] contains the global rules, which are the same for every SQL query and
    # database, plus database-specific rules. For example, 'binaryop -> ["!="]' and
    # 'query -> [select_core, groupby_clause, limit]' are global rules, while
    # 'column_name -> ["department@budget_in_billions"]' and
    # 'col_ref -> ["department@budget_in_billions"]' are not. So fields["valid_actions"] varies per
    # instance but always contains the same global rules, and it must include every rule that
    # appears in fields["action_sequence"].
    # At this point the _rule_id attribute is still None. Once all the data is loaded, AllenNLP
    # automatically builds the vocabulary and assigns a unique _rule_id to every global rule.
    # When fields["valid_actions"] is forwarded to the model it becomes a list of ProductionRules;
    # global rules carry a _rule_id while non-global rules do not. In that list, global rules
    # arrive as tuples and non-global rules as ProductionRule objects, and in both cases:
    #   rule[0] is the rule string, e.g. 'where_clause -> ["where", expr, where_conj]' or
    #           'col_ref -> ["department@budget_in_billions"]';
    #   rule[1] says whether it is a global rule (True or False);
    #   rule[2] is the _rule_id (None for non-global rules);
    #   rule[3] is the left-hand-side nonterminal, e.g. 'where_clause' for
    #           'where_clause -> ["where", expr, where_conj]'.
    #
    # fields["valid_actions"]: all actions, i.e. the whole grammar. Its information is essentially
    # the same as world.valid_actions (world is a SpiderWorld object), just represented
    # differently. There are two equivalent representations of the valid actions in this project:
    # 1. A list of records ("list-type actions"):
    #        [{rule: 'arg_list -> [expr, ",", arg_list]', is_global_rule: True, ...},
    #         {rule: 'arg_list -> [expr]',                is_global_rule: True, ...},
    #         {...}]
    #    from which the rule strings are easily extracted into a flat list, e.g.
    #        all_actions: ['arg_list -> [expr, ",", arg_list]', 'arg_list -> [expr]', ...]
    # 2. A dict keyed by the left-hand side, grouping the right-hand sides of the same nonterminal
    #    together ("dict-type actions"):
    #        {arg_list: ['[expr, ",", arg_list]', '[expr]'], ...}
    # Both are the valid actions.
    #
    # fields["utterance"]: the TextField for the utterance.
    # fields["action_sequence"]: the grammar rules (actions) deriving this query; every rule
    # indexes into the full grammar held in fields["valid_actions"].
    fields["action_sequence"] = action_sequence_field
    # A MetadataField is a Field that never gets converted into tensors; here it is mostly used for
    # computing metrics.
    # https://allenai.github.io/allennlp-docs/api/allennlp.data.fields.html?highlight=metadatafield#metadata-field
    fields["world"] = MetadataField(world)
    fields["schema"] = table_field
    return Instance(fields)
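# Added illustration (not part of the original reader): a standard-library-only sketch of the
# 4-tuple layout described in the comments above for a ProductionRule once the batch reaches the
# model. The rule id is shown as a plain int (42 is made up); in the real field it is a tensor,
# and it is None for instance-specific (schema) rules.
global_rule = ('where_clause -> ["where", expr, where_conj]', True, 42, 'where_clause')
local_rule = ('column_name -> ["department@budget_in_billions"]', False, None, 'column_name')

for rule_string, is_global, rule_id, nonterminal in (global_rule, local_rule):
    # rule[0]: full rule string; rule[1]: is_global_rule; rule[2]: rule id; rule[3]: LHS nonterminal.
    print(f'{nonterminal}: {rule_string} | global: {is_global} | id: {rule_id}')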
def text_to_instance(self,  # type: ignore
                     question: str,
                     table_lines: List[List[str]],
                     answer_json: JsonDict,
                     offline_search_output: List[str] = None) -> Instance:
    """
    Reads text inputs and makes an instance. We assume we have access to DROP paragraphs parsed
    and tagged in a format similar to the tagged tables in WikiTableQuestions.
    # TODO(pradeep): Explain the format.

    Parameters
    ----------
    question : ``str``
        Input question
    table_lines : ``List[List[str]]``
        Preprocessed paragraph content. See ``ParagraphQuestionContext.read_from_lines`` for the
        expected format.
    answer_json : ``JsonDict``
        The "answer" dict from the original data file.
    offline_search_output : List[str], optional
        List of logical forms, produced by offline search. Not required during test.
    """
    # pylint: disable=arguments-differ
    tokenized_question = self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question, self._question_token_indexers)
    # TODO(pradeep): We'll need a better way to input processed lines.
    paragraph_context = ParagraphQuestionContext.read_from_lines(table_lines,
                                                                 tokenized_question,
                                                                 self._entity_extraction_embedding,
                                                                 self._entity_extraction_distance_threshold)
    target_values_field = MetadataField(answer_json)
    world = DropWorld(paragraph_context)
    world_field = MetadataField(world)
    # Note: Not passing any feature extractors when instantiating the field below. This will make
    # it use all the available extractors.
    table_field = KnowledgeGraphField(paragraph_context.get_knowledge_graph(),
                                      tokenized_question,
                                      self._table_token_indexers,
                                      tokenizer=self._tokenizer,
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_instance_specific_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {'question': question_field,
              'table': table_field,
              'world': world_field,
              'actions': action_field,
              'target_values': target_values_field}

    # We'll make each target action sequence a List[IndexField], where the index is into
    # the action list we made above. We need to ignore the type here because mypy doesn't
    # like `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
    if offline_search_output:
        action_sequence_fields: List[Field] = []
        for logical_form in offline_search_output:
            try:
                expression = world.parse_logical_form(logical_form)
            except ParsingError as error:
                logger.debug(f'Parsing error: {error.message}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            except:
                logger.error(logical_form)
                raise
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                logger.debug(f'Logical form was: {logical_form}')
                continue
            if len(action_sequence_fields) >= self._max_offline_logical_forms:
                break

        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms. We should skip this instance. Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields['target_action_sequences'] = ListField(action_sequence_fields)

    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        for agenda_string in world.get_agenda():
            agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            agenda_index_fields = [IndexField(-1, action_field)]
        fields['agenda'] = ListField(agenda_index_fields)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     utterances: List[str],
                     sql_query: str = None) -> Instance:
    # pylint: disable=arguments-differ
    """
    Parameters
    ----------
    utterances : ``List[str]``, required
        List of utterances in the interaction, the last element is the current utterance.
    sql_query : ``str``, optional
        The SQL query, given as label during training or validation.
    """
    utterance = utterances[-1]
    action_sequence: List[str] = []
    if not utterance:
        return None
    world = AtisWorld(utterances=utterances, database_directory=self._database_directory)
    if sql_query:
        try:
            action_sequence = world.get_action_sequence(sql_query)
        except ParseError:
            logger.debug('Parsing error')
    tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
    utterance_field = TextField(tokenized_utterance, self._token_indexers)
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        lhs, _ = production_rule.split(' ->')
        is_global_rule = lhs not in ['number', 'string']
        # The whitespaces are not semantically meaningful, so we filter them out.
        production_rule = ' '.join([token for token in production_rule.split(' ') if token != 'ws'])
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(action_field.field_list)}
    index_fields: List[Field] = []
    world_field = MetadataField(world)
    fields = {'utterance': utterance_field,
              'actions': action_field,
              'world': world_field,
              'linking_scores': ArrayField(world.linking_scores)}
    if sql_query:
        if action_sequence:
            for production_rule in action_sequence:
                index_fields.append(IndexField(action_map[production_rule], action_field))
            action_sequence_field: List[Field] = []
            action_sequence_field.append(ListField(index_fields))
            fields['target_action_sequence'] = ListField(action_sequence_field)
        else:
            # If we are given a SQL query, but we are unable to parse it, then we will skip it.
            return None
    return Instance(fields)
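A quick illustration of the rule preprocessing above: the left-hand side decides whether a rule is global, and only standalone 'ws' tokens are dropped. The rule string below is made up for illustration and is not taken from the real ATIS grammar:

# Hypothetical rule string, for illustration only.
production_rule = 'col_ref -> ws "city" ws "." ws "city_code" ws'
lhs, _ = production_rule.split(' ->')
is_global_rule = lhs not in ['number', 'string']   # True for this toy rule
filtered = ' '.join(token for token in production_rule.split(' ') if token != 'ws')
assert filtered == 'col_ref -> "city" "." "city_code"'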
def text_to_instance(self,  # type: ignore
                     sentence: str,
                     structured_representations: List[List[List[JsonDict]]],
                     labels: List[str] = None,
                     target_sequences: List[List[str]] = None,
                     identifier: str = None) -> Instance:
    """
    Parameters
    ----------
    sentence : ``str``
        The query sentence.
    structured_representations : ``List[List[List[JsonDict]]]``
        A list of Json representations of all the worlds. See expected format in this class'
        docstring.
    labels : ``List[str]`` (optional)
        List of string representations of the labels (true or false) corresponding to the
        ``structured_representations``. Not required while testing.
    target_sequences : ``List[List[str]]`` (optional)
        List of target action sequences for each element which lead to the correct denotation in
        worlds corresponding to the structured representations.
    identifier : ``str`` (optional)
        The identifier from the dataset if available.
    """
    # pylint: disable=arguments-differ
    worlds = [NlvrWorld(data) for data in structured_representations]
    tokenized_sentence = self._tokenizer.tokenize(sentence)
    sentence_field = TextField(tokenized_sentence, self._sentence_token_indexers)
    production_rule_fields: List[Field] = []
    instance_action_ids: Dict[str, int] = {}
    # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
    # later.
    for production_rule in worlds[0].all_possible_actions():
        instance_action_ids[production_rule] = len(instance_action_ids)
        field = ProductionRuleField(production_rule, is_global_rule=True)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)
    worlds_field = ListField([MetadataField(world) for world in worlds])
    fields: Dict[str, Field] = {"sentence": sentence_field,
                                "worlds": worlds_field,
                                "actions": action_field}
    if identifier is not None:
        fields["identifier"] = MetadataField(identifier)
    # Depending on the type of supervision used for training the parser, we may want either
    # target action sequences or an agenda in our instance. We check if target sequences are
    # provided, and include them if they are. If not, we'll get an agenda for the sentence, and
    # include that in the instance.
    if target_sequences:
        action_sequence_fields: List[Field] = []
        for target_sequence in target_sequences:
            index_fields = ListField([IndexField(instance_action_ids[action], action_field)
                                      for action in target_sequence])
            action_sequence_fields.append(index_fields)
            # TODO(pradeep): Define a max length for this field.
        fields["target_action_sequences"] = ListField(action_sequence_fields)
    elif self._output_agendas:
        # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
        # now, but may change later too.
        agenda = worlds[0].get_agenda_for_sentence(sentence, add_paths_to_agenda=False)
        assert agenda, "No agenda found for sentence: %s" % sentence
        # agenda_field contains indices into actions.
        agenda_field = ListField([IndexField(instance_action_ids[action], action_field)
                                  for action in agenda])
        fields["agenda"] = agenda_field
    if labels:
        labels_field = ListField([LabelField(label, label_namespace='denotations')
                                  for label in labels])
        fields["labels"] = labels_field
    return Instance(fields)
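For the label-supervision branch at the end, each denotation string is wrapped in a LabelField under a dedicated namespace so that true/false denotations do not collide with other label vocabularies. A minimal sketch with made-up labels, assuming the same AllenNLP field classes:

from allennlp.data.fields import LabelField, ListField

# Toy denotations for two worlds of one example (made-up values).
toy_labels = ['true', 'false']
labels_field = ListField([LabelField(label, label_namespace='denotations')
                          for label in toy_labels])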
def test_production_rule_field_can_print(self):
    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    print(field)
def __init__(self,
             image_feat_path: str,
             topk: int = -1,
             lazy: bool = False,
             reload_tsv: bool = False,
             cache_path: str = "",
             relations_path: str = "",
             attributes_path: str = "",
             objects_path: str = "",
             positive_threshold: float = 0.5,
             negative_threshold: float = 0.5,
             object_supervision: bool = False,
             require_some_positive: bool = False) -> None:
    super().__init__(lazy)
    self._tokenizer = PretrainedTransformerTokenizer("bert-base-uncased", do_lowercase=True)
    self._word_tokenizer = WordTokenizer()
    self._token_indexers = {"tokens": PretrainedTransformerIndexer("bert-base-uncased", do_lowercase=True)}
    self._language = VisualReasoningGqaLanguage(None, None, None, None, None)
    self._production_rules = self._language.all_possible_productions()
    self._action_map = {rule: i for i, rule in enumerate(self._production_rules)}
    production_rule_fields = [ProductionRuleField(rule, is_global_rule=True)
                              for rule in self._production_rules]
    self._production_rule_field = ListField(production_rule_fields)
    self._image_feat_cache_dir = os.path.join("cache", image_feat_path.split("/")[-1])
    if len(cache_path) > 0:
        self._image_feat_cache_dir = os.path.join(cache_path, "cache", image_feat_path.split("/")[-1])
    self.img_data = None
    if reload_tsv:
        self.img_data = load_obj_tsv(image_feat_path,
                                     topk,
                                     save_cache=False,
                                     cache_path=self._image_feat_cache_dir)
        self.img_data = {img["img_id"]: img for img in self.img_data}
    self.object_data = None
    self.attribute_data = None
    if len(objects_path) > 0:
        self.object_data = json.load(open(objects_path))
        self.object_data = {img["image_id"]: img for img in self.object_data}
    if len(attributes_path) > 0:
        self.attribute_data = json.load(open(attributes_path))
        self.attribute_data = {img["image_id"]: img for img in self.attribute_data}
    if len(relations_path) > 0:
        self.relation_data = json.load(open(relations_path))
        self.relation_data = {img["image_id"]: img for img in self.relation_data}
    self.positive_threshold = positive_threshold
    self.negative_threshold = negative_threshold
    self.object_supervision = object_supervision
    self.require_some_positive = require_some_positive
def test_index_converts_field_correctly(self):
    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    field.index(self.vocab)
    assert field._rule_id == self.s_rule_index
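For context, ``field._rule_id`` is simply the vocabulary index of the rule string. A hedged sketch of how such a fixture index could be set up, assuming ProductionRuleField's default 'rule_labels' vocabulary namespace (an assumption, not taken from this file):

from allennlp.data import Vocabulary
from allennlp.data.fields import ProductionRuleField

vocab = Vocabulary()
# add_token_to_namespace returns the index assigned to the token.
s_rule_index = vocab.add_token_to_namespace('S -> [NP, VP]', namespace='rule_labels')

field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
field.index(vocab)
assert field._rule_id == s_rule_index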
def text_to_instance(self,  # type: ignore
                     question: str,
                     table_lines: List[List[str]],
                     target_values: List[str] = None,
                     offline_search_output: List[str] = None) -> Instance:
    """
    Reads text inputs and makes an instance. We pass the ``table_lines`` to
    ``TableQuestionContext.read_from_lines``, which accepts this field either as lines from
    CoreNLP-processed tagged files that come with the dataset, or simply in a TSV format where
    each line corresponds to a row and the cells are tab-separated.

    Parameters
    ----------
    question : ``str``
        Input question
    table_lines : ``List[List[str]]``
        The table content optionally preprocessed by CoreNLP. See
        ``TableQuestionContext.read_from_lines`` for the expected format.
    target_values : ``List[str]``, optional
        Target values for the denotations the logical forms should execute to. Not required for
        testing.
    offline_search_output : ``List[str]``, optional
        List of logical forms, produced by offline search. Not required during test.
    """
    # pylint: disable=arguments-differ
    tokenized_question = self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question, self._question_token_indexers)
    metadata: Dict[str, Any] = {"question_tokens": [x.text for x in tokenized_question]}
    table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
    target_values_field = MetadataField(target_values)
    world = WikiTablesLanguage(table_context)
    world_field = MetadataField(world)
    # Note: Not passing any feature extractors when instantiating the field below. This will
    # make it use all the available extractors.
    table_field = KnowledgeGraphField(table_context.get_table_knowledge_graph(),
                                      tokenized_question,
                                      self._table_token_indexers,
                                      tokenizer=self._tokenizer,
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_productions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_instance_specific_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)
    fields = {'question': question_field,
              'metadata': MetadataField(metadata),
              'table': table_field,
              'world': world_field,
              'actions': action_field,
              'target_values': target_values_field}
    # We'll make each target action sequence a List[IndexField], where the index is into the
    # action list we made above. We need to ignore the type here because mypy doesn't like
    # `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
    if offline_search_output:
        action_sequence_fields: List[Field] = []
        for logical_form in offline_search_output:
            try:
                action_sequence = world.logical_form_to_action_sequence(logical_form)
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except ParsingError as error:
                logger.debug(f'Parsing error: {error.message}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            except KeyError as error:
                logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                logger.debug(f'Logical form was: {logical_form}')
                continue
            except:
                logger.error(logical_form)
                raise
            if len(action_sequence_fields) >= self._max_offline_logical_forms:
                break
        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms. We should skip this instance. Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields['target_action_sequences'] = ListField(action_sequence_fields)
    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        for agenda_string in world.get_agenda(conservative=True):
            agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            agenda_index_fields = [IndexField(-1, action_field)]
        fields['agenda'] = ListField(agenda_index_fields)
    return Instance(fields)
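When ``get_agenda`` returns nothing, the reader above still emits an agenda field containing a single IndexField(-1, ...); -1 matches the value AllenNLP's IndexField uses for its empty field, so downstream models can treat it as padding. A self-contained sketch of that fallback with a one-rule toy action list:

from allennlp.data.fields import IndexField, ListField, ProductionRuleField

toy_actions = ListField([ProductionRuleField('S -> [NP, VP]', is_global_rule=True)])
toy_agenda: list = []                                     # suppose the agenda came back empty
agenda_index_fields = [IndexField(i, toy_actions) for i in toy_agenda]
if not agenda_index_fields:
    agenda_index_fields = [IndexField(-1, toy_actions)]   # padding sentinel
agenda_field = ListField(agenda_index_fields)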
def text_to_instance(self,
                     sentence: str,
                     identifier: str,
                     image_ids: List[str],
                     logical_form: str = None,
                     attention_mode: int = None,
                     box_annotation: Dict = None,
                     denotation: str = None) -> Instance:
    tokenized_sentence = self._tokenizer.tokenize(sentence)
    sentence_field = TextField(tokenized_sentence, self._token_indexers)
    world = VisualReasoningNlvr2Language(None, None, None, None, None, None)
    production_rule_fields: List[Field] = []
    instance_action_ids: Dict[str, int] = {}
    for production_rule in world.all_possible_productions():
        instance_action_ids[production_rule] = len(instance_action_ids)
        field = ProductionRuleField(production_rule, is_global_rule=True)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)
    boxes2 = []
    feats2 = []
    max_num_boxes = 0
    for key in image_ids:
        if self.img_data is not None:
            img_info = self.img_data[key]
        else:
            split_name = "train"
            if "dev" in key:
                split_name = "valid"
            img_info = pickle.load(
                open(os.path.join(self._image_feat_cache_dir, split_name + "_obj36.tsv", key), "rb"))
        boxes = img_info["boxes"].copy()
        feats = img_info["features"].copy()
        assert len(boxes) == len(feats)
        # Normalize the boxes (to 0 ~ 1)
        img_h, img_w = img_info["img_h"], img_info["img_w"]
        boxes[..., (0, 2)] /= img_w
        boxes[..., (1, 3)] /= img_h
        np.testing.assert_array_less(boxes, 1 + 1e-5)
        np.testing.assert_array_less(-boxes, 0 + 1e-5)
        if boxes.shape[0] > self._max_boxes:
            boxes = boxes[:self._max_boxes, :]
            feats = feats[:self._max_boxes, :]
        max_num_boxes = max(max_num_boxes, boxes.shape[0])
        boxes2.append(boxes)
        feats2.append(feats)
    boxes3 = [np.zeros((max_num_boxes, img_boxes.shape[-1])) for img_boxes in boxes2]
    feats3 = [np.zeros((max_num_boxes, img_feats.shape[-1])) for img_feats in feats2]
    for i in range(len(boxes2)):
        boxes3[i][:boxes2[i].shape[0], :] = boxes2[i]
        feats3[i][:feats2[i].shape[0], :] = feats2[i]
    boxes2 = boxes3
    feats2 = feats3
    feats = np.stack(feats2)
    boxes = np.stack(boxes2)
    metadata: Dict[str, Any] = {
        "utterance": sentence,
        "tokenized_utterance": tokenized_sentence,
        "identifier": identifier,
    }
    fields: Dict[str, Field] = {
        "sentence": sentence_field,
        "actions": action_field,
        "metadata": MetadataField(metadata),
        "image_id": MetadataField(identifier[:-2]),
        "visual_feat": ArrayField(feats),
        "pos": ArrayField(boxes),
    }
    if denotation is not None:
        fields["denotation"] = LabelField(denotation, skip_indexing=True)
    if logical_form:
        lisp_exp = annotation_to_lisp_exp(logical_form)
        target_sequence = world.logical_form_to_action_sequence(lisp_exp)
        index_field = [IndexField(instance_action_ids[action], action_field)
                       for action in target_sequence]
        fields["target_action_sequence"] = ListField(index_field)
        module_attention = annotation_to_module_attention(logical_form)
        target_attention = target_sequence_to_target_attn(target_sequence, module_attention)
        gold_question_attentions = self._assign_attention_to_tokens(
            target_attention, sentence, attention_mode)
        attn_index_field = [ListField([IndexField(att, sentence_field) for att in target_att])
                            for target_att in gold_question_attentions]
        fields["gold_question_attentions"] = ListField(attn_index_field)
        if box_annotation is None and len(self.box_annotations) > 0:
            fields["gold_box_annotations"] = MetadataField([])
        elif box_annotation is not None:
            modules = logical_form.split("\n")
            children = [[] for _ in modules]
            for j, module in enumerate(modules):
                num_periods = len(module) - len(module.strip("."))
                for k in range(j + 1, len(modules)):
                    num_periods_k = len(modules[k]) - len(modules[k].strip("."))
                    if num_periods_k <= num_periods:
                        break
                    if num_periods_k == num_periods + 1:
                        children[j].append(k)
            for j in range(len(modules) - 1, -1, -1):
                if modules[j].strip(".") == "in_left_image":
                    box_annotation[j] = {}
                    box_annotation[j]["module"] = modules[j].strip(".")
                    box_annotation[j][0] = box_annotation[j + 1][0]
                    box_annotation[j][1] = []
                    """for k in children[j]:
                        box_annotation[k][0] = box_annotation[k][0]
                        box_annotation[k][1] = []"""
                elif modules[j].strip(".") == "in_right_image":
                    box_annotation[j] = {}
                    box_annotation[j]["module"] = modules[j].strip(".")
                    box_annotation[j][1] = box_annotation[j + 1][1]
                    box_annotation[j][0] = []
                elif modules[j].strip(".") in {"in_one_image", "in_other_image"}:
                    box_annotation[j] = {}
                    box_annotation[j]["module"] = modules[j].strip(".")
                    box_annotation[j][0] = box_annotation[j + 1][0]
                    box_annotation[j][1] = box_annotation[j + 1][1]
                    """for k in children[j]:
                        box_annotation[k][0] = []
                        box_annotation[k][1] = box_annotation[k][1]"""
            keys = sorted(list(box_annotation.keys()))
            # print(identifier, keys)
            # print(box_annotation)
            # print(target_sequence)
            module_boxes = [(mod,
                             box_annotation[mod]["module"],
                             [box_annotation[mod][0], box_annotation[mod][1]])
                            for mod in keys]
            gold_boxes, gold_counts = target_sequence_to_target_boxes(
                target_sequence, module_boxes, children)
            # print(identifier, target_sequence, module_boxes, gold_boxes)
            fields["gold_box_annotations"] = MetadataField(gold_boxes)
        metadata["gold"] = world.action_sequence_to_logical_form(target_sequence)
        fields["valid_target_sequence"] = ArrayField(np.array(1, dtype=np.int32))
    else:
        fields["target_action_sequence"] = ListField([IndexField(0, action_field)])
        fields["gold_question_attentions"] = ListField([ListField([IndexField(0, sentence_field)])])
        fields["valid_target_sequence"] = ArrayField(np.array(0, dtype=np.int32))
        if len(self.box_annotations) > 0:
            fields["gold_box_annotations"] = MetadataField([])
    return Instance(fields)
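The box preprocessing in the reader above (normalize coordinates to [0, 1], truncate to a maximum box count, then zero-pad every image to the batch-wide maximum so the arrays can be stacked) can be exercised on toy data. A minimal numpy sketch with made-up shapes and values:

import numpy as np

# Toy data: two "images" with different numbers of detected boxes (made-up values).
img_w, img_h = 640, 480
boxes_per_image = [np.array([[0., 0., 320., 240.], [100., 50., 640., 480.]]),
                   np.array([[10., 10., 64., 48.]])]
max_boxes = 36

normalized = []
for boxes in boxes_per_image:
    boxes = boxes.copy()
    boxes[..., (0, 2)] /= img_w          # x1, x2 -> [0, 1]
    boxes[..., (1, 3)] /= img_h          # y1, y2 -> [0, 1]
    normalized.append(boxes[:max_boxes])

# Zero-pad every image to the same number of boxes so np.stack works.
max_num_boxes = max(b.shape[0] for b in normalized)
padded = np.zeros((len(normalized), max_num_boxes, 4))
for i, boxes in enumerate(normalized):
    padded[i, :boxes.shape[0], :] = boxes
assert padded.shape == (2, 2, 4)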