    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        json = {
                'question': self.utterance,
                'columns': ['Name in English', 'Location in English'],
                'cells': [['Paradeniz', 'Mersin'],
                          ['Lake Gala', 'Edirne']]
                }
        self.graph = TableQuestionKnowledgeGraph.read_from_json(json)
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name", namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in", namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace("english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace("location", namespace='tokens')
        self.paradeniz_index = self.vocab.add_token_to_namespace("paradeniz", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace("mersin", namespace='tokens')
        self.lake_index = self.vocab.add_token_to_namespace("lake", namespace='tokens')
        self.gala_index = self.vocab.add_token_to_namespace("gala", namespace='tokens')
        self.negative_one_index = self.vocab.add_token_to_namespace("-1", namespace='tokens')
        self.zero_index = self.vocab.add_token_to_namespace("0", namespace='tokens')
        self.one_index = self.vocab.add_token_to_namespace("1", namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string', namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()
Example #2
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        table_file = self.FIXTURES_ROOT / "data" / "wikitables" / "tables" / "341.tagged"
        self.graph = TableQuestionContext.read_from_file(
            table_file, self.utterance).get_table_knowledge_graph()
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name",
                                                            namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in",
                                                          namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace(
            "english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace(
            "location", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace(
            "mersin", namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string',
                                                    namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance,
                                         self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()
    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [field._span_overlap_fraction(entity, entity_text, token, i, utterance)
                          for i, token in enumerate(utterance)]
        assert feature_values == [0, 0, 0, 1, 2/3, 1/3, 0, 0, 0]
    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [field._span_overlap_fraction(entity, entity_text, token, i, utterance)
                          for i, token in enumerate(utterance)]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]
    def test_lemma_feature_extractor(self):
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = "string_column:name_in_english"
        lemma_feature = field._contains_lemma_match(
            entity, field._entity_text_map[entity], utterance[0], 0, utterance)
        assert lemma_feature == 1
    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        lemma_feature = field._contains_lemma_match(
            entity, field._entity_text_map[entity], utterance[0], 0, utterance)
        assert lemma_feature == 1
    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        lemma_feature = field._contains_lemma_match(entity,
                                                    field._entity_text_map[entity],
                                                    utterance[0],
                                                    0,
                                                    utterance)
        assert lemma_feature == 1
    def test_span_overlap_fraction(self):
        utterance = self.tokenizer.tokenize(
            "what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = "string_column:name_in_english"
        entity_text = field._entity_text_map[entity]
        feature_values = [
            field._span_overlap_fraction(entity, entity_text, token, i,
                                         utterance)
            for i, token in enumerate(utterance)
        ]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]
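The span-overlap tests above come in two flavours: one asserts decreasing fractions (1, 2/3, 1/3) over "name in english", consistent with the matching span being required to start at the current token, while the others assert 1 for every token inside the matching span. As a minimal, hypothetical sketch of the second behaviour (not the AllenNLP implementation), the feature can be read as the length of the longest contiguous run of utterance tokens through the current position that all occur in the entity text, divided by the entity length:

# Hypothetical sketch of a span-overlap feature, for orientation only.
def span_overlap_fraction(entity_tokens, utterance_tokens, index):
    entity_words = {token.lower() for token in entity_tokens}
    if utterance_tokens[index].lower() not in entity_words:
        return 0.0
    # Grow the span left and right while neighbouring tokens keep matching entity words.
    start = index
    while start > 0 and utterance_tokens[start - 1].lower() in entity_words:
        start -= 1
    end = index
    while end < len(utterance_tokens) - 1 and utterance_tokens[end + 1].lower() in entity_words:
        end += 1
    return (end - start + 1) / len(entity_tokens)

utterance = "what is the name in english of mersin ?".split()
entity_text = "name in english".split()
print([span_overlap_fraction(entity_text, utterance, i) for i in range(len(utterance))])
# -> [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]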
Example #10
    def _json_blob_to_instance(self, json_obj: JsonDict) -> Instance:
        question_tokens = self._read_tokens_from_json_list(
            json_obj['question_tokens'])
        question_field = TextField(question_tokens,
                                   self._question_token_indexers)
        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            json_obj['table_lines'], question_tokens)
        entity_tokens = [
            self._read_tokens_from_json_list(token_list)
            for token_list in json_obj['entity_texts']
        ]
        table_field = KnowledgeGraphField(
            table_knowledge_graph,
            question_tokens,
            tokenizer=None,
            token_indexers=self._table_token_indexers,
            entity_tokens=entity_tokens,
            linking_features=json_obj['linking_features'],
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)
        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        example_string_field = MetadataField(json_obj['example_lisp_string'])

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field,
            'example_lisp_string': example_string_field
        }

        if 'target_action_sequences' in json_obj:
            action_map = {
                action.rule: i
                for i, action in enumerate(action_field.field_list)
            }  # type: ignore
            action_sequence_fields: List[Field] = []
            for sequence in json_obj['target_action_sequences']:
                index_fields: List[Field] = []
                for production_rule in sequence:
                    index_fields.append(
                        IndexField(action_map[production_rule], action_field))
                action_sequence_fields.append(ListField(index_fields))
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)

        return Instance(fields)
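``_json_blob_to_instance`` above consumes one preprocessed JSON object per instance. A hypothetical blob with the keys the method reads might look like the following; the values are placeholders, since the exact encoding of the tokens and table lines is produced by a separate preprocessing step:

# Illustrative only: key names are taken from _json_blob_to_instance above, values are
# placeholders (the real token entries are whatever _read_tokens_from_json_list expects).
json_obj = {
    "question_tokens": ["where", "is", "mersin", "?"],
    "table_lines": ["<table lines in the format read_from_lines expects>"],
    "entity_texts": [["name", "in", "english"], ["mersin"]],
    "linking_features": [[[0.0] * 10 for _ in range(4)] for _ in range(2)],  # one vector per (entity, token) pair
    "example_lisp_string": "(example ...)",
    # "target_action_sequences": [...],  # optional supervision
}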
Example #11
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[List[str]],
            target_values: List[str] = None,
            offline_search_output: List[str] = None) -> Instance:
        """
        Reads text inputs and makes an instance. We pass the ``table_lines`` to ``TableQuestionContext``, and that
        method accepts this field either as lines from CoreNLP processed tagged files that come with the dataset,
        or simply in a tsv format where each line corresponds to a row and the cells are tab-separated.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            The table content optionally preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``, optional
            Target values for the denotations the logical forms should execute to. Not required for testing.
        offline_search_output : ``List[str]``, optional
            List of logical forms, produced by offline search. Not required during test.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [x.text for x in tokenized_question]
        }
        table_context = TableQuestionContext.read_from_lines(
            table_lines, tokenized_question)
        target_values_field = MetadataField(target_values)
        world = WikiTablesLanguage(table_context)
        world_field = MetadataField(world)
        # Note: Not passing any feature extractors when instantiating the field below. This will
        # make it use all the available extractors.
        table_field = KnowledgeGraphField(
            table_context.get_table_knowledge_graph(),
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_productions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(
                rule_right_side)
            field = ProductionRuleField(production_rule,
                                        is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'metadata': MetadataField(metadata),
            'table': table_field,
            'world': world_field,
            'actions': action_field,
            'target_values': target_values_field
        }

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)
        }  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    action_sequence = world.logical_form_to_action_sequence(
                        logical_form)
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except ParsingError as error:
                    logger.debug(
                        f'Parsing error: {error.message}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except KeyError as error:
                    logger.debug(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda(conservative=True):
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
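The reader above (and the ones that follow) all encode logical-form supervision the same way: every production rule the world can take gets a position in ``action_field``, and each target action sequence is stored as indices into that shared list. A self-contained sketch of that bookkeeping with plain Python lists, using made-up rule strings:

# Plain-Python stand-in for the ListField/IndexField bookkeeping above; the rule
# strings are illustrative, not real grammar productions.
all_actions = ["@start@ -> r", "r -> [<c,r>, c]", "c -> cell:mersin"]
action_map = {rule: index for index, rule in enumerate(all_actions)}

target_sequence = ["@start@ -> r", "r -> [<c,r>, c]", "c -> cell:mersin"]
target_indices = [action_map[rule] for rule in target_sequence]
print(target_indices)  # [0, 1, 2] -- what a single ListField of IndexFields encodes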
class KnowledgeGraphFieldTest(AllenNlpTestCase):
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        table_file = self.FIXTURES_ROOT / "data" / "wikitables" / "tables" / "341.tagged"
        self.graph = TableQuestionContext.read_from_file(
            table_file, self.utterance).get_table_knowledge_graph()
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name",
                                                            namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in",
                                                          namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace(
            "english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace(
            "location", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace(
            "mersin", namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string',
                                                    namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance,
                                         self.token_indexers, self.tokenizer)

        super().setUp()

    def test_count_vocab_items(self):
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        self.field.count_vocab_items(namespace_token_counts)

        assert namespace_token_counts["tokens"] == {
            'name': 1,
            'in': 2,
            'english': 2,
            'location': 1,
            'mersin': 1,
        }

    def test_index_converts_field_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field._indexed_entity_texts.keys() == {'tokens'}
        # Note that these are sorted by their _identifiers_, not their cell text, so the
        # `string_column:*` entities show up after the cell string `mersin`.
        expected_array = [[self.mersin_index],
                          [
                              self.location_index, self.in_index,
                              self.english_index
                          ],
                          [self.name_index, self.in_index, self.english_index]]
        assert self.field._indexed_entity_texts['tokens'] == expected_array

    def test_get_padding_lengths_raises_if_not_indexed(self):
        with pytest.raises(AssertionError):
            self.field.get_padding_lengths()

    def test_padding_lengths_are_computed_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            'num_entities': 3,
            'num_entity_tokens': 3,
            'num_utterance_tokens': 4
        }
        self.field._token_indexers[
            'token_characters'] = TokenCharactersIndexer(min_padding_length=1)
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            'num_entities': 3,
            'num_entity_tokens': 3,
            'num_utterance_tokens': 4,
            'num_token_characters': 8
        }

    def test_as_tensor_produces_correct_output(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        padding_lengths['num_utterance_tokens'] += 1
        padding_lengths['num_entities'] += 1
        tensor_dict = self.field.as_tensor(padding_lengths)
        assert tensor_dict.keys() == {'text', 'linking'}
        expected_text_tensor = [
            [self.mersin_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index], [0, 0, 0]
        ]
        assert_almost_equal(
            tensor_dict['text']['tokens'].detach().cpu().numpy(),
            expected_text_tensor)

        linking_tensor = tensor_dict['linking'].detach().cpu().numpy()
        expected_linking_tensor = [
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # string:mersin, "where"
                [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # string:mersin, "is"
                [0, 1, 1, 1, 1, 1, 0, 0, 1, 1],  # string:mersin, "mersin"
                [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # string:mersin, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # string:mersin, padding
            [
                [0, 0, 0, 0, 0, -2.6, 0, 0, 0,
                 0],  # string_column:name_in_english, "where"
                [0, 0, 0, 0, 0, -7.5, 0, 0, 0,
                 0],  # string_column:name_in_english, "is"
                [0, 0, 0, 0, 0, -1.8333, 1, 1, 0,
                 0],  # string_column:..in_english, "mersin"
                [0, 0, 0, 0, 0, -18, 0, 0, 0,
                 0],  # string_column:name_in_english, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # string_column:name_in_english, padding
            [
                [0, 0, 0, 0, 0, -1.6, 0, 0, 0,
                 0],  # string_..:location_in_english, "where"
                [0, 0, 0, 0, 0, -5.5, 0, 0, 0,
                 0],  # string_column:location_in_english, "is"
                [0, 0, 0, 0, 0, -1, 0, 0, 0,
                 0],  # string_column:location_in_english, "mersin"
                [0, 0, 0, 0, 0, -14, 0, 0, 0,
                 0],  # string_column:location_in_english, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # string_column:location_in_english, padding
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "where"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "is"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "mersin"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ]
        ]  # padding, padding
        for entity_index, entity_features in enumerate(
                expected_linking_tensor):
            for question_index, feature_vector in enumerate(entity_features):
                assert_almost_equal(linking_tensor[entity_index,
                                                   question_index],
                                    feature_vector,
                                    decimal=4,
                                    err_msg=f"{entity_index} {question_index}")

    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = 'string_column:name_in_english'
        lemma_feature = field._contains_lemma_match(
            entity, field._entity_text_map[entity], utterance[0], 0, utterance)
        assert lemma_feature == 1

    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize(
            "what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = 'string_column:name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [
            field._span_overlap_fraction(entity, entity_text, token, i,
                                         utterance)
            for i, token in enumerate(utterance)
        ]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]

    def test_batch_tensors(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        tensor_dict1 = self.field.as_tensor(padding_lengths)
        tensor_dict2 = self.field.as_tensor(padding_lengths)
        batched_tensor_dict = self.field.batch_tensors(
            [tensor_dict1, tensor_dict2])
        assert batched_tensor_dict.keys() == {'text', 'linking'}
        expected_single_tensor = [
            [self.mersin_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index]
        ]
        expected_batched_tensor = [
            expected_single_tensor, expected_single_tensor
        ]
        assert_almost_equal(
            batched_tensor_dict['text']['tokens'].detach().cpu().numpy(),
            expected_batched_tensor)
        expected_linking_tensor = torch.stack(
            [tensor_dict1['linking'], tensor_dict2['linking']])
        assert_almost_equal(
            batched_tensor_dict['linking'].detach().cpu().numpy(),
            expected_linking_tensor.detach().cpu().numpy())

    def test_field_initialized_with_empty_constructor(self):
        try:
            self.field.empty_field()
        except AssertionError as e:
            pytest.fail(str(e), pytrace=True)
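For reference, the linking tensor asserted in ``test_as_tensor_produces_correct_output`` above holds one 10-dimensional feature vector per (entity, utterance token) pair, padded by one extra entity and one extra token. A shape-only stand-in, just to make the layout explicit:

import numpy as np

# Shape-only stand-in for the expected linking tensor above:
# (3 entities + 1 padding) x (4 question tokens + 1 padding) x 10 linking features.
linking = np.zeros((3 + 1, 4 + 1, 10))
assert linking.shape == (4, 5, 10)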
Example #13
    def text_to_instance(
            self,  # type: ignore
            question: str,
            table_lines: List[str],
            example_lisp_string: str = None,
            dpd_output: List[str] = None,
            tokenized_question: List[Token] = None) -> Instance:
        """
        Reads text inputs and makes an instance. WikitableQuestions dataset provides tables as TSV
        files, which we use for training.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[str]``
            The table content itself, as a list of rows. See
            ``TableQuestionKnowledgeGraph.read_from_lines`` for the expected format.
        example_lisp_string : ``str``, optional
            The original (lisp-formatted) example string in the WikiTableQuestions dataset.  This
            comes directly from the ``.examples`` file provided with the dataset.  We pass this to
            SEMPRE for evaluating logical forms during training.  It isn't otherwise used for
            anything.
        dpd_output : List[str], optional
            List of logical forms, produced by dynamic programming on denotations. Not required
            during test.
        tokenized_question : ``List[Token]``, optional
            If you have already tokenized the question, you can pass that in here, so we don't
            duplicate that work.  You might, for example, do batch processing on the questions in
            the whole dataset, then pass the result in here.
        """
        # pylint: disable=arguments-differ
        tokenized_question = tokenized_question or self._tokenizer.tokenize(
            question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [x.text for x in tokenized_question]
        }
        metadata["original_table"] = "\n".join(table_lines)
        table_knowledge_graph = TableQuestionKnowledgeGraph.read_from_lines(
            table_lines, tokenized_question)
        table_metadata = MetadataField(table_lines)
        table_field = KnowledgeGraphField(
            table_knowledge_graph,
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            feature_extractors=self._linking_feature_extractors,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens)
        world = WikiTablesWorld(table_knowledge_graph)
        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'metadata': MetadataField(metadata),
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }
        if self._include_table_metadata:
            fields['table_metadata'] = table_metadata
        if example_lisp_string:
            fields['example_lisp_string'] = MetadataField(example_lisp_string)

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)
        }  # type: ignore
        if dpd_output:
            action_sequence_fields: List[Field] = []
            for logical_form in dpd_output:
                if not self._should_keep_logical_form(logical_form):
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(
                        f'Parsing error: {error.message}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                if len(action_sequence_fields) >= self._max_dpd_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
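The ``example_lisp_string`` handed to the reader above is stored as metadata and later passed to SEMPRE for denotation evaluation; it is never parsed here. For orientation, such a string roughly follows SEMPRE's examples format (everything in the example below is illustrative: id, utterance, csv path and target value are made up):

# Illustrative only -- roughly the shape of an entry from the WikiTableQuestions
# ``.examples`` files; none of these values come from the real dataset.
example_lisp_string = (
    '(example (id nt-0) (utterance "where is mersin?") '
    '(context (graph tables.TableKnowledgeGraph csv/204-csv/590.csv)) '
    '(targetValue (list (description "Mersin"))))'
)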
Example #14
    def text_to_instance(
        self,  # type: ignore
        question: str,
        logical_forms: List[str] = None,
        additional_metadata: Dict[str, Any] = None,
        world_extractions: Dict[str, Union[str, List[str]]] = None,
        entity_literals: Dict[str, Union[str, List[str]]] = None,
        tokenized_question: List[Token] = None,
        debug_counter: int = None,
        qr_spec_override: List[Dict[str, int]] = None,
        dynamic_entities_override: Dict[str, str] = None,
    ) -> Instance:

        tokenized_question = tokenized_question or self._tokenizer.tokenize(question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata["question_tokens"] = [token.text for token in tokenized_question]
        if world_extractions is not None:
            additional_metadata["world_extractions"] = world_extractions
        question_field = TextField(tokenized_question, self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {key: [] for key in dynamic_entities.keys()}
            knowledge_graph = KnowledgeGraph(
                entities=set(dynamic_entities.keys()),
                neighbors=neighbors,
                entity_text=dynamic_entities,
            )
            world = QuarelWorld(knowledge_graph, self._lf_syntax, qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(
            knowledge_graph,
            tokenized_question,
            self._entity_token_indexers,
            tokenizer=self._tokenizer,
        )

        if self._tagger_only:
            fields: Dict[str, Field] = {"tokens": question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(
                    self._all_entities, table_field, entity_literals, tokenized_question
                )
                if debug_counter > 0:
                    logger.info(f"raw entity tags = {entity_tags}")
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields["tags"] = SequenceLabelField(entity_tags_bio, question_field)
                additional_metadata["tags_gold"] = entity_tags_bio
            additional_metadata["words"] = [x.text for x in tokenized_question]
            fields["metadata"] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(" -> ")
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            "question": question_field,
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata["answer_index"], skip_indexing=True)
            fields["denotation_target"] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(
                ["world1", "world2"], table_field, world_extractions, tokenized_question
            )
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]

            entity_bits_field = ArrayField(np.array(entity_bits_v))
            fields["entity_bits"] = entity_bits_field

        if logical_forms:
            action_map = {
                action.rule: i for i, action in enumerate(action_field.field_list)  # type: ignore
            }
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(IndexField(action_map[production_rule], action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.info(f"Missing production rule: {error.args}, skipping logical form")
                    logger.info(f"Question was: {question}")
                    logger.info(f"Logical form was: {logical_form}")
                    continue
            fields["target_action_sequences"] = ListField(action_sequence_fields)
        fields["metadata"] = MetadataField(additional_metadata or {})
        return Instance(fields)
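A small worked example of the entity-bits encodings used above: ``_get_entity_tags`` presumably returns one tag per question token (0 for neither world, 1 for world1, 2 for world2), and each ``_entity_bits_mode`` maps that tag to a short vector.

# Worked example of the three entity_bits_mode mappings shown above; the tag
# sequence itself is illustrative.
entity_bits = [0, 1, 0, 2]  # one tag per question token

simple = [[[0, 0], [1, 0], [0, 1]][tag] for tag in entity_bits]
simple_collapsed = [[[0], [1], [1]][tag] for tag in entity_bits]
simple3 = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag] for tag in entity_bits]

print(simple)            # [[0, 0], [1, 0], [0, 0], [0, 1]]
print(simple_collapsed)  # [[0], [1], [0], [1]]
print(simple3)           # [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]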
class KnowledgeGraphFieldTest(AllenNlpTestCase):
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        json = {
            'question': self.utterance,
            'columns': ['Name in English', 'Location in English'],
            'cells': [['Paradeniz', 'Mersin'], ['Lake Gala', 'Edirne']]
        }
        self.graph = TableQuestionKnowledgeGraph.read_from_json(json)
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name",
                                                            namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in",
                                                          namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace(
            "english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace(
            "location", namespace='tokens')
        self.paradeniz_index = self.vocab.add_token_to_namespace(
            "paradeniz", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace(
            "mersin", namespace='tokens')
        self.lake_index = self.vocab.add_token_to_namespace("lake",
                                                            namespace='tokens')
        self.gala_index = self.vocab.add_token_to_namespace("gala",
                                                            namespace='tokens')
        self.negative_one_index = self.vocab.add_token_to_namespace(
            "-1", namespace='tokens')
        self.zero_index = self.vocab.add_token_to_namespace("0",
                                                            namespace='tokens')
        self.one_index = self.vocab.add_token_to_namespace("1",
                                                           namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string',
                                                    namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance,
                                         self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()

    def test_count_vocab_items(self):
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        self.field.count_vocab_items(namespace_token_counts)

        assert namespace_token_counts["tokens"] == {
            '-1': 1,
            '0': 1,
            '1': 1,
            'name': 1,
            'in': 2,
            'english': 2,
            'location': 1,
            'paradeniz': 1,
            'mersin': 1,
            'lake': 1,
            'gala': 1,
            'edirne': 1,
        }

    def test_index_converts_field_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field._indexed_entity_texts.keys() == {'tokens'}
        # Note that these are sorted by their _identifiers_, not their cell text, so the
        # `fb:row.rows` show up after the `fb:cells`.
        expected_array = [[self.negative_one_index], [self.zero_index],
                          [self.one_index], [self.edirne_index],
                          [self.lake_index, self.gala_index],
                          [self.mersin_index], [self.paradeniz_index],
                          [
                              self.location_index, self.in_index,
                              self.english_index
                          ],
                          [self.name_index, self.in_index, self.english_index]]
        assert self.field._indexed_entity_texts['tokens'] == expected_array

    def test_get_padding_lengths_raises_if_not_indexed(self):
        with pytest.raises(AssertionError):
            self.field.get_padding_lengths()

    def test_padding_lengths_are_computed_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            'num_entities': 9,
            'num_entity_tokens': 3,
            'num_utterance_tokens': 4
        }
        self.field._token_indexers[
            'token_characters'] = TokenCharactersIndexer(min_padding_length=1)
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            'num_entities': 9,
            'num_entity_tokens': 3,
            'num_utterance_tokens': 4,
            'num_token_characters': 9
        }

    def test_as_tensor_produces_correct_output(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        padding_lengths['num_utterance_tokens'] += 1
        padding_lengths['num_entities'] += 1
        tensor_dict = self.field.as_tensor(padding_lengths)
        assert tensor_dict.keys() == {'text', 'linking'}
        expected_text_tensor = [
            [self.negative_one_index, 0, 0], [self.zero_index, 0, 0],
            [self.one_index, 0, 0], [self.edirne_index, 0, 0],
            [self.lake_index, self.gala_index, 0], [self.mersin_index, 0, 0],
            [self.paradeniz_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index], [0, 0, 0]
        ]
        assert_almost_equal(
            tensor_dict['text']['tokens'].detach().cpu().numpy(),
            expected_text_tensor)

        linking_tensor = tensor_dict['linking'].detach().cpu().numpy()
        expected_linking_tensor = [
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "where"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "is"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "mersin"
                [0, 0, 0, 0, 0, -1, 0, 0, 0, 0]
            ],  # -1, "?"
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "where"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "is"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "mersin"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # 0, "?"
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "where"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "is"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "mersin"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # 1, "?"
            [
                [0, 0, 0, 0, 0, .2, 0, 0, 0, 0],  # fb:cell.edirne, "where"
                [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.edirne, "is"
                [0, 0, 0, 0, 0, .1666, 0, 0, 0, 0],  # fb:cell.edirne, "mersin"
                [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.edirne, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # fb:cell.edirne, padding
            [
                [0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.lake_gala, "where"
                [0, 0, 0, 0, 0, -3.5, 0, 0, 0, 0],  # fb:cell.lake_gala, "is"
                [0, 0, 0, 0, 0, -.3333, 0, 0, 0,
                 0],  # fb:cell.lake_gala, "mersin"
                [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.lake_gala, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # fb:cell.lake_gala, padding
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # fb:cell.mersin, "where"
                [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.mersin, "is"
                [0, 1, 1, 1, 1, 1, 0, 0, 1, 1],  # fb:cell.mersin, "mersin"
                [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.mersin, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # fb:cell.mersin, padding
            [
                [0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.paradeniz, "where"
                [0, 0, 0, 0, 0, -3, 0, 0, 0, 0],  # fb:cell.paradeniz, "is"
                [0, 0, 0, 0, 0, -.1666, 0, 0, 0,
                 0],  # fb:cell.paradeniz, "mersin"
                [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.paradeniz, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # fb:cell.paradeniz, padding
            [
                [0, 0, 0, 0, 0, -2.6, 0, 0, 0,
                 0],  # fb:row.row.name_in_english, "where"
                [0, 0, 0, 0, 0, -7.5, 0, 0, 0,
                 0],  # fb:row.row.name_in_english, "is"
                [0, 0, 0, 0, 0, -1.8333, 1, 1, 0,
                 0],  # fb:row.row.name_in_english, "mersin"
                [0, 0, 0, 0, 0, -18, 0, 0, 0,
                 0],  # fb:row.row.name_in_english, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # fb:row.row.name_in_english, padding
            [
                [0, 0, 0, 0, 0, -1.6, 0, 0, 0,
                 0],  # fb:row.row.location_in_english, "where"
                [0, 0, 0, 0, 0, -5.5, 0, 0, 0,
                 0],  # fb:row.row.location_in_english, "is"
                [0, 0, 0, 0, 0, -1, 0, 0, 0,
                 0],  # fb:row.row.location_in_english, "mersin"
                [0, 0, 0, 0, 0, -14, 0, 0, 0,
                 0],  # fb:row.row.location_in_english, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ],  # fb:row.row.location_in_english, padding
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "where"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "is"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "mersin"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            ]
        ]  # padding, padding
        for entity_index, entity_features in enumerate(
                expected_linking_tensor):
            for question_index, feature_vector in enumerate(entity_features):
                assert_almost_equal(linking_tensor[entity_index,
                                                   question_index],
                                    feature_vector,
                                    decimal=4,
                                    err_msg=f"{entity_index} {question_index}")

    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        lemma_feature = field._contains_lemma_match(
            entity, field._entity_text_map[entity], utterance[0], 0, utterance)
        assert lemma_feature == 1

    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize(
            "what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [
            field._span_overlap_fraction(entity, entity_text, token, i,
                                         utterance)
            for i, token in enumerate(utterance)
        ]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]

    def test_batch_tensors(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        tensor_dict1 = self.field.as_tensor(padding_lengths)
        tensor_dict2 = self.field.as_tensor(padding_lengths)
        batched_tensor_dict = self.field.batch_tensors(
            [tensor_dict1, tensor_dict2])
        assert batched_tensor_dict.keys() == {'text', 'linking'}
        expected_single_tensor = [
            [self.negative_one_index, 0, 0], [self.zero_index, 0, 0],
            [self.one_index, 0, 0], [self.edirne_index, 0, 0],
            [self.lake_index, self.gala_index, 0], [self.mersin_index, 0, 0],
            [self.paradeniz_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index]
        ]
        expected_batched_tensor = [
            expected_single_tensor, expected_single_tensor
        ]
        assert_almost_equal(
            batched_tensor_dict['text']['tokens'].detach().cpu().numpy(),
            expected_batched_tensor)
        expected_linking_tensor = torch.stack(
            [tensor_dict1['linking'], tensor_dict2['linking']])
        assert_almost_equal(
            batched_tensor_dict['linking'].detach().cpu().numpy(),
            expected_linking_tensor.detach().cpu().numpy())
Example #16
    def text_to_instance(self,  # type: ignore
                         question: str,
                         table_lines: List[List[str]],
                         answer_json: JsonDict,
                         offline_search_output: List[str] = None) -> Instance:
        """
        Reads text inputs and makes an instance. We assume we have access to DROP paragraphs parsed
        and tagged in a format similar to the tagged tables in WikiTableQuestions.
        # TODO(pradeep): Explain the format.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            Preprocessed paragraph content. See ``ParagraphQuestionContext.read_from_lines``
            for the expected format.
        answer_json : ``JsonDict``
            The "answer" dict from the original data file.
        offline_search_output : List[str], optional
            List of logical forms, produced by offline search. Not required during test.
        """
        # pylint: disable=arguments-differ
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question, self._question_token_indexers)
        # TODO(pradeep): We'll need a better way to input processed lines.
        paragraph_context = ParagraphQuestionContext.read_from_lines(table_lines,
                                                                     tokenized_question,
                                                                     self._entity_extraction_embedding,
                                                                     self._entity_extraction_distance_threshold)
        target_values_field = MetadataField(answer_json)
        world = DropWorld(paragraph_context)
        world_field = MetadataField(world)
        # Note: Not passing any feature extractors when instantiating the field below. This will
        # make it use all the available extractors.
        table_field = KnowledgeGraphField(paragraph_context.get_knowledge_graph(),
                                          tokenized_question,
                                          self._table_token_indexers,
                                          tokenizer=self._tokenizer,
                                          include_in_vocab=self._use_table_for_vocab,
                                          max_table_tokens=self._max_table_tokens)
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_instance_specific_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {'question': question_field,
                  'table': table_field,
                  'world': world_field,
                  'actions': action_field,
                  'target_values': target_values_field}

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {action.rule: i for i, action in enumerate(action_field.field_list)}  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    expression = world.parse_logical_form(logical_form)
                except ParsingError as error:
                    logger.debug(f'Parsing error: {error.message}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Logical form was: {logical_form}')
                    logger.debug(f'Table info was: {table_lines}')
                    continue
                except:
                    logger.error(logical_form)
                    raise
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(IndexField(action_map[production_rule], action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.debug(f'Missing production rule: {error.args}, skipping logical form')
                    logger.debug(f'Question was: {question}')
                    logger.debug(f'Table info was: {table_lines}')
                    logger.debug(f'Logical form was: {logical_form}')
                    continue
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields['target_action_sequences'] = ListField(action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
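            # world.get_agenda() returns actions that are expected to appear in the logical
            # form; each one is mapped to its index in the action list.  If the agenda is
            # empty, we pad with a single -1 IndexField below so the ListField is never empty.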
            for agenda_string in world.get_agenda():
                agenda_index_fields.append(IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields['agenda'] = ListField(agenda_index_fields)
        return Instance(fields)
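
A minimal sketch (not part of the scraped example above) of how an instance built by this
method could be indexed and inspected.  The `instance` variable is hypothetical (assume it is
the return value of the method above for one question), and the import paths follow
AllenNLP 0.8/0.9-era releases, so they may differ in other versions.

from allennlp.data import Vocabulary
from allennlp.data.dataset import Batch

# `instance` is assumed to come from the reader method above.  MetadataFields ('world',
# 'target_values') pass through untouched, while the question, table, and action fields are
# indexed against the vocabulary.
vocab = Vocabulary.from_instances([instance])
batch = Batch([instance])
batch.index_instances(vocab)
tensors = batch.as_tensor_dict(batch.get_padding_lengths())

world = tensors['world'][0]            # the DropWorld object, passed through as metadata
linking = tensors['table']['linking']  # (batch, num_entities, num_question_tokens, num_features)
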
class KnowledgeGraphFieldTest(AllenNlpTestCase):
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        json = {
                'question': self.utterance,
                'columns': ['Name in English', 'Location in English'],
                'cells': [['Paradeniz', 'Mersin'],
                          ['Lake Gala', 'Edirne']]
                }
        self.graph = TableQuestionKnowledgeGraph.read_from_json(json)
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name", namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in", namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace("english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace("location", namespace='tokens')
        self.paradeniz_index = self.vocab.add_token_to_namespace("paradeniz", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace("mersin", namespace='tokens')
        self.lake_index = self.vocab.add_token_to_namespace("lake", namespace='tokens')
        self.gala_index = self.vocab.add_token_to_namespace("gala", namespace='tokens')
        self.negative_one_index = self.vocab.add_token_to_namespace("-1", namespace='tokens')
        self.zero_index = self.vocab.add_token_to_namespace("0", namespace='tokens')
        self.one_index = self.vocab.add_token_to_namespace("1", namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string', namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()

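    # For reference (not part of the original tests): the graph built in setUp contains nine
    # entities, and the field keeps them sorted by identifier, which is the order the expected
    # arrays below assume:
    #   -1, 0, 1,
    #   fb:cell.edirne, fb:cell.lake_gala, fb:cell.mersin, fb:cell.paradeniz,
    #   fb:row.row.location_in_english, fb:row.row.name_in_english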
    def test_count_vocab_items(self):
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        self.field.count_vocab_items(namespace_token_counts)

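        # 'in' and 'english' each occur in both column headers ("Name in English" and
        # "Location in English"), so they are counted twice; every other entity token
        # appears exactly once.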
        assert namespace_token_counts["tokens"] == {
                '-1': 1,
                '0': 1,
                '1': 1,
                'name': 1,
                'in': 2,
                'english': 2,
                'location': 1,
                'paradeniz': 1,
                'mersin': 1,
                'lake': 1,
                'gala': 1,
                'edirne': 1,
                }

    def test_index_converts_field_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field._indexed_entity_texts.keys() == {'tokens'}
        # Note that these are sorted by their _identifiers_, not their cell text, so the
        # `fb:row.rows` show up after the `fb:cells`.
        expected_array = [[self.negative_one_index],
                          [self.zero_index],
                          [self.one_index],
                          [self.edirne_index],
                          [self.lake_index, self.gala_index],
                          [self.mersin_index],
                          [self.paradeniz_index],
                          [self.location_index, self.in_index, self.english_index],
                          [self.name_index, self.in_index, self.english_index]]
        assert self.field._indexed_entity_texts['tokens'] == expected_array

    def test_get_padding_lengths_raises_if_not_indexed(self):
        with pytest.raises(AssertionError):
            self.field.get_padding_lengths()

    def test_padding_lengths_are_computed_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {'num_entities': 9, 'num_entity_tokens': 3,
                                                    'num_utterance_tokens': 4}
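        # num_entities is 9 (three numbers, four cells, two columns), num_entity_tokens is 3
        # (the longest entity text, e.g. "location in english"), and num_utterance_tokens is 4
        # ("where", "is", "mersin", "?").  Adding a character indexer below also pads to the
        # longest entity token, "paradeniz", which has 9 characters.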
        self.field._token_indexers['token_characters'] = TokenCharactersIndexer()
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {'num_entities': 9, 'num_entity_tokens': 3,
                                                    'num_utterance_tokens': 4,
                                                    'num_token_characters': 9}

    def test_as_tensor_produces_correct_output(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        padding_lengths['num_utterance_tokens'] += 1
        padding_lengths['num_entities'] += 1
        tensor_dict = self.field.as_tensor(padding_lengths)
        assert tensor_dict.keys() == {'text', 'linking'}
        expected_text_tensor = [[self.negative_one_index, 0, 0],
                                [self.zero_index, 0, 0],
                                [self.one_index, 0, 0],
                                [self.edirne_index, 0, 0],
                                [self.lake_index, self.gala_index, 0],
                                [self.mersin_index, 0, 0],
                                [self.paradeniz_index, 0, 0],
                                [self.location_index, self.in_index, self.english_index],
                                [self.name_index, self.in_index, self.english_index],
                                [0, 0, 0]]
        assert_almost_equal(tensor_dict['text']['tokens'].detach().cpu().numpy(), expected_text_tensor)

        linking_tensor = tensor_dict['linking'].detach().cpu().numpy()
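        # Each row below holds ten linking features for one (entity, utterance token) pair.
        # The column order is consistent with the field's default feature extractors:
        # number_token_match, exact_token_match, contains_exact_token_match, lemma_match,
        # contains_lemma_match, edit_distance, related_column, related_column_lemma,
        # span_overlap_fraction, span_lemma_overlap.  The fractional and negative values in
        # column 5 match an edit-distance feature of the form
        # 1 - edit_distance(entity_text, token_text) / len(token_text).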
        expected_linking_tensor = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # -1, "mersin"
                                    [0, 0, 0, 0, 0, -1, 0, 0, 0, 0]],  # -1, "?"
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 0, "mersin"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # 0, "?"
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # 1, "mersin"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # 1, "?"
                                   [[0, 0, 0, 0, 0, .2, 0, 0, 0, 0],  # fb:cell.edirne, "where"
                                    [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.edirne, "is"
                                    [0, 0, 0, 0, 0, .1666, 0, 0, 0, 0],  # fb:cell.edirne, "mersin"
                                    [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.edirne, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.edirne, padding
                                   [[0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.lake_gala, "where"
                                    [0, 0, 0, 0, 0, -3.5, 0, 0, 0, 0],  # fb:cell.lake_gala, "is"
                                    [0, 0, 0, 0, 0, -.3333, 0, 0, 0, 0],  # fb:cell.lake_gala, "mersin"
                                    [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.lake_gala, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.lake_gala, padding
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # fb:cell.mersin, "where"
                                    [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # fb:cell.mersin, "is"
                                    [0, 1, 1, 1, 1, 1, 0, 0, 1, 1],  # fb:cell.mersin, "mersin"
                                    [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # fb:cell.mersin, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.mersin, padding
                                   [[0, 0, 0, 0, 0, -.6, 0, 0, 0, 0],  # fb:cell.paradeniz, "where"
                                    [0, 0, 0, 0, 0, -3, 0, 0, 0, 0],  # fb:cell.paradeniz, "is"
                                    [0, 0, 0, 0, 0, -.1666, 0, 0, 0, 0],  # fb:cell.paradeniz, "mersin"
                                    [0, 0, 0, 0, 0, -8, 0, 0, 0, 0],  # fb:cell.paradeniz, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:cell.paradeniz, padding
                                   [[0, 0, 0, 0, 0, -2.6, 0, 0, 0, 0],  # fb:row.row.location_in_english, "where"
                                    [0, 0, 0, 0, 0, -7.5, 0, 0, 0, 0],  # fb:row.row.location_in_english, "is"
                                    [0, 0, 0, 0, 0, -1.8333, 1, 1, 0, 0],  # fb:row.row.location_in_english, "mersin"
                                    [0, 0, 0, 0, 0, -18, 0, 0, 0, 0],  # fb:row.row.location_in_english, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:row.row.location_in_english, padding
                                   [[0, 0, 0, 0, 0, -1.6, 0, 0, 0, 0],  # fb:row.row.name_in_english, "where"
                                    [0, 0, 0, 0, 0, -5.5, 0, 0, 0, 0],  # fb:row.row.name_in_english, "is"
                                    [0, 0, 0, 0, 0, -1, 0, 0, 0, 0],  # fb:row.row.name_in_english, "mersin"
                                    [0, 0, 0, 0, 0, -14, 0, 0, 0, 0],  # fb:row.row.name_in_english, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # fb:row.row.name_in_english, padding
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "mersin"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]  # padding, padding
        for entity_index, entity_features in enumerate(expected_linking_tensor):
            for question_index, feature_vector in enumerate(entity_features):
                assert_almost_equal(linking_tensor[entity_index, question_index],
                                    feature_vector,
                                    decimal=4,
                                    err_msg=f"{entity_index} {question_index}")

    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        lemma_feature = field._contains_lemma_match(entity,
                                                    field._entity_text_map[entity],
                                                    utterance[0],
                                                    0,
                                                    utterance)
        assert lemma_feature == 1

    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'fb:row.row.name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [field._span_overlap_fraction(entity, entity_text, token, i, utterance)
                          for i, token in enumerate(utterance)]
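        # Only the tokens "name", "in", and "english" (positions 3-5) lie in a span that fully
        # overlaps the entity text, so those positions are expected to score 1 and all other
        # positions 0.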
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]

    def test_batch_tensors(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        tensor_dict1 = self.field.as_tensor(padding_lengths)
        tensor_dict2 = self.field.as_tensor(padding_lengths)
        batched_tensor_dict = self.field.batch_tensors([tensor_dict1, tensor_dict2])
        assert batched_tensor_dict.keys() == {'text', 'linking'}
        expected_single_tensor = [[self.negative_one_index, 0, 0],
                                  [self.zero_index, 0, 0],
                                  [self.one_index, 0, 0],
                                  [self.edirne_index, 0, 0],
                                  [self.lake_index, self.gala_index, 0],
                                  [self.mersin_index, 0, 0],
                                  [self.paradeniz_index, 0, 0],
                                  [self.location_index, self.in_index, self.english_index],
                                  [self.name_index, self.in_index, self.english_index]]
        expected_batched_tensor = [expected_single_tensor, expected_single_tensor]
        assert_almost_equal(batched_tensor_dict['text']['tokens'].detach().cpu().numpy(),
                            expected_batched_tensor)
        expected_linking_tensor = torch.stack([tensor_dict1['linking'], tensor_dict2['linking']])
        assert_almost_equal(batched_tensor_dict['linking'].detach().cpu().numpy(),
                            expected_linking_tensor.detach().cpu().numpy())