def test_pytorch_ignite_pruning_handler():
    # type: () -> None
    """Check that the Ignite handler prunes only when the pruner requests it.

    The handler is invoked with the trainer itself, whose patched ``state``
    supplies both the epoch and the ``accuracy`` metric.
    """

    def update(engine, batch):
        # type: (Engine, Iterable) -> None
        pass

    trainer = Engine(update)

    def make_handler(prune):
        # type: (bool) -> optuna.integration.PyTorchIgnitePruningHandler
        # Each scenario gets its own study so the pruner decision is isolated.
        scenario_study = optuna.create_study(pruner=DeterministicPruner(prune))
        running_trial = create_running_trial(scenario_study, 1.0)
        return optuna.integration.PyTorchIgnitePruningHandler(
            running_trial, 'accuracy', trainer)

    # Pruner activated: calling the handler must raise TrialPruned.
    pruning_handler = make_handler(True)
    with patch.object(trainer, 'state', epoch=1, metrics={'accuracy': 1}):
        with pytest.raises(optuna.exceptions.TrialPruned):
            pruning_handler(trainer)

    # Pruner not activated: calling the handler must return normally.
    pruning_handler = make_handler(False)
    with patch.object(trainer, 'state', epoch=1, metrics={'accuracy': 1}):
        pruning_handler(trainer)
def test_lightgbm_pruning_callback_call(cv: bool) -> None:
    """Exercise ``LightGBMPruningCallback`` with single-set and CV-shaped results.

    When ``cv`` is true the evaluation entry is a five-element tuple named
    ``"cv_agg"``; otherwise it is a four-element tuple named ``"validation"``.
    The callback must pass silently when the pruner is off and raise
    ``TrialPruned`` when it is on.
    """
    if cv:
        results = [("cv_agg", "binary_error", 1.0, False, 1.0)]
    else:
        results = [("validation", "binary_error", 1.0, False)]

    env = lgb.callback.CallbackEnv(
        model="test",
        params={},
        iteration=1,
        begin_iteration=0,
        end_iteration=1,
        evaluation_result_list=results,
    )

    # Pruner deactivated: the callback returns without raising.
    quiet_study = optuna.create_study(pruner=DeterministicPruner(False))
    quiet_trial = create_running_trial(quiet_study, 1.0)
    LightGBMPruningCallback(quiet_trial, "binary_error", valid_name="validation")(env)

    # Pruner activated: the callback raises TrialPruned.
    pruning_study = optuna.create_study(pruner=DeterministicPruner(True))
    pruning_trial = create_running_trial(pruning_study, 1.0)
    callback = LightGBMPruningCallback(pruning_trial, "binary_error", valid_name="validation")
    with pytest.raises(optuna.TrialPruned):
        callback(env)
def test_lightgbm_pruning_callback_call(cv):
    # type: (bool) -> None
    """Exercise ``LightGBMPruningCallback`` with single-set and CV-shaped results.

    When ``cv`` is true the evaluation entry is a five-element tuple named
    ``'cv_agg'``; otherwise it is a four-element tuple named ``'validation'``.
    """
    callback_env = partial(
        lgb.callback.CallbackEnv,
        model='test',
        params={},
        begin_iteration=0,
        end_iteration=1,
        iteration=1)

    if cv:
        env = callback_env(evaluation_result_list=[('cv_agg', 'binary_error', 1., False, 1.)])
    else:
        env = callback_env(evaluation_result_list=[('validation', 'binary_error', 1., False)])

    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, 'binary_error', valid_name='validation')
    pruning_callback(env)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, 'binary_error', valid_name='validation')
    # Use optuna.exceptions.TrialPruned rather than the deprecated
    # optuna.structs alias, matching the other tests in this module.
    with pytest.raises(optuna.exceptions.TrialPruned):
        pruning_callback(env)
def test_xgboost_pruning_callback_call():
    # type: () -> None
    """Check XGBoostPruningCallback against a hand-built ``CallbackEnv``.

    The env reports a single ``validation-error`` value of 1.0 at iteration 1.
    The callback must pass silently when the pruner is off and raise
    ``TrialPruned`` when it is on.
    """
    env = xgb.core.CallbackEnv(
        model="test",
        cvfolds=1,
        begin_iteration=0,
        end_iteration=1,
        rank=1,
        iteration=1,
        evaluation_result_list=[["validation-error", 1.0]],
    )

    # First the deactivated pruner, then the activated one.
    for should_prune in (False, True):
        study = optuna.create_study(pruner=DeterministicPruner(should_prune))
        trial = create_running_trial(study, 1.0)
        callback = XGBoostPruningCallback(trial, "validation-error")
        if not should_prune:
            callback(env)
            continue
        with pytest.raises(optuna.exceptions.TrialPruned):
            callback(env)
