示例#1
0
def test_graphpipeline_merging_node():
    """Check that a node with two parent nodes receives their concatenated outputs.

    Two ``ColumnsSelector`` nodes ("ColNum" and "ColCat") both feed the
    ``DebugPassThrough`` node "Pt". The columns reaching "Pt" (and therefore the
    pipeline output) are the parents' columns concatenated in the order in which
    the edges are declared, so both edge orders are exercised.

    ``dfX`` and ``y`` are presumably module-level test data defined elsewhere in
    this file; ``dfX`` must contain at least the columns selected below.
    """
    num_columns = ["num1", "num2", "num3"]
    cat_columns = ["cat1", "cat2"]

    # Each case: (edges, columns expected after merging at "Pt").
    # The merge concatenates the parents' outputs in edge-declaration order.
    cases = [
        ([("ColNum", "Pt"), ("ColCat", "Pt")], num_columns + cat_columns),
        ([("ColCat", "Pt"), ("ColNum", "Pt")], cat_columns + num_columns),
    ]

    for edges, merged_columns in cases:
        gpipeline = GraphPipeline(
            {
                # Pass copies so no model shares a list object with the expectations.
                "ColNum": ColumnsSelector(columns_to_use=list(num_columns)),
                "ColCat": ColumnsSelector(columns_to_use=list(cat_columns)),
                "Pt": DebugPassThrough(debug=True),
            },
            edges=edges,
        )

        gpipeline.fit(dfX, y)

        # The debug pass-through exposes what it was fitted on (presumably
        # recorded during fit because debug=True).
        pt = gpipeline.models["Pt"]
        assert pt._expected_columns == merged_columns, edges
        assert pt._expected_type == DataTypes.DataFrame, edges
        assert pt._expected_nbcols == len(merged_columns), edges

        # The output is exactly the selected input columns, in merged order.
        dfX_transformed = gpipeline.transform(dfX)
        assert (dfX_transformed == dfX.loc[:, merged_columns]).all().all(), edges

        assert gpipeline.get_feature_names() == merged_columns, edges
        assert gpipeline.get_feature_names_at_node("Pt") == merged_columns, edges
        assert gpipeline.get_feature_names_at_node("ColNum") == num_columns, edges
        assert gpipeline.get_feature_names_at_node("ColCat") == cat_columns, edges

        # Root nodes see the whole input; "Pt" sees only the merged columns.
        assert gpipeline.get_input_features_at_node("ColNum") == list(dfX.columns), edges
        assert gpipeline.get_input_features_at_node("ColCat") == list(dfX.columns), edges
        assert gpipeline.get_input_features_at_node("Pt") == merged_columns, edges
示例#2
0
def test_graphpipeline_concat_names():
    """Check the column names produced when merging heterogeneous nodes.

    A ``ColumnsSelector`` keeping two numeric columns and a
    ``CountVectorizerWrapper`` over ``text_col`` both feed a ``PassThrough``
    node. The merged output must list the selector's columns first, followed by
    the vectorizer's ``text_col__BAG__*`` columns. ``get_feature_names`` must
    report exactly those names.
    """
    sample = get_sample_df(size=100, seed=123)

    pipeline = GraphPipeline(
        models={
            "sel": ColumnsSelector(columns_to_use=["float_col", "int_col"]),
            "vec": CountVectorizerWrapper(columns_to_use=["text_col"]),
            "pt": PassThrough(),
        },
        edges=[("sel", "pt"), ("vec", "pt")],
    )

    pipeline.fit(sample)
    transformed = pipeline.transform(sample)

    expected_columns = [
        "float_col",
        "int_col",
        "text_col__BAG__aaa",
        "text_col__BAG__bbb",
        "text_col__BAG__ccc",
        "text_col__BAG__ddd",
        "text_col__BAG__eee",
        "text_col__BAG__fff",
        "text_col__BAG__jjj",
    ]

    output_columns = list(transformed.columns)
    assert output_columns == expected_columns

    # The names reported by the pipeline match the actual output DataFrame.
    assert pipeline.get_feature_names() == output_columns