def test_nested_import_from():
    """
    Check that the WIR extractor handles a `from package.module import name`
    statement whose imported function is then called inside another call.
    """
    source = cleandoc("""
            from mlinspect.utils import get_project_root

            print(get_project_root())
            """)
    wir = WirExtractor(ast.parse(source)).extract_wir()

    import_node = WirNode(0, "mlinspect.utils", "Import",
                          CodeReference(1, 0, 1, 44))
    inner_call = WirNode(1, "get_project_root", "Call",
                         CodeReference(3, 6, 3, 24))
    outer_call = WirNode(2, "print", "Call", CodeReference(3, 0, 3, 25))

    expected = networkx.DiGraph()
    expected.add_edges_from([
        # The imported module is the caller of get_project_root ...
        (import_node, inner_call, {"type": "caller", "arg_index": -1}),
        # ... whose result becomes the first argument of print.
        (inner_call, outer_call, {"type": "input", "arg_index": 0}),
    ])

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
def test_string_call_attribute():
    """
    Check WIR extraction for a method invoked directly on a string literal.
    """
    source = cleandoc("""
        "hello ".join("world")
        """)
    wir = WirExtractor(ast.parse(source)).extract_wir()

    receiver = WirNode(0, "hello ", "Constant", CodeReference(1, 0, 1, 8))
    argument = WirNode(1, "world", "Constant", CodeReference(1, 14, 1, 21))
    join_call = WirNode(2, "join", "Call", CodeReference(1, 0, 1, 22))

    expected = networkx.DiGraph()
    expected.add_edges_from([
        # The literal the method is called on is the "caller" (arg_index -1).
        (receiver, join_call, {"type": "caller", "arg_index": -1}),
        (argument, join_call, {"type": "input", "arg_index": 0}),
    ])

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
def test_list_creation():
    """
    Check WIR extraction for a list literal passed as a call argument.
    """
    source = cleandoc("""
            print(["test1", "test2"])
            """)
    wir = WirExtractor(ast.parse(source)).extract_wir()

    first_item = WirNode(0, "test1", "Constant", CodeReference(1, 7, 1, 14))
    second_item = WirNode(1, "test2", "Constant",
                          CodeReference(1, 16, 1, 23))
    list_node = WirNode(2, "as_list", "List", CodeReference(1, 6, 1, 24))
    print_call = WirNode(3, "print", "Call", CodeReference(1, 0, 1, 25))

    expected = networkx.DiGraph()
    # List elements feed the list node in positional order.
    for position, item in enumerate([first_item, second_item]):
        expected.add_edge(item, list_node, type="input", arg_index=position)
    expected.add_edge(list_node, print_call, type="input", arg_index=0)

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
def test_print_expressions():
    """
    Check WIR extraction for a method call on a literal that is nested as the
    argument of another call.
    """
    source = cleandoc("""
        print("test".isupper())
        """)
    wir = WirExtractor(ast.parse(source)).extract_wir()

    literal = WirNode(0, "test", "Constant", CodeReference(1, 6, 1, 12))
    method_call = WirNode(1, "isupper", "Call", CodeReference(1, 6, 1, 22))
    print_call = WirNode(2, "print", "Call", CodeReference(1, 0, 1, 23))

    expected = networkx.DiGraph()
    expected.add_edges_from([
        (literal, method_call, {"type": "caller", "arg_index": -1}),
        (method_call, print_call, {"type": "input", "arg_index": 0}),
    ])

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
def test_import_from():
    """
    Tests whether the WIR Extraction works for from imports: the imported
    module becomes the caller of the imported function
    """
    # No trailing whitespace inside the snippet: it is not significant to the
    # parsed positions, and pylint flags it as trailing-whitespace.
    test_code = cleandoc("""
            from math import sqrt

            sqrt(4)
            """)
    test_ast = ast.parse(test_code)
    extractor = WirExtractor(test_ast)
    extracted_wir = extractor.extract_wir()
    expected_graph = networkx.DiGraph()

    expected_import = WirNode(0, "math", "Import", CodeReference(1, 0, 1, 21))
    expected_constant = WirNode(1, "4", "Constant", CodeReference(3, 5, 3, 6))
    expected_constant_call = WirNode(2, "sqrt", "Call",
                                     CodeReference(3, 0, 3, 7))
    expected_graph.add_edge(expected_import,
                            expected_constant_call,
                            type="caller",
                            arg_index=-1)
    expected_graph.add_edge(expected_constant,
                            expected_constant_call,
                            type="input",
                            arg_index=0)

    compare(networkx.to_dict_of_dicts(extracted_wir),
            networkx.to_dict_of_dicts(expected_graph))
