import pytest
import torch

from ignite.engine import Engine, Events

# NOTE: assuming the current pytorch-ignite import layout; in older releases these
# schedulers lived under ignite.contrib.handlers.param_scheduler instead.
from ignite.handlers import LinearCyclicalScheduler, ParamGroupScheduler


def test_param_group_scheduler_asserts():
    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])

    lr_scheduler1 = LinearCyclicalScheduler(
        optimizer, "lr", param_group_index=0, start_value=1.0, end_value=0.0, cycle_size=10
    )
    lr_scheduler2 = LinearCyclicalScheduler(
        optimizer, "lr", param_group_index=1, start_value=1.0, end_value=0.0, cycle_size=10
    )

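    # Constructing ParamGroupScheduler with a bad ``schedulers`` argument should fail fast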
    with pytest.raises(TypeError, match=r"Argument schedulers should be a list/tuple"):
        ParamGroupScheduler(schedulers=None, names=["a", "b", "c"])

    with pytest.raises(ValueError, match=r"Argument schedulers should be a list/tuple of parameter schedulers"):
        ParamGroupScheduler(schedulers=[0, 1, 2], names=["a", "b", "c"])

    with pytest.raises(ValueError, match=r"Argument schedulers should be a list/tuple of parameter schedulers"):
        ParamGroupScheduler(schedulers=[lr_scheduler1, "2"], names=["a", "b"])

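    # ``names`` must be a list/tuple of strings, one name per scheduler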
    with pytest.raises(TypeError, match=r"Argument names should be a list/tuple"):
        ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler2], names="ab")

    with pytest.raises(ValueError, match=r"Argument names should be a list/tuple of parameter scheduler's names"):
        ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler2], names=[1, 2])

    with pytest.raises(ValueError, match=r"\d should be equal \d"):
        ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler2], names=["a"])

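    # load_state_dict should validate the input mapping before restoring any state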
    scheduler = ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler2], names=["a", "b"])
    with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"):
        scheduler.load_state_dict(None)

    with pytest.raises(ValueError, match=r"Required state attribute 'schedulers' is absent in provided state_dict"):
        scheduler.load_state_dict({"a": 1})

    with pytest.raises(ValueError, match=r"Input state_dict contains 0 state_dicts of param group schedulers"):
        scheduler.load_state_dict({"schedulers": []})

    with pytest.raises(ValueError, match=r"Required state attribute 'schedulers' is absent in provided state_dict"):
        scheduler.load_state_dict({})

    with pytest.raises(
        ValueError, match=r"Name of scheduler from input state dict does not correspond to required one"
    ):
        scheduler.load_state_dict({"schedulers": [("a", lr_scheduler1.state_dict()), ("bad_name", {})]})


def test_linear_scheduler():
    with pytest.raises(TypeError, match=r"Argument optimizer should be torch.optim.Optimizer"):
        LinearCyclicalScheduler({}, "lr", 1, 0, cycle_size=0)

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.0)

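    # cycle_size must be larger than 1 (a cycle needs at least two steps)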
    with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
        LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=0)

    with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
        LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=1)

    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10)
    state_dict = scheduler.state_dict()

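    # Collect the LR actually applied to the optimizer at every iteration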
    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

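    # The scheduler updates the LR before each iteration; save_lr records it afterwards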
    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

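    # Run twice: restoring the saved state_dict at the end of each pass should make
    # the second run reproduce exactly the same LR sequence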
    for _ in range(2):
        lrs = []
        trainer.run([0] * 9, max_epochs=2)

        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.0,
                    0.2,
                    0.4,
                    0.6,
                    0.8,
                    # Cycle 2
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.0,
                    0.2,
                    0.4,  # 0.6, 0.8,
                ],
            )
        )
        scheduler.load_state_dict(state_dict)

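    # With cycle_mult=2 the second cycle is twice as long (20 iterations), so the LR step halves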
    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, cycle_mult=2)
    state_dict = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run([0] * 10, max_epochs=3)

        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.0,
                    0.2,
                    0.4,
                    0.6,
                    0.8,
                    # Cycle 2
                    1.0,
                    0.9,
                    0.8,
                    0.7,
                    0.6,
                    0.5,
                    0.4,
                    0.3,
                    0.2,
                    0.1,
                    0.0,
                    0.1,
                    0.2,
                    0.3,
                    0.4,
                    0.5,
                    0.6,
                    0.7,
                    0.8,
                    0.9,
                ],
            )
        )
        scheduler.load_state_dict(state_dict)

    # With float cycle_size
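    # (a cycle_size marginally above 10 should behave like an integer cycle_size of 10)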
    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(
        optimizer, "lr", start_value=1.2, end_value=0.2, cycle_size=10.00000012, cycle_mult=1.0
    )
    state_dict = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run([0] * 9, max_epochs=2)
        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1
                    1.2,
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.4,
                    0.6,
                    0.8,
                    1.0,
                    # Cycle 2
                    1.2,
                    1.0,
                    0.8,
                    0.6,
                    0.4,
                    0.2,
                    0.4,
                    0.6,  # 0.8, 1.0,
                ],
            )
        )
        scheduler.load_state_dict(state_dict)