Example #1
def _test_func(rank, world_size, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {
        name: tuple(param.shape)
        for name, param in model.named_parameters()
    }

    # For clarity, this is what `expected_param_shapes` should look like for the world size used here:
    assert expected_param_shapes == {
        "_fsdp_wrapped_module.flat_param_0": (12, ),
        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0":
        (6, ),
    }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes)

    teardown()
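# The `_train_step` and `_eval_step` helpers above are defined elsewhere in the
# test module and are not shown here. The sketch below is a hypothetical
# stand-in (the `_check_params` name and the `in_dim` placeholder are
# assumptions, not the original helpers): it runs a step and verifies that the
# named-parameter shapes stay unchanged, which is what the assertions above set up.
import torch


def _check_params(model, expected_param_shapes):
    # Hypothetical helper: verify that named-parameter shapes did not change.
    current = {name: tuple(p.shape) for name, p in model.named_parameters()}
    assert current == expected_param_shapes, current


def _train_step(model, optim, expected_param_shapes, in_dim=3):
    # Hypothetical sketch: forward + backward + optimizer step, with shape
    # checks before and after. `in_dim` is a placeholder input size.
    _check_params(model, expected_param_shapes)
    model(torch.rand(2, in_dim).cuda()).sum().backward()
    optim.step()
    optim.zero_grad()
    _check_params(model, expected_param_shapes)


def _eval_step(model, optim, expected_param_shapes, in_dim=3):
    # Hypothetical sketch: forward-only pass in eval mode (optim is unused);
    # the parameter shapes must still match afterwards.
    model.eval()
    with torch.no_grad():
        model(torch.rand(2, in_dim).cuda())
    model.train()
    _check_params(model, expected_param_shapes)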
Example #2
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.inner = FSDP(Linear(4, 4), **fsdp_config)
            self.outer = Linear(4, 5)

        def forward(self, x):
            # Forward twice.
            i = self.inner(x)
            j = self.inner(x)
            return self.outer(i + j)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
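# The worker above is normally driven by a small launcher. A minimal sketch,
# assuming torch.multiprocessing.spawn plus the temp_files_ctx helper (used in
# a later example) to provide the two rendezvous files for dist_init; the
# harness in the real test suite may differ:
import torch.multiprocessing as mp

from fairscale.utils.testing import temp_files_ctx  # import path assumed


def _launch(world_size=2, fsdp_config=None):
    fsdp_config = fsdp_config or {}
    with temp_files_ctx(2) as temp_files:
        # mp.spawn passes the rank as the first argument to _test_func.
        mp.spawn(
            _test_func,
            args=(world_size, fsdp_config, temp_files[0], temp_files[1]),
            nprocs=world_size,
            join=True,
        )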
Example #3
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_out = torch.matmul(v, weight)
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=0.1)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_out, out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example #4
def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method,
                        tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # Freeze the trunk, either via requires_grad here or by setting grads to None in the loop below.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state,
                                 fsdp_state,
                                 raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
Example #5
def _test_func(rank, world_size, tempfile_name, unused, flatten,
               mixed_precision, amp_context, half_input, fsdp_wrap_ckpt):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FSDP(
        SimpleModuleWithCheckpointing(flatten, mixed_precision,
                                      fsdp_wrap_ckpt).cuda(),
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
    )
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {
        name: tuple(param.shape)
        for name, param in model.named_parameters()
    }

    # For clarity, this is what `expected_param_shapes` should look like for this world size, depending on `flatten`:
    if not flatten:
        assert expected_param_shapes == {
            "ffn.0.weight": (5, ),
            "ffn.0.bias": (2, ),
            "ffn.1.weight": (5, ),
            "ffn.1.bias": (2, ),
            "ffn.2.weight": (5, ),
            "ffn.2.bias": (2, ),
        }
    else:
        assert expected_param_shapes == {
            "_fsdp_wrapped_module.flat_param_0": (12, ),
            "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0":
            (6, ),
        }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes, amp_context,
                mixed_precision, half_input)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes, amp_context,
               mixed_precision, half_input)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes, amp_context,
                mixed_precision, half_input)

    teardown()
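# For reference, the `amp_context` argument above is just a context manager.
# A sketch of how such a context is typically built in these tests (see the
# later examples that use contextlib.suppress / torch.cuda.amp.autocast
# directly); the function name here is illustrative, not from the test suite:
import contextlib

import torch


def make_amp_context(use_autocast):
    # No-op context by default; CUDA autocast only when requested.
    if use_autocast:
        return torch.cuda.amp.autocast(enabled=True)
    return contextlib.suppress()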
Example #6
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    device = torch.device("cuda")
    if fsdp_config.get("mixed_precision", False):
        dtype = torch.float16
        fsdp_config["fp32_reduce_scatter"] = True
    else:
        dtype = torch.float32

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute the local output for one iteration.
            fp32_weight = model.weight.T.clone().to(device)
            weight = fp32_weight.to(dtype)
            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute the global weight update for one iteration.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).to(
                device, dtype)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = fp32_weight - grad.T * my_lr
            assert ref_weight_out.dtype == torch.float32
    # Move the model to the device but not to dtype, since FSDP manages mixed precision internally.
    model.to(device)
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).to(device, dtype)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
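# A small, self-contained illustration (not part of the original test) of the
# reference update computed above: with a bias-free linear layer and
# loss = output.sum(), every row of the weight gradient equals the input
# vector, so after data-parallel averaging SGD gives W_new = W - lr * mean(v).
import torch
from torch.nn import Linear
from torch.optim import SGD

lr = 0.1
layer = Linear(3, 2, bias=False)
w_before = layer.weight.detach().clone()

# Two "ranks" worth of inputs; dividing the summed loss by the number of rows
# mimics the gradient averaging that FSDP/DDP perform across ranks.
inputs = torch.tensor([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])
optim = SGD(layer.parameters(), lr=lr)
layer(inputs).sum().div(inputs.shape[0]).backward()
optim.step()

expected = w_before - lr * inputs.mean(0).repeat(2, 1)
assert torch.allclose(layer.weight, expected)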
Example #7
def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""

    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throws an error like this:
        #     RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip(
            "older pytorch doesn't work well with single process dist_init multiple times"
        )

    result = dist_init(rank=0,
                       world_size=1,
                       filename=temp_files[0],
                       filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            assert input_cls is dict
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)

    teardown()
Example #8
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(
        sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(
            sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data,
                         map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank,
                       world_size=world_size,
                       filename=file1,
                       filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
        # and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=test_fn == "optim_state",
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after,
                          fsdp_model.state_dict(),
                          raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()
Example #9
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    results = {}
    for iteration in range(3):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        out = model(batch)
        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

        out = sum(o.sum() for o in out[0])
        fake_loss = criterion(out, torch.tensor(0.0).cuda())
        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    assert results == expected, f"{results} but expected {expected}"

    teardown()
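# `get_cur_mem` is defined elsewhere in the test module. A plausible sketch
# (an assumption, not the original implementation): record the currently
# allocated CUDA memory, rounded to MB, under a descriptive key.
import torch


def get_cur_mem(gpu_id, results, name):
    # memory_allocated returns the bytes currently occupied by tensors on the device.
    results[name] = round(torch.cuda.memory_allocated(gpu_id) / 1024 / 1024)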
Example #10
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    # Note: different ranks may wrap in a different order due to different random
    # seeds, but the results should be the same.
    if random.randint(0, 1) == 0:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        model = auto_wrap_bn(model)
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = auto_wrap_bn(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
Example #11
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused,
               test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute the local output for one iteration.
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute the global weight update for one iteration.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = weight - grad.T * my_lr
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
Example #12
def _freeze_distributed_worker(
    gpu_id,
    world_size,
    tempfile_name,
    unused,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    # The use case for this test is one where the weights in the submodule are
    # not frozen, but the leftover weights owned by the root module are frozen.
    # Refer to issue #758 for a real-world example.
    model = FreezeModel()
    model = model.cuda()

    for param in model.head.parameters():
        param.requires_grad = False

    model = FSDP(model)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()

    teardown()
Example #13
def test_local_state_dict_calls_state_dict_recursion():
    """Testing the case of infinite recursive when FSDP is subclassed"""
    class TestModule(FSDP):
        def __init__(self):
            super().__init__(module=nn.Linear(100, 100))

        def state_dict(self, *args, **kwargs):
            return self.local_state_dict(*args, **kwargs)

    rank = 0
    world_size = 1
    with temp_files_ctx(2) as temp_files:
        result = dist_init(rank, world_size, temp_files[0], temp_files[1])
        assert result, "Dist init failed"

        m = TestModule()
        d = m.state_dict()

        teardown()
Example #14
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
Example #15
def test_pre_backward_hook(temp_files):
    """Test FSDP with a model that triggers a pre_backward hook bug."""

    result = dist_init(rank=0,
                       world_size=1,
                       filename=temp_files[0],
                       filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.l1 = Linear(4, 4).cuda()
            self.l2 = FSDP(Linear(4, 4).cuda())
            self.l3 = Linear(4, 4).cuda()

        def forward(self, x):
            x = self.l1(x)
            x = self.l2(x)
            inner_result = x
            x = self.l3(x)
            return x, inner_result

        def assert_and_clear_grad(self):
            for p in self.parameters():
                assert p.shape in [(4, 4), (4, ), (4 * 4 + 4, )], p.shape
                assert p.grad is not None
                p.grad = None

    model = FSDP(Model(), flatten_parameters=False).cuda()
    in_data = torch.rand(1, 4).cuda()
    for _ in range(3):
        out, _ = model(in_data)
        out.sum().backward()
        model.assert_and_clear_grad()

    teardown()
Example #16
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class InnerModel(Module):
        def __init__(self):
            super().__init__()
            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config), )

        def forward(self, x):
            return self.layers(x)

    inner_model = InnerModel()
    model = FSDP(inner_model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    model.assert_state(TrainingState.IDLE)

    # Wrap the inner model a second time and check that the output matches.
    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
    rewrapped_output = rewrapped_model(input)

    assert torch.allclose(output, rewrapped_output)
    teardown()
Example #17
def _distributed_worker(
    gpu_id,
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    files,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    filename, filename_rpc = files[:2]
    filename_loss = files[2:]

    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    # Set this to False when debugging, since error messages are clearer with cudnn disabled.
    torch.backends.cudnn.enabled = True

    # these make things deterministic.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure we have multiple forward passes.
    batch = [
        torch.randn(size=(2, 3, 16, 16)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
    ]

    if mixed_precision and not with_fsdp:
        batch = [x.half() for x in batch]

    model = _create_model(
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    model = model.cuda()

    if with_fsdp:
        model = FSDP(
            model,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        )
        model.set_gradient_divide_factors(1.0, 2.0, True)
        no_sync_context = contextlib.suppress()
    else:
        # With DDP, we need no_sync and manual gradient reduction below because
        # it can't handle multiple forward passes + checkpointing otherwise.
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        no_sync_context = model.no_sync()

    mp_context = contextlib.suppress()
    if mixed_precision:
        mp_context = torch.cuda.amp.autocast(enabled=True)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    losses = {}
    i = 0
    with no_sync_context:
        for iteration in range(3):
            with mp_context:
                out = model(batch)
                loss = criterion(out, target)
                print("Loss", iteration, ":", loss.item())
                losses[f"iter_{i}"] = loss
                i += 1
                optimizer.zero_grad()
                loss.backward()
            # Manual grad reduction, no autocast.
            if not with_fsdp:
                for p in model.parameters():
                    dist.all_reduce(p.grad.data)
                    p.grad.data.div_(2.0)
            # Stepping, no autocast
            optimizer.step()

    # Due to the dist.all_reduce code block above combined with ddp.no_sync, we seem to hit a bug
    # in DDP where both tensor.cpu() and torch.save() calls hang. FSDP is not affected.
    # Therefore, we compare losses here instead of states.
    with open(filename_loss[rank], "wb") as f:
        pickle.dump(losses, f)

    teardown()
Example #18
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint,
                        filename, filename_rpc, expected, model_hidden_dim,
                        fsdp_config):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Note that FSDP auto-casts the input in AMP mode, so we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim,
                         fsdp_config)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model,
                                        device_ids=[gpu_id],
                                        bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set AMP context if needed.
    context = contextlib.suppress()
    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
        context = torch.cuda.amp.autocast(enabled=True)

    # We have observed that sometimes after the 3rd iteration, the 4th one can fail (not in
    # this test, but in much larger-scale tests). We run 4 iterations here just in case that happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def cmp(results, expected):
        ret = ""
        assert results.keys() == expected.keys(
        ), f"{list(results.keys())} vs. {list(expected.keys())}"
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    assert not output, output

    teardown()
Example #19
def _distributed_worker(
    gpu_id, world_size, fsdp_config, tempfile, tempfile_rpc,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile, tempfile_rpc)
    assert result, "Dist init failed"

    # Save the original torch.distributed.all_gather function since we will
    # patch it to include an artificial delay.
    orig_all_gather = torch.distributed.all_gather

    def run(compute_cycles, all_gather_cycles):
        has_params = all_gather_cycles > 0
        model = _create_model(fsdp_config, compute_cycles, has_params)

        # Get the input and set its requires_grad to True because we
        # have fake compute in the forward pass.
        batch = torch.rand(1).cuda()
        batch.requires_grad = True

        # We run 20 iterations but only collect timing data from the 10 fastest
        # data points, because nondeterministic system events can disturb the timing.
        cpu_iter = Min10()
        cpu_wait = Min10()
        gpu_compute = Min10()
        gpu_total = Min10()
        for _ in range(20):
            # Get two events for measuring the overall time.
            e1 = Event(enable_timing=True)
            e2 = Event(enable_timing=True)

            cpu_start = time.process_time()

            all_gather_called = False

            def _delayed_all_gather(*args, **kwargs):
                nonlocal all_gather_called
                all_gather_called = True
                torch.cuda._sleep(all_gather_cycles)
                return orig_all_gather(*args, **kwargs)

            # forward pass
            #
            # Even though both e1 & e2 are on the compute stream, since
            # compute depends on all_gather, e2-e1 includes all_gather time.
            e1.record()
            with patch("torch.distributed.all_gather", _delayed_all_gather):
                out = model(batch)
                if has_params and world_size > 1:
                    assert all_gather_called
                else:
                    assert not all_gather_called
            e2.record()

            # backward pass
            out.backward()
            if torch_version() >= (1, 7, 0):
                model.zero_grad(set_to_none=True)
            else:
                for p in model.parameters():
                    p.grad = None

            cpu_iter_time = time.process_time() - cpu_start

            # wait for gpu
            out.item()
            cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

            # get sum of the compute time
            times = []
            for mod in model.modules():
                if not isinstance(mod, Layer):
                    continue
                times.append(mod.get_time())

            # get gpu compute + all_gather time
            overall_gpu_time = e1.elapsed_time(e2)

            cpu_iter.add(cpu_iter_time)
            cpu_wait.add(cpu_wait_for_gpu_time)
            gpu_compute.add(sum(times))
            gpu_total.add(overall_gpu_time)

        del model
        return {
            "cpu_iter": cpu_iter.avg(),
            "cpu_wait": cpu_wait.avg(),
            "gpu_compute": gpu_compute.avg(),
            "gpu_total": gpu_total.avg(),
        }

    sleep_cycles = int(100 * get_cycles_per_ms())

    e1 = run(0, 0)  # no compute, no all-gather
    e2 = run(0, sleep_cycles)  # no compute, only all-gather
    e3 = run(sleep_cycles, 0)  # only compute, no all-gather
    e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
    debug_string = f"\nrank{rank}:\n  e1: {e1}\n  e2: {e2}\n  e3: {e3}\n  e4: {e4}"
    print(debug_string)

    # Check the cpu/gpu timing. CPU should run ahead of GPU. Therefore, the cpu-gpu
    # wait should be long, except when there is no real work on the GPU.
    #
    # If the assertions below fail, we likely have a cpu-gpu wait in the forward/backward pass.
    short = [e1["cpu_iter"], e2["cpu_iter"], e3["cpu_iter"], e4["cpu_iter"], e1["cpu_wait"]]
    long = [e3["cpu_wait"], e4["cpu_wait"]]
    if world_size == 1:
        short.append(e2["cpu_wait"])  # all gather should not be happening.
    else:
        long.append(e2["cpu_wait"])  # all gather should happen and prolong the cpu-gpu wait.
    for s in short:
        for l in long:
            # 10X longer is a safe margin, since the GPU work takes around 100X
            # longer than the CPU work.
            assert s * 10 < l, f"{s} * 10 < {l} in " + debug_string

    # Check the GPU timing.
    short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
    long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
    if world_size == 1:
        short.append(e2["gpu_total"])  # all gather should not be happening.
    else:
        long.append(e2["gpu_total"])  # all gather should happen and prolong the cpu-gpu wait.
    for s in short:
        for l in long:
            # 10X longer is a safe margin, since the time is around 100X longer
            # when there is work on GPU vs. no work.
            assert s * 10 < l, f"{s} * 10 < {l} in " + debug_string

    # Check GPU compute/communication overlap when there is an all-gather.
    if world_size > 1:
        compute_only = e3["gpu_compute"]
        all_gather_only = e2["gpu_total"]
        both = e4["gpu_total"]
        assert compute_only + all_gather_only > 1.1 * both, (
            f"{compute_only} + {all_gather_only} > 1.1 * {both} in " + debug_string
        )

    teardown()
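# `Min10` is defined elsewhere in the test module. A plausible sketch (an
# assumption): keep only the 10 smallest measurements and average them, which
# filters out iterations slowed down by unrelated system activity.
class Min10:
    def __init__(self):
        self.data = []

    def add(self, new_data):
        if len(self.data) < 10:
            self.data.append(new_data)
        else:
            self.data = sorted(self.data)
            if new_data < self.data[-1]:
                self.data[-1] = new_data

    def avg(self):
        return sum(self.data) / len(self.data)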
Example #20
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap in a different order due to different random
        # seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()
        # Print the model for verification.
        if rank == 0:
            print(model)
    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure the final state equals state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after,
                                     fsdp_state,
                                     raise_exception=True)

    teardown()