Code example #1
import pytest

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate  # test helper from the PyTorch Lightning repo


@pytest.mark.parametrize('use_hparams', [False, True])
def test_trainer_arg_str(tmpdir, use_hparams):
    """Test that setting the trainer arg to a string works."""
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    model.my_fancy_lr = 1.0  # update with non-standard field
    model.hparams['my_fancy_lr'] = 1.0
    before_lr = model.my_fancy_lr
    if use_hparams:
        del model.my_fancy_lr
        model.configure_optimizers = model.configure_optimizers__lr_from_hparams

    # configure the trainer to run the LR finder on the custom attribute
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_lr_find='my_fancy_lr',
    )

    trainer.tune(model)
    if use_hparams:
        after_lr = model.hparams.my_fancy_lr
    else:
        after_lr = model.my_fancy_lr

    assert before_lr != after_lr, \
        'Learning rate was not altered after running learning rate finder'
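
For context, passing a string to auto_lr_find tells the learning-rate finder which model attribute holds the learning rate, rather than the default 'lr'/'learning_rate'. The sketch below shows the minimal kind of model such a test assumes; the class name FancyLRModel and its random dataset are hypothetical stand-ins for EvalModelTemplate, written against a Trainer API that still exposes trainer.tune (as in example #1).

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class FancyLRModel(pl.LightningModule):
    # hypothetical minimal model; auto_lr_find='my_fancy_lr' points the
    # LR finder at the non-standard attribute below
    def __init__(self):
        super().__init__()
        self.my_fancy_lr = 1.0
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        # the tuner overwrites self.my_fancy_lr with its suggestion, so the
        # optimizer must read the learning rate from that attribute
        return torch.optim.Adam(self.parameters(), lr=self.my_fancy_lr)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)


model = FancyLRModel()
trainer = pl.Trainer(auto_lr_find='my_fancy_lr', max_epochs=2)
trainer.tune(model)  # runs the LR finder and updates model.my_fancy_lr
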
Code example #2
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate  # test helper from the PyTorch Lightning repo


# older snapshot of the same test, written against an earlier Trainer API
def test_trainer_arg_str(tmpdir):
    """Test that setting the trainer arg to a string works."""
    model = EvalModelTemplate()
    model.my_fancy_lr = 1.0  # update with non-standard field

    before_lr = model.my_fancy_lr
    # trainer configured so fit() runs the LR finder on the custom attribute
    trainer = Trainer(default_save_path=tmpdir,
                      max_epochs=2,
                      auto_lr_find='my_fancy_lr')

    trainer.fit(model)
    after_lr = model.my_fancy_lr
    assert before_lr != after_lr, \
        'Learning rate was not altered after running learning rate finder'
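
The two snapshots capture the same test across API revisions: example #2 predates the rename of default_save_path to default_root_dir and runs the finder implicitly during trainer.fit, while example #1 uses the explicit trainer.tune entry point and adds the use_hparams branch for models that read the learning rate from hparams instead of a plain attribute.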