Code example #1
def test_optimizer_return_options():

    trainer = Trainer()
    model = EvalModelTemplate()

    # optimizers and schedulers shared by the cases below
    opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
    opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
    scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
    scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)

    # single optimizer
    model.configure_optimizers = lambda: opt_a
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0

    # opt tuple
    model.configure_optimizers = lambda: (opt_a, opt_b)
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt list
    model.configure_optimizers = lambda: [opt_a, opt_b]
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b
    assert len(lr_sched) == 0 and len(freq) == 0

    # opt tuple of 2 lists
    model.configure_optimizers = lambda: ([opt_a], [scheduler_a])
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
                               frequency=1, reduce_on_plateau=False, monitor='val_loss')

    # opt single dictionary
    model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a}
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
                               frequency=1, reduce_on_plateau=False, monitor='val_loss')

    # opt multiple dictionaries with frequencies
    model.configure_optimizers = lambda: (
        {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
        {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
    )
    optim, lr_sched, freq = trainer.init_optimizers(model)
    assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
    assert optim[0] == opt_a
    assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
                               frequency=1, reduce_on_plateau=False, monitor='val_loss')
    assert freq == [1, 5]
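For orientation, here is a minimal sketch (not part of the test file) of how a LightningModule would return the multi-dictionary format exercised at the end of this test; the class name, layer size and learning rates are illustrative assumptions only.

import torch
import pytorch_lightning as pl

class TwoOptimizerModule(pl.LightningModule):  # hypothetical module, for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        opt_a = torch.optim.Adam(self.parameters(), lr=0.002)
        opt_b = torch.optim.SGD(self.parameters(), lr=0.002)
        sched_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
        sched_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)
        # One dict per optimizer; 'frequency' controls how many batches each optimizer
        # is used for before switching, matching freq == [1, 5] in the assertion above.
        return (
            {"optimizer": opt_a, "lr_scheduler": sched_a, "frequency": 1},
            {"optimizer": opt_b, "lr_scheduler": sched_b, "frequency": 5},
        )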
Code example #2
def test_lr_scheduler_strict(tmpdir):
    """
    Test "strict" support in lr_scheduler dict
    """
    model = EvalModelTemplate()
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    model.configure_optimizers = lambda: {
        'optimizer': optimizer,
        'lr_scheduler': {'scheduler': scheduler, 'monitor': 'giraffe', 'strict': True},
    }
    with pytest.raises(
        MisconfigurationException,
        match=r'ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:',
    ):
        trainer.fit(model)

    model.configure_optimizers = lambda: {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'monitor': 'giraffe',
            'strict': False,
        },
    }
    with pytest.warns(
        RuntimeWarning, match=r'ReduceLROnPlateau conditioned on metric .* which is not available but strict'
    ):
        assert trainer.fit(model)
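For contrast, a hedged sketch (reusing the model, optimizer and scheduler names from the test above) of the same dict pointed at a metric the model actually logs, which would fit cleanly even with strict=True; the metric name 'val_loss' is an assumption here, not something this test defines.

model.configure_optimizers = lambda: {
    'optimizer': optimizer,
    'lr_scheduler': {
        'scheduler': scheduler,
        'monitor': 'val_loss',  # assumed to be logged by the model; not taken from the test
        'strict': True,
    },
}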
Code example #3
def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir):
    """
    Test exception when lr_scheduler dict has no scheduler
    """
    model = EvalModelTemplate()
    model.configure_optimizers = lambda: {
        'optimizer': torch.optim.Adam(model.parameters()),
        'lr_scheduler': {},
    }
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.raises(MisconfigurationException, match='The lr scheduler dict must have the key "scheduler"'):
        trainer.fit(model)
Code example #4
def test_reducelronplateau_with_no_monitor_raises(tmpdir):
    """
    Test exception when a ReduceLROnPlateau is used with no monitor
    """
    model = EvalModelTemplate()
    optimizer = torch.optim.Adam(model.parameters())
    model.configure_optimizers = lambda: ([optimizer], [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)])
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.raises(
        MisconfigurationException, match='`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`'
    ):
        trainer.fit(model)
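A hedged sketch of the configuration that avoids this exception: pass the ReduceLROnPlateau through the dict format with an explicit 'monitor' key, so the trainer knows which metric to watch. The metric name is illustrative, not taken from the test.

model.configure_optimizers = lambda: {
    'optimizer': optimizer,
    'lr_scheduler': {
        'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'val_loss',  # example metric name; must be logged by the model
    },
}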
Code example #5
def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir):
    """
    Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor
    """
    model = EvalModelTemplate()
    optimizer = torch.optim.Adam(model.parameters())
    model.configure_optimizers = lambda: {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        },
    }
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.raises(MisconfigurationException, match='must include a monitor when a `ReduceLROnPlateau`'):
        trainer.fit(model)
Code example #6
def test_lr_scheduler_with_extra_keys_warns(tmpdir):
    """
    Test warning when lr_scheduler dict has extra keys
    """
    model = EvalModelTemplate()
    optimizer = torch.optim.Adam(model.parameters())
    model.configure_optimizers = lambda: {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
            'foo': 1,
            'bar': 2,
        },
    }
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.warns(RuntimeWarning, match=r'Found unsupported keys in the lr scheduler dict: \[.+\]'):
        trainer.fit(model)
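Judging from the dicts asserted and passed elsewhere in these examples, the recognised scheduler-dict keys include 'scheduler', 'interval', 'frequency', 'monitor' and 'strict'. A hedged sketch of a dict that should pass without this warning (values are illustrative):

model.configure_optimizers = lambda: {
    'optimizer': optimizer,
    'lr_scheduler': {
        'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
        'interval': 'epoch',  # step the scheduler once per epoch
        'frequency': 1,
    },
}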
Code example #7
def test_reducelronplateau_scheduling(tmpdir):
    model = EvalModelTemplate()
    optimizer = torch.optim.Adam(model.parameters())
    model.configure_optimizers = lambda: {
        'optimizer': optimizer,
        'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'early_stop_on',
    }
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    results = trainer.fit(model)
    assert results == 1
    lr_scheduler = trainer.lr_schedulers[0]
    assert lr_scheduler == dict(
        scheduler=lr_scheduler['scheduler'],
        monitor='early_stop_on',
        interval='epoch',
        frequency=1,
        reduce_on_plateau=True,
        strict=True,
    ), 'lr scheduler was not correctly converted to dict'
Code example #8
def test_reducelronplateau_scheduling(tmpdir):
    model = EvalModelTemplate()
    optimizer = torch.optim.Adam(model.parameters())
    model.configure_optimizers = lambda: {
        'optimizer': optimizer,
        'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'val_acc',
    }
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    lr_scheduler = trainer.lr_schedulers[0]
    assert lr_scheduler == dict(
        scheduler=lr_scheduler['scheduler'],
        monitor='val_acc',
        interval='epoch',
        frequency=1,
        reduce_on_plateau=True,
        strict=True,
        name=None,
    ), 'lr scheduler was not correctly converted to dict'
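Compared with the previous example, the normalized dict here also carries a 'name' field (None by default). A hedged sketch of supplying it from user code; what consumes the name is not shown in these tests, so treat the key and its value as illustrative.

model.configure_optimizers = lambda: {
    'optimizer': optimizer,
    'lr_scheduler': {
        'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
        'monitor': 'val_acc',
        'name': 'plateau_lr',  # hypothetical display name; not used in the test above
    },
}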
Code example #9
def test_lr_scheduler_strict(tmpdir):
    """
    Test "strict" support in lr_scheduler dict
    """
    model = EvalModelTemplate()
    optimizer = optim.Adam(model.parameters())
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    model.configure_optimizers = lambda: {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "giraffe",
            "strict": True
        },
    }
    with pytest.raises(
        MisconfigurationException,
        match=r"ReduceLROnPlateau conditioned on metric .* which is not available\. Available metrics are:",
    ):
        trainer.fit(model)

    model.configure_optimizers = lambda: {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "giraffe",
            "strict": False
        },
    }
    with pytest.warns(
        RuntimeWarning,
        match=r"ReduceLROnPlateau conditioned on metric .* which is not available but strict",
    ):
        trainer.fit(model)