示例#1
0
def get_lm_model(args, device, config):
    """Build the GPT-2 style language model used for sequence prediction.

    Args:
        args: Parsed options; only ``args.lazy_construction`` is read here.
        device: Target device for the eagerly built model (unused on the
            lazy path).
        config: Mapping providing ``ninp``, ``nhead``, ``initrange``,
            ``dropout``, ``vocab_size``, ``nhid`` and ``num_decoder_layers``.

    Returns:
        When ``args.lazy_construction`` is truthy, a list of ``LazyModule``
        wrappers (embedding, positional encoding, one entry per decoder
        layer, final linear layer). Otherwise a
        ``transformer_lm.TransformerLM`` instance moved to ``device``.
    """
    # Keys are looked up in this fixed order before anything is built.
    ninp, nhead, initrange, dropout, vocab_size, nhid, ndecoder = (
        config[key]
        for key in (
            "ninp",
            "nhead",
            "initrange",
            "dropout",
            "vocab_size",
            "nhid",
            "num_decoder_layers",
        )
    )

    if not args.lazy_construction:
        eager_model = transformer_lm.TransformerLM(
            vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder
        )
        return eager_model.to(device)

    # Each entry only records how to build its layer; nothing is
    # instantiated in this function.
    decoder_stack = [
        LazyModule(lambda: transformer_lm.TransformerDecoderLayer(ninp, nhead, nhid, dropout))
        for _ in range(ndecoder)
    ]
    return [
        LazyModule(lambda: transformer_lm.EmbeddingLayer(vocab_size, ninp, initrange)),
        LazyModule(lambda: transformer_lm.PositionalEncodingLayer(ninp, dropout)),
        *decoder_stack,
        LazyModule(lambda: transformer_lm.LinearLayer(ninp, vocab_size, initrange)),
    ]
示例#2
0
def lazy_skippable_error(pipe_class):
    """Check that skippable layers plus lazy construction raise ``ValueError``.

    Using skippable layers in combination with lazy construction is currently
    not supported, so constructing ``pipe_class`` from such a model is
    expected to fail with the message matched below.

    Args:
        pipe_class: Pipe class (or factory) under test; called as
            ``pipe_class(model, balance, worker_map=...)``.
    """

    @skippable(stash=["1to3"])
    class Layer1(nn.Linear):
        pass

    @skippable(pop=["1to3"])
    class Layer3(nn.Linear):
        pass

    expected_message = "Can't use Skippable layers with multi-process pipe and lazy construction"

    # Stash in the first layer, pop in the third, with a plain layer between.
    lazy_layers = [
        LazyModule(lambda: Layer1(10, 10)),
        LazyModule(lambda: nn.Linear(10, 10)),
        LazyModule(lambda: Layer3(10, 10)),
    ]
    balance = [2, 1]

    with pytest.raises(ValueError, match=expected_message):
        pipe_class(lazy_layers, balance, worker_map=get_worker_map())
示例#3
0
def lazy_construction(pipeline_style):
    """Check that a pipe built from lazy modules only instantiates two of four.

    Four ``LazyModule`` wrappers are handed to ``MultiProcessPipe`` with a
    ``[2, 2]`` balance; afterwards the pipe must expose two ``Custom``
    modules and the constructor must have run exactly twice.

    Args:
        pipeline_style: Forwarded as the ``style`` argument of
            ``MultiProcessPipe``.
    """
    init_count = 0

    class Custom(nn.Module):
        def __init__(self):
            nonlocal init_count
            super().__init__()
            # Record every real instantiation so the test can count them.
            init_count += 1

        def forward(self, x):
            return x

    lazy_layers = [LazyModule(lambda: Custom()) for _ in range(4)]

    pipe = MultiProcessPipe(
        lazy_layers,
        balance=[2, 2],
        style=pipeline_style,
        worker_map=get_worker_map(),
    )

    for position in (0, 1):
        assert isinstance(pipe[position], Custom)
    assert len(pipe) == 2
    assert init_count == 2
