def test_pwlinear_scheduler_step_constant(max_epochs, milestones_values):
    """Check the final scheduled value equals the first milestone's value.

    The scheduler is attached on ``EPOCH_COMPLETED`` with ``create_new=True`` so it
    creates ``engine.state.pwlinear_scheduled_param`` itself. After the run, the
    attribute must equal ``milestones_values[0][1]``. The scheduler state is then
    round-tripped through ``state_dict``/``load_state_dict`` and must be unchanged.

    ``max_epochs`` and ``milestones_values`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator (not visible here); the milestones are
    presumably chosen so the schedule stays constant.
    """
    engine = Engine(lambda e, b: None)
    linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param", milestones_values=milestones_values, create_new=True
    )
    linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=max_epochs)
    # NOTE(review): torch.testing.assert_allclose is deprecated in favour of
    # torch.testing.assert_close, which is stricter about dtypes; migrate with care.
    torch.testing.assert_allclose(engine.state.pwlinear_scheduled_param, milestones_values[0][1])

    # Reloading the scheduler's own state must leave that state unchanged.
    state_dict = linear_state_parameter_scheduler.state_dict()
    linear_state_parameter_scheduler.load_state_dict(state_dict)
    assert linear_state_parameter_scheduler.state_dict() == state_dict
def test_pwlinear_scheduler_linear_increase(max_epochs, milestones_values, expected_val):
    """Check the piecewise-linear schedule reaches ``expected_val`` after the run.

    The scheduler is attached on ``EPOCH_COMPLETED`` with ``create_new=True``.
    After ``max_epochs`` epochs over 8 dummy batches, the scheduled attribute must
    be within an absolute tolerance of 1e-3 of ``expected_val``. The scheduler state
    is then round-tripped through ``state_dict``/``load_state_dict`` and must be
    unchanged.

    The arguments are presumably supplied by a ``pytest.mark.parametrize``
    decorator that is not visible here.
    """
    engine = Engine(lambda e, b: None)
    linear_state_parameter_scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param", milestones_values=milestones_values, create_new=True
    )
    linear_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=max_epochs)
    # rtol=0.0: only the absolute tolerance applies, independent of the value's magnitude.
    torch.testing.assert_allclose(engine.state.pwlinear_scheduled_param, expected_val, atol=0.001, rtol=0.0)

    # Reloading the scheduler's own state must leave that state unchanged.
    state_dict = linear_state_parameter_scheduler.state_dict()
    linear_state_parameter_scheduler.load_state_dict(state_dict)
    assert linear_state_parameter_scheduler.state_dict() == state_dict
def test_param_scheduler_with_ema_handler():
    """Drive an ``EMAHandler``'s state attribute with a piecewise-linear schedule.

    The EMA handler is attached under the state name ``"ema_decay"``. The scheduler
    then updates that same attribute on every ``ITERATION_COMPLETED``, ramping it
    from 0.0 at event 0 to 0.999 at event 10, and records its history.
    """
    from ignite.handlers import EMAHandler

    model = nn.Linear(2, 1)
    trainer = Engine(lambda e, b: model(b))
    data = torch.rand(100, 2)
    max_epochs = 20

    param_name = "ema_decay"

    ema_handler = EMAHandler(model)
    ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED)

    ema_decay_scheduler = PiecewiseLinearStateScheduler(
        param_name=param_name, milestones_values=[(0, 0.0), (10, 0.999)], save_history=True
    )
    ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)
    trainer.run(data, max_epochs=max_epochs)

    # Well past the last milestone, the schedule holds its final value.
    assert getattr(trainer.state, param_name) == pytest.approx(0.999)
    # save_history=True records one value per iteration: len(data) rows per epoch.
    assert len(trainer.state.param_history[param_name]) == len(data) * max_epochs
def test_param_scheduler_attach_warning():
    """Attaching with ``create_new=False`` to a missing state attribute warns.

    ``engine.state`` has no ``state_param`` attribute, so ``attach`` must emit a
    ``UserWarning`` that suggests ``create_new=True``.
    """
    trainer = Engine(lambda e, b: None)
    param_name = "state_param"

    param_scheduler = PiecewiseLinearStateScheduler(
        param_name=param_name,
        milestones_values=[(0, 0.0), (10, 0.999)],
        save_history=True,
        create_new=False,
    )

    expected_msg = (
        f"Attribute '{param_name}' is not defined in the engine.state. "
        "PiecewiseLinearStateScheduler will create it. Remove this warning by setting create_new=True."
    )
    # Escape the whole message so its periods match literally rather than as regex wildcards.
    with pytest.warns(UserWarning, match=re.escape(expected_msg)):
        param_scheduler.attach(trainer, Events.ITERATION_COMPLETED)
def test_pwlinear_scheduler_linear_increase_history(max_epochs,
                                                    milestones_values,
                                                    save_history,
                                                    expected_param_history):
    """Check the recorded history of a piecewise-linear schedule.

    The scheduler is attached on ``EPOCH_COMPLETED`` (one value per epoch). After
    the run, ``engine.state.param_history["pwlinear_scheduled_param"]`` must equal
    ``expected_param_history`` exactly. The scheduler state is then round-tripped
    through ``state_dict``/``load_state_dict`` and must be unchanged.

    The arguments are presumably supplied by a ``pytest.mark.parametrize``
    decorator that is not visible here.
    """
    engine = Engine(lambda e, b: None)
    pw_linear_step_parameter_scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param",
        milestones_values=milestones_values,
        save_history=save_history,
    )
    pw_linear_step_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=max_epochs)
    assert hasattr(engine.state, "param_history")
    state_param = engine.state.param_history["pwlinear_scheduled_param"]
    assert len(state_param) == len(expected_param_history)
    assert state_param == expected_param_history

    # Reloading the scheduler's own state must leave that state unchanged.
    state_dict = pw_linear_step_parameter_scheduler.state_dict()
    pw_linear_step_parameter_scheduler.load_state_dict(state_dict)
    assert pw_linear_step_parameter_scheduler.state_dict() == state_dict
def test_param_scheduler_attach_exception():
    """Attaching with ``create_new=True`` to an existing state attribute fails.

    ``engine.state.state_param`` is pre-set to ``None``, so ``attach`` must raise
    ``ValueError`` reporting a possible handler-name conflict.
    """
    trainer = Engine(lambda e, b: None)
    param_name = "state_param"

    # Pre-create the attribute so the scheduler's create_new=True request collides with it.
    setattr(trainer.state, param_name, None)

    param_scheduler = PiecewiseLinearStateScheduler(
        param_name=param_name,
        milestones_values=[(0, 0.0), (10, 0.999)],
        save_history=True,
        create_new=True,
    )

    expected_msg = (
        f"Attribute '{param_name}' already exists in the engine.state. "
        "This may be a conflict between multiple handlers. "
        "Please choose another name."
    )
    # Escape the whole message so its periods match literally rather than as regex wildcards.
    with pytest.raises(ValueError, match=re.escape(expected_msg)):
        param_scheduler.attach(trainer, Events.ITERATION_COMPLETED)