def _test(lr_schedulers, save_lr):
    """Drive a trainer with a ParamGroupScheduler and cross-check the recorded LRs.

    Two passes over the same data are run. After each pass the scheduler is
    rewound to the ``state_dict`` captured before the first run. Each pass
    asserts that:

    * entry 0 and entry 1 of every recorded item agree,
    * ``ParamGroupScheduler.simulate_values`` yields equal values in columns 1 and 2,
    * the recorded entry 0 matches the simulated column 1.

    Args:
        lr_schedulers: schedulers wrapped into a single ``ParamGroupScheduler``
            and named ``s_0``, ``s_1``, ...
        save_lr: handler attached to ``Events.ITERATION_COMPLETED`` with a fresh
            list as its extra argument; presumably appends one
            ``(lr_group0, lr_group1)``-like item per iteration — confirm with caller.
    """

    def column(rows, idx):
        # Pick the ``idx``-th entry out of every row.
        return [row[idx] for row in rows]

    iters_per_epoch, epochs = 10, 20
    total_events = epochs * iters_per_epoch

    group_scheduler = ParamGroupScheduler(lr_schedulers, names=[f"s_{k}" for k in range(len(lr_schedulers))])
    # Snapshot taken before any run so every pass can restart the schedule from scratch.
    initial_state = group_scheduler.state_dict()

    runner = Engine(lambda engine, batch: None)
    runner.add_event_handler(Events.ITERATION_STARTED, group_scheduler)
    batches = [0] * iters_per_epoch

    for _ in range(2):
        recorded = []
        runner.add_event_handler(Events.ITERATION_COMPLETED, save_lr, recorded)
        runner.run(batches, max_epochs=epochs)
        # Detach the recorder; the next pass re-attaches it with a fresh list.
        runner.remove_event_handler(save_lr, Events.ITERATION_COMPLETED)

        assert column(recorded, 0) == pytest.approx(column(recorded, 1))
        group_scheduler.load_state_dict(initial_state)

        simulated = ParamGroupScheduler.simulate_values(total_events, lr_schedulers)
        assert column(simulated, 1) == pytest.approx(column(simulated, 2))
        assert column(recorded, 0) == pytest.approx(column(simulated, 1))
def test_param_group_scheduler_asserts():
    """Check that ParamGroupScheduler rejects malformed arguments and state dicts.

    Every case below is executed inside ``pytest.raises`` and must raise the
    listed exception type with a message matching the listed regex.
    """

    def expect_error(exc_type, pattern, action):
        # Run ``action`` and require it to raise ``exc_type`` with a message matching ``pattern``.
        with pytest.raises(exc_type, match=pattern):
            action()

    first, second = (torch.zeros([1], requires_grad=True) for _ in range(2))
    opt = torch.optim.SGD([{"params": param, "lr": 0.1} for param in (first, second)])

    # One cyclical LR scheduler per optimizer param group (indices 0 and 1).
    sched_a, sched_b = (
        LinearCyclicalScheduler(opt, "lr", param_group_index=idx, start_value=1.0, end_value=0.0, cycle_size=10)
        for idx in (0, 1)
    )

    # --- constructor validation: ``schedulers`` argument ---
    expect_error(
        TypeError,
        r"Argument schedulers should be a list/tuple",
        lambda: ParamGroupScheduler(schedulers=None, names=["a", "b", "c"]),
    )
    expect_error(
        ValueError,
        r"Argument schedulers should be a list/tuple of parameter schedulers",
        lambda: ParamGroupScheduler(schedulers=[0, 1, 2], names=["a", "b", "c"]),
    )
    expect_error(
        ValueError,
        r"Argument schedulers should be a list/tuple of parameter schedulers",
        lambda: ParamGroupScheduler(schedulers=[sched_a, "2"], names=["a", "b"]),
    )

    # --- constructor validation: ``names`` argument ---
    expect_error(
        TypeError,
        r"Argument names should be a list/tuple",
        lambda: ParamGroupScheduler(schedulers=[sched_a, sched_b], names="ab"),
    )
    expect_error(
        ValueError,
        r"Argument names should be a list/tuple of parameter scheduler's names",
        lambda: ParamGroupScheduler(schedulers=[sched_a, sched_b], names=[1, 2]),
    )
    # Two schedulers but only one name must be rejected.
    expect_error(
        ValueError,
        r"\d should be equal \d",
        lambda: ParamGroupScheduler(schedulers=[sched_a, sched_b], names=["a"]),
    )

    # --- load_state_dict validation on a correctly built scheduler ---
    group_scheduler = ParamGroupScheduler(schedulers=[sched_a, sched_b], names=["a", "b"])
    expect_error(
        TypeError,
        r"Argument state_dict should be a dictionary",
        lambda: group_scheduler.load_state_dict(None),
    )
    expect_error(
        ValueError,
        r"Required state attribute 'schedulers' is absent in provided state_dict",
        lambda: group_scheduler.load_state_dict({"a": 1}),
    )
    expect_error(
        ValueError,
        r"Input state_dict contains 0 state_dicts of param group schedulers",
        lambda: group_scheduler.load_state_dict({"schedulers": []}),
    )
    expect_error(
        ValueError,
        r"Required state attribute 'schedulers' is absent in provided state_dict",
        lambda: group_scheduler.load_state_dict({}),
    )
    expect_error(
        ValueError,
        r"Name of scheduler from input state dict does not " r"correspond to required one",
        lambda: group_scheduler.load_state_dict({"schedulers": [("a", sched_a.state_dict()), ("bad_name", {})]}),
    )