コード例 #1
0
ファイル: test_common.py プロジェクト: shaypal5/skift
def test_col_lbl():
    ftdf = _ftdf()
    ft_clf = ColLblBasedFtClassifier('txt', epoch=14)
    ft_clf.fit(ftdf[['txt']], ftdf['lbl'])

    assert ft_clf.predict(pd.DataFrame(
        [['woof woof']], columns=['txt']))[0] == 0
    assert ft_clf.predict(
        pd.DataFrame([['meow meow']], columns=['txt']))[0] == 1
    assert ft_clf.predict(
        pd.DataFrame([['meow']], columns=['txt']))[0] == 1
    assert ft_clf.predict(
        pd.DataFrame([['woof lol']], columns=['txt']))[0] == 0
    assert ft_clf.predict(
        pd.DataFrame([['meow lolz']], columns=['txt']))[0] == 1
コード例 #2
0
ファイル: test_common.py プロジェクト: shaypal5/skift
def test_bad_param():
    ftdf = _ftdf()
    ft_clf = ColLblBasedFtClassifier('txt', bad_param=14)
    with pytest.raises(TypeError):
        ft_clf.fit(ftdf[['txt']], ftdf['lbl'])