Example #4
    def text_to_instance(
            self,  # type: ignore
            question: str,
            logical_forms: List[str] = None,
            additional_metadata: Dict[str, Any] = None,
            world_extractions: Dict[str, Union[str, List[str]]] = None,
            entity_literals: Dict[str, Union[str, List[str]]] = None,
            tokenized_question: List[Token] = None,
            debug_counter: int = None,
            qr_spec_override: List[Dict[str, int]] = None,
            dynamic_entities_override: Dict[str, str] = None) -> Instance:

        # pylint: disable=arguments-differ
        tokenized_question = tokenized_question or self._tokenizer.tokenize(
            question.lower())
        additional_metadata = additional_metadata or dict()
        additional_metadata['question_tokens'] = [
            token.text for token in tokenized_question
        ]
        if world_extractions is not None:
            additional_metadata['world_extractions'] = world_extractions
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)

        if qr_spec_override is not None or dynamic_entities_override is not None:
            # Dynamically specify theory and/or entities
            dynamic_entities = dynamic_entities_override or self._dynamic_entities
            neighbors: Dict[str, List[str]] = {
                key: []
                for key in dynamic_entities.keys()
            }
            knowledge_graph = KnowledgeGraph(entities=set(
                dynamic_entities.keys()),
                                             neighbors=neighbors,
                                             entity_text=dynamic_entities)
            world = QuarelWorld(knowledge_graph,
                                self._lf_syntax,
                                qr_coeff_sets=qr_spec_override)
        else:
            knowledge_graph = self._knowledge_graph
            world = self._world

        table_field = KnowledgeGraphField(knowledge_graph,
                                          tokenized_question,
                                          self._entity_token_indexers,
                                          tokenizer=self._tokenizer)

        if self._tagger_only:
            fields: Dict[str, Field] = {'tokens': question_field}
            if entity_literals is not None:
                entity_tags = self._get_entity_tags(self._all_entities,
                                                    table_field,
                                                    entity_literals,
                                                    tokenized_question)
                if debug_counter is not None and debug_counter > 0:
                    logger.info(f'raw entity tags = {entity_tags}')
                entity_tags_bio = self._convert_tags_bio(entity_tags)
                fields['tags'] = SequenceLabelField(entity_tags_bio,
                                                    question_field)
                additional_metadata['tags_gold'] = entity_tags_bio
            additional_metadata['words'] = [x.text for x in tokenized_question]
            fields['metadata'] = MetadataField(additional_metadata)
            return Instance(fields)

        world_field = MetadataField(world)

        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_actions():
            _, rule_right_side = production_rule.split(' -> ')
            is_global_rule = not world.is_table_entity(rule_right_side)
            field = ProductionRuleField(production_rule, is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            'question': question_field,
            'table': table_field,
            'world': world_field,
            'actions': action_field
        }

        if self._denotation_only:
            denotation_field = LabelField(additional_metadata['answer_index'],
                                          skip_indexing=True)
            fields['denotation_target'] = denotation_field

        if self._entity_bits_mode is not None and world_extractions is not None:
            entity_bits = self._get_entity_tags(['world1', 'world2'],
                                                table_field, world_extractions,
                                                tokenized_question)
            if self._entity_bits_mode == "simple":
                entity_bits_v = [[[0, 0], [1, 0], [0, 1]][tag]
                                 for tag in entity_bits]
            elif self._entity_bits_mode == "simple_collapsed":
                entity_bits_v = [[[0], [1], [1]][tag] for tag in entity_bits]
            elif self._entity_bits_mode == "simple3":
                entity_bits_v = [[[1, 0, 0], [0, 1, 0], [0, 0, 1]][tag]
                                 for tag in entity_bits]

            entity_bits_field = ArrayField(np.array(entity_bits_v))
            fields['entity_bits'] = entity_bits_field

        if logical_forms:
            action_map = {
                action.rule: i
                for i, action in enumerate(action_field.field_list)
            }  # type: ignore
            action_sequence_fields: List[Field] = []
            for logical_form in logical_forms:
                expression = world.parse_logical_form(logical_form)
                action_sequence = world.get_action_sequence(expression)
                try:
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except KeyError as error:
                    logger.info(
                        f'Missing production rule: {error.args}, skipping logical form'
                    )
                    logger.info(f'Question was: {question}')
                    logger.info(f'Logical form was: {logical_form}')
                    continue
            fields['target_action_sequences'] = ListField(
                action_sequence_fields)
        fields['metadata'] = MetadataField(additional_metadata or {})
        return Instance(fields)
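
The target-supervision block above maps every production rule to its position in the global action list and then encodes each logical form as a sequence of those indices. The following standalone sketch shows the same idea with plain dicts and lists and made-up rule strings, rather than allennlp's ProductionRuleField and IndexField:

# Illustrative sketch only: a plain-Python analogue of the action_map / IndexField
# pattern used above. The rule strings are made up for the example.
all_possible_actions = [
    "@start@ -> n",
    "n -> [<n,n>, n]",
    "<n,n> -> add_one",
    "n -> 7",
]
# Map each production rule to its position in the global action list.
action_map = {rule: i for i, rule in enumerate(all_possible_actions)}

# One "logical form", already decomposed into an action sequence by the parser.
action_sequence = ["@start@ -> n", "n -> [<n,n>, n]", "<n,n> -> add_one", "n -> 7"]

# Each target sequence becomes a list of indices into the action list; an unknown
# rule would raise a KeyError, which the reader above catches and skips.
target_indices = [action_map[rule] for rule in action_sequence]
print(target_indices)  # [0, 1, 2, 3]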
Example #5
    def text_to_instance(
        self,  # type: ignore
        question: str,
        table_lines: List[List[str]],
        target_values: List[str] = None,
        offline_search_output: List[str] = None,
    ) -> Instance:
        """
        Reads text inputs and makes an instance. We pass ``table_lines`` to
        ``TableQuestionContext.read_from_lines``, which accepts them either as lines from the
        CoreNLP-processed tagged files that come with the dataset, or simply in a TSV format
        where each line corresponds to a row and the cells are tab-separated.

        Parameters
        ----------
        question : ``str``
            Input question
        table_lines : ``List[List[str]]``
            The table content optionally preprocessed by CoreNLP. See ``TableQuestionContext.read_from_lines``
            for the expected format.
        target_values : ``List[str]``, optional
            Target values for the denotations the logical forms should execute to. Not required for testing.
        offline_search_output : ``List[str]``, optional
            List of logical forms, produced by offline search. Not required during testing.
        """
        tokenized_question = self._tokenizer.tokenize(question.lower())
        question_field = TextField(tokenized_question,
                                   self._question_token_indexers)
        metadata: Dict[str, Any] = {
            "question_tokens": [x.text for x in tokenized_question]
        }
        table_context = TableQuestionContext.read_from_lines(
            table_lines, tokenized_question)
        world = WikiTablesLanguage(table_context)
        world_field = MetadataField(world)
        # Note: we are not passing any feature extractors when instantiating the field below,
        # which makes it use all the available extractors.
        table_field = KnowledgeGraphField(
            table_context.get_table_knowledge_graph(),
            tokenized_question,
            self._table_token_indexers,
            tokenizer=self._tokenizer,
            include_in_vocab=self._use_table_for_vocab,
            max_table_tokens=self._max_table_tokens,
        )
        production_rule_fields: List[Field] = []
        for production_rule in world.all_possible_productions():
            _, rule_right_side = production_rule.split(" -> ")
            is_global_rule = not world.is_instance_specific_entity(
                rule_right_side)
            field = ProductionRuleField(production_rule,
                                        is_global_rule=is_global_rule)
            production_rule_fields.append(field)
        action_field = ListField(production_rule_fields)

        fields = {
            "question": question_field,
            "metadata": MetadataField(metadata),
            "table": table_field,
            "world": world_field,
            "actions": action_field,
        }

        if target_values is not None:
            target_values_field = MetadataField(target_values)
            fields["target_values"] = target_values_field

        # We'll make each target action sequence a List[IndexField], where the index is into
        # the action list we made above.  We need to ignore the type here because mypy doesn't
        # like `action.rule` - it's hard to tell mypy that the ListField is made up of
        # ProductionRuleFields.
        action_map = {
            action.rule: i
            for i, action in enumerate(action_field.field_list)
        }  # type: ignore
        if offline_search_output:
            action_sequence_fields: List[Field] = []
            for logical_form in offline_search_output:
                try:
                    action_sequence = world.logical_form_to_action_sequence(
                        logical_form)
                    index_fields: List[Field] = []
                    for production_rule in action_sequence:
                        index_fields.append(
                            IndexField(action_map[production_rule],
                                       action_field))
                    action_sequence_fields.append(ListField(index_fields))
                except ParsingError as error:
                    logger.debug(
                        f"Parsing error: {error.message}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Logical form was: {logical_form}")
                    logger.debug(f"Table info was: {table_lines}")
                    continue
                except KeyError as error:
                    logger.debug(
                        f"Missing production rule: {error.args}, skipping logical form"
                    )
                    logger.debug(f"Question was: {question}")
                    logger.debug(f"Table info was: {table_lines}")
                    logger.debug(f"Logical form was: {logical_form}")
                    continue
                except:  # noqa
                    logger.error(logical_form)
                    raise
                if len(action_sequence_fields) >= self._max_offline_logical_forms:
                    break

            if not action_sequence_fields:
                # This is not great, but we're only doing it when we're passed logical form
                # supervision, so we're expecting labeled logical forms, but we can't actually
                # produce the logical forms.  We should skip this instance.  Note that this affects
                # _dev_ and _test_ instances, too, so your metrics could be over-estimates on the
                # full test data.
                return None
            fields["target_action_sequences"] = ListField(
                action_sequence_fields)
        if self._output_agendas:
            agenda_index_fields: List[Field] = []
            for agenda_string in world.get_agenda(conservative=True):
                agenda_index_fields.append(
                    IndexField(action_map[agenda_string], action_field))
            if not agenda_index_fields:
                agenda_index_fields = [IndexField(-1, action_field)]
            fields["agenda"] = ListField(agenda_index_fields)
        return Instance(fields)
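
The docstring above notes that ``table_lines`` may come either from the CoreNLP-tagged files shipped with the dataset or from a plain tab-separated table. Below is a minimal sketch of preparing the TSV case, assuming (my assumption, not something the library documents here) that each inner list holds the tab-split cells of one row; the table content is made up.

# Illustrative sketch only: building ``table_lines`` for the plain-TSV case.
tsv_table = "Name in English\tLocation in English\nParadeniz\tMersin"
table_lines = [line.split("\t") for line in tsv_table.split("\n")]
print(table_lines)
# [['Name in English', 'Location in English'], ['Paradeniz', 'Mersin']]
# Together with a question string, this would then be passed as
# ``reader.text_to_instance(question, table_lines)`` on an already-constructed reader.
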
class TestKnowledgeGraphField(SemparseTestCase):
    def setup_method(self):
        self.tokenizer = SpacyTokenizer(pos_tags=True)
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        table_file = self.FIXTURES_ROOT / "data" / "wikitables" / "tables" / "341.tagged"
        self.graph = TableQuestionContext.read_from_file(
            table_file, self.utterance).get_table_knowledge_graph()
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name",
                                                            namespace="tokens")
        self.in_index = self.vocab.add_token_to_namespace("in",
                                                          namespace="tokens")
        self.english_index = self.vocab.add_token_to_namespace(
            "english", namespace="tokens")
        self.location_index = self.vocab.add_token_to_namespace(
            "location", namespace="tokens")
        self.mersin_index = self.vocab.add_token_to_namespace(
            "mersin", namespace="tokens")

        self.oov_index = self.vocab.get_token_index("random OOV string",
                                                    namespace="tokens")
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance,
                                         self.token_indexers, self.tokenizer)

        super().setup_method()

    def test_count_vocab_items(self):
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        self.field.count_vocab_items(namespace_token_counts)

        assert namespace_token_counts["tokens"] == {
            "name": 1,
            "in": 2,
            "english": 2,
            "location": 1,
            "mersin": 1,
        }

    def test_get_padding_lengths_raises_if_not_indexed(self):
        with pytest.raises(ConfigurationError):
            self.field.get_padding_lengths()

    def test_padding_lengths_are_computed_correctly(self):
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            "num_entities": 3,
            "num_utterance_tokens": 4,
            "num_fields": 3,
            "list_tokens___tokens": 3,
        }
        self.field._token_indexers[
            "token_characters"] = TokenCharactersIndexer(min_padding_length=1)
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {
            "num_entities": 3,
            "num_utterance_tokens": 4,
            "num_fields": 3,
            "list_tokens___tokens": 3,
            "list_token_characters___token_characters": 3,
            "list_token_characters___num_token_characters": 8,
        }

    def test_as_tensor_produces_correct_output(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        padding_lengths["num_utterance_tokens"] += 1
        padding_lengths["num_entities"] += 1
        padding_lengths["num_fields"] += 1
        tensor_dict = self.field.as_tensor(padding_lengths)
        assert tensor_dict.keys() == {"text", "linking"}
        expected_text_tensor = [
            [self.mersin_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index],
            [0, 0, 0],
        ]
        assert_almost_equal(
            tensor_dict["text"]["tokens"]["tokens"].detach().cpu().numpy(),
            expected_text_tensor)

        linking_tensor = tensor_dict["linking"].detach().cpu().numpy()
        expected_linking_tensor = [
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],     # string:mersin, "where"
                [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # string:mersin, "is"
                [0, 1, 1, 1, 1, 1, 0, 0, 1, 1],     # string:mersin, "mersin"
                [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],    # string:mersin, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],     # string:mersin, padding
            ],
            [
                [0, 0, 0, 0, 0, -2.6, 0, 0, 0, 0],     # string_column:name_in_english, "where"
                [0, 0, 0, 0, 0, -7.5, 0, 0, 0, 0],     # string_column:name_in_english, "is"
                [0, 0, 0, 0, 0, -1.8333, 1, 1, 0, 0],  # string_column:name_in_english, "mersin"
                [0, 0, 0, 0, 0, -18, 0, 0, 0, 0],      # string_column:name_in_english, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],        # string_column:name_in_english, padding
            ],
            [
                [0, 0, 0, 0, 0, -1.6, 0, 0, 0, 0],  # string_column:location_in_english, "where"
                [0, 0, 0, 0, 0, -5.5, 0, 0, 0, 0],  # string_column:location_in_english, "is"
                [0, 0, 0, 0, 0, -1, 0, 0, 0, 0],    # string_column:location_in_english, "mersin"
                [0, 0, 0, 0, 0, -14, 0, 0, 0, 0],   # string_column:location_in_english, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],     # string_column:location_in_english, padding
            ],
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "where"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "is"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "mersin"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "?"
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, padding
            ],
        ]
        for entity_index, entity_features in enumerate(
                expected_linking_tensor):
            for question_index, feature_vector in enumerate(entity_features):
                assert_almost_equal(
                    linking_tensor[entity_index, question_index],
                    feature_vector,
                    decimal=4,
                    err_msg=f"{entity_index} {question_index}",
                )

    def test_lemma_feature_extractor(self):
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = "string_column:name_in_english"
        lemma_feature = field._contains_lemma_match(
            entity, field._entity_text_map[entity], utterance[0], 0, utterance)
        assert lemma_feature == 1

    def test_span_overlap_fraction(self):
        utterance = self.tokenizer.tokenize(
            "what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance,
                                    self.token_indexers, self.tokenizer)
        entity = "string_column:name_in_english"
        entity_text = field._entity_text_map[entity]
        feature_values = [
            field._span_overlap_fraction(entity, entity_text, token, i,
                                         utterance)
            for i, token in enumerate(utterance)
        ]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]

    def test_batch_tensors(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        tensor_dict1 = self.field.as_tensor(padding_lengths)
        tensor_dict2 = self.field.as_tensor(padding_lengths)
        batched_tensor_dict = self.field.batch_tensors(
            [tensor_dict1, tensor_dict2])
        assert batched_tensor_dict.keys() == {"text", "linking"}
        expected_single_tensor = [
            [self.mersin_index, 0, 0],
            [self.location_index, self.in_index, self.english_index],
            [self.name_index, self.in_index, self.english_index],
        ]
        expected_batched_tensor = [
            expected_single_tensor, expected_single_tensor
        ]
        assert_almost_equal(
            batched_tensor_dict["text"]["tokens"]
            ["tokens"].detach().cpu().numpy(),
            expected_batched_tensor,
        )
        expected_linking_tensor = torch.stack(
            [tensor_dict1["linking"], tensor_dict2["linking"]])
        assert_almost_equal(
            batched_tensor_dict["linking"].detach().cpu().numpy(),
            expected_linking_tensor.detach().cpu().numpy(),
        )

    def test_field_initialized_with_empty_constructor(self):
        try:
            self.field.empty_field()
        except AssertionError as e:
            pytest.fail(str(e), pytrace=True)
class KnowledgeGraphFieldTest(SemparseTestCase):
    def setUp(self):
        self.tokenizer = WordTokenizer(SpacyWordSplitter(pos_tags=True))
        self.utterance = self.tokenizer.tokenize("where is mersin?")
        self.token_indexers = {"tokens": SingleIdTokenIndexer("tokens")}

        table_file = self.FIXTURES_ROOT / "data" / "wikitables" / "tables" / "341.tagged"
        self.graph = TableQuestionContext.read_from_file(table_file, self.utterance).get_table_knowledge_graph()
        self.vocab = Vocabulary()
        self.name_index = self.vocab.add_token_to_namespace("name", namespace='tokens')
        self.in_index = self.vocab.add_token_to_namespace("in", namespace='tokens')
        self.english_index = self.vocab.add_token_to_namespace("english", namespace='tokens')
        self.location_index = self.vocab.add_token_to_namespace("location", namespace='tokens')
        self.mersin_index = self.vocab.add_token_to_namespace("mersin", namespace='tokens')

        self.oov_index = self.vocab.get_token_index('random OOV string', namespace='tokens')
        self.edirne_index = self.oov_index
        self.field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)

        super(KnowledgeGraphFieldTest, self).setUp()

    def test_count_vocab_items(self):
        namespace_token_counts = defaultdict(lambda: defaultdict(int))
        self.field.count_vocab_items(namespace_token_counts)

        assert namespace_token_counts["tokens"] == {
                'name': 1,
                'in': 2,
                'english': 2,
                'location': 1,
                'mersin': 1,
                }

    def test_index_converts_field_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field._indexed_entity_texts.keys() == {'tokens'}
        # Note that these are sorted by their _identifiers_, not their cell text, so the
        # `fb:row.rows` show up after the `fb:cells`.
        expected_array = [[self.mersin_index],
                          [self.location_index, self.in_index, self.english_index],
                          [self.name_index, self.in_index, self.english_index]]
        assert self.field._indexed_entity_texts['tokens'] == expected_array

    def test_get_padding_lengths_raises_if_not_indexed(self):
        with pytest.raises(AssertionError):
            self.field.get_padding_lengths()

    def test_padding_lengths_are_computed_correctly(self):
        # pylint: disable=protected-access
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {'num_entities': 3, 'num_entity_tokens': 3,
                                                    'num_utterance_tokens': 4}
        self.field._token_indexers['token_characters'] = TokenCharactersIndexer(min_padding_length=1)
        self.field.index(self.vocab)
        assert self.field.get_padding_lengths() == {'num_entities': 3, 'num_entity_tokens': 3,
                                                    'num_utterance_tokens': 4,
                                                    'num_token_characters': 8}

    def test_as_tensor_produces_correct_output(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        padding_lengths['num_utterance_tokens'] += 1
        padding_lengths['num_entities'] += 1
        tensor_dict = self.field.as_tensor(padding_lengths)
        assert tensor_dict.keys() == {'text', 'linking'}
        expected_text_tensor = [[self.mersin_index, 0, 0],
                                [self.location_index, self.in_index, self.english_index],
                                [self.name_index, self.in_index, self.english_index],
                                [0, 0, 0]]
        assert_almost_equal(tensor_dict['text']['tokens'].detach().cpu().numpy(), expected_text_tensor)

        linking_tensor = tensor_dict['linking'].detach().cpu().numpy()
        expected_linking_tensor = [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # string:mersin, "where"
                                    [0, 0, 0, 0, 0, -1.5, 0, 0, 0, 0],  # string:mersin, "is"
                                    [0, 1, 1, 1, 1, 1, 0, 0, 1, 1],  # string:mersin, "mersin"
                                    [0, 0, 0, 0, 0, -5, 0, 0, 0, 0],  # string:mersin, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # string:mersin, padding
                                   [[0, 0, 0, 0, 0, -2.6, 0, 0, 0, 0],  # string_column:name_in_english, "where"
                                    [0, 0, 0, 0, 0, -7.5, 0, 0, 0, 0],  # string_column:name_in_english, "is"
                                    [0, 0, 0, 0, 0, -1.8333, 1, 1, 0, 0],  # string_column:..in_english, "mersin"
                                    [0, 0, 0, 0, 0, -18, 0, 0, 0, 0],  # string_column:name_in_english, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # string_column:name_in_english, padding
                                   [[0, 0, 0, 0, 0, -1.6, 0, 0, 0, 0],  # string_..:location_in_english, "where"
                                    [0, 0, 0, 0, 0, -5.5, 0, 0, 0, 0],  # string_column:location_in_english, "is"
                                    [0, 0, 0, 0, 0, -1, 0, 0, 0, 0],  # string_column:location_in_english, "mersin"
                                    [0, 0, 0, 0, 0, -14, 0, 0, 0, 0],  # string_column:location_in_english, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],  # string_column:location_in_english, padding
                                   [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "where"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "is"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "mersin"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # padding, "?"
                                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]  # padding, padding
        for entity_index, entity_features in enumerate(expected_linking_tensor):
            for question_index, feature_vector in enumerate(entity_features):
                assert_almost_equal(linking_tensor[entity_index, question_index],
                                    feature_vector,
                                    decimal=4,
                                    err_msg=f"{entity_index} {question_index}")

    def test_lemma_feature_extractor(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("Names in English")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'string_column:name_in_english'
        lemma_feature = field._contains_lemma_match(entity,
                                                    field._entity_text_map[entity],
                                                    utterance[0],
                                                    0,
                                                    utterance)
        assert lemma_feature == 1

    def test_span_overlap_fraction(self):
        # pylint: disable=protected-access
        utterance = self.tokenizer.tokenize("what is the name in english of mersin?")
        field = KnowledgeGraphField(self.graph, self.utterance, self.token_indexers, self.tokenizer)
        entity = 'string_column:name_in_english'
        entity_text = field._entity_text_map[entity]
        feature_values = [field._span_overlap_fraction(entity, entity_text, token, i, utterance)
                          for i, token in enumerate(utterance)]
        assert feature_values == [0, 0, 0, 1, 1, 1, 0, 0, 0]

    def test_batch_tensors(self):
        self.field.index(self.vocab)
        padding_lengths = self.field.get_padding_lengths()
        tensor_dict1 = self.field.as_tensor(padding_lengths)
        tensor_dict2 = self.field.as_tensor(padding_lengths)
        batched_tensor_dict = self.field.batch_tensors([tensor_dict1, tensor_dict2])
        assert batched_tensor_dict.keys() == {'text', 'linking'}
        expected_single_tensor = [[self.mersin_index, 0, 0],
                                  [self.location_index, self.in_index, self.english_index],
                                  [self.name_index, self.in_index, self.english_index]]
        expected_batched_tensor = [expected_single_tensor, expected_single_tensor]
        assert_almost_equal(batched_tensor_dict['text']['tokens'].detach().cpu().numpy(),
                            expected_batched_tensor)
        expected_linking_tensor = torch.stack([tensor_dict1['linking'], tensor_dict2['linking']])
        assert_almost_equal(batched_tensor_dict['linking'].detach().cpu().numpy(),
                            expected_linking_tensor.detach().cpu().numpy())

    def test_field_initialized_with_empty_constructor(self):
        try:
            self.field.empty_field()
        except AssertionError as e:
            pytest.fail(str(e), pytrace=True)