def test_pwlinear_scheduler_step_constant(max_epochs, milestones_values):
    """Check that a constant-valued piecewise-linear schedule keeps its value.

    The scheduler is attached on ``Events.EPOCH_COMPLETED`` and the engine runs
    over an 8-item dummy dataset for ``max_epochs`` epochs. Afterwards, the
    scheduled attribute on ``engine.state`` must be close (default
    ``assert_allclose`` tolerances) to the value of the first milestone entry,
    ``milestones_values[0][1]``. Presumably every parametrized case uses the same
    value at all milestones — only the first entry is checked here.

    Finally, the scheduler's state dict is fed back into ``load_state_dict`` to
    check that it can be restored.
    """
    # Testing step_constant
    trainer = Engine(lambda e, b: None)  # process function does nothing per batch
    scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param",
        milestones_values=milestones_values,
        create_new=True,
    )
    scheduler.attach(trainer, Events.EPOCH_COMPLETED)
    trainer.run([0] * 8, max_epochs=max_epochs)

    first_milestone_value = milestones_values[0][1]
    torch.testing.assert_allclose(trainer.state.pwlinear_scheduled_param, first_milestone_value)

    # Save/restore round trip of the scheduler's internal state.
    scheduler.load_state_dict(scheduler.state_dict())
def test_pwlinear_scheduler_linear_increase(max_epochs, milestones_values, expected_val):
    """Check the value reached by a piecewise-linear schedule after training.

    The scheduler updates ``engine.state.pwlinear_scheduled_param`` on every
    ``Events.EPOCH_COMPLETED``. After ``max_epochs`` epochs over an 8-item dummy
    dataset, the attribute must match ``expected_val`` within an absolute
    tolerance of 1e-3 (relative tolerance disabled).

    The scheduler's state dict is then loaded back to check that it can be
    restored.
    """
    # Testing linear increase
    trainer = Engine(lambda e, b: None)  # process function does nothing per batch
    scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param",
        milestones_values=milestones_values,
        create_new=True,
    )
    scheduler.attach(trainer, Events.EPOCH_COMPLETED)
    trainer.run([0] * 8, max_epochs=max_epochs)

    final_value = trainer.state.pwlinear_scheduled_param
    torch.testing.assert_allclose(final_value, expected_val, atol=0.001, rtol=0.0)

    # Save/restore round trip of the scheduler's internal state.
    scheduler.load_state_dict(scheduler.state_dict())
def test_pwlinear_scheduler_linear_increase_history(
    max_epochs, milestones_values, save_history, expected_param_history
):
    """Check the per-epoch history recorded by a piecewise-linear schedule.

    The scheduler is attached on ``Events.EPOCH_COMPLETED`` and the engine runs
    over an 8-item dummy dataset for ``max_epochs`` epochs. The assertions
    require ``engine.state.param_history`` to exist. Presumably the parametrized
    cases therefore set ``save_history`` truthy — TODO confirm against the
    parametrize list.

    The recorded list for ``"pwlinear_scheduled_param"`` must equal
    ``expected_param_history`` exactly. Finally, the scheduler's state dict is
    loaded back to check that it can be restored.
    """
    # Testing linear increase
    engine = Engine(lambda e, b: None)
    pw_linear_step_parameter_scheduler = PiecewiseLinearStateScheduler(
        param_name="pwlinear_scheduled_param",
        milestones_values=milestones_values,
        save_history=save_history,
        # A fresh Engine's state has no such attribute; ask the scheduler to
        # create it explicitly, as the sibling tests do.
        create_new=True,
    )
    pw_linear_step_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=max_epochs)

    assert hasattr(engine.state, "param_history"), "scheduler did not record a param_history"
    state_param = engine.state.param_history["pwlinear_scheduled_param"]
    # The length check gives a clearer failure message than the full list comparison.
    assert len(state_param) == len(expected_param_history)
    assert state_param == expected_param_history

    # Save/restore round trip of the scheduler's internal state.
    state_dict = pw_linear_step_parameter_scheduler.state_dict()
    pw_linear_step_parameter_scheduler.load_state_dict(state_dict)