Ejemplo n.º 1
0
    def test_from_params(self):
        """Exercise ``Vocabulary.from_params`` in its supported and rejected configurations.

        Covers loading a previously saved vocabulary via ``directory_path``,
        building a vocabulary from ``self.dataset``, and two configurations
        that are expected to raise ``ConfigurationError``.
        """
        # Persist a small vocabulary so there is something to load back.
        save_dir = os.path.join(self.TEST_DIR, 'vocab_save')
        original = Vocabulary(non_padded_namespaces=["a", "c"])
        # Namespace "a" is declared non-padded above; namespace "b" is not.
        tokens_by_namespace = [
            ("a0", "a"),
            ("a1", "a"),
            ("a2", "a"),
            ("b2", "b"),
            ("b3", "b"),
        ]
        for token, namespace in tokens_by_namespace:
            original.add_token_to_namespace(token, namespace=namespace)
        original.save_to_files(save_dir)

        # Loading from the saved directory must reproduce both index->token maps.
        loaded = Vocabulary.from_params(Params({"directory_path": save_dir}))
        for namespace in ("a", "b"):
            saved_mapping = original.get_index_to_token_vocabulary(namespace)
            loaded_mapping = loaded.get_index_to_token_vocabulary(namespace)
            assert saved_mapping == loaded_mapping

        # With empty params and a dataset, the vocabulary is built from the dataset.
        from_dataset = Vocabulary.from_params(Params({}), self.dataset)
        expected_tokens = {0: '@@PADDING@@', 1: '@@UNKNOWN@@', 2: 'a', 3: 'c', 4: 'b'}
        assert from_dataset.get_index_to_token_vocabulary("tokens") == expected_tokens

        # Neither a dataset nor a 'directory_path' is given: expected to raise.
        with pytest.raises(ConfigurationError):
            Vocabulary.from_params(Params({}))

        # Extra keys (here 'min_count') next to 'directory_path', with no dataset
        # supplied, are expected to be rejected as well.
        conflicting_config = {"directory_path": save_dir, "min_count": {'tokens': 2}}
        with pytest.raises(ConfigurationError):
            Vocabulary.from_params(Params(conflicting_config))
Ejemplo n.º 2
0
    def test_saving_and_loading(self):
        """Round-trip a vocabulary through ``save_to_files`` / ``from_files``.

        Builds a vocabulary with a non-padded namespace ("a") and a padded one
        ("b"), saves it, reloads it, and checks that sizes, index<->token
        lookups and the index->token dictionaries all match.
        """
        # pylint: disable=protected-access
        save_dir = os.path.join(self.TEST_DIR, 'vocab_save')

        original = Vocabulary(non_padded_namespaces=["a", "c"])
        for token in ("a0", "a1", "a2"):
            original.add_token_to_namespace(token, namespace="a")
        for token in ("b2", "b3"):
            original.add_token_to_namespace(token, namespace="b")

        original.save_to_files(save_dir)
        restored = Vocabulary.from_files(save_dir)

        assert restored._non_padded_namespaces == ["a", "c"]

        # Namespace "a" is non-padded: its tokens are expected at indices 0..2.
        assert restored.get_vocab_size(namespace='a') == 3
        expected_a = ['a0', 'a1', 'a2']
        for index, token in enumerate(expected_a):
            assert restored.get_token_from_index(index, namespace='a') == token
        for index, token in enumerate(expected_a):
            assert restored.get_token_index(token, namespace='a') == index

        # Namespace "b" is padded: padding and OOV tokens come first, so the
        # size is 4 (padding + unk + two real tokens).
        assert restored.get_vocab_size(namespace='b') == 4
        expected_b = [original._padding_token, original._oov_token, 'b2', 'b3']
        for index, token in enumerate(expected_b):
            assert restored.get_token_from_index(index, namespace='b') == token
        for index, token in enumerate(expected_b):
            assert restored.get_token_index(token, namespace='b') == index

        # The reverse (index -> token) dictionaries must be identical too.
        for namespace in ("a", "b"):
            saved_mapping = original.get_index_to_token_vocabulary(namespace)
            restored_mapping = restored.get_index_to_token_vocabulary(namespace)
            assert saved_mapping == restored_mapping