def build_embedding(dictionary, embed_dim, path=None, num_embed_chunks=1):
    """Build the embedding for ``dictionary``, optionally split into chunks.

    Args:
        dictionary: vocabulary object; ``len(dictionary)`` is used as the
            number of embeddings and ``dictionary.pad()`` as the padding
            index.
        embed_dim: total embedding dimension.
        path: optional path to pre-trained embeddings. When truthy, the
            embeddings are parsed with ``utils.parse_embedding`` and loaded
            with ``utils.load_embedding``.
        num_embed_chunks: number of equal-width chunks the embedding
            dimension is split into. Must be positive and divide
            ``embed_dim`` evenly; must be 1 when ``path`` is given.

    Returns:
        When ``path`` is truthy, a single ``Embedding`` of width
        ``embed_dim``. Otherwise an ``nn.ModuleList`` holding
        ``num_embed_chunks`` ``Embedding`` modules, each of width
        ``embed_dim // num_embed_chunks`` (a one-element list when
        ``num_embed_chunks == 1``).

    Raises:
        AssertionError: if ``num_embed_chunks`` is not positive, does not
            divide ``embed_dim``, or is not 1 while ``path`` is not None.
    """
    # Reject zero/negative chunk counts up front: zero would otherwise fail
    # with an unrelated ZeroDivisionError in the modulo below, and a negative
    # count could pass the modulo check and yield an empty ModuleList.
    assert num_embed_chunks >= 1, (
        f"Number of embedding chunks = {num_embed_chunks} should be positive")
    assert embed_dim % num_embed_chunks == 0, (
        f"Embedding dimension = {embed_dim} should be divisible by "
        f"the number of embedding chunks = {num_embed_chunks}")
    assert path is None or num_embed_chunks == 1, (
        "Loading embedding from a path with number of embedding chunks > 1"
        " is not yet supported")

    num_embeddings = len(dictionary)
    padding_idx = dictionary.pad()

    # if provided, load from preloaded dictionaries
    if path:
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        embed_dict = utils.parse_embedding(path)
        utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    embed_chunk_dim = embed_dim // num_embed_chunks
    return nn.ModuleList([
        Embedding(num_embeddings, embed_chunk_dim, padding_idx)
        for _ in range(num_embed_chunks)
    ])
def build_embedding(dictionary, embed_dim, path=None):
    """Create the embedding for ``dictionary`` and optionally preload it.

    Args:
        dictionary: vocabulary object; ``len(dictionary)`` is used as the
            number of embeddings and ``dictionary.pad()`` as the padding
            index.
        embed_dim: embedding dimension passed to ``Embedding``.
        path: optional path to pre-trained embeddings. When truthy, the
            embeddings are parsed with ``utils.parse_embedding`` and loaded
            into the new module with ``utils.load_embedding``.

    Returns:
        The ``Embedding`` built for ``dictionary``.
    """
    embedding = Embedding(len(dictionary), embed_dim, dictionary.pad())

    # Nothing to preload: hand back the freshly built module.
    if not path:
        return embedding

    # A path was provided, so load the pre-trained embeddings from it.
    pretrained = utils.parse_embedding(path)
    utils.load_embedding(pretrained, dictionary, embedding)
    return embedding