Exemplo n.º 1
0
def test_explain_linear_feature_filter(newsgroups_train, vec):
    """Check that ``feature_re`` and ``feature_filter`` restrict shown features.

    A logistic regression is fitted on the vectorized newsgroups documents and
    its weights are explained twice: once filtered by the regex ``'^ath'`` and
    once by a callable that also lets ``'<BIAS>'`` through. Every formatted
    explanation must mention the ``ath*`` features, must not mention
    ``'space'``, and must mention ``'BIAS'`` only in the second case.
    """
    docs, y, _ = newsgroups_train
    clf = LogisticRegression(random_state=42)
    clf.fit(vec.fit_transform(docs), y)
    if isinstance(vec, HashingVectorizer):
        # A plain HashingVectorizer is wrapped (and the wrapper fitted on the
        # docs) before being handed to explain_weights.
        vec = InvertableHashingVectorizer(vec)
        vec.fit(docs)

    def check_formatted(formatted, bias_expected):
        # Shared assertions for both filtering variants; only the presence
        # of the bias feature differs between them.
        for rendered in formatted:
            assert 'atheists' in rendered
            assert 'atheism' in rendered
            assert 'space' not in rendered
            assert ('BIAS' in rendered) == bias_expected

    def keep_ath_or_bias(name):
        return name.startswith('ath') or name == '<BIAS>'

    # Variant 1: regex filter -- the bias feature does not match '^ath'.
    by_regex = explain_weights(clf, vec=vec, feature_re='^ath')
    formatted = format_as_all(by_regex, clf)
    _, _ = formatted  # exactly two formatted explanations are expected
    check_formatted(formatted, bias_expected=False)

    # Variant 2: callable filter that explicitly keeps the bias feature.
    by_callable = explain_weights(
        clf, vec=vec, feature_filter=keep_ath_or_bias)
    formatted = format_as_all(by_callable, clf)
    text_expl, _ = formatted
    check_formatted(formatted, bias_expected=True)
    assert '<BIAS>' in text_expl
Exemplo n.º 2
0
def test_explain_hashing_vectorizer(newsgroups_train_binary):
    """Check explain_prediction with an explicitly passed InvertableHashingVectorizer.

    Verifies that:

    * the explanation of the first document passes
      ``check_explain_linear_binary`` and is equal across repeated calls;
    * explaining the already-vectorized document (``vectorized=True``) equals
      ``_without_weighted_spans`` applied to the text-based explanation;
    * passing ``feature_names`` taken from the invertable vectorizer
      explicitly yields the same explanation.
    """
    docs, y, target_names = newsgroups_train_binary
    first_doc = docs[0]

    vec = HashingVectorizer(n_features=1000)
    ivec = InvertableHashingVectorizer(vec)
    clf = LogisticRegression(random_state=42)
    # The invertable vectorizer is fitted only on the document being explained.
    ivec.fit([first_doc])
    clf.fit(vec.fit_transform(docs), y)

    def explain_first_doc(**extra):
        return explain_prediction(
            clf,
            first_doc,
            vec=ivec,
            target_names=target_names,
            top=20,
            **extra)

    res = explain_first_doc()
    check_explain_linear_binary(res, clf)
    # A second identical call must produce an equal explanation.
    assert res == explain_first_doc()

    doc_vector = vec.transform([first_doc])[0]
    res_vectorized = explain_prediction(
        clf,
        doc_vector,
        vec=ivec,
        target_names=target_names,
        top=20,
        vectorized=True)
    pprint(res_vectorized)
    assert res_vectorized == _without_weighted_spans(res)

    explicit_names = ivec.get_feature_names(always_signed=False)
    assert res == explain_first_doc(feature_names=explicit_names)
Exemplo n.º 3
0
def test_explain_linear_hashed(newsgroups_train, clf):
    """Check linear explanations when features come from a HashingVectorizer.

    The classifier is trained on all hashed documents, while the
    InvertableHashingVectorizer is fitted on only every second document; the
    result is validated by ``check_newsgroups_explanation_linear``.
    """
    docs, y, target_names = newsgroups_train
    hashing_vec = HashingVectorizer(n_features=10000)
    invertable_vec = InvertableHashingVectorizer(hashing_vec)

    clf.fit(hashing_vec.fit_transform(docs), y)

    # use half of the docs to find common terms, to make it more realistic
    every_other_doc = docs[::2]
    invertable_vec.fit(every_other_doc)

    check_newsgroups_explanation_linear(clf, invertable_vec, target_names)
Exemplo n.º 4
0
def test_explain_linear_hashed_pos_neg(newsgroups_train, pass_feature_weights):
    """Compare hashed-feature weight explanations against CountVectorizer ones.

    The newsgroups task is reduced to a binary one (class 0 vs. ``'other'``).
    Weights are explained either by passing ``feature_names`` and
    ``coef_scale`` taken from the invertable vectorizer
    (``pass_feature_weights`` truthy) or by passing the invertable vectorizer
    itself. The top positive and negative features must have the same names
    as those of an equivalent CountVectorizer model, with coefficients
    differing by less than 0.05.
    """
    docs, labels, target_names = newsgroups_train
    # make it binary
    labels = labels.copy()
    labels[labels != 0] = 1
    target_names = [target_names[0], 'other']

    vec = HashingVectorizer(norm=None)
    ivec = InvertableHashingVectorizer(vec)
    clf = LogisticRegression(random_state=42)
    clf.fit(vec.fit_transform(docs), labels)
    ivec.fit(docs)

    # Arguments shared by every explain_weights call in this test.
    shared = dict(top=(10, 10), target_names=target_names)

    if not pass_feature_weights:
        res = explain_weights(clf, ivec, **shared)
    else:
        res = explain_weights(
            clf,
            feature_names=ivec.get_feature_names(always_signed=False),
            coef_scale=ivec.column_signs_,
            **shared)

    # HashingVectorizer with norm=None is "the same" as CountVectorizer,
    # so we can compare it and check that explanation is almost the same.
    count_vec = CountVectorizer()
    count_clf = LogisticRegression(random_state=42)
    count_clf.fit(count_vec.fit_transform(docs), labels)
    count_res = explain_weights(count_clf, vec=count_vec, **shared)

    def sorted_weights(explanation, side):
        # Sorted (name, coef) pairs for one side of the first target.
        weights = getattr(explanation.targets[0].feature_weights, side)
        return sorted(get_names_coefs(weights))

    for side in ('pos', 'neg'):
        hashed = sorted_weights(res, side)
        counted = sorted_weights(count_res, side)
        assert len(hashed) == len(counted)
        for hashed_item, counted_item in zip(hashed, counted):
            name, coef = hashed_item
            count_name, count_coef = counted_item
            assert name == count_name
            assert abs(coef - count_coef) < 0.05