def assert_tree_classifier_explained(clf, vec, target_names, **explain_kwargs):
    """Check the text and HTML weight explanations of a tree-based classifier.

    ``explain_weights`` is called with ``top=30`` plus any ``explain_kwargs``.
    The result is rendered with ``format_as_all`` and these checks follow:

    * both renderings mention "feature importances" and the feature ``god``;
    * when more than ``top`` importances were returned, both renderings show
      a "more features" (or "more …") marker;
    * for ``DecisionTreeClassifier`` / ``OneVsRestClassifier`` the HTML
      contains ``<svg`` exactly when ``_graphviz.is_supported()`` is true;
    * calling ``explain_weights`` again yields an equal explanation.
    """
    top = 30

    def explain():
        return explain_weights(
            clf, vec=vec, target_names=target_names, top=top,
            **explain_kwargs)

    explanation = explain()
    expl_text, expl_html = format_as_all(explanation, clf)

    for rendered in (expl_text, expl_html):
        assert 'feature importances' in rendered
        assert 'god' in rendered  # expected to be a high-ranked feature
        if len(explanation.feature_importances.importances) > top:
            # Truncated output must say that some features were hidden.
            assert 'more features' in rendered or 'more …' in rendered

    if isinstance(clf, (DecisionTreeClassifier, OneVsRestClassifier)):
        # The HTML output embeds an SVG tree only when graphviz is usable.
        svg_expected = bool(_graphviz.is_supported())
        assert ('<svg' in expl_html) == svg_expected

    # Explaining the same fitted model twice must give an equal result.
    assert explanation == explain()
def test_explain_tree_classifier(newsgroups_train, clf):
    """Fit ``clf`` on bag-of-words newsgroups data and check its explanation.

    ``newsgroups_train`` is unpacked as ``(docs, y, target_names)`` and
    ``clf`` must be a not-yet-fitted classifier. Both are presumably pytest
    fixtures or parametrized values; confirm against their definitions.

    The explanation checks live in ``assert_tree_classifier_explained``
    rather than being duplicated here, so both stay in sync.
    """
    docs, y, target_names = newsgroups_train
    vec = CountVectorizer()
    X = vec.fit_transform(docs)
    # Fit on a dense copy of the vectorized data.
    # NOTE(review): presumably some estimators under test reject sparse X;
    # confirm.
    clf.fit(X.toarray(), y)
    assert_tree_classifier_explained(clf, vec, target_names)
def _format_decision_tree(treedict):
    """Render a decision tree for display.

    If ``treedict.graphviz`` is truthy and ``_graphviz.is_supported()`` is
    true, the DOT source is converted with ``_graphviz.dot2svg``. Otherwise
    the tree is rendered with ``tree2text``.
    """
    # Short-circuit: is_supported() is only consulted when DOT source exists.
    can_render_svg = treedict.graphviz and _graphviz.is_supported()
    if not can_render_svg:
        return tree2text(treedict)
    return _graphviz.dot2svg(treedict.graphviz)