예제 #1
0
def assert_tree_classifier_explained(clf, vec, target_names, **explain_kwargs):
    """Assert that ``explain_weights`` output for a tree classifier looks sane.

    Explains ``clf`` and renders the result with ``format_as_all``, then
    checks that:

    * both the text and HTML renderings mention feature importances and the
      feature ``'god'`` (assumed to be highly ranked for the training data
      the caller used — TODO confirm against callers);
    * a truncation marker is shown when more than ``top`` importances exist;
    * for ``DecisionTreeClassifier`` / ``OneVsRestClassifier``, the HTML
      contains an inline ``<svg`` exactly when ``_graphviz.is_supported()``;
    * explaining the same classifier again yields an equal result.

    :param clf: classifier to explain.
    :param vec: vectorizer passed through to ``explain_weights``.
    :param target_names: target names passed through to ``explain_weights``.
    :param explain_kwargs: extra keyword arguments for ``explain_weights``.
        ``top`` may be given here; it defaults to 30.
    """
    # Pop ``top`` so a caller-supplied value is used instead of colliding
    # with the explicit ``top=`` keyword below.
    top = explain_kwargs.pop('top', 30)

    def get_res():
        return explain_weights(
            clf, vec=vec, target_names=target_names, top=top, **explain_kwargs)

    res = get_res()
    expl_text, expl_html = format_as_all(res, clf)
    for expl in [expl_text, expl_html]:
        assert 'feature importances' in expl
        assert 'god' in expl  # high-ranked feature
        if len(res.feature_importances.importances) > top:
            # The HTML rendering may carry the ellipsis either as the
            # ``&hellip;`` entity or as a literal character; accept both.
            assert ('more features' in expl
                    or 'more &hellip;' in expl
                    or 'more …' in expl)

    if isinstance(clf, (DecisionTreeClassifier, OneVsRestClassifier)):
        # An SVG tree is embedded only when graphviz rendering is available.
        if _graphviz.is_supported():
            assert '<svg' in expl_html
        else:
            assert '<svg' not in expl_html

    # Explaining the same classifier again must give an equal result.
    assert res == get_res()
예제 #2
0
def test_explain_tree_classifier(newsgroups_train, clf):
    """Fit ``clf`` on newsgroups documents and check its weight explanation.

    ``newsgroups_train`` is unpacked as ``(docs, y, target_names)``. Both
    arguments are presumably pytest fixtures; verify against conftest.

    The test checks that:

    * the text and HTML renderings mention feature importances and ``'god'``;
    * a truncation marker appears when more than ``top`` importances exist;
    * an inline SVG tree appears in the HTML exactly when graphviz is
      supported (for ``DecisionTreeClassifier`` / ``OneVsRestClassifier``);
    * explaining the same classifier twice yields equal results.
    """
    docs, y, target_names = newsgroups_train
    vec = CountVectorizer()
    X = vec.fit_transform(docs)
    # The classifier is fitted on a dense array built from the vectorizer's output.
    clf.fit(X.toarray(), y)

    top = 30

    def get_res():
        return explain_weights(
            clf, vec=vec, target_names=target_names, top=top)

    res = get_res()
    expl_text, expl_html = format_as_all(res, clf)
    for expl in [expl_text, expl_html]:
        assert 'feature importances' in expl
        assert 'god' in expl  # high-ranked feature
        if len(res.feature_importances.importances) > top:
            # Accept the HTML ellipsis both as the ``&hellip;`` entity and
            # as a literal character.
            assert ('more features' in expl
                    or 'more &hellip;' in expl
                    or 'more …' in expl)

    if isinstance(clf, (DecisionTreeClassifier, OneVsRestClassifier)):
        # An SVG tree is embedded only when graphviz rendering is available.
        if _graphviz.is_supported():
            assert '<svg' in expl_html
        else:
            assert '<svg' not in expl_html

    # Explaining the same classifier again must give an equal result.
    assert res == get_res()
예제 #3
0
def _format_decision_tree(treedict):
    """Render a decision tree as an SVG image or as plain text.

    An SVG is produced from ``treedict.graphviz`` via ``_graphviz.dot2svg``
    only when that attribute is truthy and ``_graphviz.is_supported()``
    returns true. In every other case, ``tree2text(treedict)`` is used.
    """
    # Guard clause: fall back to text if there is no graphviz source or
    # graphviz rendering is unavailable (is_supported is not called when
    # the source is missing).
    if not treedict.graphviz or not _graphviz.is_supported():
        return tree2text(treedict)
    return _graphviz.dot2svg(treedict.graphviz)