Esempio n. 1
0
def test_ogbm_regressor(tmp_path: pathlib.Path) -> None:
    """Run selected scikit-learn estimator checks against ``OGBMRegressor``.

    The test is skipped when scikit-learn is missing or older than 0.20.0.
    Only the pickling check and the ``set_params`` check are exercised.
    The full ``check_estimator`` suite is intentionally not run here.

    Args:
        tmp_path: Directory passed to the estimator as ``model_dir``
            (presumably pytest's ``tmp_path`` fixture).
    """
    pytest.importorskip("sklearn", minversion="0.20.0")

    # NOTE(review): the full ``check_estimator`` suite was disabled in the
    # original; re-enable it once the estimator passes all generic checks.
    from sklearn.utils.estimator_checks import (
        check_estimators_pickle,
        check_set_params,
    )

    estimator = OGBMRegressor(model_dir=tmp_path)
    estimator_name = type(estimator).__name__

    for run_check in (check_estimators_pickle, check_set_params):
        run_check(estimator_name, estimator)
Esempio n. 2
0
def test_ogbm_classifier(tmp_path: pathlib.Path) -> None:
    """Run selected scikit-learn estimator checks against ``OGBMClassifier``.

    The test is skipped when scikit-learn is missing or older than 0.20.0.
    Only the pickling check and the ``set_params`` check are exercised.
    The full ``check_estimator`` suite is intentionally not run here.

    Args:
        tmp_path: Directory passed to the estimator as ``model_dir``
            (presumably pytest's ``tmp_path`` fixture).
    """
    pytest.importorskip("sklearn", minversion="0.20.0")

    # NOTE(review): the full ``check_estimator`` suite was disabled in the
    # original; re-enable it once the estimator passes all generic checks.
    from sklearn.utils.estimator_checks import (
        check_estimators_pickle,
        check_set_params,
    )

    classifier = OGBMClassifier(model_dir=tmp_path)
    classifier_name = type(classifier).__name__

    for run_check in (check_estimators_pickle, check_set_params):
        run_check(classifier_name, classifier)
Esempio n. 3
0
def test_callbacks():
    """Check that ``KerasClassifier`` invokes its callbacks and pickles them.

    The test first runs the generic pickling and ``partial_fit`` n_features
    checks. It then fits once and verifies that the callback instance
    recorded a call, and that this recorded state survives a pickle round
    trip. Finally it fits the original estimator once more.
    """
    clf_name = "KerasClassifier"
    clf = KerasClassifier(
        model=dynamic_classifier,
        callbacks=(SentinalCallback(),),
        optimizer="adam",
        model__hidden_layer_sizes=(100,),
    )

    # Generic estimator checks: pickling and partial_fit feature-count handling.
    check_estimators_pickle(clf_name, clf)
    check_estimators_partial_fit_n_features(clf_name, clf)

    # After one tiny fit, the callback instance's ``called`` value should no
    # longer equal the class-level default, i.e. the callback was invoked.
    clf.fit([[0]], [1])
    assert clf.callbacks[0].called != SentinalCallback.called

    # The recorded callback state must be identical after a pickle round trip.
    restored = pickle.loads(pickle.dumps(clf))
    assert restored.callbacks[0].called == clf.callbacks[0].called

    # Fit the original estimator again after it has been pickled.
    clf.fit([[0]], [1])
def test_estimators_pickle():
    """Run the pickling estimator check on the module-level ``CLF``/``NAME`` pair."""
    pickle_check = checks.check_estimators_pickle
    pickle_check(NAME, CLF)
Esempio n. 5
0
def test_pickling(name, est):
    """Check that ``est`` passes scikit-learn's pickling check.

    Args:
        name: Estimator name, forwarded to ``check_estimators_pickle``.
        est: Estimator instance under test. Objects without a ``fit``
            attribute are silently skipped, so nothing is checked for them.
    """
    if hasattr(est, "fit"):
        # Forward the real name instead of a placeholder. The check uses it
        # to identify the estimator, and some older scikit-learn releases
        # also keyed data handling off it (verify for the pinned version).
        check_estimators_pickle(name, est)