def test_catboost_pruning_callback_call() -> None:
    """Check the return value of ``CatBoostPruningCallback.after_iteration``.

    It must be truthy while the pruner is deactivated and falsy once the
    pruner is activated.
    """

    def make_info() -> types.SimpleNamespace:
        # Build a fresh iteration-info object for every call.
        return types.SimpleNamespace(
            iteration=1,
            metrics={
                "learn": {"Logloss": [1.0]},
                "validation": {"Logloss": [1.0]},
            },
        )

    # Pruner deactivated: training should continue.
    keep_study = optuna.create_study(pruner=DeterministicPruner(False))
    keep_trial = create_running_trial(keep_study, 1.0)
    keep_callback = CatBoostPruningCallback(keep_trial, "Logloss")
    assert keep_callback.after_iteration(make_info())

    # Pruner activated: training should stop.
    stop_study = optuna.create_study(pruner=DeterministicPruner(True))
    stop_trial = create_running_trial(stop_study, 1.0)
    stop_callback = CatBoostPruningCallback(stop_trial, "Logloss")
    assert not stop_callback.after_iteration(make_info())
def test_pytorch_ignite_pruning_handler():
    # type: () -> None
    """Check the Ignite handler with separate trainer and evaluator engines.

    The reported intermediate value must be keyed by the trainer's epoch and
    taken from the evaluator's ``accuracy`` metric. The handler must raise
    ``TrialPruned`` only when the pruner is activated.
    """

    def update(engine, batch):
        # type: (Engine, Iterable) -> None
        pass

    trainer = Engine(update)
    evaluator = Engine(update)

    def run_case(prune, epoch, accuracy):
        # type: (bool, int, int) -> None
        study = optuna.create_study(pruner=DeterministicPruner(prune))
        trial = create_running_trial(study, 1.0)
        handler = optuna.integration.PyTorchIgnitePruningHandler(trial, "accuracy", trainer)
        with patch.object(trainer, "state", epoch=epoch):
            with patch.object(evaluator, "state", metrics={"accuracy": accuracy}):
                if prune:
                    with pytest.raises(optuna.exceptions.TrialPruned):
                        handler(evaluator)
                else:
                    handler(evaluator)
        assert study.trials[0].intermediate_values == {epoch: accuracy}

    # The pruner is activated.
    run_case(True, 3, 1)
    # The pruner is not activated.
    run_case(False, 5, 2)
def test_observation_exists():
    # type: () -> None
    """Check ``ChainerPruningExtension._observation_exists``.

    The result is True only when the (patched) interval trigger fires and
    the trainer's observation contains the monitored key.
    """
    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)
    MockTrainer = namedtuple("_MockTrainer", ("observation", ))
    trainer = MockTrainer(observation={"OK": 0})

    # (trigger result, expected answer for each monitored key)
    scenarios = (
        (False, {"NG": False, "OK": False}),
        (True, {"NG": False, "OK": True}),
    )
    for trigger_fires, expected in scenarios:
        fake_call = Mock(return_value=trigger_fires)
        with patch.object(triggers.IntervalTrigger, "__call__", fake_call) as mock:
            for key in ("NG", "OK"):
                extension = ChainerPruningExtension(trial, key, (1, "epoch"))
                assert extension._observation_exists(trainer) is expected[key]
            # The trigger is consulted exactly once per extension.
            assert mock.call_count == 2
def test_tfkeras_pruning_callback_monitor_is_invalid() -> None:
    """A monitor name absent from the epoch logs must trigger a UserWarning."""
    trial = create_running_trial(
        optuna.create_study(pruner=DeterministicPruner(True)), 1.0
    )
    pruning_callback = TFKerasPruningCallback(trial, "InvalidMonitor")

    # The logs only contain "loss", so "InvalidMonitor" cannot be found.
    epoch_logs = {"loss": 1.0}
    with pytest.warns(UserWarning):
        pruning_callback.on_epoch_end(0, epoch_logs)