def test_call_after_call():
    """
    Tests whether the WIR Extraction works for chained calls, where a method
    is called on the result of another method call
    """
    # The snippet is only parsed, never executed, so count() needs no argument.
    test_code = cleandoc("""
        "hello ".capitalize().count()
        """)
    test_ast = ast.parse(test_code)
    extractor = WirExtractor(test_ast)
    extracted_wir = extractor.extract_wir()
    expected_graph = networkx.DiGraph()

    expected_constant_one = WirNode(0, "hello ", "Constant",
                                    CodeReference(1, 0, 1, 8))
    expected_call_one = WirNode(1, "capitalize", "Call",
                                CodeReference(1, 0, 1, 21))
    expected_call_two = WirNode(2, "count", "Call", CodeReference(1, 0, 1, 29))
    expected_graph.add_edge(expected_constant_one,
                            expected_call_one,
                            type="caller",
                            arg_index=-1)
    # The result of the first call is the caller of the second call.
    expected_graph.add_edge(expected_call_one,
                            expected_call_two,
                            type="caller",
                            arg_index=-1)

    compare(networkx.to_dict_of_dicts(extracted_wir),
            networkx.to_dict_of_dicts(expected_graph))
def test_print_var_usage():
    """
    Check WIR extraction for a variable that is assigned a constant and then
    passed to print.
    """
    source = cleandoc("""
        test_var = "test"
        print(test_var)""")
    wir = WirExtractor(ast.parse(source)).extract_wir()

    literal = WirNode(0, "test", "Constant", CodeReference(1, 11, 1, 17))
    assignment = WirNode(1, "test_var", "Assign", CodeReference(1, 0, 1, 17))
    print_call = WirNode(2, "print", "Call", CodeReference(2, 0, 2, 15))

    expected = networkx.DiGraph()
    # Data flows literal -> assignment -> print, each time as argument 0.
    for producer, consumer in ((literal, assignment), (assignment, print_call)):
        expected.add_edge(producer, consumer, type="input", arg_index=0)

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
def test_index_subscript_with_module_information():
    """
    Tests whether the WIR Extraction works for index subscripts and whether
    add_runtime_info attaches the given module information to the matching
    call and subscript nodes
    """
    test_code = cleandoc("""
            import pandas as pd

            data = pd.read_csv('test_path')
            data['income-per-year']
            """)
    test_ast = ast.parse(test_code)
    extractor = WirExtractor(test_ast)
    # Runtime module info keyed by the code reference of the read_csv call
    # and of the subscript expression.
    module_info = {
        CodeReference(3, 7, 3, 31): ('pandas.io.parsers', 'read_csv'),
        CodeReference(4, 0, 4, 23): ('pandas.core.frame', '__getitem__')
    }
    extracted_wir = extractor.extract_wir()
    extractor.add_runtime_info(module_info, {})
    expected_graph = networkx.DiGraph()

    expected_import = WirNode(0, "pandas", "Import",
                              CodeReference(1, 0, 1, 19))
    expected_constant_one = WirNode(1, "test_path", "Constant",
                                    CodeReference(3, 19, 3, 30))
    expected_call = WirNode(2, "read_csv", "Call", CodeReference(3, 7, 3, 31),
                            ('pandas.io.parsers', 'read_csv'))
    expected_graph.add_edge(expected_import,
                            expected_call,
                            type="caller",
                            arg_index=-1)
    expected_graph.add_edge(expected_constant_one,
                            expected_call,
                            type="input",
                            arg_index=0)

    expected_assign = WirNode(3, "data", "Assign", CodeReference(3, 0, 3, 31))
    expected_graph.add_edge(expected_call,
                            expected_assign,
                            type="input",
                            arg_index=0)

    expected_constant_two = WirNode(4, "income-per-year", "Constant",
                                    CodeReference(4, 5, 4, 22))
    expected_index_subscript = WirNode(5, "Index-Subscript", "Subscript",
                                       CodeReference(4, 0, 4, 23),
                                       ('pandas.core.frame', '__getitem__'))
    # The subscripted variable is the caller; the index is its first input.
    expected_graph.add_edge(expected_assign,
                            expected_index_subscript,
                            type="caller",
                            arg_index=-1)
    expected_graph.add_edge(expected_constant_two,
                            expected_index_subscript,
                            type="input",
                            arg_index=0)

    compare(networkx.to_dict_of_dicts(extracted_wir),
            networkx.to_dict_of_dicts(expected_graph))
