Example #1
import random

import pytest
import torch
import torch.multiprocessing as mp

# Helpers such as torch_version, torch_cuda_version, _world_size, _distributed_worker
# and _test_func, as well as the temp_files/ddp_ref fixtures and the precision,
# flatten and sync_bn parameters, are assumed to be provided by the surrounding
# test module (e.g. via pytest fixtures and parametrization).


def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref

    state_after = state_after[(precision, sync_bn)]

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    # When linear_bias is True, DDP's AMP O1 and FSDP's default AMP O1.5 differ, so we
    # force FSDP to use AMP O1 here by setting compute_dtype to float32.
    if linear_bias:
        fsdp_config["compute_dtype"] = torch.float32

    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN half of the time.
    wrap_bn = True
    if random.randint(0, 1) == 0:
        wrap_bn = False
    # Except: always wrap BN in mixed precision + sync_bn mode, regardless of
    # compute_dtype, because sync_bn wrapping otherwise errors out.
    if fsdp_config["mixed_precision"] and sync_bn != "none":
        wrap_bn = True

    # When BN is not wrapped (i.e. not in full precision), FSDP's compute_dtype needs to
    # be fp32 to match DDP (otherwise, numerical errors happen on BN's running_mean/running_var
    # buffers).
    if fsdp_config["mixed_precision"] and not wrap_bn:
        fsdp_config["compute_dtype"] = torch.float32

    world_size = _world_size
    mp.spawn(
        _distributed_worker,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,
            state_after,
            sync_bn,
            conv_bias,
            linear_bias,
        ),
        nprocs=world_size,
        join=True,
    )
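
For reference, mp.spawn invokes its target as target(rank, *args), so the worker receives the process rank first and then the tuple built above. The sketch below only illustrates that calling convention; it is not the _distributed_worker used by the test, and the file-based init_method is an assumption based on the temp files being passed in.

import torch
import torch.distributed as dist


def _worker_sketch(rank, world_size, fsdp_config, wrap_bn, _unused, file1, file2, *rest):
    # mp.spawn passes the rank first, then the contents of args in order.
    # Assumption: one of the temp files backs the process-group rendezvous.
    dist.init_process_group(
        backend="nccl",
        init_method=f"file://{file1}",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    # ... build the model, wrap it with FSDP according to fsdp_config and wrap_bn,
    # run forward/backward, and compare against the DDP reference state ...
    dist.destroy_process_group()
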
def test1(temp_files, ddp_ref, precision, flatten):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, state_after_fp, state_after_mp = ddp_ref

    if precision == "full":
        state_after = state_after_fp
    else:
        state_after = state_after_mp

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN half of the time in full precision mode.
    wrap_bn = True
    if random.randint(0, 1) == 0:
        wrap_bn = False
    # Always wrap BN in mixed precision mode.
    if fsdp_config["mixed_precision"]:
        wrap_bn = True

    world_size = _world_size
    mp.spawn(
        _test_func,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,
            state_after,
        ),
        nprocs=world_size,
        join=True,
    )
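
Both tests build fsdp_config with keys that correspond to keyword arguments of fairscale's FullyShardedDataParallel (mixed_precision, flatten_parameters, compute_dtype). A minimal sketch of how a worker might consume the config together with the wrap_bn flag is shown below; the helper name and the per-layer BN wrapping strategy are assumptions for illustration, not the test's actual implementation.

import torch.nn as nn

from fairscale.nn import FullyShardedDataParallel as FSDP


def _wrap_model_sketch(model: nn.Module, fsdp_config: dict, wrap_bn: bool) -> nn.Module:
    # Hypothetical helper: apply the FSDP config built by the tests above.
    if wrap_bn:
        # Assumption: BN layers are wrapped individually in full precision so that
        # their running_mean/running_var buffers stay in fp32. Only top-level
        # children are scanned here, for brevity.
        for name, child in model.named_children():
            if isinstance(child, nn.BatchNorm2d):
                setattr(model, name, FSDP(child, mixed_precision=False))
    return FSDP(model, **fsdp_config)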