def _test(lr_schedulers, optimizer):
    """Check that the first two param groups follow the same lr trajectory.

    The given schedulers are wrapped in a ``ParamGroupScheduler`` (named
    ``s_0``, ``s_1``, ...) which is triggered at the start of every iteration
    of a no-op trainer. After every completed iteration the current ``lr`` of
    ``optimizer.param_groups[0]`` and ``optimizer.param_groups[1]`` is
    recorded. The trainer is run twice. After each run the two recorded
    sequences must be approximately equal. The scheduler is then restored to
    the state captured before the first run.

    Args:
        lr_schedulers: parameter schedulers passed positionally to
            ``ParamGroupScheduler``.
        optimizer: optimizer with at least two param groups whose ``"lr"``
            values are compared.
    """
    n_iters = 10
    n_epochs = 20

    scheduler_names = list(map("s_{}".format, range(len(lr_schedulers))))
    scheduler = ParamGroupScheduler(lr_schedulers, names=scheduler_names)
    # Snapshot taken before any run so every run can start from the same state.
    initial_state = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)
    recorded = []

    def save_lr(engine):
        groups = optimizer.param_groups
        recorded.append((groups[0]["lr"], groups[1]["lr"]))

    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    batches = [0] * n_iters

    for _ in range(2):
        recorded.clear()
        trainer.run(batches, max_epochs=n_epochs)
        lrs_group0 = [pair[0] for pair in recorded]
        lrs_group1 = [pair[1] for pair in recorded]
        assert lrs_group0 == pytest.approx(lrs_group1)
        scheduler.load_state_dict(initial_state)
# Example #2
# 0
def test_param_group_scheduler_asserts():
    """ParamGroupScheduler must raise ValueError for malformed arguments.

    Two param groups of one SGD optimizer each get their own
    ``LinearCyclicalScheduler``. The constructor calls below use invalid
    combinations of ``schedulers`` and ``names``, and each must raise
    ``ValueError``.
    """
    t1, t2 = (torch.zeros([1], requires_grad=True) for _ in range(2))
    optimizer = torch.optim.SGD([{"params": t, 'lr': 0.1} for t in (t1, t2)])

    lr_scheduler1, lr_scheduler2 = [
        LinearCyclicalScheduler(group, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
        for group in optimizer.param_groups[:2]
    ]

    invalid_arguments = [
        # schedulers list holds plain ints
        ([0, 1, 2], ['a', 'b', 'c']),
        # schedulers list mixes a scheduler with a string
        ([lr_scheduler1, '2'], ['a', 'b']),
        # names given as a single str instead of a list
        ([lr_scheduler1, lr_scheduler2], 'ab'),
        # fewer names than schedulers
        ([lr_scheduler1, lr_scheduler2], ['a', ]),
    ]
    for schedulers, names in invalid_arguments:
        with pytest.raises(ValueError):
            ParamGroupScheduler(schedulers=schedulers, names=names)
def test_param_group_scheduler_asserts():
    """Exercise ParamGroupScheduler's argument validation and state loading.

    The constructor must raise ``ValueError`` for bad ``schedulers`` or
    ``names`` arguments. ``load_state_dict`` must raise ``ValueError`` with a
    specific message in three cases:
    - the ``"schedulers"`` key is missing;
    - the schedulers list is empty;
    - a stored scheduler name does not match.
    """
    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    param_groups = [{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}]
    optimizer = torch.optim.SGD(param_groups)

    def make_scheduler(index):
        # One linear cyclical "lr" schedule bound to param group ``index``.
        return LinearCyclicalScheduler(
            optimizer,
            "lr",
            param_group_index=index,
            start_value=1.0,
            end_value=0.0,
            cycle_size=10,
        )

    lr_scheduler1 = make_scheduler(0)
    lr_scheduler2 = make_scheduler(1)

    for bad_kwargs in (
        {"schedulers": [0, 1, 2], "names": ["a", "b", "c"]},
        {"schedulers": [lr_scheduler1, "2"], "names": ["a", "b"]},
        {"schedulers": [lr_scheduler1, lr_scheduler2], "names": "ab"},
        {"schedulers": [lr_scheduler1, lr_scheduler2], "names": ["a"]},
    ):
        with pytest.raises(ValueError):
            ParamGroupScheduler(**bad_kwargs)

    scheduler = ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler2], names=["a", "b"])

    def assert_load_rejected(state, pattern):
        # ``scheduler.load_state_dict(state)`` must raise a ValueError matching ``pattern``.
        with pytest.raises(ValueError, match=pattern):
            scheduler.load_state_dict(state)

    assert_load_rejected(
        {"a": 1},
        r"Required state attribute 'schedulers' is absent in provided state_dict",
    )
    assert_load_rejected(
        {"schedulers": []},
        r"Input state_dict contains 0 state_dicts of param group schedulers",
    )
    assert_load_rejected(
        {"schedulers": [("a", lr_scheduler1.state_dict()), ("bad_name", {})]},
        r"Name of scheduler from input state dict does not "
        r"correspond to required one",
    )