def test_keyword():
    """
    Check WIR extraction for a call that mixes positional arguments with a
    keyword argument.
    """
    source = cleandoc("""
        print('comma', 'separated', 'words', sep=', ')
        """)
    wir = WirExtractor(ast.parse(source)).extract_wir()

    comma = WirNode(0, "comma", "Constant", CodeReference(1, 6, 1, 13))
    separated = WirNode(1, "separated", "Constant",
                        CodeReference(1, 15, 1, 26))
    words = WirNode(2, "words", "Constant", CodeReference(1, 28, 1, 35))
    sep_value = WirNode(3, ", ", "Constant", CodeReference(1, 41, 1, 45))
    sep_keyword = WirNode(4, "sep", "Keyword")
    print_call = WirNode(5, "print", "Call", CodeReference(1, 0, 1, 46))

    positional_args = [comma, separated, words]
    edges = [(sep_value, sep_keyword, {"type": "input", "arg_index": 0})]
    edges.extend(
        (arg, print_call, {"type": "input", "arg_index": position})
        for position, arg in enumerate(positional_args))
    # The keyword node comes after all positional arguments.
    edges.append((sep_keyword, print_call,
                  {"type": "input", "arg_index": len(positional_args)}))

    expected = networkx.DiGraph()
    expected.add_edges_from(edges)

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
示例#10
0
def test_pipeline_executor_function_call_info_extraction():
    """
    Check that running a pipeline records, for each call site, the module and
    function name that was actually invoked.
    """
    pipeline_code = get_pandas_read_csv_and_dropna_code()

    pipeline_executor.singleton = pipeline_executor.PipelineExecutor()
    pipeline_executor.singleton.run(None, None, pipeline_code, [])

    expected_calls = [
        ((5, 13, 5, 85), ('posixpath', 'join')),
        ((5, 26, 5, 49), ('builtins', 'str')),
        ((5, 30, 5, 48), ('mlinspect.utils', 'get_project_root')),
        ((6, 11, 6, 34), ('pandas.io.parsers', 'read_csv')),
        ((7, 7, 7, 24), ('pandas.core.frame', 'dropna')),
    ]
    expected_module_info = {CodeReference(*position): module
                            for position, module in expected_calls}

    compare(pipeline_executor.singleton.code_reference_to_module,
            expected_module_info)
def test_print_stmt():
    """
    Check WIR extraction for a single print call with one string argument.
    """
    wir = WirExtractor(ast.parse("print('test')")).extract_wir()

    literal = WirNode(0, "test", "Constant", CodeReference(1, 6, 1, 12))
    print_call = WirNode(1, "print", "Call", CodeReference(1, 0, 1, 13))

    expected = networkx.DiGraph()
    expected.add_edges_from(
        [(literal, print_call, {"type": "input", "arg_index": 0})])

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
示例#12
0
def test_pipeline_executor_function_subscript_index_info_extraction():
    """
    Check that running a pipeline records module information for both call
    sites and an index subscript on a DataFrame.
    """
    pipeline_code = cleandoc("""
            import os
            import pandas as pd
            from mlinspect.utils import get_project_root

            train_file = os.path.join(str(get_project_root()), "test", "data", "adult_train.csv")
            raw_data = pd.read_csv(train_file, na_values='?', index_col=0)
            data = raw_data.dropna()
            data['income-per-year']
            """)

    pipeline_executor.singleton = pipeline_executor.PipelineExecutor()
    pipeline_executor.singleton.run(None, None, pipeline_code, [])

    expected_calls = [
        ((5, 13, 5, 85), ('posixpath', 'join')),
        ((5, 26, 5, 49), ('builtins', 'str')),
        ((5, 30, 5, 48), ('mlinspect.utils', 'get_project_root')),
        ((6, 11, 6, 62), ('pandas.io.parsers', 'read_csv')),
        ((7, 7, 7, 24), ('pandas.core.frame', 'dropna')),
        # The bare subscript on the last line is recorded as __getitem__.
        ((8, 0, 8, 23), ('pandas.core.frame', '__getitem__')),
    ]
    expected_module_info = {CodeReference(*position): module
                            for position, module in expected_calls}

    compare(pipeline_executor.singleton.code_reference_to_module,
            expected_module_info)
