def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
    """Per-rank worker: build an FSDP-wrapped ``Model`` and run one test mode.

    Args:
        rank: Index of this worker; tensors are loaded onto ``cuda(rank)``.
        world_size: Forwarded to ``dist_init``.
        files: Five entries ``(file1, file2, sd_before, sd_after, in_data)``.
            The first two are given to ``dist_init`` as ``filename`` and
            ``filename_rpc``; the remaining three are read with
            ``torch.load`` (``sd_after`` only when ``test_fn == "train"``).
        wrap_middle: Forwarded to ``Model(wrap_middle=...)``.
        test_fn: Selects the scenario: ``"train"``, ``"eval"`` or
            ``"optim_state"``. Any other value trips an assertion.
    """

    def _load_on_rank(path):
        # Place every deserialized storage on this rank's GPU.
        return torch.load(
            path, map_location=lambda storage, loc: storage.cuda(rank))

    init_file, rpc_file, before_path, after_path, input_path = files
    is_train = test_fn == "train"

    initial_state = _load_on_rank(before_path)
    # The reference "after" state is only compared in the training scenario.
    expected_state = _load_on_rank(after_path) if is_train else None
    inputs = _load_on_rank(input_path)

    # Keep the call outside the assert so it still runs when asserts are
    # disabled.
    initialized = dist_init(
        rank=rank,
        world_size=world_size,
        filename=init_file,
        filename_rpc=rpc_file,
    )
    assert initialized, "Dist init failed"

    # Debugging tip: get with_fsdp=False (no inner wrapping) working first,
    # then turn inner wrapping back on and make that work.
    inner_model = Model(with_fsdp=True, wrap_middle=wrap_middle)
    fsdp_model = FSDP(
        inner_model,
        # Parameters are flattened only for the optimizer-state scenario.
        flatten_parameters=(test_fn == "optim_state"),
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(initial_state)

    if is_train:
        _train(fsdp_model, inputs)
        # Raises if the trained state differs from the saved reference.
        objects_are_equal(
            expected_state, fsdp_model.state_dict(), raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, inputs)
    elif test_fn == "optim_state":
        optimizer = SGD(fsdp_model.parameters(), lr=0.1)
        for _step in range(3):
            output = fsdp_model(inputs)
            output.backward()
            optimizer.step()
        gathered = fsdp_model.gather_full_optim_state_dict(optimizer)
        if rank != 0:
            assert gathered is None, "only rank 0 should have the optim state"
        else:
            # There should be 8 entries in the state (the original author
            # describes them as momentum buffers).
            assert len(gathered["state"].keys()) == 8
    else:
        assert False, f"invalid test_fn {test_fn}"

    teardown()
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):
    """Per-rank worker: train an FSDP-wrapped ``Model`` and check its state.

    Args:
        rank: Index of this worker; tensors are loaded onto ``cuda(rank)``.
        world_size: Forwarded to ``dist_init``.
        files: Five entries ``(file1, file2, sd_before, sd_after, in_data)``.
            The first two are given to ``dist_init`` as ``filename`` and
            ``filename_rpc``; the remaining three are read with ``torch.load``.
        outer_flat: Used as ``flatten_parameters`` for the outer FSDP wrapper.
        inner_flat: Forwarded to ``Model(inner_flat=...)``.
        sharing: Forwarded to ``Model(sharing=...)``.
    """

    def _to_rank_gpu(storage, loc):
        # map_location callback: place each storage on this rank's GPU.
        return storage.cuda(rank)

    init_file, rpc_file, before_path, after_path, input_path = files

    initial_state = torch.load(before_path, map_location=_to_rank_gpu)
    expected_state = torch.load(after_path, map_location=_to_rank_gpu)
    inputs = torch.load(input_path, map_location=_to_rank_gpu)

    # Keep the call outside the assert so it still runs when asserts are
    # disabled.
    initialized = dist_init(
        rank=rank,
        world_size=world_size,
        filename=init_file,
        filename_rpc=rpc_file,
    )
    assert initialized, "Dist init failed"

    inner_model = Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing)
    fsdp_model = FSDP(inner_model, flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(initial_state)

    _train(fsdp_model, inputs)

    # Raises if the trained state differs from the saved reference.
    objects_are_equal(
        expected_state, fsdp_model.state_dict(), raise_exception=True)

    teardown()