Example #1
def test_lightgbm_pruning_callback_call(cv):
    # type: (bool) -> None

    callback_env = partial(
        lgb.callback.CallbackEnv,
        model='test',
        params={},
        begin_iteration=0,
        end_iteration=1,
        iteration=1)

    if cv:
        env = callback_env(evaluation_result_list=[('cv_agg', 'binary_error', 1., False, 1.)])
    else:
        env = callback_env(evaluation_result_list=[('validation', 'binary_error', 1., False)])

    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, 'binary_error', valid_name='validation')
    pruning_callback(env)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, 'binary_error', valid_name='validation')
    with pytest.raises(optuna.structs.TrialPruned):
        pruning_callback(env)
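
Examples #1-#3 rely on two test helpers that are not shown here, DeterministicPruner and create_running_trial, plus imports such as from functools import partial, import lightgbm as lgb, import pytest, import optuna, and from optuna.integration import LightGBMPruningCallback. Below is a minimal sketch of what those helpers could look like, inferred only from how they are used in the tests; the real helpers in Optuna's test suite may be implemented differently (for example through internal storage calls rather than the public study.ask() API).

class DeterministicPruner(optuna.pruners.BasePruner):
    # Always returns the same pruning decision, which makes the tests deterministic.

    def __init__(self, is_pruning):
        # type: (bool) -> None
        self.is_pruning = is_pruning

    def prune(self, study, trial):
        # type: (optuna.study.Study, optuna.trial.FrozenTrial) -> bool
        return self.is_pruning


def create_running_trial(study, value):
    # type: (optuna.study.Study, float) -> optuna.trial.Trial
    # Sketch only: open a trial in the RUNNING state via the public study.ask()
    # API and report a single intermediate value for the pruner to inspect.
    trial = study.ask()
    trial.report(value, step=0)
    return trial
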
Example #2
def test_lightgbm_pruning_callback_call(cv: bool) -> None:

    callback_env = partial(
        lgb.callback.CallbackEnv,
        model="test",
        params={},
        begin_iteration=0,
        end_iteration=1,
        iteration=1,
    )

    if cv:
        env = callback_env(evaluation_result_list=[("cv_agg", "binary_error", 1.0, False, 1.0)])
    else:
        env = callback_env(evaluation_result_list=[("validation", "binary_error", 1.0, False)])

    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, "binary_error", valid_name="validation")
    pruning_callback(env)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = create_running_trial(study, 1.0)
    pruning_callback = LightGBMPruningCallback(trial, "binary_error", valid_name="validation")
    with pytest.raises(optuna.TrialPruned):
        pruning_callback(env)
Example #3
def test_lightgbm_pruning_callback_call():
    # type: () -> None

    env = lgb.callback.CallbackEnv(model='test',
                                   params={},
                                   begin_iteration=0,
                                   end_iteration=1,
                                   iteration=1,
                                   evaluation_result_list=[
                                       ('validation', 'binary_error', 1.,
                                        False)
                                   ])

    # The pruner is deactivated.
    study = optuna.create_study(pruner=DeterministicPruner(False))
    trial = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    pruning_callback = LightGBMPruningCallback(trial,
                                               'binary_error',
                                               valid_name='validation')
    pruning_callback(env)

    # The pruner is activated.
    study = optuna.create_study(pruner=DeterministicPruner(True))
    trial = study._run_trial(func=lambda _: 1.0, catch=(Exception, ))
    pruning_callback = LightGBMPruningCallback(trial,
                                               'binary_error',
                                               valid_name='validation')
    with pytest.raises(optuna.structs.TrialPruned):
        pruning_callback(env)
Example #4
def objective(trial,
              metric='binary_error',
              valid_name='valid_0',
              force_default_valid_names=False):
    # type: (optuna.trial.Trial, str, str, bool) -> float

    dtrain = lgb.Dataset([[1.]], label=[1.])
    dtest = lgb.Dataset([[1.]], label=[1.])

    if force_default_valid_names:
        valid_names = None
    else:
        valid_names = [valid_name]

    pruning_callback = LightGBMPruningCallback(trial,
                                               metric,
                                               valid_name=valid_name)
    lgb.train({
        'objective': 'binary',
        'metric': ['auc', 'binary_error']
    },
              dtrain,
              1,
              valid_sets=[dtest],
              valid_names=valid_names,
              verbose_eval=False,
              callbacks=[pruning_callback])
    return 1.0
Example #5
def objective(
    trial: optuna.trial.Trial,
    metric: str = "binary_error",
    valid_name: str = "valid_0",
    interval: int = 1,
    num_boost_round: int = 1,
    force_default_valid_names: bool = False,
    cv: bool = False,
) -> float:

    # Tiny in-memory datasets keep the test fast; model quality is irrelevant here.
    dtrain = lgb.Dataset(np.asarray([[1.0], [2.0], [3.0]]),
                         label=[1.0, 0.0, 1.0])
    dtest = lgb.Dataset(np.asarray([[1.0]]), label=[1.0])

    # Fall back to LightGBM's default validation-set names when requested.
    if force_default_valid_names:
        valid_names = None
    else:
        valid_names = [valid_name]

    # Report the chosen metric on the named validation set to the trial every
    # `interval` boosting rounds so the pruner can decide whether to stop early.
    pruning_callback = LightGBMPruningCallback(trial,
                                               metric,
                                               valid_name=valid_name,
                                               report_interval=interval)
    if cv:
        lgb.cv(
            {
                "objective": "binary",
                "metric": ["auc", "binary_error"]
            },
            dtrain,
            num_boost_round,
            verbose_eval=False,
            nfold=2,
            callbacks=[pruning_callback],
        )
    else:
        lgb.train(
            {
                "objective": "binary",
                "metric": ["auc", "binary_error"]
            },
            dtrain,
            num_boost_round,
            valid_sets=[dtest],
            valid_names=valid_names,
            verbose_eval=False,
            callbacks=[pruning_callback],
        )
    return 1.0
Example #6
def objective(trial,
              metric="binary_error",
              valid_name="valid_0",
              force_default_valid_names=False,
              cv=False):
    # type: (optuna.trial.Trial, str, str, bool, bool) -> float

    dtrain = lgb.Dataset([[1.0], [2.0], [3.0]], label=[1.0, 0.0, 1.0])
    dtest = lgb.Dataset([[1.0]], label=[1.0])

    if force_default_valid_names:
        valid_names = None
    else:
        valid_names = [valid_name]

    pruning_callback = LightGBMPruningCallback(trial,
                                               metric,
                                               valid_name=valid_name)
    if cv:
        lgb.cv(
            {
                "objective": "binary",
                "metric": ["auc", "binary_error"]
            },
            dtrain,
            1,
            verbose_eval=False,
            nfold=2,
            callbacks=[pruning_callback],
        )
    else:
        lgb.train(
            {
                "objective": "binary",
                "metric": ["auc", "binary_error"]
            },
            dtrain,
            1,
            valid_sets=[dtest],
            valid_names=valid_names,
            verbose_eval=False,
            callbacks=[pruning_callback],
        )
    return 1.0
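
The objective functions in Examples #4-#6 only exercise the callback when they run inside a study. The snippet below is a hedged usage sketch, not taken from the source: the pruner choice and trial count are illustrative, and with such tiny datasets the trials may well finish as COMPLETE rather than PRUNED.

import optuna

# Drive the objective defined above with a real study and pruner;
# binary_error should be minimized, hence direction="minimize".
study = optuna.create_study(direction="minimize",
                            pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=2)

for finished_trial in study.trials:
    # Each trial ends up COMPLETE or PRUNED depending on the pruner's decision.
    print(finished_trial.number, finished_trial.state)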