def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
    """
    Rebuilds an ``Instance`` from a pre-processed JSON dictionary.

    Required keys are ``question_tokens``, ``table_lines``, ``entity_texts``,
    ``linking_features`` and ``example_lisp_string``. The optional keys
    ``target_action_sequences`` and ``agenda`` hold production-rule strings. When present,
    they are turned into ``IndexField``s that point into the ``actions`` list field.
    """
    # Question side: tokens, the indexed text field, and a metadata copy of the token strings.
    tokens = self._read_tokens_from_json_list(json_obj['question_tokens'])
    question_field = TextField(tokens, self._question_token_indexers)
    question_metadata = MetadataField({"question_tokens": [token.text for token in tokens]})

    # Table side: the knowledge graph, plus entity tokens and linking features that were
    # already computed and stored in the blob (hence ``tokenizer=None``).
    knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(json_obj['table_lines'], tokens)
    entity_tokens = [self._read_tokens_from_json_list(entity_text)
                     for entity_text in json_obj['entity_texts']]
    table_field = KnowledgeGraphField(knowledge_graph,
                                      tokens,
                                      tokenizer=None,
                                      token_indexers=self._table_token_indexers,
                                      entity_tokens=entity_tokens,
                                      linking_features=json_obj['linking_features'],
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)

    world = WikiTablesWorld(knowledge_graph)
    world_field = MetadataField(world)

    # Rules whose right-hand side is not a table entity are flagged as global.
    rule_fields: List[Field] = []
    for rule in world.all_possible_actions():
        _, right_hand_side = rule.split(' -> ')
        rule_fields.append(ProductionRuleField(rule, not world.is_table_entity(right_hand_side)))
    action_field = ListField(rule_fields)
    lisp_field = MetadataField(json_obj['example_lisp_string'])

    fields = {
        'question': question_field,
        'metadata': question_metadata,
        'table': table_field,
        'world': world_field,
        'actions': action_field,
        'example_lisp_string': lisp_field
    }

    has_targets = 'target_action_sequences' in json_obj
    has_agenda = 'agenda' in json_obj
    if not (has_targets or has_agenda):
        return Instance(fields)

    # Position of every production rule string inside ``action_field``.
    rule_to_index = {
        rule_field.rule: position
        for position, rule_field in enumerate(action_field.field_list)
    }  # type: ignore

    if has_targets:
        sequence_fields: List[Field] = []
        for sequence in json_obj['target_action_sequences']:
            step_fields: List[Field] = [IndexField(rule_to_index[step], action_field)
                                        for step in sequence]
            sequence_fields.append(ListField(step_fields))
        fields['target_action_sequences'] = ListField(sequence_fields)

    if has_agenda:
        agenda_fields: List[Field] = [IndexField(rule_to_index[item], action_field)
                                      for item in json_obj['agenda']]
        fields['agenda'] = ListField(agenda_fields)

    return Instance(fields)
def text_to_instance(
        self,  # type: ignore
        question: str,
        table_lines: List[str],
        example_lisp_string: str = None,
        dpd_output: List[str] = None,
        tokenized_question: List[Token] = None) -> Instance:
    """
    Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as
    TSV files, which we use for training.

    Parameters
    ----------
    question : ``str``
        Input question
    table_lines : ``List[str]``
        The table content itself, as a list of rows. See
        ``TableQuestionKnowledgeGraph.read_from_lines`` for the expected format.
    example_lisp_string : ``str``, optional
        The original (lisp-formatted) example string in the WikiTableQuestions dataset. This
        comes directly from the ``.examples`` file provided with the dataset. We pass this to
        SEMPRE for evaluating logical forms during training. It isn't otherwise used for
        anything.
    dpd_output : List[str], optional
        List of logical forms, produced by dynamic programming on denotations. Not required
        during test.
    tokenized_question : ``List[Token]``, optional
        If you have already tokenized the question, you can pass that in here, so we don't
        duplicate that work. You might, for example, do batch processing on the questions in
        the whole dataset, then pass the result in here.

    Returns
    -------
    ``Instance`` or ``None``
        The constructed instance. Despite the ``Instance`` annotation, ``None`` is returned
        when ``dpd_output`` is non-empty but none of its logical forms is kept by
        ``_should_keep_logical_form``, parses, and maps onto known production rules.
        Callers must be prepared to skip such instances.
    """
    # pylint: disable=arguments-differ
    tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question, self._question_token_indexers)
    metadata: Dict[str, Any] = {"question_tokens": [x.text for x in tokenized_question]}
    # NOTE(review): rows are concatenated with no separator, which presumably relies on each
    # entry of ``table_lines`` still ending in its newline -- confirm against callers.
    metadata["original_table"] = "".join(table_lines)
    table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(table_lines,
                                                                        tokenized_question)
    table_metadata = MetadataField(table_lines)
    table_field = KnowledgeGraphField(table_knowledge_graph,
                                      tokenized_question,
                                      self._table_token_indexers,
                                      tokenizer=self._tokenizer,
                                      feature_extractors=self._linking_feature_extractors,
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    world = WikiTablesWorld(table_knowledge_graph)
    world_field = MetadataField(world)

    # Rules whose right-hand side is not a table entity are flagged as global.
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {
        'question': question_field,
        'metadata': MetadataField(metadata),
        'table': table_field,
        'world': world_field,
        'actions': action_field
    }
    if self._include_table_metadata:
        fields['table_metadata'] = table_metadata
    if example_lisp_string:
        fields['example_lisp_string'] = MetadataField(example_lisp_string)

    # We'll make each target action sequence a List[IndexField], where the index is into
    # the action list we made above. We need to ignore the type here because mypy doesn't
    # like `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
    if dpd_output:
        action_sequence_fields: List[Field] = []
        for logical_form in dpd_output:
            if not self._should_keep_logical_form(logical_form):
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            try:
                expression = world.parse_logical_form(logical_form)
            except ParsingError as error:
                logger.debug(f'Parsing error: {error.message}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            except Exception:
                # Unexpected failure: record the offending logical form, then propagate.
                # (Previously a bare ``except:``, which also caught KeyboardInterrupt and
                # SystemExit just to log them.)
                logger.error(logical_form)
                raise
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                logger.debug(f'Logical form was: {logical_form}')
                continue
            if len(action_sequence_fields) >= self._max_dpd_logical_forms:
                break

        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms. We should skip this instance. Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields['target_action_sequences'] = ListField(action_sequence_fields)
    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        # NOTE(review): unlike the DPD loop above, a missing agenda rule is not caught here
        # and will raise KeyError.
        for agenda_string in world.get_agenda():
            agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            # An empty agenda gets a single index of -1, presumably a padding sentinel.
            agenda_index_fields = [IndexField(-1, action_field)]
        fields['agenda'] = ListField(agenda_index_fields)
    return Instance(fields)