def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
    layers = [
        TransformerDecoderLayer(args, no_encoder_attn)
        for _ in range(args.decoder_layers)
    ]
    self.decoder_layers = nn.Sequential(*layers)
    self.decoder_output_layer = TransformerDecoderOutputLayer(
        args, embed_tokens, dictionary
    )
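# Note: the pipeline variants below flatten this same stack into a single module
# list (embedding layer, each decoder layer, then the output layer), so the
# number of pipeline modules is always args.decoder_layers + 2.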
def __init__(
    self,
    args,
    dictionary,
    embed_tokens,
    no_encoder_attn=False,
    decoder_module_list=None,
):
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    try:
        from fairscale.nn import Pipe
    except ImportError:
        raise ImportError("Please install fairscale with: pip install fairscale")
    if decoder_module_list is None:
        # Build the default module stack when a prebuilt list is not provided.
        embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
        layers = [
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ]
        decoder_output_layer = TransformerDecoderOutputLayer(
            args, embed_tokens, dictionary
        )
        decoder_module_list = [embedding_layer] + layers + [decoder_output_layer]
    self.use_pipeline = getattr(args, "pipeline_decoder_balance", None) is not None
    if self.use_pipeline:
        decoder_balance = utils.eval_str_list(
            args.pipeline_decoder_balance, type=int
        )
        decoder_devices = utils.eval_str_list(
            args.pipeline_decoder_devices, type=int
        )
        assert sum(decoder_balance) == len(decoder_module_list), (
            f"Sum of decoder_balance={decoder_balance} is not equal "
            f"to num_decoder_modules={len(decoder_module_list)}"
        )
        # Split the flat module list across devices with fairscale's Pipe.
        self.model = Pipe(
            module=nn.Sequential(*decoder_module_list),
            balance=decoder_balance,
            devices=decoder_devices,
            chunks=args.pipeline_chunks,
            checkpoint=args.pipeline_checkpoint,
        )
    else:
        # No pipeline: keep the usual embedding / layers / output structure.
        self.embedding_layer = decoder_module_list[0]
        self.decoder_layers = nn.Sequential(*decoder_module_list[1:-1])
        self.decoder_output_layer = decoder_module_list[-1]
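# Minimal sketch (not fairseq code) of the balance invariant asserted above:
# the pipeline_decoder_balance string is parsed into a list of ints whose sum
# must equal the number of modules in decoder_module_list. The concrete numbers
# below are hypothetical; devices are typically listed one per partition.
num_decoder_layers = 6                            # e.g. args.decoder_layers
num_decoder_modules = 1 + num_decoder_layers + 1  # embedding + layers + output layer

decoder_balance = [3, 3, 2]   # e.g. parsed from args.pipeline_decoder_balance
decoder_devices = [0, 1, 2]   # e.g. parsed from args.pipeline_decoder_devices
assert sum(decoder_balance) == num_decoder_modules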
def __init__(
    self,
    args,
    dictionary,
    embed_tokens,
    no_encoder_attn=False,
    decoder_module_list=None,
):
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    import_pipe()
    self.use_pipeline = decoder_module_list is not None
    if not self.use_pipeline:
        # No pipeline: build the usual embedding / layers / output structure.
        self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
        self.decoder_layers = nn.Sequential(
            *[
                TransformerDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )
        self.decoder_output_layer = TransformerDecoderOutputLayer(
            args, embed_tokens, dictionary
        )
    else:
        decoder_balance = utils.eval_str_list(
            args.pipeline_decoder_balance, type=int
        )
        decoder_devices = utils.eval_str_list(
            args.pipeline_decoder_devices, type=int
        )
        assert sum(decoder_balance) == len(decoder_module_list), (
            f"Sum of decoder_balance={decoder_balance} is not equal "
            f"to num_decoder_modules={len(decoder_module_list)}"
        )
        if TORCH_PIPE:
            # The torch Pipe expects the module to already be partitioned and
            # placed on devices, so split it before wrapping.
            self.model = Pipe(
                module=partition_model(
                    nn.Sequential(*decoder_module_list),
                    decoder_balance,
                    decoder_devices,
                ),
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            # fairscale's Pipe takes the balance and device lists directly.
            self.model = Pipe(
                module=nn.Sequential(*decoder_module_list),
                balance=decoder_balance,
                devices=decoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
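# Hedged sketch of what a helper like partition_model presumably does: group the
# flat nn.Sequential into one nn.Sequential per pipeline stage according to
# `balance`, and move each stage to its device, since the torch Pipe branch above
# expects a pre-partitioned, pre-placed module. This is an illustration of the
# idea, not fairseq's actual implementation; the function name is hypothetical.
import torch.nn as nn


def partition_by_balance(module: nn.Sequential, balance, devices) -> nn.Sequential:
    children = list(module.children())
    partitions, offset = [], 0
    for num_layers, device in zip(balance, devices):
        # Take the next `num_layers` modules as one pipeline stage on `device`.
        stage = nn.Sequential(*children[offset : offset + num_layers])
        partitions.append(stage.to(f"cuda:{device}"))
        offset += num_layers
    assert offset == len(children), "balance must cover every module"
    return nn.Sequential(*partitions)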