def test_doubly_nested_field_works(self):
    """A ListField of ListFields of ProductionRuleFields pads the short inner list."""
    s_rule = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    np_rule = ProductionRuleField('NP -> test', is_global_rule=True)
    vp_rule = ProductionRuleField('VP -> eat', is_global_rule=False)
    nested = ListField([ListField([s_rule, np_rule, vp_rule]),
                        ListField([s_rule, np_rule])])
    nested.index(self.vocab)
    tensors = nested.as_tensor(nested.get_padding_lengths())

    # Two inner lists, each padded to the longer inner length of 3.
    assert isinstance(tensors, list)
    assert len(tensors) == 2
    for inner in tensors:
        assert isinstance(inner, list)
        assert len(inner) == 3

    # (rule string, is_global flag, expected vocab index or None) for every position.
    expected = [
        [('S -> [NP, VP]', True, self.s_rule_index),
         ('NP -> test', True, self.np_index),
         ('VP -> eat', False, None)],
        [('S -> [NP, VP]', True, self.s_rule_index),
         ('NP -> test', True, self.np_index),
         # This item was just padding.
         ('', False, None)],
    ]
    for actual_row, expected_row in zip(tensors, expected):
        for tensor_tuple, (rule, is_global, rule_index) in zip(actual_row, expected_row):
            assert tensor_tuple[0] == rule
            assert tensor_tuple[1] is is_global
            if rule_index is None:
                assert tensor_tuple[2] is None
            else:
                assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [rule_index])
def test_field_counts_vocab_items_correctly(self):
    """Only global rules contribute a count to the "rule_labels" namespace."""
    for is_global, expected_count in ((True, 1), (False, 0)):
        counts = defaultdict(lambda: defaultdict(int))
        ProductionRuleField('S -> [NP, VP]', is_global_rule=is_global).count_vocab_items(counts)
        assert counts["rule_labels"]["S -> [NP, VP]"] == expected_count
