Exemplo n.º 1
0
def test_missing_amp_autocast(tmpdir, half_op):
    """Outside any amp autocast context, the output dtype follows the weights.

    Args:
        tmpdir: pytest temporary-directory fixture (not used by this test).
        half_op: when truthy, both the input tensor and the module are
            converted with ``.half()``; otherwise both keep their default
            dtype. Presumably supplied by a pytest parametrization that is
            not visible here.
    """
    hidden_dim = 4
    # Create the input before the module (same order as before) so the
    # sequence of random draws is unchanged. ``inputs`` avoids shadowing
    # the ``input`` builtin.
    inputs = torch.randn(hidden_dim).cuda()
    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()
    if half_op:
        inputs = inputs.half()
        ds_linear = ds_linear.half()

    output = ds_linear(inputs)
    assert output.dtype == ds_linear.weight.dtype
Exemplo n.º 2
0
def test_disable_autocast_linear(tmpdir, half_op):
    """With amp autocast explicitly disabled, the output dtype follows the weights.

    Skipped when ``torch.cuda.amp`` cannot be imported.

    Args:
        tmpdir: pytest temporary-directory fixture (not used by this test).
        half_op: when truthy, both the input tensor and the module are
            converted with ``.half()``; otherwise both keep their default
            dtype. Presumably supplied by a pytest parametrization that is
            not visible here.
    """
    amp = pytest.importorskip("torch.cuda.amp")

    hidden_dim = 4
    # Input is created before the module, matching the previous order of
    # random draws; ``inputs`` avoids shadowing the ``input`` builtin.
    inputs = torch.randn(hidden_dim).cuda()
    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()
    if half_op:
        inputs = inputs.half()
        ds_linear = ds_linear.half()

    # enabled=False: the region runs without autocast dtype rewriting, so the
    # result is expected to match the module's own weight dtype.
    with amp.autocast(enabled=False):
        output = ds_linear(inputs)
        assert output.dtype == ds_linear.weight.dtype
Exemplo n.º 3
0
def test_disable_autocast_linear(tmpdir, half_op):
    """With amp autocast explicitly disabled, the output dtype follows the weights.

    Skipped when ``_skip_autocast_test()`` reports that amp autocast is not
    available.

    Args:
        tmpdir: pytest temporary-directory fixture (not used by this test).
        half_op: when truthy, both the input tensor and the module are
            converted with ``.half()``; otherwise both keep their default
            dtype. Presumably supplied by a pytest parametrization that is
            not visible here.
    """
    if _skip_autocast_test():
        pytest.skip("amp autocast is not available")

    hidden_dim = 4
    # Input is created before the module, matching the previous order of
    # random draws; ``inputs`` avoids shadowing the ``input`` builtin.
    inputs = torch.randn(hidden_dim).cuda()
    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()
    if half_op:
        inputs = inputs.half()
        ds_linear = ds_linear.half()

    # enabled=False: the region runs without autocast dtype rewriting, so the
    # result is expected to match the module's own weight dtype.
    with torch.cuda.amp.autocast(enabled=False):
        output = ds_linear(inputs)
        assert output.dtype == ds_linear.weight.dtype
Exemplo n.º 4
0
def test_autocast_linear(tmpdir, half_input, half_weight):
    """Under amp autocast, the output is fp16 for every input/weight dtype mix.

    Skipped when ``torch.cuda.amp`` cannot be imported.

    Args:
        tmpdir: pytest temporary-directory fixture (not used by this test).
        half_input: when truthy, the input tensor is converted with ``.half()``.
        half_weight: when truthy, the module is converted with ``.half()``.
            Both flags are presumably supplied by a pytest parametrization
            that is not visible here.
    """
    amp = pytest.importorskip("torch.cuda.amp")

    hidden_dim = 4
    # ``inputs`` avoids shadowing the ``input`` builtin.
    inputs = torch.randn(hidden_dim).cuda()
    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()

    if half_input:
        inputs = inputs.half()

    if half_weight:
        ds_linear = ds_linear.half()

    # Regardless of the input/weight dtypes chosen above, the test expects
    # autocast to produce a half-precision result.
    with amp.autocast():
        output = ds_linear(inputs)
        assert output.dtype == torch.half
Exemplo n.º 5
0
def test_autocast_linear(tmpdir, half_input, half_weight):
    """Under amp autocast, the output is fp16 for every input/weight dtype mix.

    Skipped when ``_skip_autocast_test()`` reports that amp autocast is not
    available.

    Args:
        tmpdir: pytest temporary-directory fixture (not used by this test).
        half_input: when truthy, the input tensor is converted with ``.half()``.
        half_weight: when truthy, the module is converted with ``.half()``.
            Both flags are presumably supplied by a pytest parametrization
            that is not visible here.
    """
    if _skip_autocast_test():
        pytest.skip("amp autocast is not available")

    hidden_dim = 4
    # ``inputs`` avoids shadowing the ``input`` builtin.
    inputs = torch.randn(hidden_dim).cuda()
    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()

    if half_input:
        inputs = inputs.half()

    if half_weight:
        ds_linear = ds_linear.half()

    # Regardless of the input/weight dtypes chosen above, the test expects
    # autocast to produce a half-precision result.
    with torch.cuda.amp.autocast():
        output = ds_linear(inputs)
        assert output.dtype == torch.half