Example #1
 def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
     super().__init__(dictionary)
     self.register_buffer("version", torch.Tensor([3]))
     try:
         from fairscale.nn import Pipe
     except ImportError:
         raise ImportError("Please install fairscale with: pip install fairscale")
     # Use pipeline parallelism only when a pre-built encoder module list is given;
     # otherwise build a plain sequential encoder (as in Example #2 below).
     self.use_pipeline = encoder_module_list is not None
     if not self.use_pipeline:
         self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
         self.encoder_layers = nn.Sequential(
             *[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
         )
         if isinstance(embed_tokens, nn.ModuleList):
             emb_dim = sum(e.embedding_dim for e in embed_tokens)
         else:
             emb_dim = embed_tokens.embedding_dim
         self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
     else:
         encoder_balance = utils.eval_str_list(
             args.pipeline_encoder_balance, type=int
         )
         encoder_devices = utils.eval_str_list(
             args.pipeline_encoder_devices, type=int
         )
         assert sum(encoder_balance) == len(encoder_module_list), (
             f"Sum of encoder_balance={encoder_balance} is not equal "
             + f"to num_encoder_modules={len(encoder_module_list)}"
         )
         # Wrap the encoder modules in fairscale's Pipe: split them into stages
         # according to encoder_balance and place each stage on its assigned device.
         self.model = Pipe(
             module=nn.Sequential(*encoder_module_list),
             balance=encoder_balance,
             devices=encoder_devices,
             chunks=args.pipeline_chunks,
             checkpoint=args.pipeline_checkpoint,
         )
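Example #1 differs from Example #2 only in the pipeline-parallel branch: when encoder_module_list is provided, the encoder stages are wrapped in fairscale's Pipe. The following is a minimal standalone sketch of that same Pipe call, assuming a fairscale version that exposes fairscale.nn.Pipe with the keyword arguments used above; the toy module, the two-stage balance, the CUDA device ids, and the chunk count are all hypothetical values for illustration, not fairseq defaults.

 import torch
 import torch.nn as nn
 from fairscale.nn import Pipe

 # Hypothetical four-layer model split into two pipeline stages of two layers each.
 toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU())
 pipe = Pipe(
     module=toy,
     balance=[2, 2],            # layers per stage; must sum to len(toy)
     devices=[0, 1],            # hypothetical CUDA device ids, one per stage
     chunks=4,                  # micro-batches per mini-batch
     checkpoint="except_last",  # checkpoint activations for all but the last micro-batch
 )
 # Input must live on the first stage's device; output comes back on the last stage's device.
 out = pipe(torch.randn(8, 16).to(torch.device("cuda", 0)))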
Example #2
 def __init__(self, args, dictionary, embed_tokens):
     super().__init__(dictionary)
     self.register_buffer('version', torch.Tensor([3]))
     self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
     layers = [
         TransformerEncoderLayer(args) for _ in range(args.encoder_layers)
     ]
     # Note: layer drop not supported yet
     # Note: layer wise attention not supported yet
     self.encoder_layers = nn.Sequential(*layers)
     if isinstance(embed_tokens, nn.ModuleList):
         emb_dim = sum(e.embedding_dim for e in embed_tokens)
     else:
         emb_dim = embed_tokens.embedding_dim
     self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
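Both examples compute emb_dim the same way: when embed_tokens is an nn.ModuleList (several embedding tables whose outputs are presumably concatenated), the embedding dimensions are summed; otherwise the single table's dimension is used. A minimal self-contained sketch of that logic, with hypothetical toy embedding sizes:

 import torch.nn as nn

 def embed_dim(embed_tokens):
     # Mirrors the emb_dim computation above: sum over a ModuleList, else a single dim.
     if isinstance(embed_tokens, nn.ModuleList):
         return sum(e.embedding_dim for e in embed_tokens)
     return embed_tokens.embedding_dim

 single = nn.Embedding(1000, 512)                                            # one table of dim 512
 multi = nn.ModuleList([nn.Embedding(1000, 256), nn.Embedding(1000, 256)])   # hypothetical split tables
 assert embed_dim(single) == 512
 assert embed_dim(multi) == 512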