def test_tfkeras_pruning_callback_observation_isnan() -> None:
    """With an always-pruning pruner, both a finite and a NaN loss must prune."""
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = TFKerasPruningCallback(trial, "loss")

    for observed_loss in (1.0, float("nan")):
        with pytest.raises(optuna.TrialPruned):
            pruning_callback.on_epoch_end(0, {"loss": observed_loss})
def test_tfkeras_pruning_callback_observation_isnan():
    # type: () -> None
    """With an always-pruning pruner, both a finite and a NaN loss must prune."""
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    callback = TFKerasPruningCallback(trial, 'loss')

    # optuna.exceptions.TrialPruned replaces the deprecated optuna.structs alias.
    with pytest.raises(optuna.exceptions.TrialPruned):
        callback.on_epoch_end(0, {'loss': 1.0})

    with pytest.raises(optuna.exceptions.TrialPruned):
        callback.on_epoch_end(0, {'loss': float('nan')})
def test_chainer_pruning_extension_observation_nan() -> None:
    """A NaN observation must cause the Chainer extension to prune the trial.

    ``_observation_exists`` is patched to always report True, so the
    extension reads the NaN value and raises ``TrialPruned``.
    """
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    extension = ChainerPruningExtension(trial, "main/loss", (1, "epoch"))

    MockTrainer = namedtuple("_MockTrainer", ("observation", "updater"))
    # Trailing comma: a one-element tuple. ("epoch") would just be the string
    # "epoch", which works only because namedtuple also accepts a string.
    MockUpdater = namedtuple("_MockUpdater", ("epoch",))
    trainer = MockTrainer(observation={"main/loss": float("nan")}, updater=MockUpdater(1))

    with patch.object(extension, "_observation_exists", Mock(return_value=True)) as mock:
        with pytest.raises(optuna.TrialPruned):
            extension(trainer)  # type: ignore
        assert mock.call_count == 1
def test_chainer_pruning_extension_observation_nan():
    # type: () -> None
    """A NaN observation must cause the Chainer extension to prune the trial.

    ``_observation_exists`` is patched to always report True, so the
    extension reads the NaN value and raises ``TrialPruned``.
    """
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    extension = ChainerPruningExtension(trial, 'main/loss', (1, 'epoch'))

    MockTrainer = namedtuple('_MockTrainer', ('observation', 'updater'))
    # Trailing comma: a one-element tuple. ('epoch') would just be the string
    # 'epoch', which works only because namedtuple also accepts a string.
    MockUpdater = namedtuple('_MockUpdater', ('epoch',))
    trainer = MockTrainer(observation={'main/loss': float('nan')}, updater=MockUpdater(1))

    with patch.object(extension, '_observation_exists', Mock(return_value=True)) as mock:
        with pytest.raises(TrialPruned):
            extension(trainer)
        assert mock.call_count == 1
def test_pytorch_lightning_pruning_callback_monitor_is_invalid() -> None:
    """``on_validation_end`` must warn when the monitored metric name is unknown."""
    trial = create_running_trial(
        optuna.create_study(pruner=DeterministicPruner(True)), 1.0
    )
    pruning_callback = PyTorchLightningPruningCallback(trial, "InvalidMonitor")

    lightning_trainer = pl.Trainer(
        max_epochs=1,
        enable_checkpointing=False,
        callbacks=[pruning_callback],
    )
    lightning_model = Model()

    with pytest.warns(UserWarning):
        pruning_callback.on_validation_end(lightning_trainer, lightning_model)
def test_chainer_pruning_extension_trigger():
    # type: () -> None
    """Check which trigger types ChainerPruningExtension stores or rejects.

    A ``(period, unit)`` tuple and an ``IntervalTrigger`` both yield an
    ``IntervalTrigger``. A ``ManualScheduleTrigger`` is kept as-is. A
    ``TimeTrigger`` is rejected with ``TypeError``.
    """
    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)

    def assert_trigger_type(trigger_spec, expected_type):
        ext = ChainerPruningExtension(trial, 'main/loss', trigger_spec)
        assert isinstance(ext.pruner_trigger, expected_type)

    assert_trigger_type((1, 'epoch'), triggers.IntervalTrigger)
    assert_trigger_type(triggers.IntervalTrigger(1, 'epoch'), triggers.IntervalTrigger)
    assert_trigger_type(triggers.ManualScheduleTrigger(1, 'epoch'),
                        triggers.ManualScheduleTrigger)

    with pytest.raises(TypeError):
        ChainerPruningExtension(trial, 'main/loss', triggers.TimeTrigger(1.))