def test_print_var_usage():
    """
    Checks the extracted WIR for a variable that is assigned a string constant
    and afterwards passed to print
    """
    source = cleandoc("""
        test_var = "test"
        print(test_var)""")
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    # Expected nodes, listed by ascending node id
    constant_node = WirNode(0, "test", "Constant",
                            CodeReference(1, 11, 1, 17))
    assign_node = WirNode(1, "test_var", "Assign", CodeReference(1, 0, 1, 17))
    print_node = WirNode(2, "print", "Call", CodeReference(2, 0, 2, 15))

    # constant -> assignment -> print, each as an "input" edge with index 0
    expected_wir = networkx.DiGraph()
    expected_wir.add_edges_from([
        (constant_node, assign_node, {"type": "input", "arg_index": 0}),
        (assign_node, print_node, {"type": "input", "arg_index": 0}),
    ])

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_nested_import_from():
    """
    Checks the extracted WIR for a function that is imported from a nested
    module via a from import and then called inside print
    """
    source = cleandoc("""
            from mlinspect.utils import get_project_root

            print(get_project_root())
            """)
    parsed = ast.parse(source)
    actual_wir = WirExtractor(parsed).extract_wir()

    import_node = WirNode(0, "mlinspect.utils", "Import",
                          CodeReference(1, 0, 1, 44))
    inner_call = WirNode(1, "get_project_root", "Call",
                         CodeReference(3, 6, 3, 24))
    outer_call = WirNode(2, "print", "Call", CodeReference(3, 0, 3, 25))

    expected_wir = networkx.DiGraph()
    # The import node is the caller of get_project_root ...
    expected_wir.add_edge(import_node, inner_call,
                          type="caller", arg_index=-1)
    # ... and that call is input 0 of print
    expected_wir.add_edge(inner_call, outer_call,
                          type="input", arg_index=0)

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_list_creation():
    """
    Checks the extracted WIR for a list literal that is passed to print
    """
    source = cleandoc("""
            print(["test1", "test2"])
            """)
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    list_elements = (
        WirNode(0, "test1", "Constant", CodeReference(1, 7, 1, 14)),
        WirNode(1, "test2", "Constant", CodeReference(1, 16, 1, 23)),
    )
    list_node = WirNode(2, "as_list", "List", CodeReference(1, 6, 1, 24))
    print_node = WirNode(3, "print", "Call", CodeReference(1, 0, 1, 25))

    expected_wir = networkx.DiGraph()
    # Every list element is an input of the list, indexed by its position
    for position, element in enumerate(list_elements):
        expected_wir.add_edge(element, list_node,
                              type="input", arg_index=position)
    # The list itself is input 0 of print
    expected_wir.add_edge(list_node, print_node, type="input", arg_index=0)

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_print_expressions():
    """
    Checks the extracted WIR for a method call on a string constant that is
    nested inside a print call
    """
    source = cleandoc("""
        print("test".isupper())
        """)
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    constant_node = WirNode(0, "test", "Constant",
                            CodeReference(1, 6, 1, 12))
    isupper_node = WirNode(1, "isupper", "Call", CodeReference(1, 6, 1, 22))
    print_node = WirNode(2, "print", "Call", CodeReference(1, 0, 1, 23))

    expected_wir = networkx.DiGraph()
    expected_wir.add_edges_from([
        # The string constant is the caller of isupper
        (constant_node, isupper_node, {"type": "caller", "arg_index": -1}),
        # The isupper call is input 0 of print
        (isupper_node, print_node, {"type": "input", "arg_index": 0}),
    ])

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_import_from():
    """
    Checks the extracted WIR for a function that is brought in via a from
    import and called with a constant argument
    """
    source = cleandoc("""
            from math import sqrt 

            sqrt(4)
            """)
    parsed = ast.parse(source)
    actual_wir = WirExtractor(parsed).extract_wir()

    import_node = WirNode(0, "math", "Import", CodeReference(1, 0, 1, 21))
    argument_node = WirNode(1, "4", "Constant", CodeReference(3, 5, 3, 6))
    sqrt_node = WirNode(2, "sqrt", "Call", CodeReference(3, 0, 3, 7))

    expected_wir = networkx.DiGraph()
    # The import is the caller of sqrt, the constant is its input 0
    expected_wir.add_edge(import_node, sqrt_node,
                          type="caller", arg_index=-1)
    expected_wir.add_edge(argument_node, sqrt_node,
                          type="input", arg_index=0)

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_call_after_call():
    """
    Checks the extracted WIR for a method call that is chained onto the result
    of another method call
    """
    source = cleandoc("""
        "hello ".capitalize().count()
        """)
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    # The chain, from the string constant to the outermost call
    chain = [
        WirNode(0, "hello ", "Constant", CodeReference(1, 0, 1, 8)),
        WirNode(1, "capitalize", "Call", CodeReference(1, 0, 1, 21)),
        WirNode(2, "count", "Call", CodeReference(1, 0, 1, 29)),
    ]

    expected_wir = networkx.DiGraph()
    # Each element of the chain is the caller of the element that follows it
    for caller_node, callee_node in zip(chain, chain[1:]):
        expected_wir.add_edge(caller_node, callee_node,
                              type="caller", arg_index=-1)

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_string_call_attribute():
    """
    Checks the extracted WIR for a method call on a string constant that
    receives another string constant as argument
    """
    source = cleandoc("""
        "hello ".join("world")
        """)
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    receiver_node = WirNode(0, "hello ", "Constant",
                            CodeReference(1, 0, 1, 8))
    argument_node = WirNode(1, "world", "Constant",
                            CodeReference(1, 14, 1, 21))
    join_node = WirNode(2, "join", "Call", CodeReference(1, 0, 1, 22))

    expected_wir = networkx.DiGraph()
    expected_wir.add_edges_from([
        # The constant the method is called on is the caller
        (receiver_node, join_node, {"type": "caller", "arg_index": -1}),
        # The constant passed in the parentheses is input 0
        (argument_node, join_node, {"type": "input", "arg_index": 0}),
    ])

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_print_stmt():
    """
    Checks the extracted WIR for a single print call with a string constant
    """
    source = "print('test')"
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    constant_node = WirNode(0, "test", "Constant",
                            CodeReference(1, 6, 1, 12))
    print_node = WirNode(1, "print", "Call", CodeReference(1, 0, 1, 13))

    # The only expected edge: the constant is input 0 of print
    expected_wir = networkx.DiGraph()
    expected_wir.add_edges_from([
        (constant_node, print_node, {"type": "input", "arg_index": 0}),
    ])

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_index_subscript_with_module_information():
    """
    Checks the extracted WIR for an index subscript on the result of a pandas
    read_csv call, after module information was added via add_runtime_info
    """
    source = cleandoc("""
            import pandas as pd

            data = pd.read_csv('test_path')
            data['income-per-year']
            """)
    read_csv_reference = CodeReference(3, 7, 3, 31)
    subscript_reference = CodeReference(4, 0, 4, 23)
    read_csv_module = ('pandas.io.parsers', 'read_csv')
    getitem_module = ('pandas.core.frame', '__getitem__')

    extractor = WirExtractor(ast.parse(source))
    module_info = {
        read_csv_reference: read_csv_module,
        subscript_reference: getitem_module,
    }
    actual_wir = extractor.extract_wir()
    # NOTE(review): the return value of add_runtime_info is not used; the
    # graph obtained from extract_wir is presumably updated in place - confirm
    extractor.add_runtime_info(module_info, {})

    import_node = WirNode(0, "pandas", "Import", CodeReference(1, 0, 1, 19))
    path_node = WirNode(1, "test_path", "Constant",
                        CodeReference(3, 19, 3, 30))
    read_csv_node = WirNode(2, "read_csv", "Call", read_csv_reference,
                            read_csv_module)
    assign_node = WirNode(3, "data", "Assign", CodeReference(3, 0, 3, 31))
    column_node = WirNode(4, "income-per-year", "Constant",
                          CodeReference(4, 5, 4, 22))
    subscript_node = WirNode(5, "Index-Subscript", "Subscript",
                             subscript_reference, getitem_module)

    caller_edge = {"type": "caller", "arg_index": -1}
    first_input_edge = {"type": "input", "arg_index": 0}

    expected_wir = networkx.DiGraph()
    expected_wir.add_edges_from([
        # data = pd.read_csv('test_path')
        (import_node, read_csv_node, caller_edge),
        (path_node, read_csv_node, first_input_edge),
        (read_csv_node, assign_node, first_input_edge),
        # data['income-per-year']
        (assign_node, subscript_node, caller_edge),
        (column_node, subscript_node, first_input_edge),
    ])

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_index_assign():
    """
    Checks the extracted WIR for an assignment to an index subscript of a
    variable that holds the result of a pandas read_csv call
    """
    source = cleandoc("""
            import pandas as pd

            data = pd.read_csv('test_path')
            data['label'] = "test"
            """)
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    import_node = WirNode(0, "pandas", "Import", CodeReference(1, 0, 1, 19))
    path_node = WirNode(1, "test_path", "Constant",
                        CodeReference(3, 19, 3, 30))
    read_csv_node = WirNode(2, "read_csv", "Call", CodeReference(3, 7, 3, 31))
    assign_node = WirNode(3, "data", "Assign", CodeReference(3, 0, 3, 31))
    label_node = WirNode(4, "label", "Constant", CodeReference(4, 5, 4, 12))
    value_node = WirNode(5, "test", "Constant", CodeReference(4, 16, 4, 22))
    subscript_assign_node = WirNode(6, 'data.label', 'Subscript-Assign',
                                    CodeReference(4, 0, 4, 13))

    caller_edge = {"type": "caller", "arg_index": -1}
    first_input_edge = {"type": "input", "arg_index": 0}

    expected_wir = networkx.DiGraph()
    # data = pd.read_csv('test_path')
    expected_wir.add_edges_from([
        (import_node, read_csv_node, caller_edge),
        (path_node, read_csv_node, first_input_edge),
        (read_csv_node, assign_node, first_input_edge),
    ])
    # Both constants of the subscript assignment are expected as nodes
    # without any edges
    expected_wir.add_nodes_from([label_node, value_node])
    # The subscript assignment is only connected to the variable it targets
    expected_wir.add_edge(assign_node, subscript_assign_node, **caller_edge)

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_tuple_assign():
    """
    Checks the extracted WIR for a tuple that is unpacked into two variables,
    one of which is printed afterwards
    """
    source = cleandoc("""
        x, y = (1, 2)
        print(x)""")
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    # Note: the expected node ids in this test start at 1, not at 0
    tuple_elements = (
        WirNode(1, '1', 'Constant', CodeReference(1, 8, 1, 9)),
        WirNode(2, '2', 'Constant', CodeReference(1, 11, 1, 12)),
    )
    tuple_node = WirNode(3, 'as_tuple', 'Tuple', CodeReference(1, 7, 1, 13))
    var_x_node = WirNode(4, 'x', 'Assign', CodeReference(1, 0, 1, 1))
    var_y_node = WirNode(5, 'y', 'Assign', CodeReference(1, 3, 1, 4))
    print_node = WirNode(6, 'print', 'Call', CodeReference(2, 0, 2, 8))

    expected_wir = networkx.DiGraph()
    # Every tuple element is an input of the tuple, indexed by its position
    for position, element in enumerate(tuple_elements):
        expected_wir.add_edge(element, tuple_node,
                              type="input", arg_index=position)
    # Both assigned variables receive the whole tuple as input 0
    for variable_node in (var_x_node, var_y_node):
        expected_wir.add_edge(tuple_node, variable_node,
                              type="input", arg_index=0)
    # Only x is passed on to print
    expected_wir.add_edge(var_x_node, print_node, type="input", arg_index=0)

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_keyword():
    """
    Checks the extracted WIR for a call that combines several constant
    arguments with a keyword argument
    """
    source = cleandoc("""
        print('comma', 'separated', 'words', sep=', ')
        """)
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    comma_node = WirNode(0, "comma", "Constant", CodeReference(1, 6, 1, 13))
    separated_node = WirNode(1, "separated", "Constant",
                             CodeReference(1, 15, 1, 26))
    words_node = WirNode(2, "words", "Constant", CodeReference(1, 28, 1, 35))
    separator_node = WirNode(3, ", ", "Constant",
                             CodeReference(1, 41, 1, 45))
    # The keyword node is expected without a code reference
    keyword_node = WirNode(4, "sep", "Keyword")
    print_node = WirNode(5, "print", "Call", CodeReference(1, 0, 1, 46))

    expected_wir = networkx.DiGraph()
    # The keyword value is input 0 of the keyword node
    expected_wir.add_edge(separator_node, keyword_node,
                          type="input", arg_index=0)
    # The three constants and the keyword are the inputs 0 to 3 of print
    print_inputs = (comma_node, separated_node, words_node, keyword_node)
    for position, argument in enumerate(print_inputs):
        expected_wir.add_edge(argument, print_node,
                              type="input", arg_index=position)

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def test_tuples():
    """
    Checks the extracted WIR for a tuple literal that contains a constant,
    a call with a keyword argument and a list
    """
    source = cleandoc("""
            from sklearn import preprocessing

            ('categorical', preprocessing.OneHotEncoder(handle_unknown='ignore'), ['education', 'workclass'])
            """)
    actual_wir = WirExtractor(ast.parse(source)).extract_wir()

    import_node = WirNode(0, "sklearn", "Import", CodeReference(1, 0, 1, 33))
    name_node = WirNode(1, "categorical", "Constant",
                        CodeReference(3, 1, 3, 14))
    ignore_node = WirNode(2, "ignore", "Constant",
                          CodeReference(3, 59, 3, 67))
    # The keyword node is expected without a code reference
    keyword_node = WirNode(3, "handle_unknown", "Keyword")
    encoder_node = WirNode(4, "OneHotEncoder", "Call",
                           CodeReference(3, 16, 3, 68))
    education_node = WirNode(5, "education", "Constant",
                             CodeReference(3, 71, 3, 82))
    workclass_node = WirNode(6, "workclass", "Constant",
                             CodeReference(3, 84, 3, 95))
    list_node = WirNode(7, "as_list", "List", CodeReference(3, 70, 3, 96))
    tuple_node = WirNode(8, "as_tuple", "Tuple", CodeReference(3, 0, 3, 97))

    expected_wir = networkx.DiGraph()
    expected_wir.add_edges_from([
        # preprocessing.OneHotEncoder(handle_unknown='ignore')
        (ignore_node, keyword_node, {"type": "input", "arg_index": 0}),
        (import_node, encoder_node, {"type": "caller", "arg_index": -1}),
        (keyword_node, encoder_node, {"type": "input", "arg_index": 0}),
        # ['education', 'workclass']
        (education_node, list_node, {"type": "input", "arg_index": 0}),
        (workclass_node, list_node, {"type": "input", "arg_index": 1}),
        # The surrounding tuple with its three elements
        (name_node, tuple_node, {"type": "input", "arg_index": 0}),
        (encoder_node, tuple_node, {"type": "input", "arg_index": 1}),
        (list_node, tuple_node, {"type": "input", "arg_index": 2}),
    ])

    actual_dict = networkx.to_dict_of_dicts(actual_wir)
    expected_dict = networkx.to_dict_of_dicts(expected_wir)
    compare(actual_dict, expected_dict)
