Example #1
    def check(optimizer, model):
        # Just run a couple of epochs, check that the model is properly updated
        for _ in range(epochs):
            target = torch.rand((batch, target_width), device=device)
            inputs = torch.rand((batch, input_width), device=device)

            def closure():
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                return loss

            _ = optimizer.step(closure=closure)

            # Check that all the params are the same on all ranks
            check_same_models_across_ranks(
                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
            )
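The `check` helper above is a nested fragment: `epochs`, `batch`, `input_width`, `target_width`, `device`, `loss_fn`, `process_group` and `check_same_models_across_ranks` all come from the enclosing test. A minimal sketch of the kind of enclosing scope that would make it runnable; the sizes, model and group below are assumptions, not the original test's values:

# Hypothetical enclosing scope for the nested check() helper above; in the original
# test, "def check(...)" sits inside a function where these names are already defined.
import torch
import torch.distributed as dist
from fairscale.optim import OSS

epochs, batch, input_width, hidden, target_width = 2, 3, 3, 3, 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
process_group = dist.group.WORLD  # assumes dist.init_process_group() already ran
loss_fn = torch.nn.L1Loss().to(device)

model = torch.nn.Sequential(
    torch.nn.Linear(input_width, hidden),
    torch.nn.Linear(hidden, target_width),
).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)

check(optimizer, model)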
Example #2
def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones(1) * rank)
    model.to(device)

    # Test non-trainable parameters
    next(model.parameters()).requires_grad = False

    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    broadcast_buffers=broadcast_buffers,
                                    reduce_buffer_size=reduce_buffer_size)

    # The model should be synchronized across ranks at ShardedDataParallel construction time; check that it is
    check_same_models_across_ranks(ddp_model,
                                   dist.group.WORLD,
                                   params_should_be_equal=True,
                                   check_broadcast_buffers=broadcast_buffers)

    # Optim loop
    def closure():
        optimizer.zero_grad()

        with ddp_model.no_sync() if grad_accumulation else suppress():
            input_tensor = torch.rand((64, 2)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
        return loss

    # The models should stay identical across ranks throughout the optimization steps
    for i in range(5):
        _ = optimizer.step(closure=closure)
        # When running on cpu/gloo the "nodes" are not really different
        same_params = device == torch.device("cpu") or grad_accumulation
        check_same_models_across_ranks(
            ddp_model,
            dist.group.WORLD,
            params_should_be_equal=same_params,
            check_broadcast_buffers=broadcast_buffers)

    dist.destroy_process_group()
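For reference, a test entry point like `run_one_step` is typically launched with `torch.multiprocessing.spawn`, which passes the rank as the first argument. A minimal launcher sketch; the world size, backend choice and reduce buffer size below are assumptions, not values from the original test:

# Hypothetical launcher for run_one_step; spawn prepends the rank to the args tuple.
import tempfile
import torch
import torch.multiprocessing as mp

def launch(world_size=2):
    temp_file_name = tempfile.mkstemp()[1]
    use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size
    backend = "nccl" if use_cuda else "gloo"
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    mp.spawn(
        run_one_step,
        args=(world_size, backend, device, temp_file_name, True, False, 2 ** 23),
        nprocs=world_size,
        join=True,
    )

if __name__ == "__main__":
    launch()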
Example #3
def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
    dist_init(rank, world_size, tempfile_name)
    device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 3, 3, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden),
                                torch.nn.Linear(hidden, target_width))
    model.to(device)

    loss_fn = torch.nn.L1Loss()
    loss_fn.to(device)

    # With SGD, momentum is required so that the optimizer has a state to shard
    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        loss.backward()
        return loss

    _ = optimizer.step(closure=closure)

    # Update the optimizer state on the reference rank
    optimizer.consolidate_state_dict(recipient_rank=reference_rank)

    # Fetch the state on the reference rank
    # - check that it has the correct size
    # - load it again
    if rank == reference_rank:
        optimizer_state_dict = optimizer.state_dict()
        assert len(optimizer_state_dict["state"]) == len(
            list(model.parameters()))
    else:
        optimizer_state_dict = {}

    # distribute to the other ranks
    optimizer_state_dict = sync_object_ranks(optimizer_state_dict,
                                             reference_rank, device)

    # Load the optimizer state dict
    optimizer.load_state_dict(optimizer_state_dict)

    # Check that the per-parameter states are dicts ({} when empty), not None: iterating below would fail otherwise
    for state in optimizer.state.values():
        for _, _ in state.items():
            pass

    # Test the state dict materialization on all ranks
    _ = optimizer.step(closure=closure)
    optimizer_state_dict = optimizer.state_dict(all_ranks=True)  # one per rank
    optimizer.load_state_dict(optimizer_state_dict)
    _ = optimizer.step(closure=closure)
    check_same_models_across_ranks(model,
                                   dist.group.WORLD,
                                   params_should_be_equal=True,
                                   check_broadcast_buffers=False)

    dist.destroy_process_group()
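`sync_object_ranks` is a test helper that is not shown in these snippets. A sketch of equivalent behaviour with plain `torch.distributed` (assuming PyTorch >= 1.8 for `broadcast_object_list`); this is an illustration, not the helper's actual implementation:

import torch.distributed as dist

def sync_object_ranks(obj, reference_rank, device):
    # Broadcast a picklable object from reference_rank to every other rank.
    # `device` is kept for signature compatibility; the real helper may use it
    # to stage the transfer on the right CUDA device.
    payload = [obj if dist.get_rank() == reference_rank else None]
    dist.broadcast_object_list(payload, src=reference_rank)
    return payload[0]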