Code example #1
    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        def _vocab_init(tensor, **kwargs):
            # Scaled normal initialization, then zero the row at index 1
            # (the pad symbol in fairseq's default dictionary layout).
            nn.init.normal_(tensor, mean=0, std=embed_dim ** -0.5)
            nn.init.constant_(tensor[1], 0)

        # Embedding table sharded across model-parallel workers.
        embed_tokens = VocabParallelEmbedding(
            len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init
        )
        return embed_tokens
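For comparison, the same initialization can be written against a plain nn.Embedding when no model parallelism is involved. This is a minimal single-GPU sketch, assuming only that the dictionary object exposes len() and pad() as in fairseq; the helper name is not part of the example above:

import torch.nn as nn

def build_embedding_single_gpu(dictionary, embed_dim):
    # Regular (non-sharded) embedding with the same initialization scheme as above.
    emb = nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())
    nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
    nn.init.constant_(emb.weight[dictionary.pad()], 0)
    return emb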
Code example #2
File: transformer.py  Project: NJUNLP/TMM-for-MAMS
    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        if not has_megatron_submodule:
            raise ImportError(
                '\n\nPlease install the megatron submodule:'
                '\n\n  git submodule update --init '
                'fairseq/model_parallel/megatron'
            )
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        def _vocab_init(tensor, **kwargs):
            # Scaled normal initialization, then zero the row at index 1 (the pad symbol).
            nn.init.normal_(tensor, mean=0, std=num_embeddings ** -0.5)
            nn.init.constant_(tensor[1], 0)

        # Embedding table sharded across model-parallel workers.
        emb = VocabParallelEmbedding(
            num_embeddings, embed_dim, padding_idx, init_method=_vocab_init
        )
        # Loading pretrained embeddings from a path is not supported for model parallel.
        if path:
            raise NotImplementedError(
                "Loading of embedding from path is not supported for model parallel"
            )
        return emb
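The has_megatron_submodule flag and VocabParallelEmbedding are not defined in the snippet itself. In fairseq they are typically set at module import time along these lines (a sketch based on fairseq's model_parallel code; the exact import list varies by version):

try:
    from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False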
Code example #3
    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        # Default initialization; no custom init_method is passed here.
        return VocabParallelEmbedding(vocab_size, embedding_dim, padding_idx)
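In Megatron-style model parallelism, VocabParallelEmbedding splits the embedding table along the vocabulary dimension, so each rank stores only a contiguous slice of the rows. A rough illustration of the per-rank shard bounds, assuming the vocabulary size divides evenly by the model-parallel world size (the helper name is hypothetical, not part of the API):

def vocab_shard_range(vocab_size, rank, world_size):
    # Each model-parallel rank owns a contiguous block of vocabulary rows.
    per_rank = vocab_size // world_size
    start = rank * per_rank
    return start, start + per_rank

# e.g. a 32,000-token vocabulary over 4 ranks: rank 1 holds rows 8000..15999
print(vocab_shard_range(32000, 1, 4))  # (8000, 16000)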