def test_index_assign():
    """
    Tests whether the WIR Extraction works for assignments to an index
    subscript, e.g. data['label'] = "test"
    """
    test_code = cleandoc("""
            import pandas as pd

            data = pd.read_csv('test_path')
            data['label'] = "test"
            """)
    test_ast = ast.parse(test_code)
    extractor = WirExtractor(test_ast)
    extracted_wir = extractor.extract_wir()
    expected_graph = networkx.DiGraph()

    expected_import = WirNode(0, "pandas", "Import",
                              CodeReference(1, 0, 1, 19))
    expected_constant_one = WirNode(1, "test_path", "Constant",
                                    CodeReference(3, 19, 3, 30))
    expected_call = WirNode(2, "read_csv", "Call", CodeReference(3, 7, 3, 31))
    expected_graph.add_edge(expected_import,
                            expected_call,
                            type="caller",
                            arg_index=-1)
    expected_graph.add_edge(expected_constant_one,
                            expected_call,
                            type="input",
                            arg_index=0)

    expected_assign = WirNode(3, "data", "Assign", CodeReference(3, 0, 3, 31))
    expected_graph.add_edge(expected_call,
                            expected_assign,
                            type="input",
                            arg_index=0)

    # The index constant and the assigned value are expected as isolated
    # nodes: no edges connect them to the subscript assignment.
    expected_constant_two = WirNode(4, "label", "Constant",
                                    CodeReference(4, 5, 4, 12))
    expected_graph.add_node(expected_constant_two)

    expected_constant_three = WirNode(5, "test", "Constant",
                                      CodeReference(4, 16, 4, 22))
    expected_graph.add_node(expected_constant_three)

    expected_subscript_assign = WirNode(6, 'data.label', 'Subscript-Assign',
                                        CodeReference(4, 0, 4, 13))
    expected_graph.add_edge(expected_assign,
                            expected_subscript_assign,
                            type="caller",
                            arg_index=-1)

    compare(networkx.to_dict_of_dicts(extracted_wir),
            networkx.to_dict_of_dicts(expected_graph))
示例#14
0
def get_call_description_info():
    """
    Get the expected operator descriptions for the adult_easy pipeline,
    keyed by the CodeReference of the call that produces each operator
    """
    described_calls = [
        ((12, 11, 12, 62), 'adult_train.csv'),
        ((14, 7, 14, 24), 'dropna'),
        ((16, 38, 16, 61), "to ['income-per-year']"),
        ((16, 9, 16, 89), "label_binarize, classes: ['>50K', '<=50K']"),
        ((19, 20, 19, 72), 'Categorical Encoder (OneHotEncoder)'),
        ((20, 16, 20, 46), 'Numerical Encoder (StandardScaler)'),
        ((26, 19, 26, 48), 'Decision Tree'),
    ]

    descriptions = {}
    for (line, col, end_line, end_col), description in described_calls:
        reference = CodeReference(lineno=line,
                                  col_offset=col,
                                  end_lineno=end_line,
                                  end_col_offset=end_col)
        descriptions[reference] = description
    return descriptions
def test_tuple_assign():
    """
    Tests whether the WIR Extraction works for tuple-unpacking assignments
    and for later usage of one of the unpacked variables
    """
    test_code = cleandoc("""
        x, y = (1, 2)
        print(x)""")
    test_ast = ast.parse(test_code)
    extractor = WirExtractor(test_ast)
    extracted_wir = extractor.extract_wir()
    expected_graph = networkx.DiGraph()

    # NOTE(review): expected node ids start at 1 here, not 0; presumably id 0
    # is consumed by a node that does not end up in the graph — confirm
    # against WirExtractor.
    expected_constant_one = WirNode(1, '1', 'Constant',
                                    CodeReference(1, 8, 1, 9))
    expected_constant_two = WirNode(2, '2', 'Constant',
                                    CodeReference(1, 11, 1, 12))

    expected_constant_tuple = WirNode(3, 'as_tuple', 'Tuple',
                                      CodeReference(1, 7, 1, 13))
    expected_graph.add_edge(expected_constant_one,
                            expected_constant_tuple,
                            type="input",
                            arg_index=0)
    expected_graph.add_edge(expected_constant_two,
                            expected_constant_tuple,
                            type="input",
                            arg_index=1)

    # Each unpacking target becomes its own Assign node fed by the whole tuple.
    expected_var_x = WirNode(4, 'x', 'Assign', CodeReference(1, 0, 1, 1))
    expected_var_y = WirNode(5, 'y', 'Assign', CodeReference(1, 3, 1, 4))
    expected_graph.add_edge(expected_constant_tuple,
                            expected_var_x,
                            type="input",
                            arg_index=0)
    expected_graph.add_edge(expected_constant_tuple,
                            expected_var_y,
                            type="input",
                            arg_index=0)

    expected_call = WirNode(6, 'print', 'Call', CodeReference(2, 0, 2, 8))
    expected_graph.add_edge(expected_var_x,
                            expected_call,
                            type="input",
                            arg_index=0)

    compare(networkx.to_dict_of_dicts(extracted_wir),
            networkx.to_dict_of_dicts(expected_graph))
