def test_doubly_nested_field_works(self):
    field1 = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    field2 = ProductionRuleField("NP -> test", is_global_rule=True)
    field3 = ProductionRuleField("VP -> eat", is_global_rule=False)
    list_field = ListField([ListField([field1, field2, field3]),
                            ListField([field1, field2])])
    list_field.index(self.vocab)
    padding_lengths = list_field.get_padding_lengths()
    tensors = list_field.as_tensor(padding_lengths)
    assert isinstance(tensors, list)
    assert len(tensors) == 2
    assert isinstance(tensors[0], list)
    assert len(tensors[0]) == 3
    assert isinstance(tensors[1], list)
    assert len(tensors[1]) == 3

    tensor_tuple = tensors[0][0]
    assert tensor_tuple[0] == "S -> [NP, VP]"
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.s_rule_index])

    tensor_tuple = tensors[0][1]
    assert tensor_tuple[0] == "NP -> test"
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.np_index])

    tensor_tuple = tensors[0][2]
    assert tensor_tuple[0] == "VP -> eat"
    assert tensor_tuple[1] is False
    assert tensor_tuple[2] is None

    tensor_tuple = tensors[1][0]
    assert tensor_tuple[0] == "S -> [NP, VP]"
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.s_rule_index])

    tensor_tuple = tensors[1][1]
    assert tensor_tuple[0] == "NP -> test"
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.np_index])

    # This item was just padding.
    tensor_tuple = tensors[1][2]
    assert tensor_tuple[0] == ""
    assert tensor_tuple[1] is False
    assert tensor_tuple[2] is None
def test_field_counts_vocab_items_correctly(self):
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    namespace_token_counts = defaultdict(lambda: defaultdict(int))
    field.count_vocab_items(namespace_token_counts)
    assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 1

    # Non-global rules do not contribute to the vocabulary counts.
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=False)
    namespace_token_counts = defaultdict(lambda: defaultdict(int))
    field.count_vocab_items(namespace_token_counts)
    assert namespace_token_counts["rule_labels"]["S -> [NP, VP]"] == 0
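# A minimal sketch of how these counts feed vocabulary construction; the
# Vocabulary(counter=...) call mirrors what Vocabulary.from_instances does
# internally, and the rule string here is purely illustrative.
from collections import defaultdict
from allennlp.data import Vocabulary
from allennlp.data.fields import ProductionRuleField

namespace_token_counts = defaultdict(lambda: defaultdict(int))
ProductionRuleField("S -> [NP, VP]", is_global_rule=True).count_vocab_items(namespace_token_counts)
vocab = Vocabulary(counter=namespace_token_counts)
# Global rules now have ids in the "rule_labels" namespace; this is the id that
# field.index(vocab) later stores as the field's _rule_id.
rule_index = vocab.get_token_index("S -> [NP, VP]", namespace="rule_labels")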
def text_to_instance(self,  # type: ignore
                     query: List[str],
                     prelinked_entities: Dict[str, Dict[str, str]] = None,
                     sql: List[str] = None) -> Instance:
    # pylint: disable=arguments-differ
    fields: Dict[str, Field] = {}
    tokens = TextField([Token(t) for t in query], self._token_indexers)
    fields["tokens"] = tokens

    if sql is not None:
        action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(sql,
                                                                                       prelinked_entities)
        if action_sequence is None and self._keep_if_unparsable:
            logger.debug("Parse error")
            action_sequence = []
        elif action_sequence is None:
            return None

    index_fields: List[Field] = []
    production_rule_fields: List[Field] = []

    for production_rule in all_actions:
        nonterminal, _ = production_rule.split(' ->')
        production_rule = ' '.join(production_rule.split(' '))
        field = ProductionRuleField(production_rule,
                                    self._world.is_global_rule(nonterminal),
                                    nonterminal=nonterminal)
        production_rule_fields.append(field)

    valid_actions_field = ListField(production_rule_fields)
    fields["valid_actions"] = valid_actions_field

    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(valid_actions_field.field_list)}

    for production_rule in action_sequence:
        index_fields.append(IndexField(action_map[production_rule], valid_actions_field))
    if not action_sequence:
        index_fields = [IndexField(-1, valid_actions_field)]

    action_sequence_field = ListField(index_fields)
    fields["action_sequence"] = action_sequence_field
    return Instance(fields)
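# A toy sketch of the rule-indexing pattern used above (the rule strings are
# illustrative, not the reader's real grammar): every production rule string is
# mapped to its position in the valid-actions list, and the gold action
# sequence is then encoded as a sequence of those positions.
rules = ["statement -> [query]", "query -> [select_core]"]
action_map = {rule: i for i, rule in enumerate(rules)}
gold_sequence = ["statement -> [query]", "query -> [select_core]"]
gold_indices = [action_map[rule] for rule in gold_sequence]
assert gold_indices == [0, 1]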
def text_to_instance(self,  # type: ignore
                     question: str,
                     logical_forms: List[str] = None,
                     additional_metadata: Dict[str, Any] = None,
                     world_extractions: Dict[str, Union[str, List[str]]] = None,
                     entity_literals: Dict[str, Union[str, List[str]]] = None,
                     tokenized_question: List[Token] = None,
                     debug_counter: int = None,
                     qr_spec_override: List[Dict[str, int]] = None,
                     dynamic_entities_override: Dict[str, str] = None) -> Instance:
    # pylint: disable=arguments-differ
    tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
    additional_metadata = additional_metadata or dict()
    additional_metadata['question_tokens'] = [token.text for token in tokenized_question]
    if world_extractions is not None:
        additional_metadata['world_extractions'] = world_extractions
    question_field = TextField(tokenized_question, self._question_token_indexers)

    if qr_spec_override is not None or dynamic_entities_override is not None:
        # Dynamically specify theory and/or entities.
        dynamic_entities = dynamic_entities_override or self._dynamic_entities
        neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
        knowledge_graph = KnowledgeGraph(entities=set(dynamic_entities.keys()),
                                         neighbors=neighbors,
                                         entity_text=dynamic_entities)
        world = QuarelWorld(knowledge_graph,
                            self._lf_syntax,
                            qr_coeff_sets=qr_spec_override)
    else:
        knowledge_graph = self._knowledge_graph
        world = self._world

    table_field = KnowledgeGraphField(knowledge_graph,
                                      tokenized_question,
                                      self._entity_token_indexers,
                                      tokenizer=self._tokenizer)

    if self._tagger_only:
        fields: Dict[str, Field] = {'tokens': question_field}
        if entity_literals is not None:
            entity_tags = self._get_entity_tags(self._all_entities,
                                                table_field,
                                                entity_literals,
                                                tokenized_question)
            # Guard against a None debug_counter before comparing.
            if debug_counter is not None and debug_counter > 0:
                logger.info(f'raw entity tags = {entity_tags}')
            entity_tags_bio = self._convert_tags_bio(entity_tags)
            fields['tags'] = SequenceLabelField(entity_tags_bio, question_field)
            additional_metadata['tags_gold'] = entity_tags_bio
        additional_metadata['words'] = [x.text for x in tokenized_question]
        fields['metadata'] = MetadataField(additional_metadata)
        return Instance(fields)

    world_field = MetadataField(world)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        _, rule_right_side = production_rule.split(' -> ')
        is_global_rule = not world.is_table_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {'question': question_field,
              'table': table_field,
              'world': world_field,
              'actions': action_field}

    if self._denotation_only:
        denotation_field = LabelField(additional_metadata['answer_index'], skip_indexing=True)
        fields['denotation_target'] = denotation_field

    if self._entity_bits_mode is not None and world_extractions is not None:
        entity_bits = self._get_entity_tags(['world1', 'world2'],
                                            table_field,
                                            world_extractions,
                                            tokenized_question)
        if self._entity_bits_mode == "simple":
            entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple_collapsed":
            entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
        elif self._entity_bits_mode == "simple3":
            entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]
        entity_bits_field = ArrayField(np.array(entity_bits_v))
        fields['entity_bits'] = entity_bits_field

    if logical_forms:
        action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
        action_sequence_fields: List[Field] = []
        for logical_form in logical_forms:
            expression = world.parse_logical_form(logical_form)
            action_sequence = world.get_action_sequence(expression)
            try:
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except KeyError as error:
                logger.info(f'Missing production rule: {error.args}, skipping logical form')
                logger.info(f'Question was: {question}')
                logger.info(f'Logical form was: {logical_form}')
                continue
        fields['target_action_sequences'] = ListField(action_sequence_fields)

    fields['metadata'] = MetadataField(additional_metadata or {})
    return Instance(fields)
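# The entity-bits branches above are small lookup tables indexed by tag id. A
# toy illustration of the "simple" encoding, assuming tag 0 means "no world
# mention" and tags 1 and 2 mark world1 and world2 respectively:
tags = [0, 1, 2, 1]
simple = [[[0, 0], [1, 0], [0, 1]][tag] for tag in tags]
assert simple == [[0, 0], [1, 0], [0, 1], [1, 0]]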
def test_batch_tensors_does_not_modify_list(self):
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    field.index(self.vocab)
    padding_lengths = field.get_padding_lengths()
    tensor_dict1 = field.as_tensor(padding_lengths)

    field = ProductionRuleField("NP -> test", is_global_rule=True)
    field.index(self.vocab)
    padding_lengths = field.get_padding_lengths()
    tensor_dict2 = field.as_tensor(padding_lengths)
    tensor_list = [tensor_dict1, tensor_dict2]
    assert field.batch_tensors(tensor_list) == tensor_list
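# A minimal end-to-end sketch (assuming the allennlp 0.x Batch API) of how
# instances containing these fields are batched in practice; batch_tensors on a
# ProductionRuleField simply returns the list of per-instance tuples unchanged.
from allennlp.data import Instance, Vocabulary
from allennlp.data.dataset import Batch
from allennlp.data.fields import ListField, ProductionRuleField

instances = [Instance({"actions": ListField([ProductionRuleField("S -> [NP, VP]",
                                                                 is_global_rule=True)])})]
vocab = Vocabulary.from_instances(instances)
batch = Batch(instances)
batch.index_instances(vocab)
tensors = batch.as_tensor_dict(batch.get_padding_lengths())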
def test_as_tensor_produces_correct_output(self):
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    field.index(self.vocab)
    tensor_tuple = field.as_tensor(field.get_padding_lengths())
    assert isinstance(tensor_tuple, tuple)
    assert len(tensor_tuple) == 4
    assert tensor_tuple[0] == "S -> [NP, VP]"
    assert tensor_tuple[1] is True
    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(), [self.s_rule_index])

    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=False)
    field.index(self.vocab)
    tensor_tuple = field.as_tensor(field.get_padding_lengths())
    assert isinstance(tensor_tuple, tuple)
    assert len(tensor_tuple) == 4
    assert tensor_tuple[0] == "S -> [NP, VP]"
    assert tensor_tuple[1] is False
    assert tensor_tuple[2] is None
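# For reference: the 4-tuple checked above is (rule string, is_global_rule,
# rule_id tensor or None, nonterminal). In recent AllenNLP versions this is the
# ProductionRule named tuple defined alongside ProductionRuleField; treat the
# exact tuple type as version-dependent.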
def test_padding_lengths_are_computed_correctly(self):
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    field.index(self.vocab)
    assert field.get_padding_lengths() == {}
def test_index_converts_field_correctly(self):
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    field.index(self.vocab)
    assert field._rule_id == self.s_rule_index
def test_production_rule_field_can_print(self):
    field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
    print(field)
def text_to_instance(self,  # type: ignore
                     question: str,
                     table_lines: List[List[str]],
                     target_values: List[str] = None,
                     offline_search_output: List[str] = None) -> Instance:
    """
    Reads text inputs and makes an instance. We pass the ``table_lines`` to
    ``TableQuestionContext``, and that method accepts this field either as lines from CoreNLP
    processed tagged files that come with the dataset, or simply in a tsv format where each
    line corresponds to a row and the cells are tab-separated.

    Parameters
    ----------
    question : ``str``
        Input question.
    table_lines : ``List[List[str]]``
        The table content, optionally preprocessed by CoreNLP. See
        ``TableQuestionContext.read_from_lines`` for the expected format; a toy example
        follows this method.
    target_values : ``List[str]``, optional
        Target values for the denotations the logical forms should execute to. Not required
        for testing.
    offline_search_output : ``List[str]``, optional
        List of logical forms, produced by offline search. Not required during testing.
    """
    tokenized_question = self._tokenizer.tokenize(question.lower())
    question_field = TextField(tokenized_question, self._question_token_indexers)
    metadata: Dict[str, Any] = {"question_tokens": [x.text for x in tokenized_question]}
    table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
    world = WikiTablesLanguage(table_context)
    world_field = MetadataField(world)
    # Note: Not passing any feature extractors when instantiating the field below. This will
    # make it use all the available extractors.
    table_field = KnowledgeGraphField(table_context.get_table_knowledge_graph(),
                                      tokenized_question,
                                      self._table_token_indexers,
                                      tokenizer=self._tokenizer,
                                      include_in_vocab=self._use_table_for_vocab,
                                      max_table_tokens=self._max_table_tokens)
    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_productions():
        _, rule_right_side = production_rule.split(" -> ")
        is_global_rule = not world.is_instance_specific_entity(rule_right_side)
        field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)

    fields = {
        "question": question_field,
        "metadata": MetadataField(metadata),
        "table": table_field,
        "world": world_field,
        "actions": action_field,
    }

    if target_values is not None:
        target_values_field = MetadataField(target_values)
        fields["target_values"] = target_values_field

    # We'll make each target action sequence a List[IndexField], where the index is into
    # the action list we made above. We need to ignore the type here because mypy doesn't
    # like `action.rule` - it's hard to tell mypy that the ListField is made up of
    # ProductionRuleFields.
    action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
    if offline_search_output:
        action_sequence_fields: List[Field] = []
        for logical_form in offline_search_output:
            try:
                action_sequence = world.logical_form_to_action_sequence(logical_form)
                index_fields: List[Field] = []
                for production_rule in action_sequence:
                    index_fields.append(IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            except ParsingError as error:
                logger.debug(f"Parsing error: {error.message}, skipping logical form")
                logger.debug(f"Question was: {question}")
                logger.debug(f"Logical form was: {logical_form}")
                logger.debug(f"Table info was: {table_lines}")
                continue
            except KeyError as error:
                logger.debug(f"Missing production rule: {error.args}, skipping logical form")
                logger.debug(f"Question was: {question}")
                logger.debug(f"Table info was: {table_lines}")
                logger.debug(f"Logical form was: {logical_form}")
                continue
            except:  # noqa
                logger.error(logical_form)
                raise
            if len(action_sequence_fields) >= self._max_offline_logical_forms:
                break

        if not action_sequence_fields:
            # This is not great, but we're only doing it when we're passed logical form
            # supervision, so we're expecting labeled logical forms, but we can't actually
            # produce the logical forms. We should skip this instance. Note that this affects
            # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
            # full test data.
            return None
        fields["target_action_sequences"] = ListField(action_sequence_fields)

    if self._output_agendas:
        agenda_index_fields: List[Field] = []
        for agenda_string in world.get_agenda(conservative=True):
            agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
        if not agenda_index_fields:
            agenda_index_fields = [IndexField(-1, action_field)]
        fields["agenda"] = ListField(agenda_index_fields)
    return Instance(fields)
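# A hedged toy value for ``table_lines`` in the simple tsv style mentioned in
# the docstring above: each inner list holds one row's cells, already split on
# tabs, with the header row first. The tagged CoreNLP format carries extra
# columns; see TableQuestionContext.read_from_lines for the authoritative format.
table_lines = [["Name", "Year"],
               ["Usain Bolt", "2008"],
               ["Michael Phelps", "2008"]]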
def text_to_instance(self,  # type: ignore
                     utterances: List[str],
                     sql_query_labels: List[str] = None) -> Instance:
    """
    Parameters
    ----------
    utterances: ``List[str]``, required.
        List of utterances in the interaction; the last element is the current utterance.
    sql_query_labels: ``List[str]``, optional
        The SQL queries that are given as labels during training or validation.
    """
    if self._num_turns_to_concatenate:
        utterances[-1] = f" {END_OF_UTTERANCE_TOKEN} ".join(utterances[-self._num_turns_to_concatenate:])

    utterance = utterances[-1]
    action_sequence: List[str] = []

    if not utterance:
        return None

    world = AtisWorld(utterances=utterances)

    if sql_query_labels:
        # If there are multiple SQL queries given as labels, we use the shortest
        # one for training.
        sql_query = min(sql_query_labels, key=len)
        try:
            action_sequence = world.get_action_sequence(sql_query)
        except ParseError:
            action_sequence = []
            logger.debug("Parsing error")

    tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
    utterance_field = TextField(tokenized_utterance, self._token_indexers)

    production_rule_fields: List[Field] = []
    for production_rule in world.all_possible_actions():
        nonterminal, _ = production_rule.split(" ->")
        # The whitespaces are not semantically meaningful, so we filter them out.
        production_rule = " ".join([token for token in production_rule.split(" ") if token != "ws"])
        field = ProductionRuleField(production_rule, self._is_global_rule(nonterminal))
        production_rule_fields.append(field)

    action_field = ListField(production_rule_fields)
    action_map = {action.rule: i  # type: ignore
                  for i, action in enumerate(action_field.field_list)}
    index_fields: List[Field] = []
    world_field = MetadataField(world)
    fields = {
        "utterance": utterance_field,
        "actions": action_field,
        "world": world_field,
        "linking_scores": ArrayField(world.linking_scores),
    }

    if sql_query_labels is not None:
        fields["sql_queries"] = MetadataField(sql_query_labels)
        if self._keep_if_unparseable or action_sequence:
            for production_rule in action_sequence:
                index_fields.append(IndexField(action_map[production_rule], action_field))
            if not action_sequence:
                index_fields = [IndexField(-1, action_field)]
            action_sequence_field = ListField(index_fields)
            fields["target_action_sequence"] = action_sequence_field
        else:
            # If we are given a SQL query but are unable to parse it, and we are not
            # explicitly told to keep it, we skip the instance.
            return None
    return Instance(fields)
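# A toy illustration of the shortest-gold-query selection above: min with
# key=len picks the label with the fewest characters. The queries here are
# illustrative, not taken from the ATIS data.
sql_query_labels = ["SELECT city.city_code FROM city WHERE city.city_name = 'BOSTON'",
                    "SELECT DISTINCT city.city_code FROM city WHERE city.city_name = 'BOSTON'"]
shortest = min(sql_query_labels, key=len)
assert "DISTINCT" not in shortest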
def text_to_instance(self,  # type: ignore
                     sentence: str,
                     structured_representations: List[List[List[JsonDict]]],
                     labels: List[str] = None,
                     target_sequences: List[List[str]] = None,
                     identifier: str = None) -> Instance:
    """
    Parameters
    ----------
    sentence : ``str``
        The query sentence.
    structured_representations : ``List[List[List[JsonDict]]]``
        A list of Json representations of all the worlds. See expected format in this class'
        docstring.
    labels : ``List[str]``, optional
        List of string representations of the labels (true or false) corresponding to the
        ``structured_representations``. Not required while testing.
    target_sequences : ``List[List[str]]``, optional
        List of target action sequences for each element which lead to the correct denotation
        in worlds corresponding to the structured representations.
    identifier : ``str``, optional
        The identifier from the dataset, if available.
    """
    worlds = []
    for structured_representation in structured_representations:
        boxes = {Box(object_list, box_id)
                 for box_id, object_list in enumerate(structured_representation)}
        worlds.append(NlvrLanguage(boxes))
    tokenized_sentence = self._tokenizer.tokenize(sentence)
    sentence_field = TextField(tokenized_sentence, self._sentence_token_indexers)
    production_rule_fields: List[Field] = []
    instance_action_ids: Dict[str, int] = {}
    # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may
    # change later.
    for production_rule in worlds[0].all_possible_productions():
        instance_action_ids[production_rule] = len(instance_action_ids)
        field = ProductionRuleField(production_rule, is_global_rule=True)
        production_rule_fields.append(field)
    action_field = ListField(production_rule_fields)
    worlds_field = ListField([MetadataField(world) for world in worlds])
    metadata: Dict[str, Any] = {"sentence_tokens": [x.text for x in tokenized_sentence]}
    fields: Dict[str, Field] = {
        "sentence": sentence_field,
        "worlds": worlds_field,
        "actions": action_field,
        "metadata": MetadataField(metadata),
    }
    if identifier is not None:
        fields["identifier"] = MetadataField(identifier)
    # Depending on the type of supervision used for training the parser, we may want either
    # target action sequences or an agenda in our instance. We check if target sequences are
    # provided, and include them if they are. If not, we'll get an agenda for the sentence,
    # and include that in the instance.
    if target_sequences:
        action_sequence_fields: List[Field] = []
        for target_sequence in target_sequences:
            index_fields = ListField([IndexField(instance_action_ids[action], action_field)
                                      for action in target_sequence])
            action_sequence_fields.append(index_fields)
            # TODO(pradeep): Define a max length for this field.
        fields["target_action_sequences"] = ListField(action_sequence_fields)
    elif self._output_agendas:
        # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is
        # true now, but may change later too.
        agenda = worlds[0].get_agenda_for_sentence(sentence)
        assert agenda, "No agenda found for sentence: %s" % sentence
        # agenda_field contains indices into actions.
        agenda_field = ListField([IndexField(instance_action_ids[action], action_field)
                                  for action in agenda])
        fields["agenda"] = agenda_field
    if labels:
        labels_field = ListField([LabelField(label, label_namespace="denotations")
                                  for label in labels])
        fields["labels"] = labels_field
    return Instance(fields)
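# A hedged example of the expected ``structured_representations`` nesting: the
# outer list holds worlds, each world is a list of boxes, and each box is a
# list of object JsonDicts. The attribute keys follow the NLVR structured-JSON
# conventions, but treat the exact key names here as assumptions.
structured_representations = [
    [  # world 0: a list of boxes
        [  # box 0: a list of object JsonDicts
            {"x_loc": 27, "y_loc": 21, "size": 20, "type": "triangle", "color": "Yellow"},
            {"x_loc": 47, "y_loc": 60, "size": 10, "type": "circle", "color": "Black"},
        ],
    ],
]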