def test_keras_pruning_callback_monitor_is_invalid() -> None:
    """A monitor name missing from the epoch logs should emit a ``UserWarning``.

    The callback watches ``"InvalidMonitor"``, but the logs passed to
    ``on_epoch_end`` only contain ``"loss"``.
    """
    pruner = DeterministicPruner(True)
    trial = optuna.create_study(pruner=pruner).ask()
    callback = KerasPruningCallback(trial, "InvalidMonitor")

    epoch_logs = {"loss": 1.0}
    with pytest.warns(UserWarning):
        callback.on_epoch_end(0, epoch_logs)
def test_keras_pruning_callback_observation_isnan() -> None:
    """Both a finite and a NaN ``loss`` observation should raise ``TrialPruned``.

    The study uses ``DeterministicPruner(True)``, presumably a test pruner that
    always asks to prune, so each ``on_epoch_end`` call is expected to raise.

    The NaN observation is sent at a *new* epoch. Optuna's ``Trial.report``
    ignores (with a warning) a value reported for a step that was already
    reported. Reusing epoch 0 would therefore drop the NaN and never exercise
    the NaN case.
    """
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = study.ask()
    callback = KerasPruningCallback(trial, "loss")

    with pytest.raises(optuna.TrialPruned):
        callback.on_epoch_end(0, {"loss": 1.0})

    # Fresh epoch so the NaN is actually reported rather than ignored as a
    # duplicate step.
    with pytest.raises(optuna.TrialPruned):
        callback.on_epoch_end(1, {"loss": float("nan")})
def test_keras_pruning_callback_observation_isnan():
    # type: () -> None
    """Both a finite and a NaN ``loss`` observation should raise ``TrialPruned``.

    ``DeterministicPruner(True)`` is presumably a test pruner that always asks
    to prune, so each ``on_epoch_end`` call is expected to raise.

    The NaN observation is sent at a new epoch. Optuna storages do not store a
    second intermediate value for an already-reported step, so reusing epoch 0
    would drop the NaN and the NaN case would never be exercised.
    """
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    callback = KerasPruningCallback(trial, "loss")

    with pytest.raises(optuna.exceptions.TrialPruned):
        callback.on_epoch_end(0, {"loss": 1.0})

    # Fresh epoch so the NaN is actually reported rather than discarded as a
    # duplicate step.
    with pytest.raises(optuna.exceptions.TrialPruned):
        callback.on_epoch_end(1, {"loss": float("nan")})
def test_keras_pruning_callback_observation_isnan():
    # type: () -> None
    """Feeding a finite and then a NaN ``loss`` should each raise ``TrialPruned``.

    The trial comes from ``study._run_trial`` with a constant objective. Both
    observations are sent for epoch 0, one after the other.
    """
    pruner = DeterministicPruner(True)
    study = optuna.create_study(pruner=pruner)
    trial_under_test = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    callback = KerasPruningCallback(trial_under_test, 'loss')

    # Same call order as checking 1.0 first and NaN second.
    for observed_loss in (1.0, float('nan')):
        with pytest.raises(optuna.structs.TrialPruned):
            callback.on_epoch_end(0, {'loss': observed_loss})
def test_keras_pruning_callback_observation_isnan():
    # type: () -> None
    """Feeding a finite and then a NaN ``loss`` should each raise ``TrialPruned``.

    The test is skipped when ``_available`` is falsy. ``_available`` is
    presumably a module-level flag saying whether keras could be imported;
    confirm against the module header.
    """
    # TODO(higumachan): remove this "if" section after Tensorflow supports Python 3.7.
    if not _available:
        pytest.skip(
            'This test requires keras '
            'but this version can not install keras(tensorflow) with pip.')

    pruner = DeterministicPruner(True)
    study = optuna.create_study(pruner=pruner)
    trial_under_test = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    callback = KerasPruningCallback(trial_under_test, 'loss')

    # Finite observation first, then NaN; both at epoch 0.
    for observed_loss in (1.0, float('nan')):
        with pytest.raises(optuna.structs.TrialPruned):
            callback.on_epoch_end(0, {'loss': observed_loss})