def load(dir_path):
    # `CHECKPOINT_FNAME`, `config`, and `token_makers` are expected to be
    # available from the enclosing scope.
    checkpoint_path = os.path.join(dir_path, CHECKPOINT_FNAME)
    checkpoint = torch.load(checkpoint_path)

    # Rebuild one Vocab per token type from the texts stored in the checkpoint.
    vocabs = {}
    token_config = config.token
    for token_name in token_config.names:
        token = getattr(token_config, token_name, {})
        vocab_config = getattr(token, "vocab", {})
        texts = checkpoint["vocab_texts"][token_name]

        if not isinstance(vocab_config, dict):
            vocab_config = vars(vocab_config)
        vocabs[token_name] = Vocab(token_name, **vocab_config).from_texts(texts)

    # Attach the restored vocabs to their token makers.
    for token_name, token_maker in token_makers.items():
        token_maker.set_vocab(vocabs[token_name])
    return token_makers
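# Usage sketch (an assumption, not from the source): with `config`,
# `token_makers`, and `CHECKPOINT_FNAME` defined in the enclosing scope as
# above, restoring vocabs from a saved experiment directory could look like
# the following; the directory name is hypothetical.
#
#     token_makers = load("checkpoints/my_experiment")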
def __init__(
    self,
    index,
    token_name,
    classes,
    dropout=0,
    embed_dim=15,
    trainable=True,
    padding_idx=None,
    max_norm=None,
    norm_type=2,
    scale_grad_by_freq=False,
    sparse=False,
):
    super(SparseToEmbedding, self).__init__()
    self.embed_dim = embed_dim

    vocab = Vocab(token_name)
    vocab.init()
    for c in classes[index]:
        vocab.add(c)

    embedding_params = {
        "vocab": vocab,
        "dropout": dropout,
        "embed_dim": embed_dim,
        "trainable": trainable,
        "padding_idx": padding_idx,
        "max_norm": max_norm,
        "norm_type": norm_type,
        "scale_grad_by_freq": scale_grad_by_freq,
        "sparse": sparse,
    }
    self.embedding = WordEmbedding(**embedding_params)
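# Construction sketch (an assumption, not from the source): `classes[index]`
# holds the label set for this feature slot, and a Vocab plus WordEmbedding
# is built over those labels. The token name "pos" and the label list below
# are purely illustrative.
#
#     pos_embedding = SparseToEmbedding(
#         index=0,
#         token_name="pos",
#         classes=[["NN", "VB", "JJ", "DT"]],
#         embed_dim=15,
#     )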
def test_save_and_load():
    texts = "A\nB\nC\nD"

    vocab = Vocab("token_name")
    vocab.from_texts(texts)

    vocab_path = "./test_vocab.txt"
    vocab.dump(vocab_path)

    vocab2 = Vocab("token_name")
    vocab2.load(vocab_path)
    os.remove(vocab_path)

    assert vocab.get_all_tokens() == vocab2.get_all_tokens()
def test_init_vocab():
    vocab = Vocab("token_name")
    vocab.init()

    assert vocab.get_all_tokens() == ["[PAD]", "[UNK]"]