# Example No. 1
0
def test_pwlinear_scheduler_step_constant(max_epochs, milestones_values):
    """Check that the scheduled value equals the first milestone's value.

    A ``PiecewiseLinearStateScheduler`` is attached at ``EPOCH_COMPLETED`` to an
    engine whose process function does nothing. After running for
    ``max_epochs`` epochs over 8 dummy items, ``engine.state`` must hold
    ``milestones_values[0][1]`` (the second element of the first milestone
    pair). The test ends with a ``state_dict`` -> ``load_state_dict`` round
    trip, which only has to complete without raising.
    """
    # The process function is a no-op: only the scheduler's effect on
    # engine.state is under test.
    trainer = Engine(lambda e, b: None)

    scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param",
        milestones_values=milestones_values,
        create_new=True,
    )
    scheduler.attach(trainer, Events.EPOCH_COMPLETED)

    dummy_batches = [0] * 8
    trainer.run(dummy_batches, max_epochs=max_epochs)

    actual = trainer.state.pwlinear_scheduled_param
    expected = milestones_values[0][1]
    torch.testing.assert_allclose(actual, expected)

    # Serialization smoke check: restoring the scheduler's own state must work.
    snapshot = scheduler.state_dict()
    scheduler.load_state_dict(snapshot)
# Example No. 2
0
def test_pwlinear_scheduler_linear_increase(max_epochs, milestones_values, expected_val):
    """Check the piecewise-linearly interpolated value after ``max_epochs`` epochs.

    The scheduler is attached at ``EPOCH_COMPLETED`` to a no-op engine. After
    running for ``max_epochs`` epochs over 8 dummy items, the scheduled
    ``engine.state`` attribute must match ``expected_val`` within an absolute
    tolerance of 1e-3 (relative tolerance disabled). The test ends with a
    ``state_dict`` -> ``load_state_dict`` round trip, which only has to
    complete without raising.
    """
    # The process function is a no-op: only the scheduler's effect on
    # engine.state is under test.
    trainer = Engine(lambda e, b: None)

    scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param",
        milestones_values=milestones_values,
        create_new=True,
    )
    scheduler.attach(trainer, Events.EPOCH_COMPLETED)

    dummy_batches = [0] * 8
    trainer.run(dummy_batches, max_epochs=max_epochs)

    actual = trainer.state.pwlinear_scheduled_param
    # Purely absolute tolerance: rtol=0.0 keeps the check independent of
    # the magnitude of expected_val.
    torch.testing.assert_allclose(actual, expected_val, atol=0.001, rtol=0.0)

    # Serialization smoke check: restoring the scheduler's own state must work.
    snapshot = scheduler.state_dict()
    scheduler.load_state_dict(snapshot)
def test_pwlinear_scheduler_linear_increase_history(
    max_epochs,
    milestones_values,
    save_history,
    expected_param_history,
):
    """Check the per-event value history recorded by the scheduler.

    A ``PiecewiseLinearStateScheduler`` built with ``save_history`` is attached
    at ``EPOCH_COMPLETED`` to a no-op engine and run for ``max_epochs`` epochs
    over 8 dummy items. ``engine.state.param_history`` must then exist, and its
    ``"pwlinear_scheduled_param"`` entry must equal ``expected_param_history``.
    The test ends with a ``state_dict`` -> ``load_state_dict`` round trip,
    which only has to complete without raising.
    """
    engine = Engine(lambda e, b: None)
    pw_linear_step_parameter_scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param",
        milestones_values=milestones_values,
        save_history=save_history,
        # Passed for consistency with the other piecewise-linear tests in this
        # file. The attribute does not yet exist on the fresh engine.state, so
        # creation is requested explicitly rather than implicitly.
        # NOTE(review): the implicit path is expected to only differ by a
        # warning — confirm against the installed ignite version.
        create_new=True,
    )
    pw_linear_step_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=max_epochs)

    assert hasattr(engine.state, "param_history")
    state_param = engine.state.param_history["pwlinear_scheduled_param"]
    # Length check first so a mismatch reports a clear count difference
    # before the element-wise comparison.
    assert len(state_param) == len(expected_param_history)
    assert state_param == expected_param_history

    # Serialization smoke check: restoring the scheduler's own state must work.
    state_dict = pw_linear_step_parameter_scheduler.state_dict()
    pw_linear_step_parameter_scheduler.load_state_dict(state_dict)