class TestGradientAccumulation:
    """Tests for `range_test(..., accumulation_steps=...)`.

    Each test simulates a large batch (`desired_bs`) by feeding smaller
    batches (`real_bs`) and accumulating over `accum_steps` steps. It then
    checks how many times the per-step work was invoked.
    """

    def test_gradient_accumulation(self, mocker):
        """The criterion should be called `accum_steps` times per iteration."""
        desired_bs, accum_steps = 32, 4
        real_bs = desired_bs // accum_steps
        num_iter = 10
        task = mod_task.XORTask(batch_size=real_bs)

        lr_finder = prepare_lr_finder(task)
        spy = mocker.spy(lr_finder, "criterion")

        lr_finder.range_test(task.train_loader,
                             num_iter=num_iter,
                             accumulation_steps=accum_steps)
        # NOTE: We are using smaller batch size to simulate a large batch.
        # So that the actual times of model/criterion called should be
        # `(desired_bs/real_bs) * num_iter` == `accum_steps * num_iter`
        assert spy.call_count == accum_steps * num_iter

    @pytest.mark.skipif(
        not (IS_AMP_AVAILABLE and mod_task.use_cuda()),
        reason="`apex` module and gpu is required to run this test.")
    def test_gradient_accumulation_with_apex_amp(self, mocker):
        """`amp.scale_loss` should be called `accum_steps` times per iteration.

        The model and optimizer are wrapped by `amp.initialize` exactly once.
        Wrapping an already-initialized pair a second time is not a
        supported use of apex amp.
        """
        desired_bs, accum_steps = 32, 4
        real_bs = desired_bs // accum_steps
        num_iter = 10
        task = mod_task.XORTask(batch_size=real_bs)

        # Wrap model and optimizer by `amp.initialize`. Besides, `amp` requires
        # CUDA GPU. So we have to move model to GPU first.
        model, optimizer, device = task.model, task.optimizer, task.device
        model = model.to(device)
        task.model, task.optimizer = amp.initialize(model, optimizer)

        lr_finder = prepare_lr_finder(task)
        spy = mocker.spy(amp, "scale_loss")

        lr_finder.range_test(task.train_loader,
                             num_iter=num_iter,
                             accumulation_steps=accum_steps)
        # Same reasoning as the non-amp test: one loss-scaling call per
        # accumulated mini-batch, i.e. `accum_steps * num_iter` in total.
        assert spy.call_count == accum_steps * num_iter


@pytest.mark.skipif(
    not (IS_AMP_AVAILABLE and mod_task.use_cuda()),
    reason="`apex` module and gpu is required to run these tests.")
class TestMixedPrecision:
    def test_mixed_precision(self, mocker):
        batch_size = 32
        num_iter = 10
        task = mod_task.XORTask(batch_size=batch_size)

        # Wrap model and optimizer by `amp.initialize`. Beside, `amp` requires
        # CUDA GPU. So we have to move model to GPU first.
        model, optimizer, device = task.model, task.optimizer, task.device
        model = model.to(device)
        task.model, task.optimizer = amp.initialize(model, optimizer)
        assert hasattr(task.optimizer, "_amp_stash")

        lr_finder = prepare_lr_finder(task)