def get_expected_cleaned_wir_adult_easy():
    """
    Build and return the expected cleaned WIR for the adult_easy pipeline
    as a networkx.DiGraph whose edges carry no attributes
    """
    # pylint: disable=too-many-locals
    # All expected nodes, ordered by their code reference
    print_one = WirNode(6, "print", "Call", CodeReference(10, 0, 10, 23),
                        ('builtins', 'print'))
    get_project_root = WirNode(
        7, "get_project_root", "Call", CodeReference(11, 30, 11, 48),
        ('mlinspect.utils', 'get_project_root'))
    str_call = WirNode(8, "str", "Call", CodeReference(11, 26, 11, 49),
                       ('builtins', 'str'))
    join = WirNode(12, "join", "Call", CodeReference(11, 13, 11, 85),
                   ('posixpath', 'join'))
    read_csv = WirNode(18, "read_csv", "Call", CodeReference(12, 11, 12, 62),
                       ('pandas.io.parsers', 'read_csv'))
    dropna = WirNode(20, "dropna", "Call", CodeReference(14, 7, 14, 24),
                     ('pandas.core.frame', 'dropna'))
    index_subscript = WirNode(
        23, "Index-Subscript", "Subscript", CodeReference(16, 38, 16, 61),
        ('pandas.core.frame', '__getitem__', 'Projection'))
    label_binarize = WirNode(
        28, "label_binarize", "Call", CodeReference(16, 9, 16, 89),
        ('sklearn.preprocessing._label', 'label_binarize'))
    column_transformer = WirNode(
        46, "ColumnTransformer", "Call", CodeReference(18, 25, 21, 2),
        ('sklearn.compose._column_transformer', 'ColumnTransformer'))
    one_hot_encoder = WirNode(
        33, "OneHotEncoder", "Call", CodeReference(19, 20, 19, 72),
        ('sklearn.preprocessing._encoders', 'OneHotEncoder'))
    standard_scaler = WirNode(
        39, "StandardScaler", "Call", CodeReference(20, 16, 20, 46),
        ('sklearn.preprocessing._data', 'StandardScaler'))
    pipeline = WirNode(54, "Pipeline", "Call", CodeReference(24, 18, 26, 51),
                       ('sklearn.pipeline', 'Pipeline'))
    decision_tree_classifier = WirNode(
        51, "DecisionTreeClassifier", "Call", CodeReference(26, 19, 26, 48),
        ('sklearn.tree._classes', 'DecisionTreeClassifier'))
    fit = WirNode(56, "fit", "Call", CodeReference(28, 0, 28, 33),
                  ('sklearn.pipeline', 'fit'))
    print_two = WirNode(58, "print", "Call", CodeReference(31, 0, 31, 26),
                        ('builtins', 'print'))

    cleaned_wir = networkx.DiGraph()

    # The first print is expected as a node without any edges
    cleaned_wir.add_node(print_one)

    cleaned_wir.add_edges_from([
        # Chain from get_project_root to dropna
        (get_project_root, str_call),
        (str_call, join),
        (join, read_csv),
        (read_csv, dropna),
        # dropna feeds fit directly and via the label_binarize branch
        (dropna, fit),
        (dropna, index_subscript),
        (index_subscript, label_binarize),
        (label_binarize, fit),
        # Both transformers feed the column transformer
        (one_hot_encoder, column_transformer),
        (standard_scaler, column_transformer),
        # Column transformer and classifier feed the pipeline, which feeds fit
        (column_transformer, pipeline),
        (decision_tree_classifier, pipeline),
        (pipeline, fit),
    ])

    # The second print is expected as a node without any edges, too
    cleaned_wir.add_node(print_two)

    return cleaned_wir