예제 #1
0
def test_default_logger_callbacks_cpu_model(tmpdir):
    """Train a small model on CPU, then check that freeze/unfreeze take effect.

    Training goes through ``tutils.run_model_test_no_loggers``. The trainer
    options keep the run tiny: one epoch and small train/val fractions. They
    also enable gradient clipping and NaN-gradient printing.

    :param tmpdir: pytest temporary directory used as ``default_save_path``.
    """
    tutils.reset_seed()

    trainer_options = dict(default_save_path=tmpdir,
                           max_nb_epochs=1,
                           gradient_clip_val=1.0,
                           overfit_pct=0.20,
                           print_nan_grads=True,
                           show_progress_bar=False,
                           train_percent_check=0.01,
                           val_percent_check=0.01)

    model, hparams = tutils.get_model()
    tutils.run_model_test_no_loggers(trainer_options,
                                     model,
                                     hparams,
                                     on_gpu=False)

    # test freeze on cpu. Calling freeze()/unfreeze() alone only proves they
    # don't crash, so also verify the state they are documented to produce:
    # freeze() disables gradients and switches to eval mode, and unfreeze()
    # reverses both.
    model.freeze()
    assert not model.training, 'freeze() should put the model in eval mode'
    assert all(not param.requires_grad for param in model.parameters()), \
        'freeze() should disable gradients for every parameter'

    model.unfreeze()
    assert model.training, 'unfreeze() should put the model back in train mode'
    assert all(param.requires_grad for param in model.parameters()), \
        'unfreeze() should re-enable gradients for every parameter'
예제 #2
0
def test_lbfgs_cpu_model(tmpdir):
    """Train the test model built with ``lbfgs=True`` on CPU.

    Runs one epoch over all of the training data, validating on 20% of the
    validation data. ``tutils.run_model_test_no_loggers`` is then asked to
    accept the run with ``min_acc=0.30``.

    :param tmpdir: pytest temporary directory used as ``default_save_path``.
    """
    tutils.reset_seed()

    trainer_options = dict(default_save_path=tmpdir,
                           max_epochs=1,
                           print_nan_grads=True,
                           show_progress_bar=False,
                           weights_summary='top',
                           train_percent_check=1.0,
                           val_percent_check=0.2)

    # The hparams returned by get_model() are not needed here: this helper
    # call takes only the trainer options and the model.
    model, _ = tutils.get_model(use_test_model=True, lbfgs=True)
    tutils.run_model_test_no_loggers(trainer_options, model, min_acc=0.30)
예제 #3
0
def test_lbfgs_cpu_model():
    """Train the test model built with ``lbfgs=True`` on CPU, without ``tmpdir``.

    Runs one epoch over all of the training data, validating on 20% of the
    validation data. ``tutils.run_model_test_no_loggers`` is then asked to
    accept the run with ``min_acc=0.30``. ``tutils.clear_save_dir()`` is
    always called afterwards, even if the run fails.

    NOTE(review): this shares its name with the ``test_lbfgs_cpu_model(tmpdir)``
    variant shown earlier in this listing. If both end up in one module, this
    later definition shadows the earlier one, and pytest collects only this one.
    """
    tutils.reset_seed()

    trainer_options = dict(max_nb_epochs=1,
                           print_nan_grads=True,
                           show_progress_bar=False,
                           weights_summary='top',
                           train_percent_check=1.0,
                           val_percent_check=0.2)

    model, hparams = tutils.get_model(use_test_model=True, lbfgs=True)
    try:
        tutils.run_model_test_no_loggers(trainer_options,
                                         model,
                                         hparams,
                                         on_gpu=False,
                                         min_acc=0.30)
    finally:
        # Run the cleanup even when training or the accuracy check raises.
        # Previously a failure skipped it and left the save directory
        # (presumably the run's output) behind for later tests.
        tutils.clear_save_dir()