예제 #1
0
def test_custom_cascade_layer_workflow_in_memory():
    """Fit with custom cascade estimators/predictor, then round-trip to disk.

    Four ``DecisionTreeClassifier`` instances are registered as custom base
    estimators and another one as the predictor. The model is fitted on
    ``X_train``, saved to ``save_dir`` and reloaded into a fresh
    ``CascadeForestClassifier``. Predictions on ``X_test`` must be identical
    before and after the round trip, and ``get_estimator(0, 0, "custom")``
    must return the very object stored in the first layer under the key
    ``"0-0-custom"``.

    ``save_dir`` is removed in a ``finally`` clause so that a failing
    assertion does not leave a stale serialized model behind.
    """
    model = CascadeForestClassifier()

    n_estimators = 4
    estimators = [DecisionTreeClassifier() for _ in range(n_estimators)]
    model.set_estimator(estimators)  # set custom base estimators

    predictor = DecisionTreeClassifier()
    model.set_predictor(predictor)

    model.fit(X_train, y_train)
    y_pred_before = model.predict(X_test)

    try:
        # Save and reload into a brand-new instance
        model.save(save_dir)

        model = CascadeForestClassifier()
        model.load(save_dir)

        # Predictions after loading
        y_pred_after = model.predict(X_test)

        # Make sure the same predictions before and after model serialization
        assert_array_equal(y_pred_before, y_pred_after)

        # The public hook must hand back the exact object held by the first
        # layer (identity check, not equality).
        assert (
            model.get_estimator(0, 0, "custom")
            is model._get_layer(0).estimators_["0-0-custom"].estimator_
        )

        model.clean()  # clear the buffer
    finally:
        # ignore_errors=True: if save() failed before creating the directory,
        # rmtree must not raise and mask the original failure.
        shutil.rmtree(save_dir, ignore_errors=True)


def test_model_properties_after_fitting():
    """Check the model properties after fitting a deep forest model.

    Covers the following behavior of a model fitted with ``toy_kwargs``:

    - ``len()`` and ``[]`` indexing agree with ``n_layers_`` and
      ``_get_layer``.
    - Out-of-range layer and binner lookups raise ``ValueError``.
    - Re-setting an existing layer or binner raises ``RuntimeError``.
    - ``get_estimator`` returns the stored forest object and rejects a bad
      layer index, estimator index or estimator type.
    """
    clf = CascadeForestClassifier(**toy_kwargs)
    clf.fit(X_train, y_train)

    # Sequence-style access mirrors the internal layer container.
    assert len(clf) == clf.n_layers_
    assert clf[0] is clf._get_layer(0)

    # Layer container: out-of-range read, then duplicate write.
    with pytest.raises(
        ValueError, match="The layer index should be in the range"
    ):
        clf._get_layer(clf.n_layers_)

    with pytest.raises(
        RuntimeError, match="already exists in the internal container"
    ):
        clf._set_layer(0, None)

    # Binner container: out-of-range read, then duplicate write.
    with pytest.raises(
        ValueError, match="The binner index should be in the range"
    ):
        clf._get_binner(clf.n_layers_ + 1)

    with pytest.raises(
        RuntimeError, match="already exists in the internal container"
    ):
        clf._set_binner(0, None)

    # Test the hook on forest estimator: identity with the stored object.
    hooked = clf.get_estimator(0, 0, "rf")
    assert hooked is clf._get_layer(0).estimators_["0-0-rf"].estimator_

    # Each argument of the hook is validated independently.
    with pytest.raises(ValueError, match="`layer_idx` should be in the range"):
        clf.get_estimator(clf.n_layers_, 0, "rf")

    with pytest.raises(ValueError, match="`est_idx` should be in the range"):
        clf.get_estimator(0, clf.n_estimators, "rf")

    with pytest.raises(ValueError, match="`estimator_type` should be one of"):
        clf.get_estimator(0, 0, "Unknown")