Пример #1
0
def test_calling_fit_reinitializes():
    """Refitting after ``set_params`` must pick up the new estimator params.

    ``SelectFromModel`` is fitted once, the wrapped estimator's ``C`` is
    changed through the nested ``estimator__C`` parameter, and the selector
    is fitted again; the fitted ``estimator_`` must then carry ``C == 100``.
    """
    est = LinearSVC(random_state=0)
    transformer = SelectFromModel(estimator=est)
    transformer.fit(data, y)
    # Change the wrapped estimator's hyper-parameter via nested-param syntax.
    transformer.set_params(estimator__C=100)
    transformer.fit(data, y)
    # Plain ``assert``: pytest's assertion rewriting reports both operands on
    # failure, so the nose-style ``assert_equal`` helper is unnecessary.
    assert transformer.estimator_.C == 100
Пример #2
0
def test_calling_fit_reinitializes():
    """A second ``fit`` must rebuild ``estimator_`` from the updated params.

    The selector wraps a ``LinearSVC``; after the first fit its ``C`` is
    changed through ``set_params(estimator__C=...)`` and the selector is
    refitted. The resulting fitted estimator has to expose the new ``C``.
    """
    selector = SelectFromModel(estimator=LinearSVC(random_state=0))
    selector.fit(data, y)

    # Nested parameter update, then refit so ``estimator_`` is regenerated.
    selector.set_params(estimator__C=100)
    selector.fit(data, y)

    refitted_c = selector.estimator_.C
    assert refitted_c == 100
Пример #3
0
def test_prefit_max_features():
    """Check the interaction between `prefit` and `max_features`.

    A selector built with ``prefit=True`` around an already-fitted forest is
    exercised at ``transform`` without ever calling ``fit``. Two error paths
    are pinned: a callable ``max_features`` (``NotFittedError``) and a
    non-integer ``max_features`` set afterwards (``ValueError``).
    """
    forest = RandomForestClassifier(n_estimators=5, random_state=0)
    forest.fit(data, y)

    def keep_every_feature(X):
        # Callable form of `max_features`: keep as many features as X has.
        return X.shape[1]

    selector = SelectFromModel(
        forest, prefit=True, max_features=keep_every_feature
    )

    # Case 1: with a callable `max_features`, `transform` on a prefit
    # selector that was never fitted must ask the caller to call `fit` first.
    not_fitted_msg = (
        "When `prefit=True` and `max_features` is a callable, call `fit` "
        "before calling `transform`.")
    with pytest.raises(NotFittedError, match=not_fitted_msg):
        selector.transform(data)

    # Case 2: a non-integer `max_features` (2.5) set after construction is
    # rejected when `transform` is called.
    # FIXME: we cannot validate the upper bound of the attribute at transform
    # and we should force calling `fit` if we intend to force the attribute
    # to have such an upper bound.
    selector.set_params(max_features=2.5)
    with pytest.raises(ValueError, match="`max_features` must be an integer"):
        selector.transform(data)