def test_doubly_nested_field_works(self):
        """
        A ``ListField`` of ``ListField``s of production rules should come out of
        ``as_tensor`` as a list of equally long lists of tuples, with the shorter
        inner list filled up by an empty, non-global rule.
        """
        s_rule = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
        np_rule = ProductionRuleField("NP -> test", is_global_rule=True)
        vp_rule = ProductionRuleField("VP -> eat", is_global_rule=False)
        nested_field = ListField([
            ListField([s_rule, np_rule, vp_rule]),
            ListField([s_rule, np_rule]),
        ])
        nested_field.index(self.vocab)
        tensors = nested_field.as_tensor(nested_field.get_padding_lengths())

        assert isinstance(tensors, list)
        assert len(tensors) == 2
        for inner_list in tensors:
            assert isinstance(inner_list, list)
            assert len(inner_list) == 3

        # One row per inner list; each entry is (rule string, is-global flag,
        # expected value of the third tuple element, or None when it must be None).
        expected_rows = [
            [
                ("S -> [NP, VP]", True, self.s_rule_index),
                ("NP -> test", True, self.np_index),
                ("VP -> eat", False, None),
            ],
            [
                ("S -> [NP, VP]", True, self.s_rule_index),
                ("NP -> test", True, self.np_index),
                # This item was just padding.
                ("", False, None),
            ],
        ]
        for inner_list, expected_row in zip(tensors, expected_rows):
            for tensor_tuple, (rule, is_global, rule_index) in zip(inner_list, expected_row):
                assert tensor_tuple[0] == rule
                assert tensor_tuple[1] is is_global
                if rule_index is None:
                    assert tensor_tuple[2] is None
                else:
                    assert_almost_equal(tensor_tuple[2].detach().cpu().numpy(),
                                        [rule_index])
    def test_field_counts_vocab_items_correctly(self):
        """
        A global rule should be counted once under ``rule_labels``; a non-global
        rule should not be counted at all.
        """
        rule = "S -> [NP, VP]"
        for is_global, expected_count in ((True, 1), (False, 0)):
            counts = defaultdict(lambda: defaultdict(int))
            rule_field = ProductionRuleField(rule, is_global_rule=is_global)
            rule_field.count_vocab_items(counts)
            assert counts["rule_labels"][rule] == expected_count
Example #3
0
    def text_to_instance(
            self,  # type: ignore
            query: List[str],
            prelinked_entities: Dict[str, Dict[str, str]] = None,
            sql: List[str] = None) -> Instance:
        """
        Build an ``Instance`` from a tokenized query and, optionally, its SQL.

        Parameters
        ----------
        query : ``List[str]``
            Tokens of the query; each one is wrapped in a ``Token``.
        prelinked_entities : ``Dict[str, Dict[str, str]]``, optional
            Passed through unchanged to
            ``self._world.get_action_sequence_and_all_actions``.
        sql : ``List[str]``, optional
            The SQL to derive the target action sequence from.  When omitted, the
            instance carries a single ``-1`` index as its action sequence.

        Returns
        -------
        An ``Instance`` with ``tokens``, ``valid_actions`` and ``action_sequence``
        fields, or ``None`` when ``sql`` was given, the world returned no action
        sequence for it, and ``self._keep_if_unparsable`` is false.
        """
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        fields["tokens"] = TextField([Token(t) for t in query],
                                     self._token_indexers)

        # The list of valid actions is needed for every instance, so the world is
        # queried even when no SQL is given.  Skipping the call in that case left
        # ``all_actions`` and ``action_sequence`` unbound and crashed below.
        # NOTE(review): assumes the world accepts ``sql=None`` and still returns
        # the full action list -- confirm against the world implementation.
        action_sequence, all_actions = self._world.get_action_sequence_and_all_actions(
            sql, prelinked_entities)
        if action_sequence is None:
            if sql is None:
                # No supervision was requested, so there is nothing to parse.
                action_sequence = []
            elif self._keep_if_unparsable:
                print("Parse error")
                action_sequence = []
            else:
                return None

        production_rule_fields: List[Field] = []
        for production_rule in all_actions:
            nonterminal, _ = production_rule.split(' ->')
            production_rule_fields.append(
                ProductionRuleField(production_rule,
                                    self._world.is_global_rule(nonterminal),
                                    nonterminal=nonterminal))

        valid_actions_field = ListField(production_rule_fields)
        fields["valid_actions"] = valid_actions_field

        # Map each rule string to its position in ``valid_actions``.
        action_map = {
            action.rule: i  # type: ignore
            for i, action in enumerate(valid_actions_field.field_list)
        }

        index_fields: List[Field] = [
            IndexField(action_map[production_rule], valid_actions_field)
            for production_rule in action_sequence
        ]
        if not action_sequence:
            # An empty target is represented by a single -1 index.
            index_fields = [IndexField(-1, valid_actions_field)]

        fields["action_sequence"] = ListField(index_fields)
        return Instance(fields)
