Example 1
    def _test_dtypes(cfg: Dict,
                     autocast,
                     in_dtype,
                     p_dtype,
                     loss_dtype,
                     reduce_dtype,
                     rank,
                     group,
                     expected_buffer_type=None):
        # Patch torch.distributed.reduce_scatter to check the dtype of the reduction
        orig_reduce_scatter = torch.distributed.reduce_scatter

        model: nn.Module = DeviceAndTypeCheckModule(
            expected_input_dtype=in_dtype,
            expected_param_dtype=p_dtype,
            expected_loss_dtype=loss_dtype,
            expected_buffer_dtype=expected_buffer_type,
        )

        def _reduce_scatter(output, input_list, **kwargs):
            for tensor in input_list:
                model._check("reduce_scatter.dtype",
                             tensor.dtype,
                             expected=reduce_dtype)
            return orig_reduce_scatter(output, input_list, **kwargs)

        with mock.patch("torch.distributed.reduce_scatter",
                        new=_reduce_scatter):
            model = FullyShardedDataParallel(model, group, **cfg).cuda()
            device = next(model.parameters()).device
            x = torch.rand(2, 5).to(device)
            with torch.cuda.amp.autocast(enabled=autocast):
                loss = model(x)
            loss.backward()
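A note on the pattern above: the test captures the original torch.distributed.reduce_scatter, swaps in a wrapper that checks the dtype of every input shard, and then delegates to the saved original. A minimal, non-distributed sketch of the same patch-and-delegate idiom, using torch.matmul purely for illustration (the wrapper name here is made up):

import torch
from unittest import mock

orig_matmul = torch.matmul

def _checked_matmul(a, b, **kwargs):
    # Assert something about the arguments, then delegate to the saved original.
    assert a.dtype == torch.float32, f"unexpected dtype {a.dtype}"
    assert b.dtype == torch.float32, f"unexpected dtype {b.dtype}"
    return orig_matmul(a, b, **kwargs)

with mock.patch("torch.matmul", new=_checked_matmul):
    out = torch.matmul(torch.rand(2, 3), torch.rand(3, 4))
    assert out.shape == (2, 4)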
Example 2
def _test_func(rank, world_size, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {
        name: tuple(param.shape)
        for name, param in model.named_parameters()
    }

    # For clarity, this is what `expected_param_shapes` should look like for this world size:
    assert expected_param_shapes == {
        "_fsdp_wrapped_module.flat_param_0": (12, ),
        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0":
        (6, ),
    }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes)

    teardown()
Example 3
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.inner = FSDP(Linear(4, 4), **fsdp_config)
            self.outer = Linear(4, 5)

        def forward(self, x):
            # Forward twice.
            i = self.inner(x)
            j = self.inner(x)
            return self.outer(i + j)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
Example 4
    def _test_identical_outputs_eval(
        cls,
        model_init_fn,
        config,
        rank,
        group,
        num_steps=2,
        use_cuda=True,
        lr=0.01,
        ref_ddp_fn=None,
    ):
        if config.get("mixed_precision", False):
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results to PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank,
                                                        process_group=group)
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._eval_with_config(model, autocast)
        ref_state_dict = model.module.state_dict()
        if config.get("cpu_offload", False):
            for k in ref_state_dict.keys():
                ref_state_dict[k] = ref_state_dict[k].cpu()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        # Translate the test-only "ssd_offload" flag into an OffloadConfig and
        # drop it from the config before the config is passed to FSDP.
        if config.pop("ssd_offload", False):
            config["offload_config"] = OffloadConfig(
                offload_type="ssd_offload")
        model = FullyShardedDataParallel(
            model_init_fn(group=group, wrapper_config=config), group, **config)
        if not model.ssd_offload and not model.move_params_to_cpu:
            if use_cuda:
                model = model.cuda()
            else:
                assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._eval_with_config(model, autocast)

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
        except (AssertionError, RuntimeError) as e:
            raise Exception(
                f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}"
            )
        if config.get("flatten_parameters", True):
            metadata = model.local_metadata_dict()
            assert isinstance(metadata, dict)
Example 5
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_out = torch.matmul(v, weight)
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=0.1)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_out, out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example 6
def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method,
                        tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # Freeze the trunk by turning off requires_grad.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state,
                                 fsdp_state,
                                 raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
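The worker above exercises two ways of freezing the trunk and expects them to produce the same trained state: turning off requires_grad so autograd never computes trunk gradients, or letting backward run and then discarding the trunk gradients before the optimizer step. A minimal single-process sketch of the two methods, with made-up module names and no FSDP/DDP wrapping:

import torch
import torch.nn as nn

trunk = nn.Linear(4, 4)
head = nn.Linear(4, 2)
model = nn.Sequential(trunk, head)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

freezing_method = "grad_to_none"  # or "requires_grad"
if freezing_method == "requires_grad":
    for p in trunk.parameters():
        p.requires_grad = False  # autograd skips the trunk entirely

loss = model(torch.randn(2, 4)).sum()
optimizer.zero_grad()
loss.backward()
if freezing_method == "grad_to_none":
    for p in trunk.parameters():
        p.grad = None  # the optimizer skips parameters whose grad is None
optimizer.step()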
Example 7
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    device = torch.device("cuda")
    if fsdp_config.get("mixed_precision", False):
        dtype = torch.float16
        fsdp_config["fp32_reduce_scatter"] = True
    else:
        dtype = torch.float32

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            fp32_weight = model.weight.T.clone().to(device)
            weight = fp32_weight.to(dtype)
            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).to(
                device, dtype)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = fp32_weight - grad.T * my_lr
            assert ref_weight_out.dtype == torch.float32
    # Move to device only (not dtype): FSDP manages mixed precision internally.
    model.to(device)
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).to(device, dtype)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example 8
def _test_func(rank, world_size, tempfile_name, unused, flatten,
               mixed_precision, amp_context, half_input, fsdp_wrap_ckpt):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FSDP(
        SimpleModuleWithCheckpointing(flatten, mixed_precision,
                                      fsdp_wrap_ckpt).cuda(),
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
    )
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {
        name: tuple(param.shape)
        for name, param in model.named_parameters()
    }

    # For clarity, this is what `expected_param_shapes` should look like,
    # depending on whether the parameters are flattened:
    if not flatten:
        assert expected_param_shapes == {
            "ffn.0.weight": (5, ),
            "ffn.0.bias": (2, ),
            "ffn.1.weight": (5, ),
            "ffn.1.bias": (2, ),
            "ffn.2.weight": (5, ),
            "ffn.2.bias": (2, ),
        }
    else:
        assert expected_param_shapes == {
            "_fsdp_wrapped_module.flat_param_0": (12, ),
            "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0":
            (6, ),
        }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes, amp_context,
                mixed_precision, half_input)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes, amp_context,
               mixed_precision, half_input)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes, amp_context,
                mixed_precision, half_input)

    teardown()
Example 9
def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""

    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throws an error like this:
        #     RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip(
            "older pytorch doesn't work well with single process dist_init multiple times"
        )

    result = dist_init(rank=0,
                       world_size=1,
                       filename=temp_files[0],
                       filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            assert input_cls is dict
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)

    teardown()
Example 10
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(
        sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(
            sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data,
                         map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank,
                       world_size=world_size,
                       filename=file1,
                       filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
        # and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=test_fn == "optim_state",
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after,
                          fsdp_model.state_dict(),
                          raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()
Example 11
def test_it(fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # Use a random port in case the next test runs quickly; reusing the same
    # port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        assert isinstance(fsdp_config, dict), str(fsdp_config)

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model(), **fsdp_config).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                assert input_cls is dict
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        model.assert_state(TrainingState.IDLE)

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
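Because this test manages its own process-group lifecycle (random MASTER_PORT, init_process_group, and a try/finally that destroys the group), the setup and teardown could be factored into a small context manager. A sketch of that pattern, assuming a single-process group; the helper name is hypothetical:

import contextlib
import os
import random

import torch.distributed as dist


@contextlib.contextmanager
def single_process_group(backend="nccl"):
    # Random port so back-to-back runs don't collide on the same address.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    dist.init_process_group(backend=backend, rank=0, world_size=1)
    try:
        yield
    finally:
        # Clean-up matters: the next test must be able to init its own group.
        dist.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]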
Example 12
    def _distributed_worker(
        gpu_id: int, with_fsdp: bool, sync_file: str, result_file: str
    ):
        torch.cuda.set_device(gpu_id)
        dist.init_process_group(
            backend="nccl", init_method="file://" + sync_file, world_size=2, rank=gpu_id
        )

        # Create the inputs
        torch.manual_seed(0)
        torch.backends.cudnn.deterministic = True
        batch = torch.randn(size=(8, 3, 224, 224)).cuda()

        # Create a fake model based on SWAV blocks
        config = TestRegnetFSDP._create_config(with_fsdp)
        model = build_model(config["MODEL"], config["OPTIMIZER"])
        model = model.cuda()
        if with_fsdp:
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[gpu_id])
        criterion = SwAVLoss(loss_config=config["LOSS"]["swav_loss"])
        optimizer = optim.SGD(model.parameters(), lr=1e-2)

        # Run a few iterations and collect the losses
        losses = []
        for iteration in range(5):
            out = model(batch)
            loss = criterion(out[0], torch.tensor(0.0).cuda())
            if gpu_id == 0:
                losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            if iteration <= 2:
                for name, param in model.named_parameters():
                    if "prototypes" in name:
                        param.grad = None
            optimizer.step()

        # Store the losses in a file to compare several methods
        if gpu_id == 0:
            with open(result_file, "wb") as f:
                pickle.dump(losses, f)
Example 13
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = weight - grad.T * my_lr
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example 14
    def _test_identical_outputs(
        cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
    ):
        if config["mixed_precision"]:
            autocast = True
            # Force the compute dtype to be torch.float32 so that we get
            # identical results to PyTorch DDP when using autocast. Note that
            # this will cause the all-gather to happen in FP32, which is slower
            # than necessary in most cases.
            config["compute_dtype"] = torch.float32
        else:
            autocast = False

        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrapper_config=None).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(
                model, device_ids=[rank], output_device=rank, process_group=group
            )
        else:
            model = ref_ddp_fn(model, group)
        ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        ref_state_dict = model.module.state_dict()

        # Confirm we get the same behavior using FullyShardedDataParallel.
        model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
        if use_cuda:
            model = model.cuda()
        else:
            assert next(model.parameters()).device == torch.device("cpu")
        shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
        shard_state_dict = model.state_dict()

        try:
            torch.testing.assert_allclose(ref_loss, shard_loss)
            assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
        except (AssertionError, RuntimeError) as e:
            raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
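objects_are_equal, used above to compare the reference DDP state dict against the FSDP one, checks the two structures entry by entry. A rough, hypothetical stand-in for the flat tensor-dict case, just to make the comparison concrete:

import torch


def state_dicts_allclose(ref, other, rtol=1e-5, atol=1e-8):
    # Compare two flat state dicts of tensors key by key, on CPU and in fp32.
    if ref.keys() != other.keys():
        return False
    return all(
        torch.allclose(ref[k].float().cpu(), other[k].float().cpu(),
                       rtol=rtol, atol=atol)
        for k in ref
    )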
Example 15
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class InnerModel(Module):
        def __init__(self):
            super().__init__()
            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config))

        def forward(self, x):
            return self.layers(x)

    inner_model = InnerModel()
    model = FSDP(inner_model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    model.assert_state(TrainingState.IDLE)

    # Wrap the inner model a second time and check that outputs still match.
    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
    rewrapped_output = rewrapped_model(input)

    assert torch.allclose(output, rewrapped_output)
    teardown()
Example 16
def main(local_rank, *args):
    torch.backends.cudnn.benchmark = True
    init_method = "tcp://%s:%s" % ("0.0.0.0", "9999")
    torch.distributed.init_process_group(backend="nccl",
                                         rank=local_rank,
                                         world_size=8,
                                         init_method=init_method)
    print("[Train]: Time = %s, Initialized Dist Process for Rank = %s" %
          (get_time_string(), local_rank))
    device = torch.device(f"cuda:{local_rank}")  # Unique only on an individual node.
    torch.cuda.set_device(device)
    fsdp_params = dict(mixed_precision=True,
                       flatten_parameters=True,
                       bucket_cap_mb=25,
                       reshard_after_forward=False,
                       fp32_reduce_scatter=False,
                       cpu_offload=False,
                       move_grads_to_cpu=False,
                       process_group=torch.distributed.group.WORLD)
    with enable_wrap(wrapper_cls=FullyShardedDDP, **fsdp_params):
        nn_model = nn.Sequential(
            nn.Linear(200, 200),
            wrap(
                checkpoint_wrapper(nn.Sequential(
                    nn.Linear(200, 200), nn.Linear(200, 200),
                    wrap(
                        checkpoint_wrapper(nn.Linear(200, 200),
                                           offload_to_cpu=True)),
                    checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                    nn.Linear(200, 200)),
                                   offload_to_cpu=True)),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.LayerNorm(200, eps=1e-7), nn.Linear(200, 64)).cuda()

        model = FullyShardedDDP(nn_model, **fsdp_params)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=1e-4,
                                  eps=1e-7,
                                  weight_decay=1e-2,
                                  betas=(0.9, 0.99))
    optimizer.zero_grad(set_to_none=True)

    for i in range(1000):
        optimizer.zero_grad(set_to_none=True)
        fake_inputs = torch.randn(32, 200, device=device)
        fake_labels = torch.randn(32, 64, device=device)
        outputs = model(fake_inputs)
        loss = ((outputs - fake_labels)**2).mean()
        loss.backward()
        model.clip_grad_norm_(1.0)
        optimizer.step()
        if i % 100 == 0:
            print("Loss = %s, rank = %s" % (loss.item(), local_rank))

    state_dict = model.state_dict()
    nn_model = nn.Sequential(
        nn.Linear(200, 200),
        nn.Sequential(nn.Linear(200, 200), nn.Linear(200, 200),
                      nn.Linear(200, 200),
                      checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                      nn.Linear(200, 200)),
        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
        nn.LayerNorm(200, eps=1e-7), nn.Linear(200, 64)).cuda()
    nn_model.load_state_dict(state_dict)
    print("[Train]: Time = %s, Trainable Params = %s" %
          (get_time_string(), numel(nn_model) / 1_000_000))
Example 17
def _distributed_worker(
    gpu_id,
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    files,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    filename, filename_rpc = files[:2]
    filename_loss = files[2:]

    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    # Set this to False when debugging, since error messages are less informative with cudnn enabled.
    torch.backends.cudnn.enabled = True

    # these make things deterministic.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure we have multiple forward passes.
    batch = [
        torch.randn(size=(2, 3, 16, 16)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
    ]

    if mixed_precision and not with_fsdp:
        batch = [x.half() for x in batch]

    model = _create_model(
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    model = model.cuda()

    if with_fsdp:
        model = FSDP(
            model,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        )
        model.set_gradient_divide_factors(1.0, 2.0, True)
        no_sync_context = contextlib.suppress()
    else:
        # With DDP, we need no_sync and manual gradient reduction below because
        # it can't handle multiple forward passes + checkpointing otherwise.
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        no_sync_context = model.no_sync()

    mp_context = contextlib.suppress()
    if mixed_precision:
        mp_context = torch.cuda.amp.autocast(enabled=True)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    losses = {}
    i = 0
    with no_sync_context:
        for iteration in range(3):
            with mp_context:
                out = model(batch)
                loss = criterion(out, target)
                print("Loss", iteration, ":", loss.item())
                losses[f"iter_{i}"] = loss
                i += 1
                optimizer.zero_grad()
                loss.backward()
            # Manual grad reduction, no autocast.
            if not with_fsdp:
                for p in model.parameters():
                    dist.all_reduce(p.grad.data)
                    p.grad.data.div_(2.0)
            # Stepping, no autocast
            optimizer.step()

    # Due to dist.all_reduce code block above with ddp.no_sync, we seem to hit a bug
    # in DDP where tensor.cpu() and torch.save() calls both hang. FSDP is not affected.
    # Therefore, we have to compare losses here instead of states.
    with open(filename_loss[rank], "wb") as f:
        pickle.dump(losses, f)

    teardown()