Example #1
import os

import torch
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe


def bench_single_process(args):
    # Pipe uses the RPC framework internally, so it must be initialized
    # even for a single-process benchmark.
    os.environ.update({"MASTER_ADDR": args.host})
    os.environ.update({"MASTER_PORT": "10638"})
    rpc.init_rpc("worker", rank=0, world_size=1)

    num_devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
    num_devices = min(args.num_devices, num_devices)
    assert num_devices > 0
    init_random_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    blob = make_model_and_data(args, None)
    model = blob["model"]

    # Split the layers across the available devices, then wrap the
    # partitioned model with Pipe for pipeline-parallel execution.
    balance = generate_balance(num_devices, len(model))
    model = partition_model(model, balance)
    p = Pipe(model, chunks=args.chunks, checkpoint=args.checkpoint)
    del model
    del blob["model"]

    train(blob["data"], p, blob["criterion"], blob["optimizer"],
          blob["vocab_size"], args)
Example #2
def test_1to3(balance, checkpoint, setup_rpc):
    if torch.cuda.device_count() < len(balance):
        pytest.skip("at least %d cuda devices required" % len(balance))

    # Layer1 stashes its input so that Layer3, possibly on another
    # device, can pop it and add it back as a skip connection.
    @skippable(stash=["1to3"])
    class Layer1(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            yield stash("1to3", input)  # save the input for Layer3
            output = self.conv(input)
            return output  # noqa

    class Layer2(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            output = self.conv(input)
            return output

    # Layer3 pops the tensor stashed by Layer1; Pipe moves it to this
    # partition's device before handing it over.
    @skippable(pop=["1to3"])
    class Layer3(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

        def forward(self, input):
            skip_1to3 = yield pop("1to3")
            output = self.conv(input) + skip_1to3
            return output

    model = nn.Sequential(Layer1(), Layer2(), Layer3())
    model = partition_model(model, balance)
    model = Pipe(model, chunks=3, checkpoint=checkpoint)

    in_device = model.devices[0]
    out_device = model.devices[-1]

    input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
    output = model(input)
    loss = output.local_value().mean()
    loss.backward()

    # Expected norms assume the test framework's fixed RNG seed.
    assert torch.allclose(output.local_value().norm(),
                          torch.tensor(1039.0, device=out_device),
                          atol=6e-1)
    assert torch.allclose(input.grad.norm(),
                          torch.tensor(0.0004533053, device=in_device))
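The @skippable/stash/pop trio in Example #2 is the pipeline skip-connection API: stashing turns forward into a generator, and a later layer on any partition pops the tensor, with the pipeline carrying it between devices instead of threading it through every intermediate layer. A minimal residual-style pairing (class names here are illustrative; the import path below is torch's, and torchgpipe exposes the same names under torchgpipe.skip):

import torch.nn as nn
from torch.distributed.pipeline.sync.skip import pop, skippable, stash

@skippable(stash=["identity"])
class StashIdentity(nn.Module):
    def forward(self, x):
        yield stash("identity", x)  # save x for a later partition
        return x  # noqa: B901 -- generator-style forward still returns a value

@skippable(pop=["identity"])
class PopAndAdd(nn.Module):
    def forward(self, x):
        identity = yield pop("identity")  # arrives on this partition's device
        return x + identity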
Example #3
def __init__(
    self,
    args,
    dictionary,
    embed_tokens,
    no_encoder_attn=False,
    decoder_module_list=None,
):
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    import_pipe()
    self.use_pipeline = decoder_module_list is not None
    if not self.use_pipeline:
        # Plain (non-pipelined) decoder: embedding, a stack of decoder
        # layers, and an output projection.
        self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
        self.decoder_layers = nn.Sequential(*[
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])
        self.decoder_output_layer = TransformerDecoderOutputLayer(
            args, embed_tokens, dictionary
        )
    else:
        decoder_balance = utils.eval_str_list(
            args.pipeline_decoder_balance, type=int
        )
        decoder_devices = utils.eval_str_list(
            args.pipeline_decoder_devices, type=int
        )
        assert sum(decoder_balance) == len(decoder_module_list), (
            f"Sum of decoder_balance={decoder_balance} is not equal "
            f"to num_decoder_modules={len(decoder_module_list)}"
        )
        if TORCH_PIPE:
            # torch's Pipe expects a module already partitioned across
            # its devices.
            self.model = Pipe(
                module=partition_model(
                    nn.Sequential(*decoder_module_list),
                    decoder_balance,
                    decoder_devices,
                ),
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            # fairscale's Pipe partitions the module itself from
            # balance and devices.
            self.model = Pipe(
                module=nn.Sequential(*decoder_module_list),
                balance=decoder_balance,
                devices=decoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
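utils.eval_str_list is fairseq's helper for turning a command-line string such as "[6, 6]" into a typed list, so the pipeline balance and device assignments can be passed as literals. A sketch of its behavior:

def eval_str_list(x, type=float):
    # "[6, 6]" -> [6, 6] with type=int; a bare scalar is wrapped
    # into a one-element list.
    if x is None:
        return None
    if isinstance(x, str):
        x = eval(x)
    try:
        return list(map(type, x))
    except TypeError:
        return [type(x)]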
Example #4
def __init__(self,
             args,
             dictionary,
             embed_tokens,
             encoder_module_list=None):
    super().__init__(dictionary)
    self.register_buffer("version", torch.Tensor([3]))
    import_pipe()
    self.use_pipeline = encoder_module_list is not None
    if not self.use_pipeline:
        # Plain (non-pipelined) encoder: embedding, a stack of encoder
        # layers, and a final layer norm.
        self.embedding_layer = TransformerEncoderEmbedding(
            args, embed_tokens)
        self.encoder_layers = nn.Sequential(*[
            TransformerEncoderLayer(args)
            for _ in range(args.encoder_layers)
        ])
        if isinstance(embed_tokens, nn.ModuleList):
            emb_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            emb_dim = embed_tokens.embedding_dim
        self.final_layer_norm = TransformerEncoderLayerNorm(args, emb_dim)
    else:
        encoder_balance = utils.eval_str_list(
            args.pipeline_encoder_balance, type=int)
        encoder_devices = utils.eval_str_list(
            args.pipeline_encoder_devices, type=int)
        assert sum(encoder_balance) == len(encoder_module_list), (
            f"Sum of encoder_balance={encoder_balance} is not equal "
            f"to num_encoder_modules={len(encoder_module_list)}")
        if TORCH_PIPE:
            # torch's Pipe expects a module already partitioned across
            # its devices.
            self.model = Pipe(
                module=partition_model(nn.Sequential(*encoder_module_list),
                                       encoder_balance, encoder_devices),
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
        else:
            # fairscale's Pipe partitions the module itself from
            # balance and devices.
            self.model = Pipe(
                module=nn.Sequential(*encoder_module_list),
                balance=encoder_balance,
                devices=encoder_devices,
                chunks=args.pipeline_chunks,
                checkpoint=args.pipeline_checkpoint,
            )
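TORCH_PIPE selects between the two Pipe APIs visible in the branches above: torch's Pipe consumes a module that is already partitioned onto its devices, while fairscale's Pipe takes the flat module plus balance/devices and partitions it itself. A hedged sketch of how import_pipe might set the flag (fairseq's real helper may differ in detail):

TORCH_PIPE = False
Pipe = None

def import_pipe():
    global TORCH_PIPE, Pipe
    try:
        # Prefer the pipeline parallelism shipped with torch (>= 1.8).
        from torch.distributed.pipeline.sync import Pipe as _Pipe
        TORCH_PIPE = True
    except ImportError:
        # Fall back to fairscale's Pipe, which has the balance/devices API.
        from fairscale.nn import Pipe as _Pipe
        TORCH_PIPE = False
    Pipe = _Pipe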
Example #5
def __init__(self, encoder, decoder, balance, devices, chunks, checkpoint):
    import_pipe()
    super().__init__()
    assert isinstance(encoder, FairseqEncoder)
    assert isinstance(decoder, FairseqDecoder)
    # Flatten encoder and decoder into a single module list so the whole
    # model can be partitioned into one pipeline.
    encoder_module_list = (
        [encoder.embedding_layer]
        + list(encoder.encoder_layers)
        + [encoder.final_layer_norm]
    )
    self.num_encoder_modules = len(encoder_module_list)
    decoder_module_list = (
        [decoder.embedding_layer]
        + list(decoder.decoder_layers)
        + [decoder.decoder_output_layer]
    )
    self.num_decoder_modules = len(decoder_module_list)
    module_list = encoder_module_list + decoder_module_list
    self.devices = devices
    if TORCH_PIPE:
        self.model = Pipe(
            partition_model(nn.Sequential(*module_list), balance, devices),
            chunks=chunks,
            checkpoint=checkpoint,
        )
    else:
        self.model = Pipe(
            nn.Sequential(*module_list),
            balance=balance,
            devices=devices,
            chunks=chunks,
            checkpoint=checkpoint,
        )
    self.encoder_max_positions = self.max_positions_helper(
        encoder.embedding_layer, "max_source_positions"
    )
    self.decoder_max_positions = self.max_positions_helper(
        decoder.embedding_layer, "max_target_positions"
    )
    self.adaptive_softmax = getattr(decoder, "adaptive_softmax", None)
    # Note: to be populated during inference.
    self.encoder = None
    self.decoder = None
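max_positions_helper is not shown; judging from its call sites it reads the positional limit (max_source_positions / max_target_positions) off an embedding layer. A plausible sketch:

def max_positions_helper(self, embedding_layer,
                         max_positions_field="max_source_positions"):
    # Read the sequence-length limit stored on the embedding layer.
    if hasattr(embedding_layer, max_positions_field):
        return getattr(embedding_layer, max_positions_field)
    raise ValueError(
        f"{embedding_layer.__class__.__name__} has no attribute "
        f"{max_positions_field}"
    )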