# Example #4
# 0
def test_param_group_scheduler_asserts():
    """Check the exact exceptions raised by ParamGroupScheduler.

    The test covers three groups of failures:
    - constructor type/value errors for ``schedulers`` and ``names``;
    - ``load_state_dict`` rejections;
    - grouping schedulers that are bound to different optimizers.
    Every check pins both the exception type and a message regex.
    """
    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])

    # Shared schedule shape for every LinearCyclicalScheduler in this test.
    schedule_kwargs = dict(start_value=1.0, end_value=0.0, cycle_size=10)
    lr_scheduler1 = LinearCyclicalScheduler(optimizer, "lr", param_group_index=0, **schedule_kwargs)
    lr_scheduler2 = LinearCyclicalScheduler(optimizer, "lr", param_group_index=1, **schedule_kwargs)

    constructor_cases = [
        (
            TypeError,
            r"Argument schedulers should be a list/tuple",
            {"schedulers": None, "names": ["a", "b", "c"]},
        ),
        (
            ValueError,
            r"Argument schedulers should be a list/tuple of parameter schedulers",
            {"schedulers": [0, 1, 2], "names": ["a", "b", "c"]},
        ),
        (
            ValueError,
            r"Argument schedulers should be a list/tuple of parameter schedulers",
            {"schedulers": [lr_scheduler1, "2"], "names": ["a", "b"]},
        ),
        (
            TypeError,
            r"Argument names should be a list/tuple",
            {"schedulers": [lr_scheduler1, lr_scheduler2], "names": "ab"},
        ),
        (
            ValueError,
            r"Argument names should be a list/tuple of parameter scheduler's names",
            {"schedulers": [lr_scheduler1, lr_scheduler2], "names": [1, 2]},
        ),
        (
            ValueError,
            r"\d should be equal \d",
            {"schedulers": [lr_scheduler1, lr_scheduler2], "names": ["a"]},
        ),
    ]
    for expected_error, pattern, kwargs in constructor_cases:
        with pytest.raises(expected_error, match=pattern):
            ParamGroupScheduler(**kwargs)

    scheduler = ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler2], names=["a", "b"])

    missing_schedulers = r"Required state attribute 'schedulers' is absent in provided state_dict"
    load_cases = [
        ({"a": 1}, missing_schedulers),
        ({"schedulers": []}, r"Input state_dict contains 0 state_dicts of param group schedulers"),
        ({}, missing_schedulers),
    ]
    for state, pattern in load_cases:
        with pytest.raises(ValueError, match=pattern):
            scheduler.load_state_dict(state)

    mismatched_names = {"schedulers": [("a", lr_scheduler1.state_dict()), ("bad_name", {})]}
    with pytest.raises(
        ValueError, match=r"Name of scheduler from input state dict does not correspond to required one"
    ):
        scheduler.load_state_dict(mismatched_names)

    # A scheduler bound to a second optimizer must not be grouped with lr_scheduler1.
    other_optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])
    lr_scheduler3 = LinearCyclicalScheduler(other_optimizer, "lr", param_group_index=0, **schedule_kwargs)
    with pytest.raises(ValueError, match=r"schedulers should be related to same optimizer"):
        ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler3])
def test_param_group_scheduler_asserts():
    """Verify ParamGroupScheduler argument checks and ``load_state_dict`` errors.

    Invalid ``schedulers``/``names`` combinations must raise ``ValueError``.
    ``load_state_dict`` must raise ``ValueError`` with a matching message for
    three inputs:
    - a dict without ``'schedulers'``;
    - an empty schedulers list;
    - an entry whose name is unexpected.
    """
    params = [torch.zeros([1], requires_grad=True), torch.zeros([1], requires_grad=True)]
    optimizer = torch.optim.SGD([{"params": p, 'lr': 0.1} for p in params])

    lr_scheduler1, lr_scheduler2 = (
        LinearCyclicalScheduler(optimizer,
                                "lr",
                                param_group_index=group_index,
                                start_value=1.0,
                                end_value=0.0,
                                cycle_size=10)
        for group_index in (0, 1)
    )

    rejected_arguments = (
        ([0, 1, 2], ['a', 'b', 'c']),
        ([lr_scheduler1, '2'], ['a', 'b']),
        ([lr_scheduler1, lr_scheduler2], 'ab'),
        ([lr_scheduler1, lr_scheduler2], ['a']),
    )
    for schedulers, names in rejected_arguments:
        with pytest.raises(ValueError):
            ParamGroupScheduler(schedulers=schedulers, names=names)

    scheduler = ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler2],
                                    names=['a', 'b'])

    for state, pattern in (
        ({'a': 1},
         r"Required state attribute 'schedulers' is absent in provided state_dict"),
        ({'schedulers': []},
         r"Input state_dict contains 0 state_dicts of param group schedulers"),
    ):
        with pytest.raises(ValueError, match=pattern):
            scheduler.load_state_dict(state)

    wrong_name_state = {
        'schedulers': [('a', lr_scheduler1.state_dict()), ('bad_name', {})]
    }
    with pytest.raises(ValueError,
                       match=r"Name of scheduler from input state dict does not "
                       r"correspond to required one"):
        scheduler.load_state_dict(wrong_name_state)