Ejemplo n.º 1
0
        def load(dir_path):
            """Rebuild the token vocabularies from a saved checkpoint.

            Loads the file ``CHECKPOINT_FNAME`` found in *dir_path*, creates one
            ``Vocab`` for every name in ``config.token.names`` from the texts
            stored under ``checkpoint["vocab_texts"]``, and hands each vocab to
            the matching token maker. ``config`` and ``token_makers`` are taken
            from the surrounding scope.

            Args:
                dir_path: Directory that contains the checkpoint file.

            Returns:
                The surrounding scope's ``token_makers`` mapping, after
                ``set_vocab`` has been called on every value.

            Raises:
                KeyError: If a token maker's name has no vocab built for it
                    (i.e. the name is missing from ``config.token.names``).
            """
            checkpoint_path = os.path.join(dir_path, CHECKPOINT_FNAME)
            # Only the vocab texts are read from the checkpoint, so remap any
            # stored tensors to the CPU; this keeps loading working on hosts
            # without the device the checkpoint was saved from.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")

            vocabs = {}
            token_config = config.token
            for token_name in token_config.names:
                token = getattr(token_config, token_name, {})
                vocab_config = getattr(token, "vocab", {})

                texts = checkpoint["vocab_texts"][token_name]
                # The vocab settings may be a dict (including dict subclasses)
                # or an attribute-style object; only the latter needs vars().
                if not isinstance(vocab_config, dict):
                    vocab_config = vars(vocab_config)
                # NOTE(review): the stored value is whatever from_texts()
                # returns - presumably the Vocab itself; confirm in Vocab.
                vocabs[token_name] = Vocab(token_name,
                                           **vocab_config).from_texts(texts)

            for token_name, token_maker in token_makers.items():
                token_maker.set_vocab(vocabs[token_name])
            return token_makers
Ejemplo n.º 2
0
    def __init__(
        self,
        index,
        token_name,
        classes,
        dropout=0,
        embed_dim=15,
        trainable=True,
        padding_idx=None,
        max_norm=None,
        norm_type=2,
        scale_grad_by_freq=False,
        sparse=False,
    ):
        """Create a vocab from one group of classes and embed it.

        Args:
            index: Selects which entry of *classes* supplies the vocab items.
            token_name: Name passed to the internally created ``Vocab``.
            classes: Indexable collection; every item of ``classes[index]``
                is added to the vocab.
            dropout: Forwarded to ``WordEmbedding``.
            embed_dim: Stored on the instance and forwarded to
                ``WordEmbedding``.
            trainable: Forwarded to ``WordEmbedding``.
            padding_idx: Forwarded to ``WordEmbedding``.
            max_norm: Forwarded to ``WordEmbedding``.
            norm_type: Forwarded to ``WordEmbedding``.
            scale_grad_by_freq: Forwarded to ``WordEmbedding``.
            sparse: Forwarded to ``WordEmbedding``.
        """
        super(SparseToEmbedding, self).__init__()

        self.embed_dim = embed_dim

        # Initialise the vocab first, then register each class of the group.
        class_vocab = Vocab(token_name)
        class_vocab.init()
        for class_label in classes[index]:
            class_vocab.add(class_label)

        self.embedding = WordEmbedding(
            vocab=class_vocab,
            dropout=dropout,
            embed_dim=embed_dim,
            trainable=trainable,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )
Ejemplo n.º 3
0
def test_save_and_load():
    """A vocab dumped to disk and loaded back exposes the same tokens."""
    texts = "A\nB\nC\nD"

    vocab = Vocab("token_name")
    vocab.from_texts(texts)

    vocab_path = "./test_vocab.txt"
    try:
        vocab.dump(vocab_path)

        vocab2 = Vocab("token_name")
        vocab2.load(vocab_path)
    finally:
        # Always clean up, even when dump/load raises, so a failing run does
        # not leave the file behind. The existence check keeps a failed dump
        # from being masked by a FileNotFoundError here.
        if os.path.exists(vocab_path):
            os.remove(vocab_path)

    assert vocab.get_all_tokens() == vocab2.get_all_tokens()
Ejemplo n.º 4
0
def test_init_vocab():
    """After init(), the vocab's token list is exactly [PAD] then [UNK]."""
    expected_tokens = ["[PAD]", "[UNK]"]

    fresh_vocab = Vocab("token_name")
    fresh_vocab.init()

    assert fresh_vocab.get_all_tokens() == expected_tokens