示例#1
0
def test_keras_pruning_callback_monitor_is_invalid() -> None:
    """Check that an unknown monitor name makes ``on_epoch_end`` warn.

    The callback is told to monitor ``"InvalidMonitor"`` while the logs passed
    to ``on_epoch_end`` only carry a ``"loss"`` entry, and a ``UserWarning``
    is expected.
    """
    pruner = DeterministicPruner(True)
    trial = optuna.create_study(pruner=pruner).ask()
    pruning_callback = KerasPruningCallback(trial, "InvalidMonitor")
    epoch_logs = {"loss": 1.0}

    with pytest.warns(UserWarning):
        pruning_callback.on_epoch_end(0, epoch_logs)
示例#2
0
def test_keras_pruning_callback_observation_isnan() -> None:
    """Check that ``on_epoch_end`` raises ``TrialPruned`` for finite and NaN losses."""
    pruner = DeterministicPruner(True)
    study = optuna.create_study(pruner=pruner)
    pruning_callback = KerasPruningCallback(study.ask(), "loss")

    # A regular observation first, then a NaN one; each call must raise.
    for observed_loss in (1.0, float("nan")):
        with pytest.raises(optuna.TrialPruned):
            pruning_callback.on_epoch_end(0, {"loss": observed_loss})
示例#3
0
def test_keras_pruning_callback_observation_isnan():
    # type: () -> None
    """Check that ``on_epoch_end`` raises ``TrialPruned`` for finite and NaN losses."""
    pruner = DeterministicPruner(True)
    study = optuna.create_study(pruner=pruner)
    running_trial = create_running_trial(study, 1.0)
    pruning_callback = KerasPruningCallback(running_trial, "loss")

    # A regular observation first, then a NaN one; each call must raise.
    for observed_loss in (1.0, float("nan")):
        with pytest.raises(optuna.exceptions.TrialPruned):
            pruning_callback.on_epoch_end(0, {"loss": observed_loss})
示例#4
0
def test_keras_pruning_callback_observation_isnan():
    # type: () -> None
    """Check that ``on_epoch_end`` raises ``TrialPruned`` for finite and NaN losses."""
    pruner = DeterministicPruner(True)
    study = optuna.create_study(pruner=pruner)
    # The trial comes from running a constant objective through the study.
    trial = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    pruning_callback = KerasPruningCallback(trial, 'loss')

    # A regular observation first, then a NaN one; each call must raise.
    for observed_loss in (1.0, float('nan')):
        with pytest.raises(optuna.structs.TrialPruned):
            pruning_callback.on_epoch_end(0, {'loss': observed_loss})
def test_keras_pruning_callback_observation_isnan():
    # type: () -> None
    """Check that ``on_epoch_end`` raises ``TrialPruned`` for finite and NaN losses.

    The test is skipped when the module-level ``_available`` flag is falsy.
    """
    # TODO(higumachan): remove this "if" section after Tensorflow supports Python 3.7.
    if not _available:
        pytest.skip(
            'This test requires keras '
            'but this version can not install keras(tensorflow) with pip.')

    pruner = DeterministicPruner(True)
    study = optuna.create_study(pruner=pruner)
    # The trial comes from running a constant objective through the study.
    trial = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    pruning_callback = KerasPruningCallback(trial, 'loss')

    # A regular observation first, then a NaN one; each call must raise.
    for observed_loss in (1.0, float('nan')):
        with pytest.raises(optuna.structs.TrialPruned):
            pruning_callback.on_epoch_end(0, {'loss': observed_loss})