Example #1
    def test_embedding_vocab_extension_without_stored_namespace(self):
        vocab = Vocabulary()
        vocab.add_token_to_namespace("word1", "tokens_a")
        vocab.add_token_to_namespace("word2", "tokens_a")
        embedding_params = Params({
            "vocab_namespace": "tokens_a",
            "embedding_dim": 10
        })
        embedder = Embedding.from_vocab_or_file(
            vocab, **embedding_params.as_dict(quiet=True))

        # Previous models won't have a _vocab_namespace attribute; force it to None.
        embedder._vocab_namespace = None
        original_weight = embedder.weight

        assert original_weight.shape[0] == 4

        extension_counter = {"tokens_a": {"word3": 1}}
        vocab._extend(extension_counter)

        embedder.extend_vocab(vocab, "tokens_a")  # specified namespace

        extended_weight = embedder.weight
        assert extended_weight.shape[0] == 5
        assert torch.all(extended_weight[:4, :] == original_weight[:4, :])
Example #2
    def test_embedding_vocab_extension_with_default_namespace(self):
        vocab = Vocabulary()
        vocab.add_token_to_namespace('word1')
        vocab.add_token_to_namespace('word2')
        embedding_params = Params({"vocab_namespace": "tokens",
                                   "embedding_dim": 10})
        embedder = Embedding.from_params(vocab, embedding_params)
        original_weight = embedder.weight

        assert original_weight.shape[0] == 4

        extension_counter = {"tokens": {"word3": 1}}
        vocab._extend(extension_counter)

        embedder.extend_vocab(vocab)  # default namespace

        extended_weight = embedder.weight
        assert extended_weight.shape[0] == 5
        assert torch.all(extended_weight[:4, :] == original_weight[:4, :])
Example #3
    def test_embedding_vocab_extension_with_specified_namespace(self):
        vocab = Vocabulary()
        vocab.add_token_to_namespace("word1", "tokens_a")
        vocab.add_token_to_namespace("word2", "tokens_a")
        embedding_params = Params({
            "vocab_namespace": "tokens_a",
            "embedding_dim": 10
        })
        embedder = Embedding.from_params(embedding_params, vocab=vocab)
        original_weight = embedder.weight

        assert original_weight.shape[0] == 4

        extension_counter = {"tokens_a": {"word3": 1}}
        vocab._extend(extension_counter)

        embedder.extend_vocab(vocab, "tokens_a")  # specified namespace

        extended_weight = embedder.weight
        assert extended_weight.shape[0] == 5
        assert torch.all(extended_weight[:4, :] == original_weight[:4, :])
Example #4
from collections import OrderedDict
from enum import Enum
from typing import List, Optional, Type, Union

from allennlp.data import Vocabulary


def add_env_tokens_to_vocab(vocab: Vocabulary,
                            actions: Optional[List[Union[str, int]]] = None,
                            stack_states: Optional[Type[Enum]] = None,
                            exec_states: Optional[Type[Enum]] = None) -> Vocabulary:
    """Extend ``vocab`` with environment tokens in the non-padded namespaces
    ``stack``, ``exec`` and ``action``."""
    # Normalize each argument to a list of token strings.
    actions = [str(action) for action in actions] if actions else []
    stack_states = [state.name for state in stack_states] if stack_states else []
    exec_states = [state.name for state in exec_states] if exec_states else []

    # _extend expects a mapping of namespace -> token -> count.
    extra_vocab_counter = {
        'stack': OrderedDict({state: 1 for state in stack_states}),
        'exec': OrderedDict({state: 1 for state in exec_states}),
        'action': OrderedDict({action: 1 for action in actions})
    }
    vocab._extend(extra_vocab_counter, non_padded_namespaces=['stack', 'exec', 'action'])
    return vocab
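
A minimal usage sketch for the helper above (not part of the original examples): the StackState and ExecState enums and the action list are hypothetical placeholders. It assumes add_env_tokens_to_vocab from Example #4 is in scope, and shows that passing an Enum class works because iterating it yields members whose .name becomes the token string.

from enum import Enum

from allennlp.data import Vocabulary


class StackState(Enum):  # hypothetical stack states, for illustration only
    PUSH = 0
    POP = 1


class ExecState(Enum):   # hypothetical execution states, for illustration only
    RUNNING = 0
    HALTED = 1


vocab = add_env_tokens_to_vocab(Vocabulary(),
                                actions=[0, 1, "NOOP"],
                                stack_states=StackState,
                                exec_states=ExecState)

# The namespaces were registered as non-padded, so they contain exactly the
# tokens passed in (no padding or OOV entries).
assert vocab.get_vocab_size("stack") == 2
assert vocab.get_vocab_size("exec") == 2
assert vocab.get_vocab_size("action") == 3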