Example #1
import pytest
import torch

from ignite.engine import Engine, Events
# NOTE: import path assumed here; LRScheduler lived under ignite.contrib.handlers
# in the releases this test targets and was later moved to ignite.handlers.
from ignite.contrib.handlers.param_scheduler import LRScheduler


def test_lr_scheduler():
    def _test(torch_lr_scheduler_cls, **kwargs):

        tensor = torch.zeros([1], requires_grad=True)
        optimizer1 = torch.optim.SGD([tensor], lr=0.1)
        optimizer2 = torch.optim.SGD([tensor], lr=0.1)

        torch_lr_scheduler1 = torch_lr_scheduler_cls(optimizer=optimizer1,
                                                     **kwargs)
        torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2,
                                                     **kwargs)
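        # wrap one torch scheduler as an Ignite handler; the second scheduler
        # is stepped manually and serves as the reference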
        scheduler = LRScheduler(torch_lr_scheduler1)

        lrs = []
        lrs_true = []

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_STARTED)
        def torch_lr_scheduler_step(engine):
            torch_lr_scheduler2.step()

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr(engine):
            lrs.append(optimizer1.param_groups[0]['lr'])

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_true_lr(engine):
            lrs_true.append(optimizer2.param_groups[0]['lr'])

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
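        # the wrapped scheduler and the manual step both fire at ITERATION_STARTED,
        # so the two recorded lr sequences should match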

        data = [0] * 10
        max_epochs = 2
        trainer.run(data, max_epochs=max_epochs)

        assert lrs_true == pytest.approx(lrs)

        optimizer3 = torch.optim.SGD([tensor], lr=0.1)
        torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3,
                                                     **kwargs)

        simulated_values = LRScheduler.simulate_values(
            num_events=len(data) * max_epochs,
            lr_scheduler=torch_lr_scheduler3)
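        # the simulated schedule should reproduce the lrs recorded during the run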
        assert lrs == pytest.approx([v for i, v in simulated_values])

    _test(torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.5)
    _test(torch.optim.lr_scheduler.ExponentialLR, gamma=0.78)

    # test _replicate_lr_scheduler
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                          gamma=0.78)
    init_lr_scheduler_state = dict(lr_scheduler.state_dict())
    copy_lr_scheduler = LRScheduler._replicate_lr_scheduler(lr_scheduler)
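    # stepping the original scheduler should leave the replicated copy untouched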
    for _ in range(10):
        lr_scheduler.step()

    assert copy_lr_scheduler.state_dict() == init_lr_scheduler_state
Example #2
# (imports as in Example #1)
def test_lr_scheduler():
    def _test(torch_lr_scheduler_cls, **kwargs):

        tensor = torch.zeros([1], requires_grad=True)
        optimizer1 = torch.optim.SGD([tensor], lr=0.01)
        optimizer2 = torch.optim.SGD([tensor], lr=0.01)
        opt_state_dict1 = optimizer1.state_dict()
        opt_state_dict2 = optimizer2.state_dict()

        torch_lr_scheduler1 = torch_lr_scheduler_cls(optimizer=optimizer1,
                                                     **kwargs)
        scheduler = LRScheduler(torch_lr_scheduler1)
        state_dict1 = scheduler.state_dict()

        torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2,
                                                     **kwargs)
        state_dict2 = torch_lr_scheduler2.state_dict()

        def dummy_update(engine, batch):
            optimizer1.step()
            optimizer2.step()

        trainer = Engine(dummy_update)

        @trainer.on(Events.ITERATION_STARTED)
        def save_lr(engine):
            lrs.append(optimizer1.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_STARTED)
        def save_true_lr(engine):
            lrs_true.append(optimizer2.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_COMPLETED)
        def torch_lr_scheduler_step(engine):
            torch_lr_scheduler2.step()

        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
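        # here both schedulers step at ITERATION_COMPLETED, while lrs are read
        # at ITERATION_STARTED, i.e. before the current iteration's step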

        for _ in range(2):
            lrs = []
            lrs_true = []
            data = [0] * 10
            max_epochs = 2
            trainer.run(data, max_epochs=max_epochs)
            assert lrs_true == pytest.approx(
                lrs), "{}: {} ({}) vs {} ({})".format(_, lrs_true,
                                                      len(lrs_true), lrs,
                                                      len(lrs))
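            # restore optimizers and schedulers so the next run starts from the
            # same initial state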
            optimizer1.load_state_dict(opt_state_dict1)
            scheduler.load_state_dict(state_dict1)
            optimizer2.load_state_dict(opt_state_dict2)
            torch_lr_scheduler2.load_state_dict(state_dict2)

        optimizer3 = torch.optim.SGD([tensor], lr=0.01)
        torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3,
                                                     **kwargs)

        simulated_values = LRScheduler.simulate_values(
            num_events=len(data) * max_epochs,
            lr_scheduler=torch_lr_scheduler3)
        assert lrs == pytest.approx([v for i, v in simulated_values])

    _test(torch.optim.lr_scheduler.StepLR, step_size=5, gamma=0.5)
    _test(torch.optim.lr_scheduler.ExponentialLR, gamma=0.78)

    # test _replicate_lr_scheduler
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.01)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                          gamma=0.78)
    init_lr_scheduler_state = dict(lr_scheduler.state_dict())
    copy_lr_scheduler = LRScheduler._replicate_lr_scheduler(lr_scheduler)
    for _ in range(10):
        optimizer.step()
        lr_scheduler.step()

    assert copy_lr_scheduler.state_dict() == init_lr_scheduler_state

    with pytest.raises(TypeError):
        LRScheduler._replicate_lr_scheduler(12)