示例#1
0
    def test_doubly_nested_field_works(self):
        """Check the output of ``as_tensor`` for a list of lists of production rules.

        The two inner lists hold three and two rules respectively, so the last
        entry of the second inner list is expected to be a padding entry.
        """
        global_s = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
        global_np = ProductionRuleField('NP -> test', is_global_rule=True)
        local_vp = ProductionRuleField('VP -> eat', is_global_rule=False)
        nested_field = ListField([ListField([global_s, global_np, local_vp]),
                                  ListField([global_s, global_np])])
        nested_field.index(self.vocab)
        tensors = nested_field.as_tensor(nested_field.get_padding_lengths())

        # Both inner lists are expected to come back with three entries.
        assert isinstance(tensors, list)
        assert len(tensors) == 2
        for inner_list in tensors:
            assert isinstance(inner_list, list)
            assert len(inner_list) == 3

        # (outer index, inner index, rule string, is-global flag, expected rule id)
        # The rule id is only checked for global rules; otherwise the third
        # tuple element must be ``None``.
        expected_entries = [
            (0, 0, 'S -> [NP, VP]', True, self.s_rule_index),
            (0, 1, 'NP -> test', True, self.np_index),
            (0, 2, 'VP -> eat', False, None),
            (1, 0, 'S -> [NP, VP]', True, self.s_rule_index),
            (1, 1, 'NP -> test', True, self.np_index),
            # This item was just padding.
            (1, 2, '', False, None),
        ]
        for outer, inner, rule_string, is_global, rule_id in expected_entries:
            entry = tensors[outer][inner]
            assert entry[0] == rule_string
            assert entry[1] is is_global
            if is_global:
                assert_almost_equal(entry[2].detach().cpu().numpy(), [rule_id])
            else:
                assert entry[2] is None
示例#2
0
    def test_field_counts_vocab_items_correctly(self):
        """A global rule is counted once under ``rule_labels``; a non-global rule is not counted."""
        rule = 'S -> [NP, VP]'
        for is_global, expected_count in ((True, 1), (False, 0)):
            rule_field = ProductionRuleField(rule, is_global_rule=is_global)
            # Fresh counter per case so the two checks cannot influence each other.
            counts = defaultdict(lambda: defaultdict(int))
            rule_field.count_vocab_items(counts)
            assert counts["rule_labels"][rule] == expected_count
