Example #1
 def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
     super().__init__(dictionary)
     # Persist a version marker in checkpoints for backward compatibility.
     self.register_buffer('version', torch.Tensor([3]))
     # Keep embedding, decoder stack, and output projection as separate
     # modules so the decoder can later be cut into pipeline partitions.
     self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
     layers = [
         TransformerDecoderLayer(args, no_encoder_attn)
         for _ in range(args.decoder_layers)
     ]
     self.decoder_layers = nn.Sequential(*layers)
     self.decoder_output_layer = TransformerDecoderOutputLayer(
         args, embed_tokens, dictionary)
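Because the three stages live in a single nn.Sequential, each stage's output feeds the next in order, which is exactly the layout a pipeline engine can later split. A minimal, self-contained sketch of that chaining; the Linear/ReLU stages here are stand-ins for the fairseq modules, not the real interface:

 import torch
 import torch.nn as nn

 # Stand-in stages: "embedding" -> "decoder layer" -> "output projection".
 stages = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
 out = stages(torch.randn(2, 8))  # each stage consumes the previous output
 print(out.shape)                 # torch.Size([2, 8])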
Example #2
 def __init__(
     self,
     args,
     dictionary,
     embed_tokens,
     no_encoder_attn=False,
     decoder_module_list=None,
 ):
     super().__init__(dictionary)
     self.register_buffer("version", torch.Tensor([3]))
     try:
         from fairscale.nn import Pipe
     except ImportError:
         raise ImportError(
             "Please install fairscale with: pip install fairscale")
     if decoder_module_list is None:
         embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
         layers = [
             TransformerDecoderLayer(args, no_encoder_attn)
             for _ in range(args.decoder_layers)
         ]
         decoder_output_layer = TransformerDecoderOutputLayer(
             args, embed_tokens, dictionary)
         decoder_module_list = (
             [embedding_layer] + layers + [decoder_output_layer]
         )
     # Pipelining is enabled only when a decoder balance was configured.
     self.use_pipeline = (
         getattr(args, "pipeline_decoder_balance", None) is not None
     )
     if self.use_pipeline:
         decoder_balance = utils.eval_str_list(
             args.pipeline_decoder_balance, type=int)
         decoder_devices = utils.eval_str_list(
             args.pipeline_decoder_devices, type=int)
         assert sum(decoder_balance) == len(decoder_module_list), (
             f"Sum of decoder_balance={decoder_balance} is not equal "
             f"to num_decoder_modules={len(decoder_module_list)}"
         )
         # fairscale's Pipe splits the Sequential into partitions of the
         # given sizes, one per device, and pipelines micro-batches through.
         self.model = Pipe(
             module=nn.Sequential(*decoder_module_list),
             balance=decoder_balance,
             devices=decoder_devices,
             chunks=args.pipeline_chunks,
             checkpoint=args.pipeline_checkpoint,
         )
     else:
         # Without pipelining, keep the three stages as ordinary submodules.
         self.embedding_layer = decoder_module_list[0]
         self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1])
         self.decoder_output_layer = decoder_module_list[-1]
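The balance and device flags arrive as strings that utils.eval_str_list turns into integer lists. An illustration with assumed flag values; with 6 decoder layers the module list holds 1 + 6 + 1 = 8 entries:

 from fairseq import utils

 # Assumed flag values, not defaults: 8 modules split 2/4/2 across 3 devices.
 decoder_balance = utils.eval_str_list("[2, 4, 2]", type=int)  # -> [2, 4, 2]
 decoder_devices = utils.eval_str_list("[0, 1, 2]", type=int)  # -> [0, 1, 2]
 assert sum(decoder_balance) == 8  # 1 embedding + 6 layers + 1 output layer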
Example #3
 def __init__(
     self,
     args,
     dictionary,
     embed_tokens,
     no_encoder_attn=False,
     decoder_module_list=None,
 ):
     super().__init__(dictionary)
     self.register_buffer("version", torch.Tensor([3]))
     # import_pipe() selects the Pipe backend (torch's native pipeline API
     # when available, otherwise fairscale) and sets TORCH_PIPE to match.
     import_pipe()
     # Here the caller opts into pipelining by passing a pre-built module list.
     self.use_pipeline = decoder_module_list is not None
     if not self.use_pipeline:
         self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
         self.decoder_layers = nn.Sequential(*[
             TransformerDecoderLayer(args, no_encoder_attn)
             for _ in range(args.decoder_layers)
         ])
         self.decoder_output_layer = TransformerDecoderOutputLayer(
             args, embed_tokens, dictionary
         )
     else:
         decoder_balance = utils.eval_str_list(
             args.pipeline_decoder_balance, type=int
         )
         decoder_devices = utils.eval_str_list(
             args.pipeline_decoder_devices, type=int
         )
         assert sum(decoder_balance) == len(decoder_module_list), (
             f"Sum of decoder_balance={decoder_balance} is not equal "
             + f"to num_decoder_modules={len(decoder_module_list)}"
         )
         if TORCH_PIPE:
             # torch's native Pipe expects modules already placed on their
             # devices; partition_model does that using the balance list.
             self.model = Pipe(
                 module=partition_model(
                     nn.Sequential(*decoder_module_list),
                     decoder_balance,
                     decoder_devices,
                 ),
                 chunks=args.pipeline_chunks,
                 checkpoint=args.pipeline_checkpoint,
             )
         else:
             # fairscale's Pipe takes balance/devices and does the placement.
             self.model = Pipe(
                 module=nn.Sequential(*decoder_module_list),
                 balance=decoder_balance,
                 devices=decoder_devices,
                 chunks=args.pipeline_chunks,
                 checkpoint=args.pipeline_checkpoint,
             )
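For orientation, a sketch of the argument namespace these constructors read; the values below are illustrative assumptions, not fairseq defaults:

 from argparse import Namespace

 # Illustrative values only: six decoder layers cut into 2/4/2 partitions on
 # devices 0/1/2, four micro-batches per batch, no activation checkpointing.
 args = Namespace(
     decoder_layers=6,
     pipeline_decoder_balance="[2, 4, 2]",
     pipeline_decoder_devices="[0, 1, 2]",
     pipeline_chunks=4,
     pipeline_checkpoint="never",
 )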