def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
    """
    Builds an ``Instance`` from a JSON blob whose inputs were already pre-processed.

    Question and entity tokens are read from the blob with ``self._read_tokens_from_json_list``,
    the table graph is built from ``json_obj['table_lines']``, and linking features are passed
    through from ``json_obj['linking_features']`` (the table field gets ``tokenizer=None``).

    Required keys: 'question_tokens', 'table_lines', 'entity_texts', 'linking_features',
    'example_lisp_string'. Optional keys 'target_action_sequences' (a list of rule-string
    sequences) and 'agenda' (a list of rule strings) are converted into ``IndexField``s
    pointing into the 'actions' field.
    """
    question_tokens = self._read_tokens_from_json_list(json_obj['question_tokens'])
    question_field = TextField(question_tokens, self._question_token_indexers)
    question_metadata = MetadataField({"question_tokens": [token.text for token in question_tokens]})
    knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(json_obj['table_lines'],
                                                                  question_tokens)
    entity_tokens = [self._read_tokens_from_json_list(token_list)
                     for token_list in json_obj['entity_texts']]
    table_field = KnowledgeGraphField(knowledge_graph,
                                      question_tokens,
                                      tokenizer=None,
                                      token_indexers=self._table_token_indexers,
                                      entity_tokens=entity_tokens,
                                      linking_features=json_obj['linking_features'],
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    world = WikiTablesWorld(knowledge_graph)

    rule_fields: List[Field] = []
    for rule in world.all_possible_actions():
        _, right_side = rule.split(' -> ')
        # A rule that produces a table entity is table-specific, hence not global.
        rule_fields.append(ProductionRuleField(rule, not world.is_table_entity(right_side)))
    action_field = ListField(rule_fields)

    fields = {
        'question': question_field,
        'metadata': question_metadata,
        'table': table_field,
        'world': MetadataField(world),
        'actions': action_field,
        'example_lisp_string': MetadataField(json_obj['example_lisp_string']),
    }

    has_targets = 'target_action_sequences' in json_obj
    has_agenda = 'agenda' in json_obj
    if has_targets or has_agenda:
        rule_to_index = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
        if has_targets:
            sequence_fields: List[Field] = [
                ListField([IndexField(rule_to_index[rule], action_field) for rule in sequence])
                for sequence in json_obj['target_action_sequences']
            ]
            fields['target_action_sequences'] = ListField(sequence_fields)
        if has_agenda:
            agenda_fields: List[Field] = [IndexField(rule_to_index[rule], action_field)
                                          for rule in json_obj['agenda']]
            fields['agenda'] = ListField(agenda_fields)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     question: str,
                     table_lines: List[str],
                     example_lisp_string: str = None,
                     dpd_output: List[str] = None,
                     tokenized_question: List[Token] = None) -> Instance:
    """
    Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV
    files, which we use for training.

    Parameters
    ----------
    question : ``str``
        Input question
    table_lines : ``List[str]``
        The table content itself, as a list of rows. See
        ``TableQuestionKnowledgeGraph.read_from_lines`` for the expected format.
    example_lisp_string : ``str``, optional
        The original (lisp-formatted) example string in the WikiTableQuestions dataset.  This
        comes directly from the ``.examples`` file provided with the dataset.  We pass this to
        SEMPRE for evaluating logical forms during training.  It isn't otherwise used for
        anything.
    dpd_output : List[str], optional
        List of logical forms, produced by dynamic programming on denotations. Not required
        during test.
    tokenized_question : ``List[Token]``, optional
        If you have already tokenized the question, you can pass that in here, so we don't
        duplicate that work.  You might, for example, do batch processing on the questions in
        the whole dataset, then pass the result in here.

    Returns
    -------
    ``Instance`` or ``None``
        ``None`` is returned when ``dpd_output`` is non-empty but none of its logical forms
        survives filtering (``_should_keep_logical_form``), parsing, and mapping onto this
        table's actions.

    Raises
    ------
    Any non-``ParsingError`` exception from ``world.parse_logical_form`` is re-raised after the
    offending logical form is logged.
    """
    # pylint: disable=arguments-differ
    tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question, self._question_token_indexers)
    metadata: Dict[str, Any] = {"question_tokens": [x.text for x in tokenized_question]}
    # Rows are concatenated as-is; no separator is inserted between them.
    metadata["original_table"] = "".join(table_lines)
    table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(table_lines, tokenized_question)
    table_metadata = MetadataField(table_lines)
    table_field = KnowledgeGraphField(table_knowledge_graph,
                                      tokenized_question,
                                      self._table_token_indexers,
                                      tokenizer=self._tokenizer,
                                      feature_extractors=self._linking_feature_extractors,
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    world = WikiTablesWorld(table_knowledge_graph)
    world_field = MetadataField(world)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        # Rules producing a table entity depend on this table, so they are not global.
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {'question': question_field,
              'metadata': MetadataField(metadata),
              'table': table_field,
              'world': world_field,
              'actions': action_field}
    if self._include_table_metadata:
        fields['table_metadata'] = table_metadata
    if example_lisp_string:
        fields['example_lisp_string'] = MetadataField(example_lisp_string)

    # We'll make each target action sequence a List[IndexField], where the index is into
    # the action list we made above.  We need to ignore the type here because mypy doesn't
    # like `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
    if dpd_output:
        action_sequence_fields: List[Field] = []
        for logical_form in dpd_output:
            if not self._should_keep_logical_form(logical_form):
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            try:
                expression = world.parse_logical_form(logical_form)
            except ParsingError as error:
                logger.debug(f'Parsing error: {error.message}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Logical form was: {logical_form}')
                logger.debug(f'Table info was: {table_lines}')
                continue
            except Exception:
                # Any other failure is unexpected: record which logical form triggered it and
                # propagate.  (A bare ``except:`` would also intercept KeyboardInterrupt and
                # SystemExit, which have nothing to do with this logical form.)
                logger.error(logical_form)
                raise
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                # The logical form uses a rule that is not among this world's actions.
                logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                logger.debug(f'Question was: {question}')
                logger.debug(f'Table info was: {table_lines}')
                logger.debug(f'Logical form was: {logical_form}')
                continue
            if len(action_sequence_fields) >= self._max_dpd_logical_forms:
                break

        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms.  We should skip this instance.  Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields['target_action_sequences'] = ListField(action_sequence_fields)
    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        for agenda_string in world.get_agenda():
            agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            # NOTE(review): -1 presumably acts as an "empty agenda" placeholder index -- confirm
            # against the model code that consumes this field.
            agenda_index_fields = [IndexField(-1, action_field)]
        fields['agenda'] = ListField(agenda_index_fields)
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     sentence: str,
                     structured_representations: List[List[List[JsonDict]]],
                     labels: List[str] = None,
                     target_sequences: List[List[str]] = None,
                     identifier: str = None) -> Instance:
    """
    Builds an instance for one sentence paired with one or more worlds.

    Parameters
    ----------
    sentence : ``str``
        The query sentence.
    structured_representations : ``List[List[List[JsonDict]]]``
        A list of Json representations of all the worlds. See expected format in this class'
        docstring.
    labels : ``List[str]`` (optional)
        List of string representations of the labels (true or false) corresponding to the
        ``structured_representations``. Not required while testing.
    target_sequences : ``List[List[str]]`` (optional)
        List of target action sequences for each element which lead to the correct denotation
        in worlds corresponding to the structured representations.
    identifier : ``str`` (optional)
        The identifier from the dataset if available.

    Notes
    -----
    When ``target_sequences`` is non-empty it is used as supervision; otherwise, if
    ``self._output_agendas`` is set, an agenda is derived from the first world instead.
    """
    # pylint: disable=arguments-differ
    worlds = [NlvrWorld(data) for data in structured_representations]
    sentence_field = TextField(self._tokenizer.tokenize(sentence), self._sentence_token_indexers)

    # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
    # later.
    rule_fields: List[Field] = []
    rule_ids: Dict[str, int] = {}
    for rule in worlds[0].all_possible_actions():
        # For distinct rules, the dict size before insertion equals the rule's position in
        # ``rule_fields``.
        rule_ids[rule] = len(rule_ids)
        rule_fields.append(ProductionRuleField(rule, is_global_rule=True))
    action_field = ListField(rule_fields)

    def to_index_fields(actions) -> List[Field]:
        # Map each action string to an IndexField pointing into ``action_field``.
        return [IndexField(rule_ids[action], action_field) for action in actions]

    fields: Dict[str, Field] = {
        "sentence": sentence_field,
        "worlds": ListField([MetadataField(world) for world in worlds]),
        "actions": action_field,
    }
    if identifier is not None:
        fields["identifier"] = MetadataField(identifier)

    if target_sequences:
        sequence_fields: List[Field] = [ListField(to_index_fields(sequence))
                                        for sequence in target_sequences]
        # TODO(pradeep): Define a max length for this field.
        fields["target_action_sequences"] = ListField(sequence_fields)
    elif self._output_agendas:
        # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
        # now, but may change later too.
        agenda = worlds[0].get_agenda_for_sentence(sentence, add_paths_to_agenda=False)
        assert agenda, "No agenda found for sentence: %s" % sentence
        fields["agenda"] = ListField(to_index_fields(agenda))

    if labels:
        fields["labels"] = ListField([LabelField(label, label_namespace='denotations')
                                      for label in labels])
    return Instance(fields)
