def bench_single_process(args):
    """Benchmark a pipelined model inside a single process (world size 1)."""
    # Point the RPC layer at ourselves; one lone worker forms the whole group.
    os.environ["MASTER_ADDR"] = args.host
    os.environ["MASTER_PORT"] = "10638"
    rpc.init_rpc(
        "worker",
        rank=0,
        world_size=1,
    )

    # Cap the requested device count at what the machine actually has;
    # without CUDA we pretend there is a single (CPU) device.
    available = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_devices = min(args.num_devices, available)
    assert num_devices > 0

    init_random_seed(0)
    # NOTE(review): ``device`` is never read below — kept for parity.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    blob = make_model_and_data(args, None)
    model = blob["model"]

    # Split the layers evenly across the devices and wrap them in a Pipe.
    balance = generate_balance(num_devices, len(model))
    model = partition_model(model, balance)
    pipe = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint)

    # Drop the spare references so only the Pipe holds the parameters.
    del model
    del blob["model"]

    train(
        blob["data"],
        pipe,
        blob["criterion"],
        blob["optimizer"],
        blob["vocab_size"],
        args,
    )
def test_1to3(balance, checkpoint, setup_rpc):
    """Exercise a skip connection that jumps from the first to the third layer.

    Layer1 stashes its raw input under the name "1to3"; Layer3 pops it and
    adds it to its own conv output, so the stashed tensor must be carried
    across whatever device boundaries ``balance`` introduces.
    """
    if torch.cuda.device_count() < len(balance):
        pytest.skip("at least %d cuda devices required" % len(balance))

    @skippable(stash=["1to3"])
    class Layer1(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            # ``yield stash(...)`` is the skippable protocol: the tensor is
            # handed to the pipeline, and the generator's return value below
            # becomes this layer's output.
            yield stash("1to3", input)
            output = self.conv(input)
            return output  # noqa

    class Layer2(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            output = self.conv(input)
            return output

    @skippable(pop=["1to3"])
    class Layer3(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            # Retrieve the tensor Layer1 stashed and fuse it with our output.
            skip_1to3 = yield pop("1to3")
            output = self.conv(input) + skip_1to3
            return output

    model = nn.Sequential(Layer1(), Layer2(), Layer3())
    model = partition_model(model, balance)
    model = Pipe(model, chunks=3, checkpoint=checkpoint)

    in_device = model.devices[0]
    out_device = model.devices[-1]

    input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
    output = model(input)
    loss = output.local_value().mean()
    loss.backward()

    # Expected norms are magic constants — presumably recorded from a
    # reference run (TODO confirm); atol absorbs per-device numeric drift.
    assert torch.allclose(output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
    assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))
def __init__(
    self,
    args,
    dictionary,
    embed_tokens,
    no_encoder_attn=False,
    decoder_module_list=None,
):
    """Build the decoder: either its own layer stack, or — when a module
    list is supplied — a pipeline-parallel ``Pipe`` over those modules."""
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    import_pipe()
    self.use_pipeline = decoder_module_list is not None
    if self.use_pipeline:
        # Pipeline mode: distribute the supplied modules over devices
        # according to the user-provided balance/device specs.
        balance = utils.eval_str_list(args.pipeline_decoder_balance, type=int)
        devices = utils.eval_str_list(args.pipeline_decoder_devices, type=int)
        assert sum(balance) == len(decoder_module_list), (
            f"Sum of decoder_balance={balance} is not equal "
            + f"to num_decoder_modules={len(decoder_module_list)}"
        )
        stack = nn.Sequential(*decoder_module_list)
        if TORCH_PIPE:
            self.model = Pipe(
                module=partition_model(stack, balance, devices),
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            self.model = Pipe(
                module=stack,
                balance=balance,
                devices=devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
    else:
        # Plain decoder: embedding -> N decoder layers -> output projection.
        self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
        layers = [
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ]
        self.decoder_layers = nn.Sequential(*layers)
        self.decoder_output_layer = TransformerDecoderOutputLayer(
            args, embed_tokens, dictionary
        )
def __init__(self, args, dictionary, embed_tokens, encoder_module_list=None):
    """Build the encoder: either its own layer stack, or — when a module
    list is supplied — a pipeline-parallel ``Pipe`` over those modules.

    Args:
        args: parsed arguments (layer counts, pipeline balance/devices, ...).
        dictionary: source-side vocabulary, forwarded to the base class.
        embed_tokens: token embedding module, or an ``nn.ModuleList`` of them.
        encoder_module_list: when not None, run these pre-built modules as a
            pipeline instead of constructing layers here.
    """
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    import_pipe()
    self.use_pipeline = encoder_module_list is not None
    if not self.use_pipeline:
        self.embedding_layer = TransformerEncoderEmbedding(
            args, embed_tokens)
        # Fix: loop variable was the unused name ``i``; use ``_`` to match
        # the decoder's identical construction.
        self.encoder_layers = nn.Sequential(*[
            TransformerEncoderLayer(args)
            for _ in range(args.encoder_layers)
        ])
        # Embedding dim is summed when several embedding tables are combined.
        if isinstance(embed_tokens, nn.ModuleList):
            emb_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            emb_dim = embed_tokens.embedding_dim
        self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
    else:
        encoder_balance = utils.eval_str_list(
            args.pipeline_encoder_balance, type=int)
        encoder_devices = utils.eval_str_list(
            args.pipeline_encoder_devices, type=int)
        assert sum(encoder_balance) == len(encoder_module_list), (
            f"Sum of encoder_balance={encoder_balance} is not equal "
            + f"to num_encoder_modules={len(encoder_module_list)}")
        if TORCH_PIPE:
            self.model = Pipe(
                module=partition_model(nn.Sequential(*encoder_module_list),
                                       encoder_balance, encoder_devices),
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            self.model = Pipe(
                module=nn.Sequential(*encoder_module_list),
                balance=encoder_balance,
                devices=encoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
    """Fuse a fairseq encoder/decoder pair into one pipeline-parallel model."""
    import_pipe()
    super().__init__()
    assert isinstance(encoder, FairseqEncoder)
    assert isinstance(decoder, FairseqDecoder)

    # Flatten both halves into ordered module lists:
    # embedding -> layers -> final norm / output projection.
    enc_modules = [
        encoder.embedding_layer,
        *encoder.encoder_layers,
        encoder.final_layer_norm,
    ]
    dec_modules = [
        decoder.embedding_layer,
        *decoder.decoder_layers,
        decoder.decoder_output_layer,
    ]
    self.num_encoder_modules = len(enc_modules)
    self.num_decoder_modules = len(dec_modules)
    self.devices = devices

    stack = nn.Sequential(*enc_modules, *dec_modules)
    if TORCH_PIPE:
        self.model = Pipe(
            partition_model(stack, balance, devices),
            chunks=chunks,
            checkpoint=checkpoint,
        )
    else:
        self.model = Pipe(
            stack,
            balance=balance,
            devices=devices,
            chunks=chunks,
            checkpoint=checkpoint,
        )

    self.encoder_max_positions = self.max_positions_helper(
        encoder.embedding_layer, "max_source_positions"
    )
    self.decoder_max_positions = self.max_positions_helper(
        decoder.embedding_layer, "max_target_positions"
    )
    self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
    # Note: To be populated during inference
    self.encoder = None
    self.decoder = None