def test_update_optim_scale():
    # With an update frequency of 1, a successful step should double the
    # optimizer's internal loss scale.
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
    optimizer._optim_scale_update_freq = 1
    optimizer._optim_scale = 2**15

    optimizer.zero_grad()
    loss = (weight.mv(input) + bias).pow(2).sum()
    loss.backward()
    optimizer.step()

    assert optimizer._optim_scale == 2**16
def test_exploding_optimizer_state():
    # A non-finite (inf) parameter should make the pure-FP16 optimizer raise
    # a RuntimeError when it updates its state.
    weight = torch.tensor([[float("inf")]]).half().cuda().requires_grad_()
    input = torch.tensor([1.0]).half().cuda().requires_grad_()

    optimizer = Adam([weight], lr=1e-3, precision=Precision.PURE_FP16)
    optimizer._optim_scale = 1.0

    optimizer.zero_grad()
    loss = (weight.mv(input)).pow(2).sum()
    loss.backward()
    with pytest.raises(RuntimeError):
        optimizer.step()
def test_step_with_grad_scaler():
    # Training through GradScaler for a few iterations should reduce the loss
    # below its initial value.
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
    scaler = GradScaler()
    initial_value = None

    for _i in range(5):
        optimizer.zero_grad()
        loss = (weight.mv(input) + bias).pow(2).sum()
        if _i == 0:
            initial_value = loss.item()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    assert loss.item() < initial_value