def text_to_instance(self,  # type: ignore
                     utterances: List[str],
                     sql_query: str = None) -> Instance:
    """
    Builds an instance for the current (last) utterance of an interaction.

    Parameters
    ----------
    utterances: ``List[str]``, required.
        List of utterances in the interaction, the last element is the current utterance.
    sql_query: ``str``, optional
        The SQL query, given as label during training or validation.

    Returns
    -------
    ``Instance`` or ``None``
        ``None`` when there is no current utterance (``utterances`` is empty or its last
        element is empty), or when ``sql_query`` is given but yields no action sequence
        (for example because it raises ``ParseError``).
    """
    # pylint: disable=arguments-differ
    # Check the list itself too: ``utterances[-1]`` on an empty list would raise IndexError
    # instead of taking the "no utterance" skip path.
    if not utterances or not utterances[-1]:
        return None
    utterance = utterances[-1]

    world = AtisWorld(utterances)
    action_sequence: List[str] = []
    if sql_query:
        try:
            action_sequence = world.get_action_sequence(sql_query)
        except ParseError as error:
            logger.debug(f'Parsing error: {error}')
    if sql_query and not action_sequence:
        # If we are given a SQL query, but we are unable to parse it, then we will skip it.
        # Returning here avoids tokenizing and building fields that would be discarded.
        return None

    tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
    utterance_field = TextField(tokenized_utterance, self._token_indexers)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        lhs, _ = production_rule.split(' ->')
        # Rules expanding 'number' or 'string' are marked non-global.
        is_global_rule = lhs not in ['number', 'string']
        # The whitespaces are not semantically meaningful, so we filter them out.
        rule = ' '.join([token for token in production_rule.split(' ') if token != 'ws'])
        production_rule_fields.append(ProductionRuleField(rule, is_global_rule))
    action_field = ListField(production_rule_fields)

    fields = {'utterance': utterance_field,
              'actions': action_field,
              'world': MetadataField(world),
              'linking_scores': ArrayField(world.linking_scores)}

    if sql_query:
        action_map = {action.rule: i  # type: ignore
                      for i, action in enumerate(action_field.field_list)}
        index_fields: List[Field] = [IndexField(action_map[production_rule], action_field)
                                     for production_rule in action_sequence]
        action_sequence_field: List[Field] = [ListField(index_fields)]
        fields['target_action_sequence'] = ListField(action_sequence_field)
    return Instance(fields)
