import pytest
import torch

from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3


@pytest.mark.parametrize('half_op', [False, True])
def test_missing_amp_autocast(tmpdir, half_op):
    hidden_dim = 4
    if half_op:
        input = torch.randn(hidden_dim).cuda().half()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda().half()
    else:
        input = torch.randn(hidden_dim).cuda()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()

    # Without autocast active, the output dtype must match the weight dtype.
    output = ds_linear(input)
    assert output.dtype == ds_linear.weight.dtype
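
# `_skip_autocast_test` is referenced by the tests below but was not defined in
# this snippet. A minimal sketch, assuming the intent is to skip when
# torch.cuda.amp.autocast is unavailable (e.g. on older torch releases that
# ship torch.cuda.amp without the autocast context manager):
def _skip_autocast_test():
    try:
        from torch.cuda.amp import autocast  # noqa: F401
    except (ImportError, AttributeError):
        return True
    return False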
@pytest.mark.parametrize('half_op', [False, True])
def test_disable_autocast_linear(tmpdir, half_op):
    if _skip_autocast_test():
        pytest.skip("amp autocast is not available")

    hidden_dim = 4
    if half_op:
        input = torch.randn(hidden_dim).cuda().half()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda().half()
    else:
        input = torch.randn(hidden_dim).cuda()
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()

    # With autocast explicitly disabled, the output dtype must still match
    # the weight dtype, exactly as in the no-autocast case above.
    with torch.cuda.amp.autocast(False):
        output = ds_linear(input)
        assert output.dtype == ds_linear.weight.dtype
@pytest.mark.parametrize('half_input, half_weight',
                         [(False, False), (False, True), (True, False), (True, True)])
def test_autocast_linear(tmpdir, half_input, half_weight):
    if _skip_autocast_test():
        pytest.skip("amp autocast is not available")

    hidden_dim = 4
    input = torch.randn(hidden_dim).cuda()
    ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).cuda()
    if half_input:
        input = input.half()
    if half_weight:
        ds_linear = ds_linear.half()

    # Under autocast, the linear op should run in (and return) fp16,
    # regardless of the input and weight dtypes.
    with torch.cuda.amp.autocast():
        output = ds_linear(input)
        assert output.dtype == torch.half