def test_ogbm_regressor(tmp_path: pathlib.Path) -> None:
    """Run selected scikit-learn estimator checks against ``OGBMRegressor``.

    The test is skipped unless scikit-learn >= 0.20.0 is importable. The
    regressor is constructed with pytest's ``tmp_path`` as its ``model_dir``.
    The pickling check runs first, then the ``set_params`` check.
    """
    pytest.importorskip("sklearn", minversion="0.20.0")

    # Imported only after the importorskip guard so a missing/old sklearn
    # results in a skip rather than an ImportError.
    from sklearn.utils.estimator_checks import (
        check_estimators_pickle,
        check_set_params,
    )

    estimator = OGBMRegressor(model_dir=tmp_path)
    estimator_name = type(estimator).__name__

    # NOTE(review): the full ``check_estimator`` suite is not run here,
    # presumably because it does not pass yet — TODO confirm.
    for check in (check_estimators_pickle, check_set_params):
        check(estimator_name, estimator)
def test_ogbm_classifier(tmp_path: pathlib.Path) -> None:
    """Run selected scikit-learn estimator checks against ``OGBMClassifier``.

    The test is skipped unless scikit-learn >= 0.20.0 is importable. The
    classifier is constructed with pytest's ``tmp_path`` as its
    ``model_dir``. The pickling check runs first, then the ``set_params``
    check.
    """
    pytest.importorskip("sklearn", minversion="0.20.0")

    # Imported only after the importorskip guard so a missing/old sklearn
    # results in a skip rather than an ImportError.
    from sklearn.utils.estimator_checks import (
        check_estimators_pickle,
        check_set_params,
    )

    estimator = OGBMClassifier(model_dir=tmp_path)
    estimator_name = type(estimator).__name__

    # NOTE(review): the full ``check_estimator`` suite is not run here,
    # presumably because it does not pass yet — TODO confirm.
    for check in (check_estimators_pickle, check_set_params):
        check(estimator_name, estimator)
def test_callbacks():
    """Test that ``KerasClassifier`` works with a user-supplied callback.

    The test proceeds in four steps:

    1. Build a classifier carrying a ``SentinalCallback``.
    2. Run the sklearn pickling and partial-fit n_features checks on it.
    3. Fit once and assert that the callback's ``called`` attribute no longer
       equals the class-level ``SentinalCallback.called``.
    4. Round-trip the estimator through ``pickle`` and assert that the
       restored callback reports the same ``called`` value.
    """
    estimator = KerasClassifier(
        model=dynamic_classifier,
        callbacks=(SentinalCallback(),),
        optimizer="adam",
        model__hidden_layer_sizes=(100,),
    )

    # Generic sklearn checks: pickling and partial_fit feature-count handling.
    check_estimators_pickle("KerasClassifier", estimator)
    check_estimators_partial_fit_n_features("KerasClassifier", estimator)

    # One tiny fit; the callback should record that it was invoked.
    estimator.fit([[0]], [1])
    assert estimator.callbacks[0].called != SentinalCallback.called

    # The callback's recorded state must survive a pickle round-trip.
    restored = pickle.loads(pickle.dumps(estimator))
    assert restored.callbacks[0].called == estimator.callbacks[0].called

    # Fit the original estimator once more after pickling (presumably to
    # confirm it is still usable — TODO confirm intent).
    estimator.fit([[0]], [1])
def test_estimators_pickle():
    """Run sklearn's ``check_estimators_pickle`` on the module-level ``CLF``.

    ``NAME`` is passed through as the estimator name expected by the check.
    """
    name, estimator = NAME, CLF
    checks.check_estimators_pickle(name, estimator)
def test_pickling(name, est):
    """Check that ``est`` survives sklearn's pickling check, if it can be fit.

    Parameters
    ----------
    name : str
        Estimator name, forwarded as the first argument of
        ``check_estimators_pickle``.
    est : object
        Estimator instance under test. Objects without a ``fit`` attribute
        are skipped; the test returns without running any check.
    """
    if not hasattr(est, "fit"):
        return
    # Forward the caller-supplied name instead of a hard-coded placeholder
    # label, so the check receives the real estimator name.
    check_estimators_pickle(name, est)