def test_get_best_booster(self) -> None:
    """Exercise ``get_best_booster`` before tuning, after tuning and on a resumed study.

    Expectations checked here:

    * calling ``get_best_booster`` before any tuning step raises ``ValueError``;
    * after ``tune_regularization_factors`` the returned booster's
      ``lambda_l1`` differs from the value supplied in ``params``;
    * accessing the ``best_booster`` attribute emits ``DeprecationWarning``;
    * a second tuner built on the same study raises ``ValueError`` again.
    """
    # The original author labels this value "out of scope"; the test expects
    # the tuned booster not to keep it.
    initial_lambda_l1 = 20

    study = optuna.create_study()
    train_set = lgb.Dataset(np.zeros((10, 10)))
    params = {"verbose": -1, "lambda_l1": initial_lambda_l1}  # type: Dict

    fresh_tuner = LightGBMTuner(params, train_set, valid_sets=train_set, study=study)

    # Nothing has been tuned yet, so no best booster is available.
    with pytest.raises(ValueError):
        fresh_tuner.get_best_booster()

    # Pin the score so the tuning step does not depend on real training results.
    score_patch = mock.patch.object(BaseTuner, "_get_booster_best_score", return_value=0.0)
    with score_patch:
        fresh_tuner.tune_regularization_factors()

    tuned_booster = fresh_tuner.get_best_booster()
    assert tuned_booster.params["lambda_l1"] != initial_lambda_l1

    # TODO(toshihikoyanase): Remove this check when LightGBMTuner.best_booster is removed.
    with pytest.warns(DeprecationWarning):
        _ = fresh_tuner.best_booster

    # Resumed study does not have the best booster.
    resumed_tuner = LightGBMTuner(params, train_set, valid_sets=train_set, study=study)
    with pytest.raises(ValueError):
        resumed_tuner.get_best_booster()
def test_best_booster_with_model_dir(self) -> None:
    """Check that a tuner resumed with the same ``model_dir`` yields the same best-booster params.

    A first tuner runs ``tune_regularization_factors`` with ``model_dir`` set
    to a temporary directory. A second tuner is then created on the same
    study and directory, and the ``params`` of both best boosters are
    compared for equality.
    """
    study = optuna.create_study()
    train_set = lgb.Dataset(np.zeros((10, 10)))
    params = {"verbose": -1}  # type: Dict

    with TemporaryDirectory() as model_dir:

        def build_tuner() -> LightGBMTuner:
            # Both tuners must share the study and the model directory.
            return LightGBMTuner(
                params, train_set, valid_sets=train_set, study=study, model_dir=model_dir
            )

        first_tuner = build_tuner()

        # Pin the score so the tuning step does not depend on real training results.
        score_patch = mock.patch.object(BaseTuner, "_get_booster_best_score", return_value=0.0)
        with score_patch:
            first_tuner.tune_regularization_factors()

        original_booster = first_tuner.get_best_booster()

        # The second tuner never tunes; it only asks for the best booster.
        resumed_booster = build_tuner().get_best_booster()

        assert original_booster.params == resumed_booster.params
def test_resume_run(self) -> None:
    """Check that re-running a finished tuning step on the same study adds no trials.

    A ``LightGBMTunerCV`` runs ``tune_regularization_factors`` once. A second
    ``LightGBMTunerCV`` is then created on the same study and asked to run the
    same step; the number of trials in the study must not change.
    """
    params = {"verbose": -1}  # type: Dict
    dataset = lgb.Dataset(np.zeros((10, 10)))

    study = optuna.create_study()
    tuner = LightGBMTunerCV(params, dataset, study=study)

    # Pin the CV scores so the tuning step does not depend on real training results.
    with mock.patch.object(OptunaObjectiveCV, "_get_cv_scores", return_value=[1.0]):
        tuner.tune_regularization_factors()

    n_trials = len(study.trials)
    # The equality check at the end is only meaningful if the first run
    # actually recorded trials (the previous assertion here compared
    # ``n_trials`` with the expression it had just been assigned from).
    assert n_trials > 0

    # Resume with the same tuner class that produced the trials, matching the
    # ``OptunaObjectiveCV`` mock used below.
    tuner2 = LightGBMTunerCV(params, dataset, study=study)
    with mock.patch.object(OptunaObjectiveCV, "_get_cv_scores", return_value=[1.0]):
        tuner2.tune_regularization_factors()
    assert n_trials == len(study.trials)