示例#16
0
def get_module_info():
    """
    Get the expected runtime module information for the adult_easy pipeline.

    Maps the CodeReference of each instrumented call site to a tuple that
    starts with the module path and the function/class name; the projection
    on line 16 additionally carries the label 'Projection'.
    """
    positions_and_modules = [
        ((10, 0, 10, 23), ('builtins', 'print')),
        ((11, 30, 11, 48), ('mlinspect.utils', 'get_project_root')),
        ((11, 26, 11, 49), ('builtins', 'str')),
        ((11, 13, 11, 85), ('posixpath', 'join')),
        ((12, 11, 12, 62), ('pandas.io.parsers', 'read_csv')),
        ((14, 7, 14, 24), ('pandas.core.frame', 'dropna')),
        ((16, 38, 16, 61), ('pandas.core.frame', '__getitem__', 'Projection')),
        ((16, 9, 16, 89), ('sklearn.preprocessing._label', 'label_binarize')),
        ((19, 20, 19, 72), ('sklearn.preprocessing._encoders',
                            'OneHotEncoder')),
        ((20, 16, 20, 46), ('sklearn.preprocessing._data', 'StandardScaler')),
        ((18, 25, 21, 2), ('sklearn.compose._column_transformer',
                           'ColumnTransformer')),
        ((26, 19, 26, 48), ('sklearn.tree._classes',
                            'DecisionTreeClassifier')),
        ((24, 18, 26, 51), ('sklearn.pipeline', 'Pipeline')),
        ((28, 0, 28, 33), ('sklearn.pipeline', 'fit')),
        ((31, 0, 31, 26), ('builtins', 'print')),
    ]

    module_info = {}
    for (line, col, end_line, end_col), module in positions_and_modules:
        reference = CodeReference(lineno=line,
                                  col_offset=col,
                                  end_lineno=end_line,
                                  end_col_offset=end_col)
        module_info[reference] = module
    return module_info
示例#17
0
def get_expected_dag_adult_easy_py():
    """
    Build the DAG that inspecting the adult_easy pipeline is expected to yield.

    Returns a networkx.DiGraph whose nodes are DagNode instances and whose
    edges point from an operator to the operator consuming its output.
    """
    # pylint: disable=too-many-locals
    dag = networkx.DiGraph()

    data_source = DagNode(18, OperatorType.DATA_SOURCE,
                          CodeReference(12, 11, 12, 62),
                          ('pandas.io.parsers', 'read_csv'),
                          "adult_train.csv")
    dag.add_node(data_source)

    selection = DagNode(20, OperatorType.SELECTION,
                        CodeReference(14, 7, 14, 24),
                        ('pandas.core.frame', 'dropna'), "dropna")
    dag.add_edge(data_source, selection)

    train_data = DagNode(56, OperatorType.TRAIN_DATA,
                         CodeReference(24, 18, 26, 51),
                         ('sklearn.pipeline', 'fit', 'Train Data'))
    dag.add_edge(selection, train_data)

    # One entry per ColumnTransformer branch:
    # (node id, projection description, encoder code reference,
    #  encoder module, encoder description)
    column_branches = [
        (34, "to ['education'] (ColumnTransformer)",
         CodeReference(19, 20, 19, 72),
         ('sklearn.preprocessing._encoders', 'OneHotEncoder', 'Pipeline'),
         "Categorical Encoder (OneHotEncoder), Column: 'education'"),
        (35, "to ['workclass'] (ColumnTransformer)",
         CodeReference(19, 20, 19, 72),
         ('sklearn.preprocessing._encoders', 'OneHotEncoder', 'Pipeline'),
         "Categorical Encoder (OneHotEncoder), Column: 'workclass'"),
        (40, "to ['age'] (ColumnTransformer)",
         CodeReference(20, 16, 20, 46),
         ('sklearn.preprocessing._data', 'StandardScaler', 'Pipeline'),
         "Numerical Encoder (StandardScaler), Column: 'age'"),
        (41, "to ['hours-per-week'] (ColumnTransformer)",
         CodeReference(20, 16, 20, 46),
         ('sklearn.preprocessing._data', 'StandardScaler', 'Pipeline'),
         "Numerical Encoder (StandardScaler), Column: 'hours-per-week'"),
    ]

    # All projections are added before any transformer so that node and edge
    # insertion order matches the step-by-step construction.
    projections = []
    for node_id, projection_description, _, _, _ in column_branches:
        projection = DagNode(
            node_id, OperatorType.PROJECTION, CodeReference(18, 25, 21, 2),
            ('sklearn.compose._column_transformer', 'ColumnTransformer',
             'Projection'), projection_description)
        dag.add_edge(train_data, projection)
        projections.append(projection)

    transformers = []
    for projection, branch in zip(projections, column_branches):
        node_id, _, encoder_reference, encoder_module, encoder_description = branch
        transformer = DagNode(node_id, OperatorType.TRANSFORMER,
                              encoder_reference, encoder_module,
                              encoder_description)
        dag.add_edge(projection, transformer)
        transformers.append(transformer)

    concatenation = DagNode(
        46, OperatorType.CONCATENATION, CodeReference(18, 25, 21, 2),
        ('sklearn.compose._column_transformer', 'ColumnTransformer',
         'Concatenation'))
    for transformer in transformers:
        dag.add_edge(transformer, concatenation)

    estimator = DagNode(
        51, OperatorType.ESTIMATOR, CodeReference(26, 19, 26, 48),
        ('sklearn.tree._classes', 'DecisionTreeClassifier', 'Pipeline'),
        "Decision Tree")
    dag.add_edge(concatenation, estimator)

    pipeline_fit = DagNode(56, OperatorType.FIT,
                           CodeReference(24, 18, 26, 51),
                           ('sklearn.pipeline', 'fit', 'Pipeline'))
    dag.add_edge(estimator, pipeline_fit)

    # Label branch: project the target column, binarize it, feed it to fit.
    label_projection = DagNode(
        23, OperatorType.PROJECTION, CodeReference(16, 38, 16, 61),
        ('pandas.core.frame', '__getitem__', 'Projection'),
        "to ['income-per-year']")
    dag.add_edge(selection, label_projection)

    label_binarization = DagNode(
        28, OperatorType.PROJECTION_MODIFY, CodeReference(16, 9, 16, 89),
        ('sklearn.preprocessing._label', 'label_binarize'),
        "label_binarize, classes: ['>50K', '<=50K']")
    dag.add_edge(label_projection, label_binarization)

    train_labels = DagNode(
        56, OperatorType.TRAIN_LABELS, CodeReference(24, 18, 26, 51),
        ('sklearn.pipeline', 'fit', 'Train Labels'))
    dag.add_edge(label_binarization, train_labels)
    dag.add_edge(train_labels, pipeline_fit)

    return dag
