def test_nontrivial_hierarchy_leaf_classification():
    """Test that a nontrivial hierarchy leaf classification behaves as expected.

    We build the following class hierarchy along with data from the handwritten digits dataset:

            <ROOT>
           /      \
          A        B
         / \      /|\
        1   7    3 8 9

    """
    class_hierarchy = {
        ROOT: ["A", "B"],
        "A": [1, 7],
        "B": [3, 8, 9],
    }
    base_estimator = svm.SVC(gamma=0.001, kernel="rbf", probability=True)
    clf = make_classifier(
        base_estimator=base_estimator,
        class_hierarchy=class_hierarchy,
    )
    X, y = make_digits_dataset(targets=[1, 7, 3, 8, 9])
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=RANDOM_STATE,
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)

    assert_that(accuracy, is_(close_to(1., delta=0.02)))
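Note: these examples exercise the sklearn-hierarchical-classification package, but the helpers make_classifier and make_digits_dataset (and the RANDOM_STATE constant) are test fixtures that are not shown here. The following is a minimal sketch of what they plausibly look like, assuming ROOT comes from sklearn_hierarchical_classification.constants and the data comes from scikit-learn's digits dataset; it is not the package's actual fixture code.

import numpy as np
from sklearn.datasets import load_digits
from sklearn_hierarchical_classification.classifier import HierarchicalClassifier

RANDOM_STATE = 42  # assumed value; the tests only need a fixed seed


def make_digits_dataset(targets=None, as_str=True):
    # Load the handwritten digits data, optionally restricted to a
    # subset of target classes; labels are cast to str by default.
    X, y = load_digits(return_X_y=True)
    if targets is not None:
        mask = np.isin(y, targets)
        X, y = X[mask], y[mask]
    if as_str:
        y = y.astype(str)
    return X, y


def make_classifier(base_estimator=None, class_hierarchy=None, **kwargs):
    # Thin wrapper around HierarchicalClassifier; extra keyword
    # arguments (e.g. prediction_depth) pass straight through.
    return HierarchicalClassifier(
        base_estimator=base_estimator,
        class_hierarchy=class_hierarchy,
        **kwargs,
    )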
Example #2
def test_base_estimator_as_dict():
    """Test that specifying base_estimator as a dictionary mappings nodes to base estimators works."""
    class_hierarchy = {
        ROOT: ["A", "B"],
        "A": [1, 7],
        "B": [3, 8, 9],
    }
    clf = make_classifier(
        base_estimator={
            ROOT: KNeighborsClassifier(),
            "B": svm.SVC(probability=True),
            DEFAULT: MultinomialNB(),
        },
        class_hierarchy=class_hierarchy,
    )
    X, y = make_digits_dataset(
        targets=[1, 7, 3, 8, 9],
        as_str=False,
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=RANDOM_STATE,
    )

    clf.fit(X_train, y_train)

    assert_that(
        isinstance(clf.graph_.nodes[ROOT][CLASSIFIER], KNeighborsClassifier))
    assert_that(isinstance(clf.graph_.nodes["B"][CLASSIFIER], svm.SVC))
    assert_that(isinstance(clf.graph_.nodes["A"][CLASSIFIER], MultinomialNB))
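The test above asserts that node "B" gets its own SVC while node "A", which has no entry of its own, falls back to the estimator under the DEFAULT key. A resolution scheme consistent with that behavior could look like the hypothetical helper below (resolve_base_estimator is illustrative, not the library's internal API):

from sklearn.base import clone


def resolve_base_estimator(base_estimator, node_id):
    # Dict form: prefer the node-specific estimator, falling back to
    # the DEFAULT entry; otherwise base_estimator is a single template.
    if isinstance(base_estimator, dict):
        template = base_estimator.get(node_id, base_estimator[DEFAULT])
    else:
        template = base_estimator
    # Clone so every hierarchy node trains its own independent copy.
    return clone(template)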
Example #3
def test_nmlnp_strategy_on_tree_with_dummy_classifier():
    """Test classification works on a tree graph when one of the nodes has out-degree 1 resulting in
    creation of a "dummy" classifier at that node to trivially predict its child."""
    # since NMLNP results in a mix of intermediate and lefa nodes,
    # make sure they are all of same dtype (str)
    class_hierarchy = {
        ROOT: ["A", "B", "C"],
        "A": ["1", "5", "6", "7"],
        "B": ["2", "3", "8", "9"],
        "C": ["4"],
    }
    base_estimator = svm.SVC(gamma=0.001, kernel="rbf", probability=True)
    clf = make_classifier(
        base_estimator=base_estimator,
        class_hierarchy=class_hierarchy,
        prediction_depth="nmlnp",
        stopping_criteria=0.9,
    )

    X, y = make_digits_dataset()
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=RANDOM_STATE,
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    assert_that(list(y_pred), has_item("4"))
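The "dummy" classifier the docstring refers to can be pictured as a trivial estimator attached to any node with a single child, such as "C" above, whose only child is "4". A hypothetical sketch (DummySingleChildClassifier is illustrative, not the library's class):

import numpy as np


class DummySingleChildClassifier:
    # Stand-in for a node with out-degree 1: there is nothing to
    # learn, so it always predicts the single child with probability 1.
    def __init__(self, child):
        self.child = child
        self.classes_ = np.array([child])

    def fit(self, X, y=None):
        return self  # nothing to fit

    def predict(self, X):
        return np.full(len(X), self.child)

    def predict_proba(self, X):
        return np.ones((len(X), 1))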
Example #4
def test_nmlnp_strategy_with_float_stopping_criteria():
    """Test that the NMLNP strategy with a float stopping criteria can emit intermediate labels."""
    # Since NMLNP results in a mix of intermediate and leaf nodes,
    # make sure they are all of the same dtype (str).
    class_hierarchy = {
        ROOT: ["A", "B"],
        "A": ["1", "5", "6", "7"],
        "B": ["2", "3", "4", "8", "9"],
    }
    base_estimator = svm.SVC(gamma=0.001, kernel="rbf", probability=True)
    clf = make_classifier(
        base_estimator=base_estimator,
        class_hierarchy=class_hierarchy,
        prediction_depth="nmlnp",
        stopping_criteria=0.9,
    )

    X, y = make_digits_dataset()
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=RANDOM_STATE,
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    assert_that(list(y_pred), has_item("B"))
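The reason y_pred can contain the intermediate label "B" is the NMLNP ("non-mandatory leaf node prediction") strategy: prediction walks down the hierarchy and stops early when no child clears the stopping_criteria probability threshold. A hypothetical sketch of that descent for a single sample (predict_nmlnp is illustrative, not the library API; graph is the fitted networkx digraph that the other tests inspect via clf.graph_):

def predict_nmlnp(graph, x, stopping_criteria=0.9):
    node = ROOT
    while graph.out_degree(node) > 0:
        clf = graph.nodes[node][CLASSIFIER]
        probs = clf.predict_proba([x])[0]
        best = int(probs.argmax())
        if probs[best] < stopping_criteria:
            return node  # confidence too low: stop at an intermediate node
        node = clf.classes_[best]
    return node  # reached a leaf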
Example #5
def classify_digits():
    """Demonstrate leaf classification over a nontrivial hierarchy.

    We build the following class hierarchy along with data from the handwritten digits dataset:

            <ROOT>
           /      \
          A        B
         / \      /|\
        1   7    3 8 9

    """
    class_hierarchy = {
        ROOT: ["A", "B"],
        "A": [1, 7],
        "B": [3, 8, 9],
    }
    base_estimator = make_pipeline(
        TruncatedSVD(n_components=24),
        svm.SVC(
            gamma=0.001,
            kernel="rbf",
            probability=True
        ),
    )
    clf = HierarchicalClassifier(
        base_estimator=base_estimator,
        class_hierarchy=class_hierarchy,
    )
    X, y = make_digits_dataset(
        targets=[1, 7, 3, 8, 9],
    )
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=RANDOM_STATE,
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print("Classification Report:\n", classification_report(y_test, y_pred))
Example #6
def test_nmlnp_strategy_on_dag_with_dummy_classifier():
    """Test classification works on a "deep" DAG when one of the nodes has out-degree 1 resulting in
    creation of a "dummy" classifier at that node to triially predict its child.

    This test case actually tests a few more subtle edge cases:

    - String-based target labels with length > 1
    - Multi-level degenerate sub-graphs, e.g. some nodes having a sub-graph that is a path.

    """
    # Since NMLNP results in a mix of intermediate and leaf nodes,
    # make sure they are all of the same dtype (str).
    class_hierarchy = {
        ROOT: ["A", "B", "C"],
        "A": ["1", "5", "6", "7"],
        "B": ["2", "BC.1", "8", "9"],
        "BC.1": ["3a"],
        "C": ["BC.1"],
    }
    base_estimator = svm.SVC(gamma=0.001, kernel="rbf", probability=True)
    clf = make_classifier(
        base_estimator=base_estimator,
        class_hierarchy=class_hierarchy,
        prediction_depth="nmlnp",
        stopping_criteria=0.9,
    )

    X, y = make_digits_dataset()
    y[where(y == "3")] = "3a"
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=RANDOM_STATE,
    )

    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    assert_that(list(y_pred), has_item("3a"))