示例#1
0
def test_graphpipeline_blockselector_cv():
    """Check cross-validation of a block-selecting GraphPipeline.

    The same two blocks (a text DataFrame and a numeric array) are fed to
    ``cross_val_score`` twice:

    * wrapped in a plain ``dict`` -- a ``ValueError`` is expected;
    * wrapped in a ``BlockManager`` -- one score per fold is expected.
    """
    from sklearn.model_selection import cross_val_score

    def _build_pipeline():
        # A new, unfitted pipeline for each scenario so that nothing is
        # shared between the failing run and the successful one.
        return GraphPipeline(
            models={
                "BS_text": BlockSelector("text"),
                "CV": CountVectorizerWrapper(analyzer="char"),
                "BS_num": BlockSelector("num"),
                "RF": DecisionTreeClassifier(),
            },
            edges=[("BS_text", "CV", "RF"), ("BS_num", "RF")],
        )

    n_folds = 10

    Xnum, y = make_classification(n_samples=100)
    dfX_text = pd.DataFrame({
        "text1": get_random_strings(100),
        "text2": get_random_strings(100),
    })

    # --- Scenario 1: blocks stored in a plain dictionary ---
    dict_blocks = {"text": dfX_text, "num": Xnum}
    dict_pipeline = _build_pipeline()

    # Doesn't work: a dictionary can't be subset into folds.
    with pytest.raises(ValueError):
        cross_val_score(dict_pipeline, dict_blocks, y, scoring="accuracy", cv=n_folds)

    # --- Scenario 2: same blocks wrapped in a BlockManager ---
    managed_blocks = BlockManager({"text": dfX_text, "num": Xnum})
    managed_pipeline = _build_pipeline()

    fold_scores = cross_val_score(
        managed_pipeline,
        managed_blocks,
        y,
        scoring="accuracy",
        cv=n_folds,
    )

    assert len(fold_scores) == 10
示例#2
0
def test_graphpipeline_blockselector():
    """Check BlockSelector nodes inside a GraphPipeline for several containers.

    First a full pipeline (text block -> count vectorizer, numeric block,
    both feeding a decision tree) is fitted on a dict of blocks and its
    predictions are checked for shape.

    Then a pass-through pipeline is run on the same two blocks held in a
    dict, a list and a ``BlockManager``; in each case the output must
    contain the text columns unchanged plus the numeric columns unchanged.
    """
    Xnum, y = make_classification(n_samples=100)
    dfX_text = pd.DataFrame({"text1": get_random_strings(100), "text2": get_random_strings(100)})

    text_cols = ["text1", "text2"]

    # --- Part 1: fit / predict through the complete graph ---
    blocks = {"text": dfX_text, "num": Xnum}
    classifier_graph = GraphPipeline(
        models={
            "BS_text": BlockSelector("text"),
            "CV": CountVectorizerWrapper(analyzer="char"),
            "BS_num": BlockSelector("num"),
            "RF": DecisionTreeClassifier(),
        },
        edges=[("BS_text", "CV", "RF"), ("BS_num", "RF")],
    )
    classifier_graph.fit(blocks, y)
    predictions = classifier_graph.predict(blocks)

    assert predictions.ndim == 1
    assert predictions.shape[0] == y.shape[0]

    # --- Part 2: pass-through graph on each supported container ---
    def _check_passthrough(container, text_key, num_key):
        # Select both blocks and send them to the same pass-through node,
        # then verify the result against the two original blocks.
        passthrough_graph = GraphPipeline(
            models={
                "BS_text": BlockSelector(text_key),
                "BS_num": BlockSelector(num_key),
                "PT": DebugPassThrough(),
            },
            edges=[("BS_text", "PT"), ("BS_num", "PT")],
        )
        merged = passthrough_graph.fit_transform(container)

        assert merged.shape[0] == dfX_text.shape[0]
        assert merged.shape[1] == dfX_text.shape[1] + Xnum.shape[1]

        assert "text1" in merged.columns
        assert "text2" in merged.columns
        assert (merged.loc[:, text_cols] == dfX_text).all().all()

        # Every column that is not a text column must hold the numeric block.
        other_cols = diff(list(merged.columns), text_cols)
        assert (merged.loc[:, other_cols].values == Xnum).all()

    # X = dict: blocks addressed by key
    _check_passthrough({"text": dfX_text, "num": Xnum}, "text", "num")

    # X = list: blocks addressed by position
    _check_passthrough([dfX_text, Xnum], 0, 1)

    # X = BlockManager: blocks addressed by key
    _check_passthrough(BlockManager({"text": dfX_text, "num": Xnum}), "text", "num")