def test_tuples():
    """
    Check WIR extraction for a tuple literal whose elements are a constant,
    a keyword-argument call on an imported module, and a list literal.
    """
    source = cleandoc("""
            from sklearn import preprocessing

            ('categorical', preprocessing.OneHotEncoder(handle_unknown='ignore'), ['education', 'workclass'])
            """)
    wir = WirExtractor(ast.parse(source)).extract_wir()

    sklearn_import = WirNode(0, "sklearn", "Import",
                             CodeReference(1, 0, 1, 33))
    name_constant = WirNode(1, "categorical", "Constant",
                            CodeReference(3, 1, 3, 14))
    ignore_constant = WirNode(2, "ignore", "Constant",
                              CodeReference(3, 59, 3, 67))
    handle_unknown_keyword = WirNode(3, "handle_unknown", "Keyword")
    encoder_call = WirNode(4, "OneHotEncoder", "Call",
                           CodeReference(3, 16, 3, 68))
    education_constant = WirNode(5, "education", "Constant",
                                 CodeReference(3, 71, 3, 82))
    workclass_constant = WirNode(6, "workclass", "Constant",
                                 CodeReference(3, 84, 3, 95))
    columns_list = WirNode(7, "as_list", "List", CodeReference(3, 70, 3, 96))
    outer_tuple = WirNode(8, "as_tuple", "Tuple", CodeReference(3, 0, 3, 97))

    edges = [
        (ignore_constant, handle_unknown_keyword,
         {"type": "input", "arg_index": 0}),
        (sklearn_import, encoder_call, {"type": "caller", "arg_index": -1}),
        (handle_unknown_keyword, encoder_call,
         {"type": "input", "arg_index": 0}),
        (education_constant, columns_list, {"type": "input", "arg_index": 0}),
        (workclass_constant, columns_list, {"type": "input", "arg_index": 1}),
    ]
    # Tuple elements feed the tuple node in positional order.
    edges.extend(
        (element, outer_tuple, {"type": "input", "arg_index": position})
        for position, element in enumerate(
            [name_constant, encoder_call, columns_list]))

    expected = networkx.DiGraph()
    expected.add_edges_from(edges)

    compare(networkx.to_dict_of_dicts(wir),
            networkx.to_dict_of_dicts(expected))