示例#4
0
def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
    """Compare BatchNorm parameter gradients between a pipe and a plain module.

    A ``BatchNorm2d(3)`` and a deep copy of it are fed the same random input;
    the copy runs inside ``MultiProcessPipe`` with ``deferred_batch_norm=True``
    and a single chunk. Weight and bias gradients must exist on the piped
    module and match the reference within ``atol=1e-4``.

    Args:
        checkpoint: Forwarded as the pipe's ``checkpoint`` argument.
        lazy: When truthy, the piped module is supplied through a
            ``LazyModule``; otherwise through ``nn.Sequential``.
        pipeline_style: Forwarded as the pipe's ``style`` argument.
    """
    reference_bn = nn.BatchNorm2d(3)
    piped_bn = deepcopy(reference_bn)

    def build_piped_bn():
        # Always hands back the same pre-built copy.
        return piped_bn

    model = [LazyModule(build_piped_bn)] if lazy else nn.Sequential(piped_bn)

    pipe = MultiProcessPipe(
        model,
        balance=[1],
        style=pipeline_style,
        worker_map=get_worker_map(),
        chunks=1,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    reference_bn(x).mean().backward()

    param_names = ("weight", "bias")
    for param in param_names:
        assert getattr(pipe[0], param).grad is not None

    for param in param_names:
        assert torch.allclose(getattr(pipe[0], param).grad, getattr(reference_bn, param).grad, atol=1e-4)
示例#5
0
def make_model(args, device, ntokens):
    """Build the transformer model plus its loss, optimizer factory and scaler.

    Args:
        args: Parsed options. ``args.num_decoder_layers`` and
            ``args.lazy_construction`` are read here; ``xpipe`` and
            ``spectrain`` are read later from whatever ``args`` the caller
            passes to the returned optimizer factory.
        device: Target device for the eagerly built model (unused on the
            lazy path).
        ntokens: Vocabulary size, used for the embedding and output layers.

    Returns:
        A 4-tuple ``(model, criterion, optimizer, scaler)`` where

        * ``model`` is a list of ``LazyModule`` wrappers when
          ``args.lazy_construction`` is truthy, otherwise a
          ``TransformerLMSequntial`` moved to ``device``;
        * ``criterion`` is an ``nn.CrossEntropyLoss``;
        * ``optimizer`` is NOT an optimizer instance but a factory
          ``optimizer(model, args)`` that builds one on demand;
        * ``scaler`` is a ``GradScaler``.
    """
    ninp = 2048  # embedding dimension
    nhid = 2048  # the dimension of the feedforward network model in nn.TransformerEncoder
    nhead = 32  # the number of heads in the multiheadattention models
    dropout = 0
    initrange = 0.1
    ndecoder = args.num_decoder_layers

    if args.lazy_construction:
        # Only record how to build each layer; nothing is instantiated here.
        layers = [
            LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)),
            LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)),
        ]
        for _ in range(ndecoder):
            layers.append(
                LazyModule(lambda: TransformerDecoderLayer(
                    ninp, nhead, nhid, dropout)))

        layers.append(
            LazyModule(lambda: LinearLayer(ninp, ntokens, initrange)))
        model = layers
    else:
        model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout,
                                       initrange, ndecoder).to(device)

    criterion = nn.CrossEntropyLoss()
    lr = 0.01  # learning rate

    # The unused nested ``make_adam`` helper was removed: it was never called
    # and never returned, so it had no effect.

    def make_custom_optimizer(model, args):
        """Create the optimizer selected by ``args`` for ``model``.

        Both parameters intentionally shadow the outer names: the caller
        supplies the model and options at the time of the call.
        """
        if args.xpipe:
            return XpipeAdam(model.parameters(), lr=lr)
        elif args.spectrain:
            return SpectrainSGDMomentum(model.parameters(), lr=lr)
        else:
            return MySGD(model.parameters(), lr=lr)

    # Return the factory itself; no optimizer is created in this function.
    optimizer = make_custom_optimizer
    scaler = GradScaler()

    return model, criterion, optimizer, scaler
示例#6
0
def deferred_batch_norm(checkpoint, lazy, pipe_class):
    """Compare BatchNorm running statistics between a pipe and a plain module.

    A ``BatchNorm2d(3)`` and a deep copy of it are fed the same random input;
    the copy runs inside ``pipe_class`` with ``deferred_batch_norm=True`` and
    two chunks. Running mean and variance of the piped module must match the
    reference within ``atol=1e-4``.

    Args:
        checkpoint: Forwarded as the pipe's ``checkpoint`` argument.
        lazy: When truthy, the piped module is supplied through a
            ``LazyModule``; otherwise through ``nn.Sequential``.
        pipe_class: Pipe class (or factory) under test.
    """
    reference_bn = nn.BatchNorm2d(3)
    piped_bn = deepcopy(reference_bn)

    def build_piped_bn():
        # Always hands back the same pre-built copy.
        return piped_bn

    model = [LazyModule(build_piped_bn)] if lazy else nn.Sequential(piped_bn)

    pipe = pipe_class(
        model,
        balance=[1],
        worker_map=get_worker_map(),
        chunks=2,
        checkpoint=checkpoint,
        deferred_batch_norm=True,
    )

    x = torch.rand(4, 3, 10, 10)
    pipe(x).mean().backward()
    reference_bn(x).mean().backward()

    for stat in ("running_mean", "running_var"):
        assert torch.allclose(getattr(pipe[0], stat), getattr(reference_bn, stat), atol=1e-4)
