def test_graphpipeline_merging_node():
    """Check a node with two incoming edges sees its parents' columns concatenated.

    Two ``ColumnsSelector`` nodes ("ColNum" and "ColCat") both feed a single
    ``DebugPassThrough`` node ("Pt").  The scenario is run twice, once per edge
    ordering, and asserts that the columns observed at "Pt" follow the order
    in which the edges were declared.
    """
    numeric = ("num1", "num2", "num3")
    categorical = ("cat1", "cat2")

    def fitted_pipeline(edges):
        # Each call hands fresh lists to the selectors so that nothing is
        # shared between the two scenarios below.
        pipeline = GraphPipeline(
            {
                "ColNum": ColumnsSelector(columns_to_use=list(numeric)),
                "ColCat": ColumnsSelector(columns_to_use=list(categorical)),
                "Pt": DebugPassThrough(debug=True),
            },
            edges=edges,
        )
        pipeline.fit(dfX, y)
        return pipeline

    # --- numeric branch declared first -------------------------------------
    num_then_cat = [*numeric, *categorical]
    gpipeline = fitted_pipeline([("ColNum", "Pt"), ("ColCat", "Pt")])

    pt = gpipeline.models["Pt"]
    assert pt._expected_columns == num_then_cat
    assert pt._expected_type == DataTypes.DataFrame
    assert pt._expected_nbcols == 5

    dfX_transformed = gpipeline.transform(dfX)
    matches = dfX_transformed == dfX.loc[:, num_then_cat]
    assert matches.all().all()

    assert gpipeline.get_feature_names() == num_then_cat
    assert gpipeline.get_feature_names_at_node("Pt") == num_then_cat
    assert gpipeline.get_feature_names_at_node("ColNum") == list(numeric)
    assert gpipeline.get_feature_names_at_node("ColCat") == list(categorical)

    assert gpipeline.get_input_features_at_node("ColNum") == list(dfX.columns)
    assert gpipeline.get_input_features_at_node("ColCat") == list(dfX.columns)
    assert gpipeline.get_input_features_at_node("Pt") == num_then_cat

    # --- concatenation in the other order: categorical branch first --------
    cat_then_num = [*categorical, *numeric]
    gpipeline = fitted_pipeline([("ColCat", "Pt"), ("ColNum", "Pt")])

    pt = gpipeline.models["Pt"]
    # Concatenation happens in the order of the edges.
    assert pt._expected_columns == cat_then_num
    assert pt._expected_type == DataTypes.DataFrame
    assert pt._expected_nbcols == 5

    assert gpipeline.get_feature_names() == cat_then_num
    assert gpipeline.get_feature_names_at_node("Pt") == cat_then_num
    assert gpipeline.get_feature_names_at_node("ColNum") == list(numeric)
    assert gpipeline.get_feature_names_at_node("ColCat") == list(categorical)

    assert gpipeline.get_input_features_at_node("ColNum") == list(dfX.columns)
    assert gpipeline.get_input_features_at_node("ColCat") == list(dfX.columns)
    assert gpipeline.get_input_features_at_node("Pt") == cat_then_num

    dfX_transformed = gpipeline.transform(dfX)
    matches = dfX_transformed == dfX.loc[:, cat_then_num]
    assert matches.all().all()
def test_graphpipeline_concat_names():
    """Check the column names produced when a selector and a vectorizer merge.

    A ``ColumnsSelector`` and a ``CountVectorizerWrapper`` both feed a
    ``PassThrough`` node; the transformed frame must expose the selected
    columns followed by the ``text_col__BAG__*`` columns, and
    ``get_feature_names`` must agree with the frame's columns.
    """
    df = get_sample_df(size=100, seed=123)

    nodes = {
        "sel": ColumnsSelector(columns_to_use=["float_col", "int_col"]),
        "vec": CountVectorizerWrapper(columns_to_use=["text_col"]),
        "pt": PassThrough(),
    }
    gpipeline = GraphPipeline(models=nodes, edges=[("sel", "pt"), ("vec", "pt")])
    gpipeline.fit(df)

    df_res = gpipeline.transform(df)
    result_columns = list(df_res.columns)

    # Selected columns come first (the "sel" edge is declared first), then
    # the bag-of-words columns from the vectorizer.
    expected_columns = ["float_col", "int_col"]
    expected_columns += [
        "text_col__BAG__aaa",
        "text_col__BAG__bbb",
        "text_col__BAG__ccc",
        "text_col__BAG__ddd",
        "text_col__BAG__eee",
        "text_col__BAG__fff",
        "text_col__BAG__jjj",
    ]
    assert result_columns == expected_columns

    assert gpipeline.get_feature_names() == list(df_res.columns)