    def test_from_dataset_respects_max_vocab_size_single_int(self):
        max_vocab_size = 1
        vocab = ExtendedVocabulary.from_instances(
            self.dataset, max_vocab_size=max_vocab_size)
        words = vocab.get_index_to_token_vocabulary().values()
        # Additional 2 tokens are '@@PADDING@@' and '@@UNKNOWN@@' by default
        assert len(words) == max_vocab_size + 2

        vocab = ExtendedVocabulary.from_instances(self.dataset, min_count=None)
        words = vocab.get_index_to_token_vocabulary().values()
        # With no constraints, all 3 distinct tokens plus the 2 default tokens remain.
        assert len(words) == 5

    def test_from_dataset_respects_min_count(self):
        vocab = ExtendedVocabulary.from_instances(
            self.dataset, min_count={'tokens': 4})
        words = vocab.get_index_to_token_vocabulary().values()
        assert 'a' in words
        assert 'b' not in words
        assert 'c' not in words

        vocab = ExtendedVocabulary.from_instances(self.dataset, min_count=None)
        words = vocab.get_index_to_token_vocabulary().values()
        assert 'a' in words
        assert 'b' in words
        assert 'c' in words
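
    # Not shown in this excerpt: the ``self.dataset`` these tests rely on. A
    # minimal sketch of a setUp that would satisfy the assertions above: 'a'
    # must occur at least four times and 'b'/'c' fewer, giving 3 distinct
    # tokens (hence 5 entries once padding and unknown are added). The
    # AllenNLP-style imports and types below are assumptions, not part of the
    # original tests.
    def setUp(self):
        from allennlp.data import Instance, Token
        from allennlp.data.dataset import Batch
        from allennlp.data.fields import TextField
        from allennlp.data.token_indexers import SingleIdTokenIndexer

        token_indexer = SingleIdTokenIndexer('tokens')
        tokens = [Token(t) for t in ['a', 'a', 'a', 'a', 'b', 'b', 'c', 'c']]
        text_field = TextField(tokens, {'tokens': token_indexer})
        self.dataset = Batch([Instance({'text': text_field})])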
    def __init__(self,
                 vocabulary: ExtendedVocabulary,
                 namespace: str = 'tokens',
                 oov_token: str = DEFAULT_OOV_TOKEN) -> None:
        # Compute the penalty applied to p(<unk>): the log of the number of
        # token types in the unknown namespace, which corresponds to spreading
        # the <unk> probability mass uniformly over those types.
        unk_vocab_size = vocabulary.get_vocab_size(namespace + '_unk')
        if unk_vocab_size > 0:
            self._unk_penalty = math.log(unk_vocab_size)  # pylint: disable=no-member
        else:
            self._unk_penalty = 0.0

        # Identify the index of the <unk> token.
        self._unk_idx = vocabulary.get_token_index(oov_token, namespace=namespace)

        # Initialize the metric variables.
        self._sum_log_p = 0.0
        self._total_count = 0.0
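
# A minimal sketch, not from the original metric, of how the penalty computed
# in ``__init__`` above is typically applied: when the target token is <unk>,
# the model's log-probability is reduced by log(unk_vocab_size), spreading the
# <unk> mass uniformly over the candidate unknown types. The function name and
# signature here are illustrative, not part of the original class.
def penalized_log_prob(log_p: float, target_idx: int,
                       unk_idx: int, unk_penalty: float) -> float:
    """Discount the score of an <unk> prediction by the unk penalty."""
    if target_idx == unk_idx:
        return log_p - unk_penalty
    return log_p
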
    def test_saving_and_loading(self):
        # pylint: disable=protected-access
        vocab_dir = self.TEST_DIR / 'vocab_save'

        vocab = ExtendedVocabulary(non_padded_namespaces=["a", "c"])
        # non-padded, should start at 0
        vocab.add_token_to_namespace("a0", namespace="a")
        vocab.add_token_to_namespace("a1", namespace="a")
        vocab.add_token_to_namespace("a2", namespace="a")
        # padded, should start at 2
        vocab.add_token_to_namespace("b2", namespace="b")
        vocab.add_token_to_namespace("b3", namespace="b")

        vocab.save_to_files(vocab_dir)
        vocab2 = ExtendedVocabulary.from_files(vocab_dir)

        assert vocab2._non_padded_namespaces == {"a", "c"}

        # Check namespace a.
        assert vocab2.get_vocab_size(namespace='a') == 3
        assert vocab2.get_token_from_index(0, namespace='a') == 'a0'
        assert vocab2.get_token_from_index(1, namespace='a') == 'a1'
        assert vocab2.get_token_from_index(2, namespace='a') == 'a2'
        assert vocab2.get_token_index('a0', namespace='a') == 0
        assert vocab2.get_token_index('a1', namespace='a') == 1
        assert vocab2.get_token_index('a2', namespace='a') == 2

        # Check namespace b.
        # (padding + unk + two tokens)
        assert vocab2.get_vocab_size(namespace='b') == 4
        assert vocab2.get_token_from_index(
            0, namespace='b') == vocab._padding_token
        assert vocab2.get_token_from_index(
            1, namespace='b') == vocab._oov_token
        assert vocab2.get_token_from_index(2, namespace='b') == 'b2'
        assert vocab2.get_token_from_index(3, namespace='b') == 'b3'
        assert vocab2.get_token_index(vocab._padding_token, namespace='b') == 0
        assert vocab2.get_token_index(vocab._oov_token, namespace='b') == 1
        assert vocab2.get_token_index('b2', namespace='b') == 2
        assert vocab2.get_token_index('b3', namespace='b') == 3

        # Check the dictionaries containing the reverse mapping are identical.
        assert vocab.get_index_to_token_vocabulary(
            "a") == vocab2.get_index_to_token_vocabulary("a")
        assert vocab.get_index_to_token_vocabulary(
            "b") == vocab2.get_index_to_token_vocabulary("b")
    def test_unk_namespace_is_empty_if_vocab_unconstrained(self):
        vocab = ExtendedVocabulary.from_instances(self.dataset)
        words = vocab.get_index_to_token_vocabulary('tokens_unk')
        # Without a size or count constraint, nothing is cut from the main
        # vocabulary, so the unknown-token namespace stays empty.
        assert not words
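
    # For contrast with the test above, a hedged sketch of the constrained
    # case (assuming, as this excerpt does not show, that tokens cut by
    # max_vocab_size are diverted into the ``tokens_unk`` namespace):
    def test_unk_namespace_is_populated_if_vocab_constrained(self):
        vocab = ExtendedVocabulary.from_instances(self.dataset, max_vocab_size=1)
        words = vocab.get_index_to_token_vocabulary('tokens_unk')
        assert words  # overflow tokens should land here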