Example #1
def test_catch_grad_grad():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP catches the corner case of a gradient which itself requires grad
        dist.init_process_group(init_method="file://" + temp_files[0],
                                backend="gloo",
                                rank=0,
                                world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.train()
        chained_grad = torch.zeros_like(next(model.parameters()))
        chained_grad.requires_grad = True
        next(model.parameters()).grad = chained_grad

        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=1e-3,
                        momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        inputs = torch.rand(100, 2)
        with pytest.raises(RuntimeError):
            _ = ddp_model(inputs)

        dist.destroy_process_group()
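All of the examples in this collection rely on temp_files_ctx, a fairscale test helper that yields temporary file paths (used for the file:// init_method and for passing results between ranks). Its implementation is not part of this excerpt; the following is a minimal sketch of such a context manager, assuming it only needs to yield num paths and delete them on exit. The real fairscale helper may differ in its details.

from contextlib import contextmanager
import os
import tempfile


@contextmanager
def temp_files_ctx(num):
    """Yield a list of `num` temporary file paths and delete them on exit (illustrative sketch)."""
    paths = []
    try:
        for _ in range(num):
            fd, path = tempfile.mkstemp()
            os.close(fd)  # only the path is needed, e.g. for a file:// init_method
            paths.append(path)
        yield paths
    finally:
        for path in paths:
            if os.path.exists(path):
                os.remove(path)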
Example #2
def _get_cached_results(
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    """Cache the training to save time. For DDP, flatten, wrap_bn etc. doesn't matter, so
    the results can be cached.
    """
    if not with_fsdp:
        flatten = None
        wrap_bn = None
        fp32_reduce_scatter = None

    key = (
        world_size,
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    global _result_cache
    if key not in _result_cache:
        # Get 2 + world_size files: 2 for dist_init and one per rank to save the losses.
        with temp_files_ctx(num=2 + world_size) as temp_files:
            mp.spawn(
                _distributed_worker,
                (
                    world_size,
                    with_model2,
                    with_sync_bn,
                    with_fsdp,
                    with_checkpoint,
                    temp_files,
                    mixed_precision,
                    flatten,
                    wrap_bn,
                    fp32_reduce_scatter,
                    bucket_cap_mb,
                ),
                nprocs=world_size,
            )
            final_losses = {}
            for rank in range(world_size):
                with open(temp_files[2 + rank], "rb") as f:
                    for iter_key, loss in pickle.load(f).items():
                        final_losses[f"rank_{rank}_{iter_key}"] = loss
            _result_cache[key] = final_losses
    return _result_cache[key]
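The _distributed_worker spawned by _get_cached_results is not included in this excerpt. As a rough, simplified sketch of the protocol it relies on (the signature and the training loop below are placeholders, not the real worker): each rank initializes a file-based process group from the first temp file and pickles a dict of per-iteration losses into its own file at index 2 + rank.

import pickle

import torch
import torch.distributed as dist


def _distributed_worker_sketch(rank, world_size, temp_files):
    # temp_files[0] (and temp_files[1]) are reserved for distributed init;
    # temp_files[2 + rank] receives this rank's losses.
    dist.init_process_group(
        init_method="file://" + temp_files[0],
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    # Placeholder "training": record one loss value per iteration.
    losses = {f"iter_{i}": float(torch.rand(1)) for i in range(3)}

    with open(temp_files[2 + rank], "wb") as f:
        pickle.dump(losses, f)

    dist.destroy_process_group()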
Example #3
def test_train_and_eval_with_checkpointing(flatten, mixed_precision,
                                           amp_context, half_input,
                                           fsdp_wrap_ckpt):

    flatten = flatten == "flat"
    mixed_precision = mixed_precision == "fp16"
    amp_context = amp_context == "autocast"
    half_input = half_input == "halfin"
    fsdp_wrap_ckpt = fsdp_wrap_ckpt == "F->C"

    # Expecting a known bug in 4 out of 32 cases.
    if fsdp_wrap_ckpt and mixed_precision and not flatten:
        pytest.skip("known bug")

    world_size = 2

    with temp_files_ctx(2) as (temp_file_name, unused):
        mp.spawn(
            _test_func,
            args=(
                world_size,
                temp_file_name,
                unused,
                flatten,
                mixed_precision,
                amp_context,
                half_input,
                fsdp_wrap_ckpt,
            ),
            nprocs=world_size,
            join=True,
        )
Example #4
def test_consolidation(embedding_size: int, flatten_parameters: bool):

    world_size = 2
    with in_temporary_directory():
        with temp_files_ctx(num=1) as temp_files:
            mp.spawn(_worker, (temp_files[0], world_size, embedding_size,
                               flatten_parameters),
                     nprocs=world_size)
Example #5
def test_gpt2(world_size):
    # Check that having trainable unused params is fine
    backend = "gloo"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(run_test_gpt2,
                 args=(world_size, backend, device, temp_files[0]),
                 nprocs=world_size,
                 join=True)
Example #6
def test_forward_overlap(world_size, flatten, mixed):
    fsdp_config = {
        "flatten_parameters": flatten == "flatten",
        "mixed_precision": mixed == "mixed",
    }
    with temp_files_ctx(2) as temp_files:
        mp.spawn(
            _distributed_worker, (world_size, fsdp_config, temp_files[0], temp_files[1]), nprocs=world_size,
        )
Example #7
def test_train_eval_change():
    world_size = 4
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_train_eval_change,
            args=(world_size, temp_files[0]),
            nprocs=world_size,
            join=True,
        )
Example #8
def test_multiple_groups(reduce_buffer_size, backend):
    world_size = 4
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_multiple_groups,
            args=(world_size, temp_files[0], backend, reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
Example #9
def test_ddp_parity_two_optim(reduce_buffer_size):
    world_size = 2
    backend = dist.Backend.NCCL
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_ddp_parity_two_optim,
            args=(world_size, backend, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
Example #10
def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
             reduce_buffer_size, optimizer_type):
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_one_step,
            args=(world_size, backend, device, temp_files[0],
                  broadcast_buffers, grad_accumulation, reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
Example #11
def test_two_optimizers():
    # Check that ShardedDDP works when two optimizers are used
    world_size = 2
    backend = "gloo"
    device = "cpu"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(run_test_two_optimizers,
                 args=(world_size, backend, device, temp_files[0]),
                 nprocs=world_size,
                 join=True)
Example #12
def test_training_change(reduce_buffer_size):
    world_size = 2
    backend = "nccl"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_training_change,
            args=(world_size, backend, device, temp_files[0],
                  reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
Example #13
def test_ddp_sync_batch_norm():
    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
    world_size = 2
    backend = "gloo"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_ddp_sync_batch_norm,
            args=(world_size, backend, device, temp_files[0]),
            nprocs=world_size,
            join=True,
        )
Example #14
def test_device_change(reduce_buffer_size):
    # Check that ShardedDDP handles a device change properly
    world_size = 2
    backend = "nccl"
    with temp_files_ctx(num=1) as temp_files:
        device = "cuda"
        mp.spawn(
            run_test_device_change,
            args=(world_size, backend, device, temp_files[0],
                  reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
Example #15
def test_train_and_eval_with_checkpointing():
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    world_size = 2

    with temp_files_ctx(2) as (temp_file_name, unused):
        mp.spawn(
            _test_func,
            args=(world_size, temp_file_name, unused),
            nprocs=world_size,
            join=True,
        )
Example #16
def test_memory_tracking_fsdp():
    """
    Check that we can collect memory traces of a simplistic model
    in the context of FSDP distributed training
    """

    with temp_files_ctx(num=2) as sync_files:
        world_size = 2
        mp.spawn(
            _layer_memory_tracking_fsdp_worker,
            (sync_files, world_size),
            nprocs=world_size,
        )
Example #17
def test_inputs(reduce_buffer_size, backend, device):
    # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
    world_size = 2
    if backend == "nccl" and device == "cpu":
        pytest.skip("Incompatible combination, or cuda not available")
        return
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_two_inputs,
            args=(world_size, backend, device, temp_files[0],
                  reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
Example #18
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throw this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            # Get 4 files: 2 for dist_init and one per rank to save the losses.
            with temp_files_ctx(num=2 + world_size) as temp_files:
                mp.spawn(
                    _distributed_worker,
                    (
                        world_size,
                        with_fsdp,
                        with_checkpoint,
                        temp_files,
                        mixed_precision,
                        flatten,
                        wrap_bn,
                        fp32_reduce_scatter,
                    ),
                    nprocs=world_size,
                )
                final_losses = {}
                for rank in range(world_size):
                    with open(temp_files[2 + rank], "rb") as f:
                        final_losses[f"rank_{rank}"] = pickle.load(f)
                if expected_losses is None:
                    expected_losses = final_losses
                else:
                    print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
                    assert objects_are_equal(expected_losses,
                                             final_losses,
                                             raise_exception=True)
Example #19
def test_local_state_dict_calls_state_dict_recursion():
    """Testing the case of infinite recursive when FSDP is subclassed"""
    class TestModule(FSDP):
        def __init__(self):
            super().__init__(module=nn.Linear(100, 100))

        def state_dict(self, *args, **kwargs):
            return self.local_state_dict(*args, **kwargs)

    rank = 0
    world_size = 1
    with temp_files_ctx(2) as temp_files:
        result = dist_init(rank, world_size, temp_files[0], temp_files[1])
        assert result, "Dist init failed"

        m = TestModule()
        d = m.state_dict()

        teardown()
Example #20
def test_mixed_types():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP handles a model mixing parameter types
        dist.init_process_group(init_method="file://" + temp_files[0],
                                backend="gloo",
                                rank=0,
                                world_size=1)

        model = _get_mlp(tripwire=True)

        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=1e-3,
                        momentum=0.99)
        model = ShardedDataParallel(model, optimizer)
        input_tensor = torch.rand((2, 2))
        _ = model(input_tensor)

        dist.destroy_process_group()
Example #21
def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as PyTorch's DDP
    # - is multi_device_module
    # - device_type
    with temp_files_ctx(num=1) as temp_files:
        dist.init_process_group(init_method="file://" + temp_files[0],
                                backend="gloo",
                                rank=0,
                                world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=1e-3,
                        momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        assert hasattr(ddp_model, "is_multi_device_module")
        assert hasattr(ddp_model, "device_type")
        dist.destroy_process_group()
Example #22
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size,
              optimizer_type, reduce_fp16, setup):
    world_size = 2
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_one_step,
            args=(
                world_size,
                setup[0],
                setup[1],
                temp_files[0],
                broadcast_buffers,
                grad_accumulation,
                reduce_buffer_size,
                optimizer_type,
                reduce_fp16,
            ),
            nprocs=world_size,
            join=True,
        )
Example #23
def test_random_attributes():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP exposes the original module's attributes
        dist.init_process_group(init_method="file://" + temp_files[0],
                                backend="gloo",
                                rank=0,
                                world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.banana = "sweet"

        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=1e-3,
                        momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        assert hasattr(ddp_model, "banana")
        assert not hasattr(ddp_model, "orange")

        dist.destroy_process_group()
Example #24
def test_ddp_parity(
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    if torch_version() < (1, 8, 0):
        pytest.skip("pytorch version >= 1.8.0 required")
    if manual_reduction and change_train_graph:
        pytest.skip(
            "Skipping the changing train graph and manual reduction combination, makes little sense"
        )

    world_size = torch.cuda.device_count()
    backend = dist.Backend.NCCL
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_ddp_parity,
            args=(
                world_size,
                backend,
                temp_files[0],
                reduce_buffer_size,
                grad_accumulation,
                change_train_graph,
                fp16_reduction,
                clip_grad_norm,
                amp,
                manual_reduction,
                multiple_fw,
            ),
            nprocs=world_size,
            join=True,
        )
Example #25
def temp_files():
    # dist_init needs 2 files
    with temp_files_ctx(2) as files:
        yield files
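The temp_files generator in Example #25 is presumably registered as a pytest fixture (the decorator is not visible in this excerpt). A hypothetical test consuming it would receive the yielded file paths by matching the argument name:

import pytest


@pytest.fixture
def temp_files():
    # dist_init needs 2 files
    with temp_files_ctx(2) as files:
        yield files


def test_uses_temp_files(temp_files):
    # pytest injects the list yielded by the fixture above; here we only check its length
    assert len(temp_files) == 2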
Example #26
def test_fsdp_memory(fsdp, ckpt):
    expected = {
        ("ddp", "no_ckpt"): {
            "iter 0: start": 9,
            "iter 0: after fwd": 346,
            "iter 0: after loss": 346,
            "iter 0: after bwd": 14,
            "iter 0: after step": 14,
            "iter 0: done": 9,
            "iter 1: start": 9,
            "iter 1: after fwd": 346,
            "iter 1: after loss": 346,
            "iter 1: after bwd": 14,
            "iter 1: after step": 14,
            "iter 1: done": 9,
            "iter 2: start": 9,
            "iter 2: after fwd": 346,
            "iter 2: after loss": 346,
            "iter 2: after bwd": 14,
            "iter 2: after step": 14,
            "iter 2: done": 9,
        },
        ("fsdp", "no_ckpt"): {
            "iter 0: start": 3,
            "iter 0: after fwd": 340,
            "iter 0: after loss": 340,
            "iter 0: after bwd": 66,
            "iter 0: after step": 66,
            "iter 0: done": 3,
            "iter 1: start": 3,
            "iter 1: after fwd": 340,
            "iter 1: after loss": 340,
            "iter 1: after bwd": 66,
            "iter 1: after step": 66,
            "iter 1: done": 3,
            "iter 2: start": 3,
            "iter 2: after fwd": 340,
            "iter 2: after loss": 340,
            "iter 2: after bwd": 66,
            "iter 2: after step": 66,
            "iter 2: done": 3,
        },
        ("ddp", "ckpt"): {
            "iter 0: start": 9,
            "iter 0: after fwd": 57,
            "iter 0: after loss": 57,
            "iter 0: after bwd": 14,
            "iter 0: after step": 14,
            "iter 0: done": 9,
            "iter 1: start": 9,
            "iter 1: after fwd": 57,
            "iter 1: after loss": 57,
            "iter 1: after bwd": 14,
            "iter 1: after step": 14,
            "iter 1: done": 9,
            "iter 2: start": 9,
            "iter 2: after fwd": 57,
            "iter 2: after loss": 57,
            "iter 2: after bwd": 14,
            "iter 2: after step": 14,
            "iter 2: done": 9,
        },
        ("fsdp", "ckpt"): {
            "iter 0: start": 3,
            "iter 0: after fwd": 51,
            "iter 0: after loss": 51,
            "iter 0: after bwd": 66,
            "iter 0: after step": 66,
            "iter 0: done": 3,
            "iter 1: start": 3,
            "iter 1: after fwd": 51,
            "iter 1: after loss": 51,
            "iter 1: after bwd": 66,
            "iter 1: after step": 66,
            "iter 1: done": 3,
            "iter 2: start": 3,
            "iter 2: after fwd": 51,
            "iter 2: after loss": 51,
            "iter 2: after bwd": 66,
            "iter 2: after step": 66,
            "iter 2: done": 3,
        },
    }[(fsdp, ckpt)]
    fsdp = fsdp == "fsdp"
    ckpt = ckpt == "ckpt"
    world_size = 2
    with temp_files_ctx(num=2) as temp_files:
        mp.spawn(
            _distributed_worker, (world_size, fsdp, ckpt, temp_files[0], temp_files[1], expected), nprocs=world_size
        )
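The expected dictionaries above appear to record memory usage in MB at named points within each iteration, which the spawned worker then compares against its own measurements. The worker is not shown; as a purely illustrative, hypothetical helper, one way to capture such numbers is to log allocated CUDA memory, rounded to MB, at each checkpoint:

import torch


def record_memory_mb(memory_log, key):
    # Hypothetical helper: store currently allocated CUDA memory, rounded to MB,
    # under a key such as "iter 0: after fwd".
    memory_log[key] = round(torch.cuda.memory_allocated() / 1024 / 1024)


# Sketch of how it could be used inside a training loop:
# memory_log = {}
# record_memory_mb(memory_log, "iter 0: start")
# ... forward ...
# record_memory_mb(memory_log, "iter 0: after fwd")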
Example #27
def test_fsdp_memory(fsdp, ckpt):
    expected = {
        ("ddp", "no_ckpt"): {
            "iter 0: start": 9,
            "iter 0: after fwd": 346,
            "iter 0: after loss": 346,
            "iter 0: after bwd": 14,
            "iter 0: after step": 17,
            "iter 0: done": 13,
            "iter 1: start": 13,
            "iter 1: after fwd": 350,
            "iter 1: after loss": 350,
            "iter 1: after bwd": 17,
            "iter 1: after step": 17,
            "iter 1: done": 13,
            "iter 2: start": 13,
            "iter 2: after fwd": 350,
            "iter 2: after loss": 350,
            "iter 2: after bwd": 17,
            "iter 2: after step": 17,
            "iter 2: done": 13,
            "iter 3: start": 13,
            "iter 3: after fwd": 350,
            "iter 3: after loss": 350,
            "iter 3: after bwd": 17,
            "iter 3: after step": 17,
            "iter 3: done": 13,
        },
        ("fsdp", "no_ckpt"): {
            "iter 0: start": 3,
            "iter 0: after fwd": 340,
            "iter 0: after loss": 340,
            "iter 0: after bwd": 16,
            "iter 0: after step": 18,
            "iter 0: done": 5,
            "iter 1: start": 5,
            "iter 1: after fwd": 342,
            "iter 1: after loss": 342,
            "iter 1: after bwd": 18,
            "iter 1: after step": 18,
            "iter 1: done": 5,
            "iter 2: start": 5,
            "iter 2: after fwd": 342,
            "iter 2: after loss": 342,
            "iter 2: after bwd": 18,
            "iter 2: after step": 18,
            "iter 2: done": 5,
            "iter 3: start": 5,
            "iter 3: after fwd": 342,
            "iter 3: after loss": 342,
            "iter 3: after bwd": 18,
            "iter 3: after step": 18,
            "iter 3: done": 5,
        },
        ("fsdp_amp_default", "no_ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 630,
            "iter 0: after loss": 630,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 657,
            "iter 1: after loss": 657,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 657,
            "iter 2: after loss": 657,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 657,
            "iter 3: after loss": 657,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
        ("fsdp_amp_compute_dtype32", "no_ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 657,
            "iter 0: after loss": 657,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 684,
            "iter 1: after loss": 684,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 684,
            "iter 2: after loss": 684,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 684,
            "iter 3: after loss": 684,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
        ("ddp", "ckpt"): {
            "iter 0: start": 9,
            "iter 0: after fwd": 57,
            "iter 0: after loss": 57,
            "iter 0: after bwd": 14,
            "iter 0: after step": 17,
            "iter 0: done": 13,
            "iter 1: start": 13,
            "iter 1: after fwd": 61,
            "iter 1: after loss": 61,
            "iter 1: after bwd": 17,
            "iter 1: after step": 17,
            "iter 1: done": 13,
            "iter 2: start": 13,
            "iter 2: after fwd": 61,
            "iter 2: after loss": 61,
            "iter 2: after bwd": 17,
            "iter 2: after step": 17,
            "iter 2: done": 13,
            "iter 3: start": 13,
            "iter 3: after fwd": 61,
            "iter 3: after loss": 61,
            "iter 3: after bwd": 17,
            "iter 3: after step": 17,
            "iter 3: done": 13,
        },
        ("fsdp", "ckpt"): {
            "iter 0: start": 3,
            "iter 0: after fwd": 51,
            "iter 0: after loss": 51,
            "iter 0: after bwd": 16,
            "iter 0: after step": 18,
            "iter 0: done": 5,
            "iter 1: start": 5,
            "iter 1: after fwd": 53,
            "iter 1: after loss": 53,
            "iter 1: after bwd": 18,
            "iter 1: after step": 18,
            "iter 1: done": 5,
            "iter 2: start": 5,
            "iter 2: after fwd": 53,
            "iter 2: after loss": 53,
            "iter 2: after bwd": 18,
            "iter 2: after step": 18,
            "iter 2: done": 5,
            "iter 3: start": 5,
            "iter 3: after fwd": 53,
            "iter 3: after loss": 53,
            "iter 3: after bwd": 18,
            "iter 3: after step": 18,
            "iter 3: done": 5,
        },
        ("fsdp_amp_default", "ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 52,
            "iter 0: after loss": 52,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 79,
            "iter 1: after loss": 79,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 79,
            "iter 2: after loss": 79,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 79,
            "iter 3: after loss": 79,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
        ("fsdp_amp_compute_dtype32", "ckpt"): {
            "iter 0: start": 28,
            "iter 0: after fwd": 52,
            "iter 0: after loss": 52,
            "iter 0: after bwd": 67,
            "iter 0: after step": 93,
            "iter 0: done": 54,
            "iter 1: start": 54,
            "iter 1: after fwd": 79,
            "iter 1: after loss": 79,
            "iter 1: after bwd": 93,
            "iter 1: after step": 93,
            "iter 1: done": 54,
            "iter 2: start": 54,
            "iter 2: after fwd": 79,
            "iter 2: after loss": 79,
            "iter 2: after bwd": 93,
            "iter 2: after step": 93,
            "iter 2: done": 54,
            "iter 3: start": 54,
            "iter 3: after fwd": 79,
            "iter 3: after loss": 79,
            "iter 3: after bwd": 93,
            "iter 3: after step": 93,
            "iter 3: done": 54,
        },
    }[(fsdp, ckpt)]

    # Compute the FSDP config.
    fsdp_config = {}

    # Set mixed precision.
    if "amp" in fsdp:
        fsdp_config["mixed_precision"] = True

    # When compute_dtype is FP32, make sure we use clear_autocast_cache.
    # Setting fp32_reduce_scatter and verbose for more code coverage.
    if "compute_dtype32" in fsdp:
        fsdp_config["compute_dtype"] = torch.float32
        fsdp_config["fp32_reduce_scatter"] = True
        fsdp_config["clear_autocast_cache"] = True
        fsdp_config["verbose"] = True

    # Use a bigger hidden dimension for AMP to increase the model size so
    # that a bug in handling params will show up, but we don't do that in
    # the base case to keep the test fast.
    #   - hidden_dim 128: model size ~4MB
    #   - hidden_dim 512: model size ~55MB
    #   - hidden_dim 1024: model size ~200MB (seems to be too big for CI tests though)
    model_hidden_dim = 128
    if "amp" in fsdp:
        model_hidden_dim = 512

    # Get the fsdp and checkpoint flags.
    with_fsdp = "fsdp" in fsdp
    with_ckpt = ckpt == "ckpt"

    world_size = 2
    with temp_files_ctx(num=2) as temp_files:
        mp.spawn(
            _distributed_worker,
            (world_size, with_fsdp, with_ckpt, temp_files[0], temp_files[1],
             expected, model_hidden_dim, fsdp_config),
            nprocs=world_size,
        )
Example #28
def temp_files():
    # dist_init needs 2 files, plus 3 files for the before state, after state, and in_data.
    with temp_files_ctx(5) as files:
        yield files