Example #1
def test_pytorch_ignite_pruning_handler():
    # type: () -> None

    def update(engine, batch):
        # type: (Engine, Iterable) -> None

        pass

    trainer = Engine(update)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)

    handler = optuna.integration.PyTorchIgnitePruningHandler(
        trial, 'accuracy', trainer)
    with patch.object(trainer, 'state', epoch=1, metrics={'accuracy': 1}):
        with pytest.raises(optuna.exceptions.TrialPruned):
            handler(trainer)

    # The pruner is not activated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)

    handler = optuna.integration.PyTorchIgnitePruningHandler(
        trial, 'accuracy', trainer)
    with patch.object(trainer, 'state', epoch=1, metrics={'accuracy': 1}):
        handler(trainer)
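Usage note (not part of the example above): in real code the handler is attached to an evaluation engine rather than called by hand. A minimal sketch, with the training and evaluation steps stubbed out:

import optuna
from ignite.engine import Engine, Events


def objective(trial):
    def train_step(engine, batch):
        pass  # forward/backward pass would go here

    def eval_step(engine, batch):
        pass  # metric computation would go here

    trainer = Engine(train_step)
    evaluator = Engine(eval_step)

    # Report the evaluator's "accuracy" metric after each evaluation run and
    # raise TrialPruned when the study's pruner says to stop.
    handler = optuna.integration.PyTorchIgnitePruningHandler(trial, "accuracy", trainer)
    evaluator.add_event_handler(Events.COMPLETED, handler)

    # trainer.run(train_loader, max_epochs=...) would drive training here.
    return 0.0  # placeholder objective value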
Example #2
def test_lightgbm_pruning_callback_call(cv: bool) -> None:

    callback_env = partial(
        lgb.callback.CallbackEnv,
        model="test",
        params={},
        begin_iteration=0,
        end_iteration=1,
        iteration=1,
    )

    if cv:
        env = callback_env(evaluation_result_list=[("cv_agg", "binary_error", 1.0, False, 1.0)])
    else:
        env = callback_env(evaluation_result_list=[("validation", "binary_error", 1.0, False)])

    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, "binary_error", valid_name="validation")
    pruning_callback(env)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, "binary_error", valid_name="validation")
    with pytest.raises(optuna.TrialPruned):
        pruning_callback(env)
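Usage note: a minimal sketch of how LightGBMPruningCallback is typically wired into an objective. The synthetic data and parameter values are placeholders; with lgb.train's default naming the single validation set is called "valid_0", which is also the callback's default valid_name:

import lightgbm as lgb
import numpy as np
import optuna


def objective(trial):
    X = np.random.rand(100, 4)
    y = np.random.randint(2, size=100)
    dtrain = lgb.Dataset(X[:80], label=y[:80])
    dvalid = lgb.Dataset(X[80:], label=y[80:], reference=dtrain)

    # Reports "binary_error" on the validation set each iteration and
    # raises TrialPruned when the pruner decides to stop the trial.
    callback = optuna.integration.LightGBMPruningCallback(trial, "binary_error")
    gbm = lgb.train(
        {"objective": "binary", "metric": "binary_error"},
        dtrain,
        num_boost_round=10,
        valid_sets=[dvalid],
        callbacks=[callback],
    )
    preds = gbm.predict(X[80:])
    return float(np.mean((preds > 0.5) != y[80:]))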
Example #3
def test_lightgbm_pruning_callback_call(cv):
    # type: (bool) -> None

    callback_env = partial(
        lgb.callback.CallbackEnv,
        model='test',
        params={},
        begin_iteration=0,
        end_iteration=1,
        iteration=1)

    if cv:
        env = callback_env(evaluation_result_list=[('cv_agg', 'binary_error', 1., False, 1.)])
    else:
        env = callback_env(evaluation_result_list=[('validation', 'binary_error', 1., False)])

    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, 'binary_error', valid_name='validation')
    pruning_callback(env)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, 'binary_error', valid_name='validation')
    with pytest.raises(optuna.structs.TrialPruned):
        pruning_callback(env)
Example #4
def test_xgboost_pruning_callback_call():
    # type: () -> None

    env = xgb.core.CallbackEnv(
        model="test",
        cvfolds=1,
        begin_iteration=0,
        end_iteration=1,
        rank=1,
        iteration=1,
        evaluation_result_list=[["validation-error", 1.0]],
    )

    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)
    pruning_callback = XGBoostPruningCallback(trial, "validation-error")
    pruning_callback(env)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = XGBoostPruningCallback(trial, "validation-error")
    with pytest.raises(optuna.exceptions.TrialPruned):
        pruning_callback(env)
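Usage note: the "validation-error" string above follows XGBoost's "<eval name>-<metric>" convention. A minimal sketch with synthetic placeholder data:

import numpy as np
import optuna
import xgboost as xgb


def objective(trial):
    X = np.random.rand(100, 4)
    y = np.random.randint(2, size=100)
    dtrain = xgb.DMatrix(X[:80], label=y[:80])
    dvalid = xgb.DMatrix(X[80:], label=y[80:])

    # Monitors the "error" metric of the eval set named "validation".
    callback = optuna.integration.XGBoostPruningCallback(trial, "validation-error")
    bst = xgb.train(
        {"objective": "binary:logistic", "eval_metric": "error"},
        dtrain,
        num_boost_round=10,
        evals=[(dvalid, "validation")],
        callbacks=[callback],
    )
    preds = bst.predict(dvalid)
    return float(np.mean((preds > 0.5) != y[80:]))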
Example #5
def test_catboost_pruning_callback_call() -> None:
    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)
    pruning_callback = CatBoostPruningCallback(trial, "Logloss")
    info = types.SimpleNamespace(
        iteration=1,
        metrics={"learn": {"Logloss": [1.0]}, "validation": {"Logloss": [1.0]}},
    )
    assert pruning_callback.after_iteration(info)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = CatBoostPruningCallback(trial, "Logloss")
    info = types.SimpleNamespace(
        iteration=1,
        metrics={"learn": {"Logloss": [1.0]}, "validation": {"Logloss": [1.0]}},
    )
    assert not pruning_callback.after_iteration(info)
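Usage note: unlike the other callbacks, after_iteration() only tells CatBoost to stop training; check_pruned() must be called after fit() to actually raise TrialPruned. A minimal sketch with placeholder data:

import numpy as np
import optuna
from catboost import CatBoostClassifier


def objective(trial):
    X = np.random.rand(100, 4)
    y = np.random.randint(2, size=100)

    pruning_callback = optuna.integration.CatBoostPruningCallback(trial, "Logloss")
    model = CatBoostClassifier(iterations=10, verbose=0)
    model.fit(
        X[:80],
        y[:80],
        eval_set=[(X[80:], y[80:])],
        callbacks=[pruning_callback],
    )
    pruning_callback.check_pruned()  # raises TrialPruned if training was cut short
    return model.score(X[80:], y[80:])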
Example #6
def test_pytorch_ignite_pruning_handler():
    # type: () -> None

    def update(engine, batch):
        # type: (Engine, Iterable) -> None

        pass

    trainer = Engine(update)
    evaluator = Engine(update)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)

    handler = optuna.integration.PyTorchIgnitePruningHandler(
        trial, "accuracy", trainer)
    with patch.object(trainer, "state", epoch=3):
        with patch.object(evaluator, "state", metrics={"accuracy": 1}):
            with pytest.raises(optuna.exceptions.TrialPruned):
                handler(evaluator)
            assert study.trials[0].intermediate_values == {3: 1}

    # The pruner is not activated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)

    handler = optuna.integration.PyTorchIgnitePruningHandler(
        trial, "accuracy", trainer)
    with patch.object(trainer, "state", epoch=5):
        with patch.object(evaluator, "state", metrics={"accuracy": 2}):
            handler(evaluator)
            assert study.trials[0].intermediate_values == {5: 2}
Example #7
def test_observation_exists():
    # type: () -> None

    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)
    MockTrainer = namedtuple("_MockTrainer", ("observation", ))
    trainer = MockTrainer(observation={"OK": 0})

    # Trigger is deactivated. Return False whether trainer has observation or not.
    with patch.object(triggers.IntervalTrigger, "__call__",
                      Mock(return_value=False)) as mock:
        extension = ChainerPruningExtension(trial, "NG", (1, "epoch"))
        assert extension._observation_exists(trainer) is False
        extension = ChainerPruningExtension(trial, "OK", (1, "epoch"))
        assert extension._observation_exists(trainer) is False
        assert mock.call_count == 2

    # Trigger is activated. Return True if trainer has observation.
    with patch.object(triggers.IntervalTrigger, "__call__",
                      Mock(return_value=True)) as mock:
        extension = ChainerPruningExtension(trial, "NG", (1, "epoch"))
        assert extension._observation_exists(trainer) is False
        extension = ChainerPruningExtension(trial, "OK", (1, "epoch"))
        assert extension._observation_exists(trainer) is True
        assert mock.call_count == 2
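Usage note: a minimal sketch of registering the extension on a trainer. make_trainer is a hypothetical helper standing in for the usual chainer model/optimizer/updater setup:

import optuna
from optuna.integration import ChainerPruningExtension


def objective(trial):
    trainer = make_trainer(trial)  # hypothetical: returns a chainer.training.Trainer

    # Check "validation/main/loss" once per epoch; raises TrialPruned on pruning.
    trainer.extend(ChainerPruningExtension(trial, "validation/main/loss", (1, "epoch")))
    trainer.run()
    return 0.0  # placeholder; how the final metric is read back is application-specific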
Example #8
def test_tfkeras_pruning_callback_monitor_is_invalid() -> None:

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    callback = TFKerasPruningCallback(trial, "InvalidMonitor")

    with pytest.warns(UserWarning):
        callback.on_epoch_end(0, {"loss": 1.0})
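Usage note: a minimal sketch of the callback in a regular model.fit() loop, with synthetic placeholder data. Monitoring "val_accuracy" assumes the model is compiled with the accuracy metric and fit with validation data:

import numpy as np
import optuna
import tensorflow as tf
from optuna.integration import TFKerasPruningCallback


def objective(trial):
    X = np.random.rand(100, 4)
    y = np.random.randint(2, size=100)

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

    # Reports "val_accuracy" after each epoch; raises TrialPruned when pruned.
    history = model.fit(
        X,
        y,
        epochs=3,
        validation_split=0.2,
        callbacks=[TFKerasPruningCallback(trial, "val_accuracy")],
        verbose=0,
    )
    return history.history["val_accuracy"][-1]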
Example #9
def test_tfkeras_pruning_callback_observation_isnan() -> None:

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    callback = TFKerasPruningCallback(trial, "loss")

    with pytest.raises(optuna.TrialPruned):
        callback.on_epoch_end(0, {"loss": 1.0})

    with pytest.raises(optuna.TrialPruned):
        callback.on_epoch_end(0, {"loss": float("nan")})
Example #10
def test_tfkeras_pruning_callback_observation_isnan():
    # type: () -> None

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    callback = TFKerasPruningCallback(trial, 'loss')

    with pytest.raises(optuna.structs.TrialPruned):
        callback.on_epoch_end(0, {'loss': 1.0})

    with pytest.raises(optuna.structs.TrialPruned):
        callback.on_epoch_end(0, {'loss': float('nan')})
Example #11
def test_chainer_pruning_extension_observation_nan() -> None:

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    extension = ChainerPruningExtension(trial, "main/loss", (1, "epoch"))

    MockTrainer = namedtuple("_MockTrainer", ("observation", "updater"))
    MockUpdater = namedtuple("_MockUpdater", ("epoch",))
    trainer = MockTrainer(observation={"main/loss": float("nan")}, updater=MockUpdater(1))

    with patch.object(extension, "_observation_exists", Mock(return_value=True)) as mock:
        with pytest.raises(optuna.TrialPruned):
            extension(trainer)  # type: ignore
        assert mock.call_count == 1
Example #12
def test_chainer_pruning_extension_observation_nan():
    # type: () -> None

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    extension = ChainerPruningExtension(trial, 'main/loss', (1, 'epoch'))

    MockTrainer = namedtuple('_MockTrainer', ('observation', 'updater'))
    MockUpdater = namedtuple('_MockUpdater', ('epoch',))
    trainer = MockTrainer(observation={'main/loss': float('nan')}, updater=MockUpdater(1))

    with patch.object(extension, '_observation_exists', Mock(return_value=True)) as mock:
        with pytest.raises(TrialPruned):
            extension(trainer)
        assert mock.call_count == 1
Example #13
def test_pytorch_lightning_pruning_callback_monitor_is_invalid() -> None:

    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    callback = PyTorchLightningPruningCallback(trial, "InvalidMonitor")

    trainer = pl.Trainer(
        max_epochs=1,
        enable_checkpointing=False,
        callbacks=[callback],
    )
    model = Model()

    with pytest.warns(UserWarning):
        callback.on_validation_end(trainer, model)
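Usage note: a minimal sketch of the callback in a normal fit() run. It assumes the LightningModule logs a "val_acc" metric during validation (e.g. via self.log("val_acc", ...)):

import pytorch_lightning as pl
from optuna.integration import PyTorchLightningPruningCallback


def objective(trial):
    model = Model()  # the same LightningModule used in the test above

    trainer = pl.Trainer(
        max_epochs=2,
        enable_checkpointing=False,
        callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")],
    )
    trainer.fit(model)
    return trainer.callback_metrics["val_acc"].item()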
Example #14
def test_chainer_pruning_extension_trigger():
    # type: () -> None

    study = optuna.create_study()
    trial = create_running_trial(study, 1.0)

    extension = ChainerPruningExtension(trial, 'main/loss', (1, 'epoch'))
    assert isinstance(extension.pruner_trigger, triggers.IntervalTrigger)
    extension = ChainerPruningExtension(trial, 'main/loss', triggers.IntervalTrigger(1, 'epoch'))
    assert isinstance(extension.pruner_trigger, triggers.IntervalTrigger)
    extension = ChainerPruningExtension(trial, 'main/loss',
                                        triggers.ManualScheduleTrigger(1, 'epoch'))
    assert isinstance(extension.pruner_trigger, triggers.ManualScheduleTrigger)

    with pytest.raises(TypeError):
        ChainerPruningExtension(trial, 'main/loss', triggers.TimeTrigger(1.))