Example 1
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
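Note: the snippets on this page are extracted from a test suite and omit their imports. A plausible import block for them (an assumption; the exact module paths can vary between fairscale versions) is:

import copy

import numpy as np
import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

The _DoubleInput helper referenced above is a small two-input module defined in the same test file; Example 3 below shows an inline version of it.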
Example 2
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)  # Any model works. Add one different buffer per rank

    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2

    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )

    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ between DDP and ShardedDDP: {p} vs {ddp_p}"
        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ between DDP and ShardedDDP"

    check_same_model_params()  # The sharded and DDP models should start out identical on every rank

    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        check_same_model_params()

    dist.destroy_process_group()
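This parity test trains the same model twice, once under vanilla DDP with two plain SGD optimizers and once under ShardedDataParallel with two OSS shards, and asserts after every step that parameters and buffers stay in lockstep. Each rank needs its own GPU, so a function like this is launched with one process per rank. A minimal launch sketch (an assumption, not part of the original test) could look like:

import tempfile
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2  # one process per GPU
    temp_file = tempfile.mkstemp()[1]  # shared file used for the rendezvous
    # mp.spawn passes the rank as the first argument of the target function
    mp.spawn(run_ddp_parity_two_optim, args=(world_size, "nccl", temp_file), nprocs=world_size, join=True)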
Example 3
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3),
                                  Linear(3, 3), Linear(3, 3), Linear(3, 3))

        def forward(self, x, y):
            x1 = self.mlp(x)
            x2 = self.mlp(y)
            return torch.cat((x1, x2), dim=1)

    model = _DoubleInput().to(device)

    parameters = list(model.parameters())
    optimizer_1 = OSS(params=parameters[:-10],
                      optim=torch.optim.SGD,
                      lr=0.01,
                      momentum=0.99)
    optimizer_2 = OSS(params=parameters[-10:],
                      optim=torch.optim.SGD,
                      lr=0.01,
                      momentum=0.99)
    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

    # Optim loop
    def closure():
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for i in range(5):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        _ = optimizer_1.step(closure=closure)
        _ = optimizer_2.step(closure=closure)

    dist.destroy_process_group()
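Example 3 is a self-contained variant of Example 1: the _DoubleInput module is defined inline and the learning rate is 0.01 instead of 1e-3. Note that step(closure=closure) re-runs the closure (forward plus backward) before applying the update, so each loop iteration performs two forward/backward passes, one per optimizer.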
Example 4
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name,
                             reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)  # Any model works. Add one different buffer per rank

    BATCHS = 20

    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
    optim_settings = {"lr": 1e-3, "momentum": 0.99}

    sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params],
                            optim=torch.optim.SGD,
                            **optim_settings)
    sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:],
                              optim=torch.optim.SGD,
                              **optim_settings)

    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(
        list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
    ddp_optimizer_2 = torch.optim.SGD(
        list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    for i in range(BATCHS):
        input_tensor = _get_random_inputs(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()
        torch.cuda.synchronize(device)

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        torch.cuda.synchronize(device)

        check_same_model_params(
            sharded_ddp_model,
            ddp_model,
            f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()
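Example 4 is a later form of the parity test in Example 2: both OSS shards are handed to ShardedDataParallel as a list, a reduce_buffer_size is threaded through to exercise bucketed gradient reduction, and torch.cuda.synchronize is called before each comparison. The helpers _get_mlp_emb, _get_random_inputs and check_same_model_params come from the surrounding test module and are not shown here; a hypothetical reconstruction of the comparison helper, modeled on the inline checks in Example 2, might look like:

def check_same_model_params(model_a, model_b, message=""):
    # Compare parameters and buffers of the two replicas within a small tolerance
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ: {message}"
    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b, atol=1e-3), f"Model buffers differ: {message}"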