def test_batch_tensors_does_not_modify_list(self):
    """``batch_tensors`` returns the per-instance tensors unchanged and leaves its input intact."""
    field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    field.index(self.vocab)
    padding_lengths = field.get_padding_lengths()
    tensor_dict1 = field.as_tensor(padding_lengths)

    field = ProductionRuleField('NP -> test', is_global_rule=True)
    field.index(self.vocab)
    padding_lengths = field.get_padding_lengths()
    tensor_dict2 = field.as_tensor(padding_lengths)

    tensor_list = [tensor_dict1, tensor_dict2]
    # Snapshot the input first.  Comparing the result with the same list object that
    # ``batch_tensors`` received would pass even if it had mutated that list in place.
    expected = list(tensor_list)
    batched = field.batch_tensors(tensor_list)
    assert batched == expected
    assert tensor_list == expected
def test_as_tensor_produces_correct_output(self):
    """``as_tensor`` yields (rule, is_global, id tensor); non-global rules carry no id tensor."""
    for is_global in (True, False):
        field = ProductionRuleField('S -> [NP, VP]', is_global_rule=is_global)
        field.index(self.vocab)
        tensor_tuple = field.as_tensor(field.get_padding_lengths())
        assert isinstance(tensor_tuple, tuple)
        assert len(tensor_tuple) == 3
        rule, global_flag, rule_id = tensor_tuple
        assert rule == 'S -> [NP, VP]'
        assert global_flag is is_global
        if is_global:
            assert_almost_equal(rule_id.detach().cpu().numpy(), [self.s_rule_index])
        else:
            assert rule_id is None
def test_padding_lengths_are_computed_correctly(self):
    """An indexed ProductionRuleField reports no padding dimensions."""
    rule_field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    rule_field.index(self.vocab)
    padding = rule_field.get_padding_lengths()
    assert padding == {}
def test_index_converts_field_correctly(self):
    """Indexing a global rule stores the id expected by the fixture (``self.s_rule_index``)."""
    rule_field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
    rule_field.index(self.vocab)
    stored_id = rule_field._rule_id  # pylint: disable=protected-access
    assert stored_id == self.s_rule_index