Example #4
0
    def text_to_instance(
            self,  # type: ignore
            question: str,
            logical_forms: List[str] = None,
            additional_metadata: Dict[str, Any] = None,
            world_extractions: Dict[str, Union[str, List[str]]] = None,
            entity_literals: Dict[str, Union[str, List[str]]] = None,
            tokenized_question: List[Token] = None,
            debug_counter: int = None,
            qr_spec_override: List[Dict[str, int]] = None,
            dynamic_entities_override: Dict[str, str] = None) -> Instance:
        """
        Build an ``Instance`` for one question.

        Parameters
        ----------
        question : ``str``
            The question text.  It is lower-cased and tokenized with
            ``self._tokenizer`` unless ``tokenized_question`` is given.
        logical_forms : ``List[str]``, optional
            Logical forms to convert into target action sequences.  A logical form
            that uses a production rule missing from the world's actions is logged
            and skipped.
        additional_metadata : ``Dict[str, Any]``, optional
            Extra metadata for the instance.  A non-empty dict passed in here is
            modified in place (``question_tokens`` and, depending on the mode,
            ``world_extractions``, ``tags_gold`` and ``words`` are added).  It must
            contain ``answer_index`` when ``self._denotation_only`` is set.
        world_extractions : ``Dict[str, Union[str, List[str]]]``, optional
            Stored in the metadata and, when ``self._entity_bits_mode`` is set, used
            to compute the ``entity_bits`` field.
        entity_literals : ``Dict[str, Union[str, List[str]]]``, optional
            Used in tagger-only mode to compute the ``tags`` field.
        tokenized_question : ``List[Token]``, optional
            Pre-tokenized question; takes precedence over tokenizing ``question``.
        debug_counter : ``int``, optional
            When greater than zero, the raw entity tags are logged in tagger-only
            mode.  ``None`` disables that logging.
        qr_spec_override : ``List[Dict[str, int]]``, optional
        dynamic_entities_override : ``Dict[str, str]``, optional
            When either override is given, a fresh ``KnowledgeGraph`` and
            ``QuarelWorld`` are built for this instance instead of reusing the
            reader's own.

        Returns
        -------
        In tagger-only mode, an ``Instance`` with ``tokens``, ``metadata`` and
        optionally ``tags``.  Otherwise an ``Instance`` with ``question``,
        ``table``, ``world``, ``actions`` and ``metadata``, plus
        ``denotation_target``, ``entity_bits`` and ``target_action_sequences``
        when the corresponding inputs or settings are present.

        Raises
        ------
        ValueError
            If ``self._entity_bits_mode`` is set to an unsupported value while
            ``world_extractions`` is given.
        """

        # pylint: disable=arguments-differ
        tokenized_question = tokenized_question or self._tokenizer.tokenize(
            question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata['question_tokens'] = [
            token.text for token in tokenized_question
        ]
        if world_extractions is not None:
            additional_metadata['world_extractions'] = world_extractions
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {
                key: []
                for key in dynamic_entities.keys()
            }
            knowledge_graph = KnowledgeGraph(entities=set(
                dynamic_entities.keys()),
                                             neighbors=neighbors,
                                             entity_text=dynamic_entities)
            world = QuarelWorld(knowledge_graph,
                                self._lf_syntax,
                                qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(knowledge_graph,
                                          tokenized_question,
                                          self._entity_token_indexers,
                                          tokenizer=self._tokenizer)

        if self._tagger_only:
            fields: Dict[str, Field] = {'tokens': question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(self._all_entities,
                                                    table_field,
                                                    entity_literals,
                                                    tokenized_question)
                # ``debug_counter`` defaults to None, and ``None > 0`` raises a
                # TypeError on Python 3, so the None case is checked first.
                if debug_counter is not None and debug_counter > 0:
                    logger.info(f'raw entity tags = {entity_tags}')
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields['tags'] = SequenceLabelField(entity_tags_bio,
                                                    question_field)
                additional_metadata['tags_gold'] = entity_tags_bio
            additional_metadata['words'] = [x.text for x in tokenized_question]
            fields['metadata'] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata['answer_index'],
                                          skip_indexing=True)
            fields['denotation_target'] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(['world1', 'world2'],
                                                table_field, world_extractions,
                                                tokenized_question)
            # Each tag indexes into a small lookup table whose rows depend on the mode.
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag]
                                 for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag]
                                 for tag in entity_bits]
            else:
                # Without this branch an unknown mode left ``entity_bits_v``
                # unbound and failed below with an opaque UnboundLocalError.
                raise ValueError(
                    f'Unsupported entity_bits_mode: {self._entity_bits_mode!r}')

            entity_bits_field = ArrayField(np.array(entity_bits_v))
            fields['entity_bits'] = entity_bits_field

        if logical_forms:
            # Map each rule string to its position in ``actions``.
            action_map = {
                action.rule: i
                for i, action in enumerate(action_field.field_list)
            }  # type: ignore
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.info(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.info(f'Question was: {question}')
                    logger.info(f'Logical form was: {logical_form}')
                    continue
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        fields['metadata'] = MetadataField(additional_metadata or {})
        return Instance(fields)
    def test_batch_tensors_does_not_modify_list(self):
        """``batch_tensors`` should return a list equal to the one it was given."""
        tensor_list = []
        field = None
        for rule in ("S -> [NP, VP]", "NP -> test"):
            field = ProductionRuleField(rule, is_global_rule=True)
            field.index(self.vocab)
            lengths = field.get_padding_lengths()
            tensor_list.append(field.as_tensor(lengths))
        # ``field`` is the last field built above ("NP -> test").
        assert field.batch_tensors(tensor_list) == tensor_list
    def test_as_tensor_produces_correct_output(self):
        """
        ``as_tensor`` should return a 4-tuple starting with the rule string and the
        global flag; the third element holds the rule index for a global rule and
        is ``None`` otherwise.
        """
        rule = "S -> [NP, VP]"
        for is_global in (True, False):
            rule_field = ProductionRuleField(rule, is_global_rule=is_global)
            rule_field.index(self.vocab)
            output = rule_field.as_tensor(rule_field.get_padding_lengths())
            assert isinstance(output, tuple)
            assert len(output) == 4
            assert output[0] == rule
            assert output[1] is is_global
            if is_global:
                assert_almost_equal(output[2].detach().cpu().numpy(),
                                    [self.s_rule_index])
            else:
                assert output[2] is None
 def test_padding_lengths_are_computed_correctly(self):
     """An indexed production rule field should report an empty padding-lengths dict."""
     rule_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
     rule_field.index(self.vocab)
     padding_lengths = rule_field.get_padding_lengths()
     assert padding_lengths == {}
 def test_index_converts_field_correctly(self):
     """Indexing a global rule should store the expected rule index in ``_rule_id``."""
     rule_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
     rule_field.index(self.vocab)
     expected_index = self.s_rule_index
     assert rule_field._rule_id == expected_index
 def test_production_rule_field_can_print(self):
     """Smoke test: printing a production rule field must not raise."""
     rule_field = ProductionRuleField("S -> [NP, VP]", is_global_rule=True)
     print(rule_field)
    def text_to_instance(
        self,  # type: ignore
        question: str,
        table_lines: List[List[str]],
        target_values: List[str] = None,
        offline_search_output: List[str] = None,
    ) -> Instance:
        """
        Turn a question and its table into an ``Instance``.  ``table_lines`` is handed to
        ``TableQuestionContext.read_from_lines`` together with the tokenized question.

        Parameters
        ----------
        question : ``str``
            Input question; it is lower-cased before tokenization.
        table_lines : ``List[List[str]]``
            The table content.  See ``TableQuestionContext.read_from_lines`` for the
            expected format.
        target_values : ``List[str]``, optional
            Target values for the denotations; stored as metadata when given.
        offline_search_output : ``List[str]``, optional
            Logical forms to convert into target action sequences.  At most
            ``self._max_offline_logical_forms`` of them are kept.

        Returns
        -------
        The assembled ``Instance``, or ``None`` when ``offline_search_output`` was given
        but none of its logical forms could be converted into an action sequence.
        """
        question_tokens = self._tokenizer.tokenize(question.lower())
        question_field = TextField(question_tokens,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [token.text for token in question_tokens]
        }
        context = TableQuestionContext.read_from_lines(table_lines,
                                                       question_tokens)
        world = WikiTablesLanguage(context)
        world_field = MetadataField(world)
        # Note: no feature extractors are passed when instantiating the field below, which
        # makes it use all the available extractors.
        table_field = KnowledgeGraphField(
            context.get_table_knowledge_graph(),
            question_tokens,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens,
        )

        rule_fields: List[Field] = []
        for rule in world.all_possible_productions():
            _, right_hand_side = rule.split(" -> ")
            # A rule is global unless its right side is an instance-specific entity.
            instance_specific = world.is_instance_specific_entity(
                right_hand_side)
            rule_fields.append(
                ProductionRuleField(rule,
                                    is_global_rule=not instance_specific))
        action_field = ListField(rule_fields)

        fields = {
            "question": question_field,
            "metadata": MetadataField(metadata),
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }

        if target_values is not None:
            fields["target_values"] = MetadataField(target_values)

        # Every target action sequence becomes a list of ``IndexField``s pointing into the
        # action list built above.  mypy cannot tell that the ``ListField`` holds
        # ``ProductionRuleField``s, hence the ignore on ``action.rule``.
        action_map = {
            action.rule: position
            for position, action in enumerate(action_field.field_list)
        }  # type: ignore

        if offline_search_output:
            sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    rules = world.logical_form_to_action_sequence(logical_form)
                    indices: List[Field] = [
                        IndexField(action_map[rule], action_field)
                        for rule in rules
                    ]
                    sequence_fields.append(ListField(indices))
                except ParsingError as error:
                    logger.debug(
                        f"Parsing error: {error.message}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Logical form was: {logical_form}")
                    logger.debug(f"Table info was: {table_lines}")
                except KeyError as error:
                    logger.debug(
                        f"Missing production rule: {error.args}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Table info was: {table_lines}")
                    logger.debug(f"Logical form was: {logical_form}")
                except:  # noqa
                    logger.error(logical_form)
                    raise
                else:
                    # Only a successfully converted logical form can fill the quota.
                    if len(sequence_fields) >= self._max_offline_logical_forms:
                        break

            if not sequence_fields:
                # Not ideal, but this only happens when logical form supervision was
                # passed in and none of it could be used, so the instance is skipped.
                # This affects _dev_ and _test_ instances too, so metrics could be
                # over-estimates on the full test data.
                return None
            fields["target_action_sequences"] = ListField(sequence_fields)

        if self._output_agendas:
            agenda_indices: List[Field] = [
                IndexField(action_map[agenda_rule], action_field)
                for agenda_rule in world.get_agenda(conservative=True)
            ]
            if not agenda_indices:
                # An empty agenda is represented by a single -1 index.
                agenda_indices = [IndexField(-1, action_field)]
            fields["agenda"] = ListField(agenda_indices)

        return Instance(fields)
Example #11
0
    def text_to_instance(  # type: ignore
            self,
            utterances: List[str],
            sql_query_labels: List[str] = None) -> Instance:
        """
        Build an ``Instance`` for the last utterance of an interaction.

        Parameters
        ----------
        utterances: ``List[str]``, required.
            List of utterances in the interaction, the last element is the current utterance.
            The list is not modified.
        sql_query_labels: ``List[str]``, optional
            The SQL queries that are given as labels during training or validation.

        Returns
        -------
        An ``Instance`` with ``utterance``, ``actions``, ``world`` and ``linking_scores``
        fields, plus ``sql_queries`` and ``target_action_sequence`` when labels are given.
        ``None`` is returned when there is no current utterance (empty list or empty
        string), or when labels were given but produced no action sequence and
        ``self._keep_if_unparseable`` is false.
        """
        # An empty interaction has no current utterance; skip it like an empty
        # utterance instead of failing on ``utterances[-1]``.
        if not utterances:
            return None

        if self._num_turns_to_concatenate:
            # Build a new list rather than assigning to ``utterances[-1]``, so the
            # caller's list is left untouched.
            concatenated = f" {END_OF_UTTERANCE_TOKEN} ".join(
                utterances[-self._num_turns_to_concatenate:])
            utterances = [*utterances[:-1], concatenated]

        utterance = utterances[-1]
        action_sequence: List[str] = []

        if not utterance:
            return None

        world = AtisWorld(utterances=utterances)

        if sql_query_labels:
            # If there are multiple sql queries given as labels, we use the shortest
            # one for training.
            sql_query = min(sql_query_labels, key=len)
            try:
                action_sequence = world.get_action_sequence(sql_query)
            except ParseError:
                action_sequence = []
                logger.debug("Parsing error")

        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
        utterance_field = TextField(tokenized_utterance, self._token_indexers)

        production_rule_fields: List[Field] = []

        for production_rule in world.all_possible_actions():
            nonterminal, _ = production_rule.split(" ->")
            # The whitespaces are not semantically meaningful, so we filter them out.
            production_rule = " ".join([
                token for token in production_rule.split(" ") if token != "ws"
            ])
            field = ProductionRuleField(production_rule,
                                        self._is_global_rule(nonterminal))
            production_rule_fields.append(field)

        action_field = ListField(production_rule_fields)
        # Map each rule string to its position in ``actions``.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)  # type: ignore
        }
        index_fields: List[Field] = []
        world_field = MetadataField(world)
        fields = {
            "utterance": utterance_field,
            "actions": action_field,
            "world": world_field,
            "linking_scores": ArrayField(world.linking_scores),
        }

        if sql_query_labels is not None:
            fields["sql_queries"] = MetadataField(sql_query_labels)
            if self._keep_if_unparseable or action_sequence:
                for production_rule in action_sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule], action_field))
                if not action_sequence:
                    # An empty target is represented by a single -1 index.
                    index_fields = [IndexField(-1, action_field)]
                action_sequence_field = ListField(index_fields)
                fields["target_action_sequence"] = action_sequence_field
            else:
                # If we are given a SQL query, but we are unable to parse it, and we do not
                # specify explicitly to keep it, then we will skip it.
                return None

        return Instance(fields)