示例#19
0
def get_expected_result():
    """
    Build the expected PrintFirstRowsAnalyzer(2) result for the adult_easy example.

    Returns a dict keyed by DagNode. Each value is the list of the two
    InspectionInputRow objects expected for that node. The one exception is the
    DecisionTreeClassifier estimator node, which maps to None.
    """
    # pylint: disable=too-many-locals
    adult_columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status',
                     'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss',
                     'hours-per-week', 'native-country', 'income-per-year']
    first_adult = [46, 'Private', 128645, 'Some-college', 10, 'Divorced', 'Prof-specialty',
                   'Not-in-family', 'White', 'Female', 0, 0, 40, 'United-States', '<=50K']
    second_adult = [29, 'Local-gov', 115585, 'Some-college', 10, 'Never-married', 'Handlers-cleaners',
                    'Not-in-family', 'White', 'Male', 0, 0, 50, 'United-States', '<=50K']

    def two_rows(first_values, second_values, fields):
        # Copy every list so that no two InspectionInputRow objects share a mutable list.
        return [InspectionInputRow(values=list(first_values), fields=list(fields)),
                InspectionInputRow(values=list(second_values), fields=list(fields))]

    def label_rows():
        # Each call creates fresh array(1) objects for both rows.
        return two_rows([array(1)], [array(1)], ['array'])

    def column_rows(column):
        # Single-column slice of the two adult rows defined above.
        position = adult_columns.index(column)
        return two_rows([first_adult[position]], [second_adult[position]], [column])

    def column_projection(node_id, description):
        return DagNode(node_id=node_id, operator_type=OperatorType.PROJECTION,
                       code_reference=CodeReference(18, 25, 21, 2),
                       module=('sklearn.compose._column_transformer', 'ColumnTransformer', 'Projection'),
                       description=description)

    def scaler_node(node_id, description):
        return DagNode(node_id=node_id, operator_type=OperatorType.TRANSFORMER,
                       code_reference=CodeReference(20, 16, 20, 46),
                       module=('sklearn.preprocessing._data', 'StandardScaler', 'Pipeline'),
                       description=description)

    def encoder_node(node_id, description):
        return DagNode(node_id=node_id, operator_type=OperatorType.TRANSFORMER,
                       code_reference=CodeReference(19, 20, 19, 72),
                       module=('sklearn.preprocessing._encoders', 'OneHotEncoder', 'Pipeline'),
                       description=description)

    def fit_node(operator_type, part):
        return DagNode(node_id=56, operator_type=operator_type, module=('sklearn.pipeline', 'fit', part),
                       code_reference=CodeReference(24, 18, 26, 51), description=None)

    return {
        DagNode(node_id=18, operator_type=OperatorType.DATA_SOURCE, module=('pandas.io.parsers', 'read_csv'),
                code_reference=CodeReference(12, 11, 12, 62), description='adult_train.csv'):
            two_rows(first_adult, second_adult, adult_columns),
        DagNode(node_id=20, operator_type=OperatorType.SELECTION, module=('pandas.core.frame', 'dropna'),
                code_reference=CodeReference(14, 7, 14, 24), description='dropna'):
            two_rows(first_adult, second_adult, adult_columns),
        DagNode(node_id=23, operator_type=OperatorType.PROJECTION,
                module=('pandas.core.frame', '__getitem__', 'Projection'),
                code_reference=CodeReference(16, 38, 16, 61), description="to ['income-per-year']"):
            two_rows(['<=50K'], ['<=50K'], ['array']),
        DagNode(node_id=28, operator_type=OperatorType.PROJECTION_MODIFY,
                module=('sklearn.preprocessing._label', 'label_binarize'),
                code_reference=CodeReference(16, 9, 16, 89),
                description="label_binarize, classes: ['>50K', '<=50K']"):
            label_rows(),
        fit_node(OperatorType.TRAIN_DATA, 'Train Data'):
            two_rows(first_adult + [1], second_adult + [2], adult_columns + ['mlinspect_index']),
        fit_node(OperatorType.TRAIN_LABELS, 'Train Labels'):
            label_rows(),
        column_projection(40, "to ['age'] (ColumnTransformer)"): column_rows('age'),
        column_projection(34, "to ['education'] (ColumnTransformer)"): column_rows('education'),
        column_projection(41, "to ['hours-per-week'] (ColumnTransformer)"): column_rows('hours-per-week'),
        column_projection(35, "to ['workclass'] (ColumnTransformer)"): column_rows('workclass'),
        scaler_node(40, "Numerical Encoder (StandardScaler), Column: 'age'"):
            two_rows([array(RangeComparison(0.5, 0.6))],
                     [array(RangeComparison(-0.8, -0.7))],
                     ['array']),
        scaler_node(41, "Numerical Encoder (StandardScaler), Column: 'hours-per-week'"):
            two_rows([array(RangeComparison(-0.09, -0.08))],
                     [array(RangeComparison(0.7, 0.8))],
                     ['array']),
        encoder_node(34, "Categorical Encoder (OneHotEncoder), Column: 'education'"):
            two_rows([array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])],
                     [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])],
                     ['array']),
        encoder_node(35, "Categorical Encoder (OneHotEncoder), Column: 'workclass'"):
            two_rows([array([0., 0., 1., 0., 0., 0., 0.])],
                     [array([0., 1., 0., 0., 0., 0., 0.])],
                     ['array']),
        DagNode(node_id=46, operator_type=OperatorType.CONCATENATION, code_reference=CodeReference(18, 25, 21, 2),
                module=('sklearn.compose._column_transformer', 'ColumnTransformer', 'Concatenation'),
                description=None):
            two_rows([array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                             0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
                             0., 0., 0., RangeComparison(0.5, 0.6),
                             RangeComparison(-0.09, -0.08)])],
                     [array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                             0., 0., 0., 0., 0., 1., 0., 1., 0., 0.,
                             0., 0., 0., RangeComparison(-0.8, -0.7),
                             RangeComparison(0.7, 0.8)])],
                     ['array']),
        DagNode(node_id=51, operator_type=OperatorType.ESTIMATOR, code_reference=CodeReference(26, 19, 26, 48),
                module=('sklearn.tree._classes', 'DecisionTreeClassifier', 'Pipeline'),
                description='Decision Tree'): None,
    }
