Exemplo n.º 1
0
def test_get_features_names_out_classifier_error():
    """Soft voting with ``flatten_transform=False`` rejects feature names.

    A ``VotingClassifier`` is fitted with ``voting="soft"`` and
    ``flatten_transform=False``; calling ``get_feature_names_out`` on it
    is expected to raise a ``ValueError`` carrying the message below.
    """
    features = [[1, 2], [3, 4], [5, 6]]
    labels = [0, 1, 2]
    base_estimators = [
        ("lr", LogisticRegression(random_state=0)),
        ("tree", DecisionTreeClassifier(random_state=0)),
    ]

    ensemble = VotingClassifier(
        estimators=base_estimators, voting="soft", flatten_transform=False
    )
    ensemble.fit(features, labels)

    # ``match`` is applied as a regular-expression search on the error text.
    expected_error = (
        "get_feature_names_out is not supported when `voting='soft'` and "
        "`flatten_transform=False`"
    )
    with pytest.raises(ValueError, match=expected_error):
        ensemble.get_feature_names_out()
Exemplo n.º 2
0
def test_get_features_names_out_classifier(kwargs, expected_names):
    """Verify ``get_feature_names_out`` of a fitted ``VotingClassifier``.

    Parameters
    ----------
    kwargs : dict
        Extra keyword arguments forwarded to the ``VotingClassifier``
        constructor (presumably supplied by a parametrization that is not
        shown here).
    expected_names : sequence
        Names the fitted ensemble is expected to report; its length must
        also equal the number of columns produced by ``transform``.
    """
    features = [[1, 2], [3, 4], [5, 6], [1, 1.2]]
    labels = [0, 1, 2, 0]
    base_estimators = [
        ("lr", LogisticRegression(random_state=0)),
        ("tree", DecisionTreeClassifier(random_state=0)),
    ]

    ensemble = VotingClassifier(estimators=base_estimators, **kwargs)
    ensemble.fit(features, labels)

    transformed = ensemble.transform(features)
    feature_names = ensemble.get_feature_names_out()

    # One reported name per transformed column, and the names must match.
    n_columns = transformed.shape[1]
    assert n_columns == len(expected_names)
    assert_array_equal(feature_names, expected_names)