示例#3
0
    def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
        """Build an ``Instance`` from a JSON blob.

        The blob is read for the keys ``question_tokens``, ``table_lines``,
        ``entity_texts``, ``linking_features`` and ``example_lisp_string``.
        When the optional keys ``target_action_sequences`` and / or ``agenda``
        are present, the matching supervision fields are added as lists of
        indices into the ``actions`` field.
        """
        tokens = self._read_tokens_from_json_list(json_obj['question_tokens'])
        question_field = TextField(tokens, self._question_token_indexers)
        question_metadata = MetadataField(
            {"question_tokens": [token.text for token in tokens]})
        knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            json_obj['table_lines'], tokens)
        entity_tokens = list(
            map(self._read_tokens_from_json_list, json_obj['entity_texts']))
        table_field = KnowledgeGraphField(
            knowledge_graph,
            tokens,
            tokenizer=None,
            token_indexers=self._table_token_indexers,
            entity_tokens=entity_tokens,
            linking_features=json_obj['linking_features'],
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(knowledge_graph)

        def make_rule_field(rule: str) -> Field:
            # A rule is treated as global unless its right-hand side is a table entity.
            _, right_side = rule.split(' -> ')
            return ProductionRuleField(rule,
                                       not world.is_table_entity(right_side))

        action_field = ListField(
            [make_rule_field(rule) for rule in world.all_possible_actions()])

        fields = {
            'question': question_field,
            'metadata': question_metadata,
            'table': table_field,
            'world': MetadataField(world),
            'actions': action_field,
            'example_lisp_string': MetadataField(json_obj['example_lisp_string'])
        }

        has_sequences = 'target_action_sequences' in json_obj
        has_agenda = 'agenda' in json_obj
        if has_sequences or has_agenda:
            # Maps each rule string to its position in the action list.
            rule_to_index = {}
            for position, rule_field in enumerate(action_field.field_list):
                rule_to_index[rule_field.rule] = position  # type: ignore
        if has_sequences:
            fields['target_action_sequences'] = ListField([
                ListField([
                    IndexField(rule_to_index[rule], action_field)
                    for rule in sequence
                ]) for sequence in json_obj['target_action_sequences']
            ])
        if has_agenda:
            fields['agenda'] = ListField([
                IndexField(rule_to_index[agenda_action], action_field)
                for agenda_action in json_obj['agenda']
            ])
        return Instance(fields)
示例#4
0
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[str],
            example_lisp_string: str = None,
            dpd_output: List[str] = None,
            tokenized_question: List[Token] = None) -> Instance:
        """
        Builds an ``Instance`` from a question and the table it refers to.

        Parameters
        ----------
        question : ``str``
            The input question.  It is lower-cased and tokenized here unless a non-empty
            ``tokenized_question`` is supplied.
        table_lines : ``List[str]``
            The table content itself, as a list of rows. See
            ``TableQuestionKnowledgeGraph.read_from_lines`` for the expected format.
        example_lisp_string : ``str``, optional
            The original (lisp-formatted) example string in the WikiTableQuestions dataset, as
            found in the ``.examples`` file provided with the dataset.  We pass this to SEMPRE
            for evaluating logical forms during training; here it is only stored in a
            ``MetadataField`` when given.
        dpd_output : ``List[str]``, optional
            List of logical forms, produced by dynamic programming on denotations. Not required
            during test.
        tokenized_question : ``List[Token]``, optional
            An already tokenized question, so that the tokenization work is not repeated here.

        Returns
        -------
        The assembled ``Instance``, or ``None`` when ``dpd_output`` was given but none of its
        logical forms could be converted into an action sequence.
        """
        # pylint: disable=arguments-differ
        if not tokenized_question:
            tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [token.text for token in tokenized_question],
            "original_table": "".join(table_lines),
        }
        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            table_lines, tokenized_question)
        table_metadata = MetadataField(table_lines)
        table_field = KnowledgeGraphField(
            table_knowledge_graph,
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            feature_extractors=self._linking_feature_extractors,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)
        world_field = MetadataField(world)

        def make_rule_field(rule: str) -> Field:
            # A rule is treated as global unless its right-hand side is a table entity.
            _, right_side = rule.split(' -> ')
            return ProductionRuleField(rule,
                                       not world.is_table_entity(right_side))

        action_field = ListField(
            [make_rule_field(rule) for rule in world.all_possible_actions()])

        fields = {
            'question': question_field,
            'metadata': MetadataField(metadata),
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }
        if self._include_table_metadata:
            fields['table_metadata'] = table_metadata
        if example_lisp_string:
            fields['example_lisp_string'] = MetadataField(example_lisp_string)

        # Each target action sequence becomes a list of IndexFields pointing into the action
        # list built above.  The type is ignored because mypy cannot tell that the ListField
        # is made up of ProductionRuleFields, which is what provides ``.rule``.
        action_map = {}
        for position, rule_field in enumerate(action_field.field_list):
            action_map[rule_field.rule] = position  # type: ignore

        if dpd_output:
            action_sequence_fields: List[Field] = []
            for logical_form in dpd_output:
                if not self._should_keep_logical_form(logical_form):
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(
                        f'Parsing error: {error.message}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:
                    # Any other failure is logged with the offending logical form and re-raised.
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    sequence_field = ListField([
                        IndexField(action_map[rule], action_field)
                        for rule in action_sequence
                    ])
                except KeyError as error:
                    logger.debug(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                action_sequence_fields.append(sequence_field)
                if len(action_sequence_fields) >= self._max_dpd_logical_forms:
                    break

            if not action_sequence_fields:
                # Logical form supervision was passed in, but none of it could be turned into
                # an action sequence, so the instance is skipped.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on
                # the full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)

        if self._output_agendas:
            agenda_index_fields: List[Field] = [
                IndexField(action_map[agenda_string], action_field)
                for agenda_string in world.get_agenda()
            ]
            if not agenda_index_fields:
                # An empty agenda is represented by a single index of -1.
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
示例#5
0
    def text_to_instance(
            self,  # type: ignore
            sentence: str,
            structured_representations: List[List[List[JsonDict]]],
            labels: List[str] = None,
            target_sequences: List[List[str]] = None,
            identifier: str = None) -> Instance:
        """
        Builds an ``Instance`` from a sentence and the worlds it is evaluated in.

        Parameters
        ----------
        sentence : ``str``
            The query sentence.
        structured_representations : ``List[List[List[JsonDict]]]``
            A list of Json representations of all the worlds. See expected format in this class' docstring.
        labels : ``List[str]`` (optional)
            List of string representations of the labels (true or false) corresponding to the
            ``structured_representations``. Not required while testing.
        target_sequences : ``List[List[str]]`` (optional)
            List of target action sequences for each element which lead to the correct denotation in
            worlds corresponding to the structured representations.
        identifier : ``str`` (optional)
            The identifier from the dataset if available.
        """
        # pylint: disable=arguments-differ
        worlds = list(map(NlvrWorld, structured_representations))
        sentence_field = TextField(self._tokenizer.tokenize(sentence),
                                   self._sentence_token_indexers)

        # TODO(pradeep): Assuming that possible actions are the same in all worlds. This may change
        # later.
        rule_ids: Dict[str, int] = {}
        rule_fields: List[Field] = []
        for rule in worlds[0].all_possible_actions():
            # The id is the number of distinct rules recorded so far.
            rule_ids[rule] = len(rule_ids)
            rule_fields.append(ProductionRuleField(rule, is_global_rule=True))
        action_field = ListField(rule_fields)

        fields: Dict[str, Field] = {
            "sentence": sentence_field,
            "worlds": ListField([MetadataField(world) for world in worlds]),
            "actions": action_field
        }
        if identifier is not None:
            fields["identifier"] = MetadataField(identifier)

        # The supervision is either the provided target action sequences or, failing that and
        # if agendas are enabled, an agenda computed for the sentence.
        if target_sequences:
            # TODO(pradeep): Define a max length for this field.
            fields["target_action_sequences"] = ListField([
                ListField([
                    IndexField(rule_ids[action], action_field)
                    for action in target_sequence
                ]) for target_sequence in target_sequences
            ])
        elif self._output_agendas:
            # TODO(pradeep): Assuming every world gives the same agenda for a sentence. This is true
            # now, but may change later too.
            agenda = worlds[0].get_agenda_for_sentence(
                sentence, add_paths_to_agenda=False)
            assert agenda, "No agenda found for sentence: %s" % sentence
            # The agenda field contains indices into actions.
            fields["agenda"] = ListField(
                [IndexField(rule_ids[action], action_field) for action in agenda])

        if labels:
            fields["labels"] = ListField([
                LabelField(label, label_namespace='denotations')
                for label in labels
            ])

        return Instance(fields)
示例#6
0
    def text_to_instance(self,  # type: ignore
                         utterances: List[str],
                         sql_query: str = None) -> Instance:
        # pylint: disable=arguments-differ
        """
        Builds an ``Instance`` for the last utterance of an interaction.

        Parameters
        ----------
        utterances: ``List[str]``, required.
            List of utterances in the interaction, the last element is the current utterance.
        sql_query: ``str``, optional
            The SQL query, given as label during training or validation.

        Returns
        -------
        The assembled ``Instance``, or ``None`` when the instance should be skipped: the
        interaction is empty, the current utterance is empty, or a SQL query was given but no
        action sequence could be obtained from it.
        """
        # An empty interaction has no current utterance; skip it the same way an empty
        # utterance is skipped instead of failing on ``utterances[-1]``.
        if not utterances:
            return None
        utterance = utterances[-1]
        if not utterance:
            return None

        world = AtisWorld(utterances)

        action_sequence: List[str] = []
        if sql_query:
            try:
                action_sequence = world.get_action_sequence(sql_query)
            except ParseError:
                logger.debug('Parsing error')
            if not action_sequence:
                # If we are given a SQL query, but we are unable to parse it, then we will skip
                # it.  Returning here avoids building fields that would be thrown away.
                return None

        tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
        utterance_field = TextField(tokenized_utterance, self._token_indexers)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            # Split on the first arrow only, so an arrow inside the right-hand side cannot
            # break the unpacking.
            lhs, _ = production_rule.split(' ->', 1)
            is_global_rule = lhs not in ['number', 'string']
            # The whitespaces are not semantically meaningful, so we filter them out.
            production_rule = ' '.join([token for token in production_rule.split(' ') if token != 'ws'])
            production_rule_fields.append(ProductionRuleField(production_rule, is_global_rule))

        action_field = ListField(production_rule_fields)
        # Maps each (whitespace-filtered) rule string to its position in the action list.
        action_map = {action.rule: i # type: ignore
                      for i, action in enumerate(action_field.field_list)}
        fields = {'utterance' : utterance_field,
                  'actions' : action_field,
                  'world' : MetadataField(world),
                  'linking_scores' : ArrayField(world.linking_scores)}

        if sql_query:
            # ``action_sequence`` is non-empty here because of the early return above.
            index_fields: List[Field] = [IndexField(action_map[production_rule], action_field)
                                         for production_rule in action_sequence]
            action_sequence_field: List[Field] = [ListField(index_fields)]
            fields['target_action_sequence'] = ListField(action_sequence_field)

        return Instance(fields)
示例#7
0
    def test_batch_tensors_does_not_modify_list(self):
        """``batch_tensors`` is expected to return a list equal to the one it was given."""
        tensor_list = []
        for rule in ('S -> [NP, VP]', 'NP -> test'):
            field = ProductionRuleField(rule, is_global_rule=True)
            field.index(self.vocab)
            tensor_list.append(field.as_tensor(field.get_padding_lengths()))
        # ``field`` is the last field built above; batching is called on that one.
        assert field.batch_tensors(tensor_list) == tensor_list
示例#8
0
    def test_as_tensor_produces_correct_output(self):
        """``as_tensor`` yields a ``(rule, is_global, tensor-or-None)`` triple.

        For the global rule the third element must match ``self.s_rule_index``;
        for the non-global rule it must be ``None``.
        """
        rule = 'S -> [NP, VP]'
        for is_global in (True, False):
            field = ProductionRuleField(rule, is_global_rule=is_global)
            field.index(self.vocab)
            output = field.as_tensor(field.get_padding_lengths())
            assert isinstance(output, tuple)
            assert len(output) == 3
            assert output[0] == rule
            assert output[1] is is_global
            if is_global:
                assert_almost_equal(output[2].detach().cpu().numpy(), [self.s_rule_index])
            else:
                assert output[2] is None
示例#9
0
 def test_padding_lengths_are_computed_correctly(self):
     """An indexed production rule field reports an empty padding-lengths dict."""
     rule_field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
     rule_field.index(self.vocab)
     padding_lengths = rule_field.get_padding_lengths()
     assert padding_lengths == {}
示例#10
0
 def test_index_converts_field_correctly(self):
     """After indexing, ``_rule_id`` must equal the fixture's ``s_rule_index``."""
     rule_field = ProductionRuleField('S -> [NP, VP]', is_global_rule=True)
     rule_field.index(self.vocab)
     indexed_rule_id = rule_field._rule_id
     assert indexed_rule_id == self.s_rule_index