def get_expected_cleaned_wir_adult_easy():
    """
    Build the expected cleaned WIR graph for the adult_easy pipeline.

    Every node is a WirNode carrying its resolved module/function tuple. Edges
    are directed and carry no attributes. The two print calls are added as
    isolated nodes.
    """
    # pylint: disable=too-many-locals
    first_print = WirNode(6, "print", "Call", CodeReference(10, 0, 10, 23), ('builtins', 'print'))
    project_root_call = WirNode(7, "get_project_root", "Call", CodeReference(11, 30, 11, 48),
                                ('mlinspect.utils', 'get_project_root'))
    str_call = WirNode(8, "str", "Call", CodeReference(11, 26, 11, 49), ('builtins', 'str'))
    join_call = WirNode(12, "join", "Call", CodeReference(11, 13, 11, 85), ('posixpath', 'join'))
    read_csv_call = WirNode(18, "read_csv", "Call", CodeReference(12, 11, 12, 62),
                            ('pandas.io.parsers', 'read_csv'))
    dropna_call = WirNode(20, "dropna", "Call", CodeReference(14, 7, 14, 24), ('pandas.core.frame', 'dropna'))
    fit_call = WirNode(56, "fit", "Call", CodeReference(28, 0, 28, 33), ('sklearn.pipeline', 'fit'))
    label_subscript = WirNode(23, "Index-Subscript", "Subscript", CodeReference(16, 38, 16, 61),
                              ('pandas.core.frame', '__getitem__', 'Projection'))
    binarize_call = WirNode(28, "label_binarize", "Call", CodeReference(16, 9, 16, 89),
                            ('sklearn.preprocessing._label', 'label_binarize'))
    encoder_call = WirNode(33, "OneHotEncoder", "Call", CodeReference(19, 20, 19, 72),
                           ('sklearn.preprocessing._encoders', 'OneHotEncoder'))
    scaler_call = WirNode(39, "StandardScaler", "Call", CodeReference(20, 16, 20, 46),
                          ('sklearn.preprocessing._data', 'StandardScaler'))
    transformer_call = WirNode(46, "ColumnTransformer", "Call", CodeReference(18, 25, 21, 2),
                               ('sklearn.compose._column_transformer', 'ColumnTransformer'))
    tree_call = WirNode(51, "DecisionTreeClassifier", "Call", CodeReference(26, 19, 26, 48),
                        ('sklearn.tree._classes', 'DecisionTreeClassifier'))
    pipeline_call = WirNode(54, "Pipeline", "Call", CodeReference(24, 18, 26, 51),
                            ('sklearn.pipeline', 'Pipeline'))
    second_print = WirNode(58, "print", "Call", CodeReference(31, 0, 31, 26), ('builtins', 'print'))

    graph = networkx.DiGraph()
    graph.add_node(first_print)
    # The edge list order fixes the graph's node and adjacency insertion order.
    graph.add_edges_from([
        (project_root_call, str_call),
        (str_call, join_call),
        (join_call, read_csv_call),
        (read_csv_call, dropna_call),
        (dropna_call, fit_call),
        (dropna_call, label_subscript),
        (label_subscript, binarize_call),
        (binarize_call, fit_call),
        (encoder_call, transformer_call),
        (scaler_call, transformer_call),
        (transformer_call, pipeline_call),
        (tree_call, pipeline_call),
        (pipeline_call, fit_call),
    ])
    graph.add_node(second_print)
    return graph