示例#7
0
def reuse_lazy():
    """Compare a plain ``nn.Sequential`` with a ``MultiProcessPipe`` that reuses a layer.

    Both models contain the same ``nn.Linear`` instance at three positions.
    Each is built right after ``set_random_seed(1234)``. After one SGD step
    on each, the eval-mode outputs are required to be exactly equal on the
    final pipeline stage.

    Despite the function name, the ``LazyModule`` variant and the pre-training
    forward comparison are both disabled via ``if False:  # speed``.
    """
    if False:  # speed
        # Disabled: same layer layout, but with the reused layer wrapped in
        # a LazyModule.
        reused = LazyModule(lambda: nn.Linear(10, 10))
        model = [
            reused,
            nn.Linear(10, 10),
            nn.ReLU(), reused,
            nn.ReLU(), reused,
            nn.ReLU()
        ]
        # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
        pipe = MultiProcessPipe(model, [3, 1, 1],
                                style=MultiProcessPipe.AsyncSchedule,
                                worker_map=get_worker_map())
        pipe.eval()
        output = pipe(torch.rand(10))

        print(f"output on {pipe.group.rank()}, {output}")
        torch.distributed.barrier()

    set_random_seed(1234)
    # Reference model: the same Linear instance appears at positions 0, 3
    # and 5 of a plain nn.Sequential.
    reused = nn.Linear(10, 10)
    layers = [
        reused,
        nn.Linear(10, 10),
        nn.ReLU(), reused,
        nn.ReLU(), reused,
        nn.ReLU()
    ]
    model = nn.Sequential(*layers)
    model.eval()

    set_random_seed(1234)
    # ensure identical weights but no sharing between model and pipe
    # (the seed is reset and fresh module instances are created in the same
    # order as above)
    reused = nn.Linear(10, 10)
    layers = [
        reused,
        nn.Linear(10, 10),
        nn.ReLU(), reused,
        nn.ReLU(), reused,
        nn.ReLU()
    ]
    # Seven layers split 3/1/1 across stages, using the async schedule.
    pipe = MultiProcessPipe(layers, [3, 1, 1],
                            style=MultiProcessPipe.AsyncSchedule,
                            worker_map=get_worker_map())
    pipe.eval()
    model_optimizer = torch.optim.SGD(model.parameters(),
                                      lr=0.01,
                                      momentum=0.9)
    # Only build an optimizer when the pipe exposes at least one parameter;
    # presumably a rank can hold only parameter-free layers -- TODO confirm.
    pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01,
                                     momentum=0.9) if len(
                                         list(pipe.parameters())) else None
    inputs = torch.rand(10)
    if False:  # speed
        # Disabled: forward-only comparison before any training step.
        model_out = model(inputs)
        pipe_out = pipe(inputs)

        torch.distributed.barrier()

        if pipe.final_stage:
            assert torch.equal(model_out, pipe_out)

    # One training step on both models with the same input.
    model.train()
    pipe.train()
    model_out = model(inputs)
    pipe_out = pipe(inputs)
    # backward() on the pipe output is only called on the final stage.
    if pipe.final_stage:
        pipe_loss = pipe_out.mean()
        pipe_loss.backward()

    model_loss = model_out.mean()
    model_loss.backward()

    model_optimizer.step()
    if pipe_optimizer:
        pipe_optimizer.step()

    # Re-run the forward pass in eval mode after the update.
    model.eval()
    pipe.eval()
    model_out = model(inputs)
    pipe_out = pipe(inputs)

    print(f"before barrier on {torch.distributed.get_rank()}")
    torch.distributed.barrier()
    print(f"after barrier on {torch.distributed.get_rank()}")

    # Exact equality is required, and only checked on the final stage.
    if pipe.final_stage:
        assert torch.equal(model_out, pipe_out)