Code example #1
        def check_step():
            input_tensor = torch.rand((batch, in_channels)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                ddp_optimizer.zero_grad()
                ddp_loss = ddp_model(input_tensor).abs().sum()
                ddp_loss.backward()
                return ddp_loss

            def closure_sharded(input_tensor=input_tensor):
                sharded_optimizer.zero_grad()
                sharded_loss = oss_ddp_model(input_tensor).abs().sum()
                sharded_loss.backward()
                return sharded_loss

            loss_ddp = cast(torch.Tensor,
                            ddp_optimizer.step(closure=closure_ddp))
            loss_sharded_optim = cast(
                torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

            assert torch.allclose(
                loss_ddp, loss_sharded_optim
            ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"

            check_same_model_params(oss_ddp_model, ddp_model)
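
The snippets in this section are fragments of fairscale's OSS / ShardedDataParallel parity tests and omit their imports. The block below is a guess at the imports that would make the shared names resolvable; the exact module paths are an assumption, not taken from the snippets, and helpers such as check_same_model_params, dist_init, sync_object_ranks, RECIPIENT_RANK, torch_version, _get_mlp*, and _get_random_inputs come from the surrounding test module (some of them are sketched after the later examples).

# Assumed imports for these snippets (a sketch, not copied from the test suite).
import copy
from contextlib import suppress
from typing import Any, Dict, List, Type, cast

import numpy as np
import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale import optim                     # some snippets refer to optim.OSS
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler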
Code example #2
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)  # Any model works. Add one different buffer per rank

    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2

    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )

    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    check_same_model_params()  # The models should stay the same in between the ranks

    for i in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        check_same_model_params()

    dist.destroy_process_group()
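
For context, per-rank entry points like run_ddp_parity_two_optim are normally launched with one process per GPU. A minimal launcher sketch, assuming torch.multiprocessing.spawn and a throwaway rendezvous file (none of this is part of the original snippet):

import tempfile

import torch.multiprocessing as mp


def launch_parity_two_optim(world_size: int = 2, backend: str = "nccl") -> None:
    # spawn() calls run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name)
    # once per rank, with the rank prepended automatically.
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(
        run_ddp_parity_two_optim,
        args=(world_size, backend, temp_file_name),
        nprocs=world_size,
        join=True,
    )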
Code example #3
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}" +
        f" - change train graph {change_train_graph}" + f" - amp {amp}" +
        f" - manual reduction {manual_reduction}" +
        f" - buffers {reduce_buffer_size}" + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model,
                scaler,
                input_tensor,
                should_accumulate,
                _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1

        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(),
                            optim=torch.optim.SGD,
                            lr=1e-4,
                            momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                    lr=1e-4,
                                    momentum=0.99)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True,
                    find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None,
                                     hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor,
                           grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model,
                                f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norm are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(),
                    0.3,
                    norm_type=2.0,
                    error_if_nonfinite=False)  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(
                    0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm,
                                          total_norm,
                                          atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(
                            zip(ddp_model.parameters(),
                                sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      torch.norm(p_sdp.grad),
                                      flush=True)
                            else:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      "not owned",
                                      flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(
                sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(
                ddp_model.parameters()).requires_grad
            check_same_model_params(
                sharded_ddp_model, ddp_model,
                f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()
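
check_same_model_params(model_a, model_b, message) is a helper from the surrounding test module and is not shown. Based on the inline variants in Code example #2 and Code example #8, a plausible sketch (an assumption, not the actual implementation) is:

import torch


def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module, message: str = "") -> None:
    # Compare parameters and buffers pairwise, tolerating tiny numerical differences.
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        assert torch.allclose(p_a, p_b, atol=1e-3), f"Model parameters differ\n{message}"
    for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
        assert torch.allclose(b_a, b_b, atol=1e-3), f"Model buffers differ\n{message}"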
Code example #4
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name,
                             reduce_buffer_size):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)  # Any model works. Add one different buffer per rank

    BATCHS = 20

    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2
    optim_settings = {"lr": 1e-3, "momentum": 0.99}

    sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params],
                            optim=torch.optim.SGD,
                            **optim_settings)
    sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:],
                              optim=torch.optim.SGD,
                              **optim_settings)

    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(
        list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
    ddp_optimizer_2 = torch.optim.SGD(
        list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    for i in range(BATCHS):
        input_tensor = _get_random_inputs(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()
        torch.cuda.synchronize(device)

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        torch.cuda.synchronize(device)

        check_same_model_params(
            sharded_ddp_model,
            ddp_model,
            f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()
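
_get_mlp_emb and _get_random_inputs are also defined elsewhere in the test file. Purely hypothetical stand-ins that would satisfy the way they are called above (any small model and a matching random batch will do, per the comments in the snippets):

import torch


def _get_mlp_emb(multiple_fw: bool = False) -> torch.nn.Module:
    # Hypothetical stand-in: any small model works for the parity checks.
    # The multiple_fw flag (multiple forward passes) is ignored in this sketch.
    return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))


def _get_random_inputs(device: torch.device) -> torch.Tensor:
    # Hypothetical stand-in: a random batch matching the stand-in model's input width.
    return torch.rand((64, 2)).to(device)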
Code example #5
    def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer],
                                    change_train_graph: bool = False):
        # Any model works. Add one different buffer per rank
        trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden),
                                    torch.nn.Linear(hidden, hidden))
        trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
        trunk.to(device)

        head = torch.nn.Linear(hidden, out_channels).to(device)

        # Define a model to be trained by OSS
        oss_module = torch.nn.Sequential(trunk, head)
        oss_trainable_params = [
            {
                "params": trunk.parameters(),
                "lr": 1e-5
            },
            {
                "params": head.parameters(),
                "lr": 1e-4
            },
        ]

        optimizer_settings: Dict[Any, Any] = {}
        if issubclass(optimizer, torch.optim.SGD):  # optimizer is a class (Type[Optimizer]) here, not an instance
            optimizer_settings["momentum"] = 0.9

        sharded_optimizer = optim.OSS(
            params=oss_trainable_params,
            optim=optimizer,
            group=None,
            broadcast_buffer_size=2**10,
            **optimizer_settings,
        )

        oss_ddp_model = DDP(module=oss_module,
                            device_ids=[rank],
                            broadcast_buffers=True,
                            find_unused_parameters=True)

        # Define a model to be trained by normal pytorch + DDP
        ddp_trunk = copy.deepcopy(trunk)
        ddp_head = copy.deepcopy(head)
        ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)

        ddp_trainable_params = [
            {
                "params": ddp_trunk.parameters(),
                "lr": 1e-5
            },
            {
                "params": ddp_head.parameters(),
                "lr": 1e-4
            },
        ]
        ddp_optimizer = optimizer(ddp_trainable_params,
                                  **optimizer_settings)  # type: ignore
        ddp_model = DDP(module=ddp_module,
                        device_ids=[rank],
                        broadcast_buffers=True,
                        find_unused_parameters=True)

        def check_step():
            input_tensor = torch.rand((batch, in_channels)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                ddp_optimizer.zero_grad()
                ddp_loss = ddp_model(input_tensor).abs().sum()
                ddp_loss.backward()
                return ddp_loss

            def closure_sharded(input_tensor=input_tensor):
                sharded_optimizer.zero_grad()
                sharded_loss = oss_ddp_model(input_tensor).abs().sum()
                sharded_loss.backward()
                return sharded_loss

            loss_ddp = cast(torch.Tensor,
                            ddp_optimizer.step(closure=closure_ddp))
            loss_sharded_optim = cast(
                torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

            assert torch.allclose(
                loss_ddp, loss_sharded_optim
            ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"

            check_same_model_params(oss_ddp_model, ddp_model)

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(oss_ddp_model, ddp_model)

        # The models should stay the same in between ddp and sharded optimizer
        for i in range(5):
            check_step()

            # Check that altering the trainable parameters does not cause DDP and OSS to diverge
            if change_train_graph:
                # Flip the first parameter from trainable to non-trainable and vice-versa
                next(ddp_module.parameters()).requires_grad = not next(
                    ddp_module.parameters()).requires_grad
                next(oss_module.parameters()).requires_grad = not next(
                    oss_module.parameters()).requires_grad
                # sharded_optimizer.refresh_trainable()

        # Check that the checkpoints are compatible
        # - get states
        ddp_state_dict = ddp_optimizer.state_dict()
        sharded_optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)
        sharded_optim_state_dict = sharded_optimizer.state_dict() if rank == RECIPIENT_RANK else {}
        sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict,
                                                     RECIPIENT_RANK, device)

        # - cross load the states
        ddp_optimizer.load_state_dict(
            sharded_optim_state_dict)  # mixup on purpose !
        sharded_optimizer.load_state_dict(ddp_state_dict)

        # - run one step and check that the models are still the same
        check_step()
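
RECIPIENT_RANK and sync_object_ranks come from the surrounding test module as well. Given how they are used (one rank consolidates the optimizer state, every rank ends up with a copy), a minimal sketch under that assumption:

from typing import Any

import torch
import torch.distributed as dist

RECIPIENT_RANK = 0  # assumed value; the snippets only require a fixed, agreed-upon rank


def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
    # Broadcast an arbitrary picklable object (here an optimizer state_dict) from
    # reference_rank so that every rank returns the same copy.
    # The device argument is accepted for signature compatibility and unused in this sketch.
    package = [something_to_sync]
    dist.broadcast_object_list(package, src=reference_rank)
    return package[0]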
Code example #6
def run_state_dict_distributed(rank, world_size, tempfile_name):
    dist_init(rank, world_size, tempfile_name, backend="gloo")

    device = torch.device(rank)
    torch.manual_seed(rank)  # make sure that the different ranks get different data

    # Set up two problems in parallel; we'll make sure that the second track (with save/load) follows the first one (untouched)
    # We split the model in two to test the multiple param groups support
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden),
                                     torch.nn.Linear(hidden,
                                                     hidden)).to(device)
    head_oss1 = torch.nn.Linear(hidden, target_width).to(device)

    model_oss2 = copy.deepcopy(model_oss1)
    head_oss2 = copy.deepcopy(head_oss1)

    # For this test the gradients are (all) reduced in the same way between the torch reference and fairscale.
    # Normally OSS would use ShardedDDP and only reduce to the owning rank, but that would not change the
    # gradient norm computation from OSS and would add a dependency.
    # To keep the comparison apples-to-apples, DDP is used in both cases.
    model_oss1 = DDP(
        module=model_oss1,
        device_ids=[rank],
    )
    sharded_optimizer1 = optim.OSS(model_oss1.parameters(),
                                   lr=0.1,
                                   momentum=0.99)
    sharded_optimizer1.add_param_group({"params": head_oss1.parameters()})

    model_oss2 = DDP(
        module=model_oss2,
        device_ids=[rank],
    )
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(),
                                   lr=0.1,
                                   momentum=0.99)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})

    loss_fn = torch.nn.L1Loss().to(device)

    def run_grad_step(model, head, optimizer):
        model.zero_grad()
        outputs = head(model(inputs))
        loss = loss_fn(outputs, target)
        loss.backward()
        optimizer.step()

    # pull the current state, broadcast it to all ranks
    sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)  # all ranks
    state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
    state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)

    # re-create a new optimizer from scratch with absurd values, load the previous state
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(),
                                   lr=1e6,
                                   momentum=0.0001)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
    sharded_optimizer2.load_state_dict(state_dict2)
    check_same_model_params(
        model_oss1, model_oss2,
        "parameters of the two identical models have diverged (before any steps)"
    )

    # now take a step and check that parameters are equal
    run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
    check_same_model_params(
        model_oss1, model_oss2,
        "parameters of the two identical models have diverged (after stepping)"
    )

    # save the state dict for one model only, then distribute to the other ranks
    sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)  # all ranks
    state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
    state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)

    # Check that the pulled state and the .param_groups attribute are in sync
    for replica in range(len(state_dict2["param_groups"])):
        for k in state_dict2["param_groups"][replica].keys():
            if k != "params":
                assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]

    # take a step
    run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
    check_same_model_params(
        model_oss1, model_oss2,
        "parameters of the two identical models have diverged (after consolidating)"
    )

    # save again for one rank, then distribute to the others
    sharded_optimizer2.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)  # all ranks
    state_dict2 = sharded_optimizer2.state_dict() if rank == RECIPIENT_RANK else {}
    state_dict2 = sync_object_ranks(state_dict2, RECIPIENT_RANK, device)

    # reload the state_dict
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(),
                                   lr=0.1,
                                   momentum=0.99)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
    sharded_optimizer2.load_state_dict(state_dict2)

    # take a step
    run_grad_step(model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(model_oss2, head_oss2, sharded_optimizer2)
    check_same_model_params(
        model_oss1, model_oss2,
        "parameters of the two identical models have diverged (after reloading)"
    )

    dist.destroy_process_group()
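
dist_init is not shown either; given the explicit init_process_group calls in Code examples #2 and #3, it presumably wraps the same file-based rendezvous. A sketch under that assumption:

import torch.distributed as dist


def dist_init(rank: int, world_size: int, tempfile_name: str, backend: str = "nccl") -> None:
    # Same pattern as the explicit calls above: file-based rendezvous, one process per rank.
    url = "file://" + tempfile_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)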
Code example #7
    def check_parity(manual_reduction: bool):

        # The API should be the exact same in between the sharded and non-sharded variants, generic closure
        def closure(model,
                    scaler,
                    input_tensor,
                    should_accumulate,
                    _manual_reduction=False):
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                        scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()

            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

            if not _manual_reduction:
                step()
            else:
                with model.no_sync():
                    step()

                model.reduce()

        # Any model works. Add one different buffer per rank
        model = _get_mlp()
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts with non-trainable, so that we check for the buckets to be
        # properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(),
                                optim=torch.optim.SGD,
                                lr=1e-4,
                                momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model,
            sharded_optimizer=sharded_optimizer,
            broadcast_buffers=True,
            reduce_buffer_size=reduce_buffer_size,
            reduce_fp16=fp16_reduction,
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                        lr=1e-4,
                                        momentum=0.99)
        ddp_model = DDP(ddp_model_single,
                        device_ids=[rank],
                        broadcast_buffers=True,
                        find_unused_parameters=True)

        if fp16_reduction:
            from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

            ddp_model.register_comm_hook(
                state=None, hook=fp16_compress_hook)  # type: ignore

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_scaler = ShardedGradScaler() if amp else None

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

            def ddp_closure(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor,
                               grad_accumulation)

            def sharded_closure(input_tensor=input_tensor):
                return closure(
                    sharded_ddp_model,
                    sharded_scaler,
                    input_tensor,
                    grad_accumulation,
                    _manual_reduction=manual_reduction,
                )

            # Step/scale both
            for _scaler, _closure, _optimizer in (
                (ddp_scaler, ddp_closure, ddp_optimizer),
                (sharded_scaler, sharded_closure, sharded_optimizer),
            ):
                if _scaler is not None:
                    _ = _closure(input_tensor)
                    _scaler.step(_optimizer)
                    _scaler.update()
                else:
                    # No AMP: the closure runs forward/backward, then the optimizer steps
                    _optimizer.step(_closure())

            check_same_model_params(sharded_ddp_model, ddp_model,
                                    f"Rank: {rank} - Step {i} broke")

            # Check that the two grad norm are equivalent
            # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
            # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
            # be valid for ShardedDDP
            if clip_grad_norm:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
                if not torch.isnan(total_norm):
                    oss_total_norm = sharded_optimizer.clip_grad_norm(
                        0.3, norm_type=2.0)
                    assert torch.allclose(
                        oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8
                    ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
                else:
                    print(rank, "NaN grad norm in DDP", flush=True)

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(
                    ddp_model.parameters()).requires_grad
                check_same_model_params(
                    sharded_ddp_model, ddp_model,
                    f"Rank: {rank} - Trainability refresh {i} broke")
Code example #8
    def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        next(model.parameters()).requires_grad = False  # Test non-trainable parameters

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)

        def check_same_model_params(same_params: bool):
            # Check that all the params are the same on all ranks
            # This should be true with and without broadcast_buffers, we don't have any real buffer here
            receptacle: List[torch.Tensor] = []

            if dist.get_backend() != "nccl":
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        # Check the params
                        receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
                        dist.gather(p, receptacle, dst=0)
                        if rank == 0:
                            for sync_p in receptacle[1:]:
                                if same_params:
                                    assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Gradients should not have been synced"

                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
                if broadcast_buffers:
                    for b in ddp_model.buffers():
                        receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
                        dist.gather(b, receptacle, dst=0)
                        if rank == 0:
                            for sync_b in receptacle[1:]:
                                if same_params:
                                    assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Gradients should not have been synced"

                        assert b.cpu().item() == 0.0

        # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
        check_same_model_params(same_params=True)

        # Optim loop
        def closure():
            optimizer.zero_grad()

            with ddp_model.no_sync() if grad_accumulation else suppress():
                input_tensor = torch.rand((64, 2)).to(device)
                loss = ddp_model(input_tensor).abs().sum()
                loss.backward()
            return loss

        # The models should stay the same in between the ranks
        for i in range(5):
            _ = optimizer.step(closure=closure)
            # when running on cpu/gloo the "nodes" are not really different
            same_params = device == torch.device("cpu") or grad_accumulation
            check_same_model_params(same_params=same_params)
Code example #9
    def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):

        # The API should be the exact same in between the sharded and non-sharded variants, generic closure
        def closure(model, scaler, input_tensor, should_accumulate):
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                        scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()

            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

            step()

        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts with non-trainable, so that we check for the buckets to be
        # properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_ddp_scaler = ShardedGradScaler() if amp else None

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor, accumulate)

            def closure_sharded(input_tensor=input_tensor):
                return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)

            # Step/scale both
            if ddp_scaler is not None:
                _ = closure_ddp(input_tensor)
                ddp_scaler.step(ddp_optimizer)
                ddp_scaler.update()
            else:
                ddp_optimizer.step(closure=closure_ddp)

            if sharded_ddp_scaler is not None:
                _ = closure_sharded(input_tensor)
                sharded_ddp_scaler.step(sharded_optimizer)
                sharded_ddp_scaler.update()
            else:
                sharded_optimizer.step(closure=closure_sharded)

            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()
                ).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
                check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")