Example #12
0
    def text_to_instance(
        self,  # type: ignore
        sentence: str,
        structured_representations: List[List[List[JsonDict]]],
        labels: List[str] = None,
        target_sequences: List[List[str]] = None,
        identifier: str = None,
    ) -> Instance:
        """
        Build an ``Instance`` from a sentence and the worlds it is evaluated in.

        Parameters
        ----------
        sentence : ``str``
            The query sentence.
        structured_representations : ``List[List[List[JsonDict]]]``
            One Json representation per world; each is a list of boxes, and each box is
            turned into a ``Box`` numbered by its position.
        labels : ``List[str]`` (optional)
            String labels, stored in the ``denotations`` label namespace.
        target_sequences : ``List[List[str]]`` (optional)
            Target action sequences.  When given they take precedence over the agenda.
        identifier : ``str`` (optional)
            The identifier from the dataset if available.

        Returns
        -------
        An ``Instance`` with ``sentence``, ``worlds``, ``actions`` and ``metadata`` fields,
        plus ``identifier``, ``target_action_sequences`` or ``agenda``, and ``labels`` when
        the corresponding inputs or settings are present.
        """
        worlds = [
            NlvrLanguage({
                Box(object_list, box_id)
                for box_id, object_list in enumerate(representation)
            })
            for representation in structured_representations
        ]
        sentence_tokens = self._tokenizer.tokenize(sentence)
        sentence_field = TextField(sentence_tokens, self._sentence_token_indexers)

        rule_fields: List[Field] = []
        instance_action_ids: Dict[str, int] = {}
        # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
        # later.
        for rule in worlds[0].all_possible_productions():
            # The id is the number of distinct rules seen so far.
            instance_action_ids[rule] = len(instance_action_ids)
            rule_fields.append(ProductionRuleField(rule, is_global_rule=True))
        action_field = ListField(rule_fields)
        worlds_field = ListField([MetadataField(world) for world in worlds])

        metadata: Dict[str, Any] = {
            "sentence_tokens": [token.text for token in sentence_tokens]
        }
        fields: Dict[str, Field] = {
            "sentence": sentence_field,
            "worlds": worlds_field,
            "actions": action_field,
            "metadata": MetadataField(metadata),
        }
        if identifier is not None:
            fields["identifier"] = MetadataField(identifier)

        # The kind of supervision decides what goes into the instance: explicit target
        # action sequences when they are provided, otherwise (if enabled) an agenda
        # derived from the sentence.
        if target_sequences:
            # TODO(pradeep): Define a max length for this field.
            sequence_fields: List[Field] = [
                ListField([
                    IndexField(instance_action_ids[action], action_field)
                    for action in target_sequence
                ])
                for target_sequence in target_sequences
            ]
            fields["target_action_sequences"] = ListField(sequence_fields)
        elif self._output_agendas:
            # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
            # now, but may change later too.
            agenda = worlds[0].get_agenda_for_sentence(sentence)
            assert agenda, "No agenda found for sentence: %s" % sentence
            # The agenda field holds indices into ``actions``.
            fields["agenda"] = ListField([
                IndexField(instance_action_ids[action], action_field)
                for action in agenda
            ])

        if labels:
            fields["labels"] = ListField([
                LabelField(label, label_namespace="denotations")
                for label in labels
            ])

        return Instance(fields)