def test_lr_scheduling_on_non_torch_optimizers():
    """Check that parameter schedulers work with optimizer-like wrapper objects.

    Regression test for https://github.com/pytorch/ignite/issues/1162.

    Two things are exercised:

    1. ``FakeParamScheduler`` can be constructed around a ``MagicMock`` that
       only exposes ``param_groups``.
    2. ``PiecewiseLinear`` drives the ``lr`` of an SGD optimizer wrapped in
       ``MockFP16DeepSpeedZeroOptimizer``; the value read back from
       ``param_groups[0]["lr"]`` after every iteration must match the
       expected schedule.
    """
    # Part 1: constructing a scheduler around a bare mock must not raise.
    mock_optimizer = MagicMock()
    mock_optimizer.param_groups = [{"params": 0}]
    FakeParamScheduler(mock_optimizer, "lr")

    # Part 2: schedule the lr through a wrapper around a real SGD optimizer.
    param = torch.zeros([1], requires_grad=True)
    wrapped_optimizer = MockFP16DeepSpeedZeroOptimizer(torch.optim.SGD([param], lr=0))

    lr_scheduler = PiecewiseLinear(
        wrapped_optimizer,
        "lr",
        milestones_values=[(5, 0.5), (15, 1.0)],
    )

    recorded_lrs = []

    def record_lr(engine):
        # Read the lr back through the wrapper's param_groups.
        recorded_lrs.append(wrapped_optimizer.param_groups[0]["lr"])

    def noop_step(engine, batch):
        # The process function does no work; only the handlers matter here.
        return None

    trainer = Engine(noop_step)
    # Registration order matters: the scheduler runs before the recorder.
    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, record_lr)

    trainer.run([0] * 15, max_epochs=1)

    # One recorded value per iteration: 0.5 for the first six, then +0.05 per step.
    expected_lrs = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
    assert recorded_lrs == [pytest.approx(value) for value in expected_lrs]
# Example #2
# 0
def test_opt_params_handler_on_non_torch_optimizers():
    """Check that the optimizer-params handler reads ``lr`` through a wrapper.

    An SGD optimizer is wrapped in ``MockFP16DeepSpeedZeroOptimizer`` and
    handed to ``DummyOptParamsHandler``. Calling the handler must return a
    dict whose ``"lr/group_0"`` entry equals the lr given to SGD.
    """
    expected_lr = 0.1234

    param = torch.zeros([1], requires_grad=True)
    wrapped_optimizer = MockFP16DeepSpeedZeroOptimizer(
        torch.optim.SGD([param], lr=expected_lr)
    )

    params_handler = DummyOptParamsHandler(optimizer=wrapped_optimizer, param_name="lr")
    # Engine, logger and event name are all passed as None.
    output = params_handler(engine=None, logger=None, event_name=None)

    assert isinstance(output, dict)
    assert "lr/group_0" in output
    assert output["lr/group_0"] == expected_lr