def test_nested_import_from(): """ Tests whether the WIR Extraction works for nested from imports """ test_code = cleandoc(""" from mlinspect.utils import get_project_root print(get_project_root()) """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_import = WirNode(0, "mlinspect.utils", "Import", CodeReference(1, 0, 1, 44)) expected_call_one = WirNode(1, "get_project_root", "Call", CodeReference(3, 6, 3, 24)) expected_graph.add_edge(expected_import, expected_call_one, type="caller", arg_index=-1) expected_call_two = WirNode(2, "print", "Call", CodeReference(3, 0, 3, 25)) expected_graph.add_edge(expected_call_one, expected_call_two, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_print_var_usage(): """ Tests whether the WIR Extraction works for a very simple var usage """ test_code = cleandoc(""" test_var = "test" print(test_var)""") test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant = WirNode(0, "test", "Constant", CodeReference(1, 11, 1, 17)) expected_assign = WirNode(1, "test_var", "Assign", CodeReference(1, 0, 1, 17)) expected_graph.add_edge(expected_constant, expected_assign, type="input", arg_index=0) expected_call = WirNode(2, "print", "Call", CodeReference(2, 0, 2, 15)) expected_graph.add_node(expected_call) expected_graph.add_edge(expected_assign, expected_call, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_import_from(): """ Tests whether the WIR Extraction works for from imports """ test_code = cleandoc(""" from math import sqrt sqrt(4) """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_import = WirNode(0, "math", "Import", CodeReference(1, 0, 1, 21)) expected_constant = WirNode(1, "4", "Constant", CodeReference(3, 5, 3, 6)) expected_constant_call = WirNode(2, "sqrt", "Call", CodeReference(3, 0, 3, 7)) expected_graph.add_edge(expected_import, expected_constant_call, type="caller", arg_index=-1) expected_graph.add_edge(expected_constant, expected_constant_call, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_print_expressions(): """ Tests whether the WIR Extraction works for an expression with very simple nested calls """ test_code = cleandoc(""" print("test".isupper()) """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant = WirNode(0, "test", "Constant", CodeReference(1, 6, 1, 12)) expected_call_one = WirNode(1, "isupper", "Call", CodeReference(1, 6, 1, 22)) expected_graph.add_edge(expected_constant, expected_call_one, type="caller", arg_index=-1) expected_call_two = WirNode(2, "print", "Call", CodeReference(1, 0, 1, 23)) expected_graph.add_edge(expected_call_one, expected_call_two, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_call_after_call(): """ Tests whether the WIR Extraction works for a very simple attribute call """ test_code = cleandoc(""" "hello ".capitalize().count() """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant_one = WirNode(0, "hello ", "Constant", CodeReference(1, 0, 1, 8)) expected_call_one = WirNode(1, "capitalize", "Call", CodeReference(1, 0, 1, 21)) expected_call_two = WirNode(2, "count", "Call", CodeReference(1, 0, 1, 29)) expected_graph.add_edge(expected_constant_one, expected_call_one, type="caller", arg_index=-1) expected_graph.add_edge(expected_call_one, expected_call_two, type="caller", arg_index=-1) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_string_call_attribute(): """ Tests whether the WIR Extraction works for a very simple attribute call """ test_code = cleandoc(""" "hello ".join("world") """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant_one = WirNode(0, "hello ", "Constant", CodeReference(1, 0, 1, 8)) expected_constant_two = WirNode(1, "world", "Constant", CodeReference(1, 14, 1, 21)) expected_attribute_call = WirNode(2, "join", "Call", CodeReference(1, 0, 1, 22)) expected_graph.add_edge(expected_constant_one, expected_attribute_call, type="caller", arg_index=-1) expected_graph.add_edge(expected_constant_two, expected_attribute_call, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_list_creation(): """ Tests whether the WIR Extraction works for lists """ test_code = cleandoc(""" print(["test1", "test2"]) """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant_one = WirNode(0, "test1", "Constant", CodeReference(1, 7, 1, 14)) expected_constant_two = WirNode(1, "test2", "Constant", CodeReference(1, 16, 1, 23)) expected_list = WirNode(2, "as_list", "List", CodeReference(1, 6, 1, 24)) expected_graph.add_edge(expected_constant_one, expected_list, type="input", arg_index=0) expected_graph.add_edge(expected_constant_two, expected_list, type="input", arg_index=1) expected_call = WirNode(3, "print", "Call", CodeReference(1, 0, 1, 25)) expected_graph.add_edge(expected_list, expected_call, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_index_subscript_with_module_information(): """ Tests whether the WIR Extraction works for lists """ test_code = cleandoc(""" import pandas as pd data = pd.read_csv('test_path') data['income-per-year'] """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) module_info = { CodeReference(3, 7, 3, 31): ('pandas.io.parsers', 'read_csv'), CodeReference(4, 0, 4, 23): ('pandas.core.frame', '__getitem__') } extracted_wir = extractor.extract_wir() extractor.add_runtime_info(module_info, {}) expected_graph = networkx.DiGraph() expected_import = WirNode(0, "pandas", "Import", CodeReference(1, 0, 1, 19)) expected_constant_one = WirNode(1, "test_path", "Constant", CodeReference(3, 19, 3, 30)) expected_call = WirNode(2, "read_csv", "Call", CodeReference(3, 7, 3, 31), ('pandas.io.parsers', 'read_csv')) expected_graph.add_edge(expected_import, expected_call, type="caller", arg_index=-1) expected_graph.add_edge(expected_constant_one, expected_call, type="input", arg_index=0) expected_assign = WirNode(3, "data", "Assign", CodeReference(3, 0, 3, 31)) expected_graph.add_edge(expected_call, expected_assign, type="input", arg_index=0) expected_constant_two = WirNode(4, "income-per-year", "Constant", CodeReference(4, 5, 4, 22)) expected_index_subscript = WirNode(5, "Index-Subscript", "Subscript", CodeReference(4, 0, 4, 23), ('pandas.core.frame', '__getitem__')) expected_graph.add_edge(expected_assign, expected_index_subscript, type="caller", arg_index=-1) expected_graph.add_edge(expected_constant_two, expected_index_subscript, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def get_test_wir(): """ Get the extracted WIR for the adult_easy pipeline with runtime info """ test_ast = get_adult_easy_py_ast() extractor = WirExtractor(test_ast) extractor.extract_wir() wir = extractor.add_runtime_info(get_module_info(), get_call_description_info()) return wir
def test_fails_for_unknown_ast_node_types(): """ Tests whether the WIR Extraction fails properly for node types it can not handle yet """ test_code = "print(*['hello ', 'world'])" parsed_ast = ast.parse(test_code) exec(compile(parsed_ast, filename="<ast>", mode="exec"), {}) extractor = WirExtractor(parsed_ast) with pytest.raises(AssertionError): extractor.extract_wir()
def test_adult_easy_pipeline(): """ Tests whether the WIR Extraction works for the adult_easy pipeline """ with open(FILE_PY) as file: test_code = file.read() test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() assert len(extracted_wir) == 59
def test_index_assign(): """ Tests whether the WIR Extraction works for lists """ test_code = cleandoc(""" import pandas as pd data = pd.read_csv('test_path') data['label'] = "test" """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_import = WirNode(0, "pandas", "Import", CodeReference(1, 0, 1, 19)) expected_constant_one = WirNode(1, "test_path", "Constant", CodeReference(3, 19, 3, 30)) expected_call = WirNode(2, "read_csv", "Call", CodeReference(3, 7, 3, 31)) expected_graph.add_edge(expected_import, expected_call, type="caller", arg_index=-1) expected_graph.add_edge(expected_constant_one, expected_call, type="input", arg_index=0) expected_assign = WirNode(3, "data", "Assign", CodeReference(3, 0, 3, 31)) expected_graph.add_edge(expected_call, expected_assign, type="input", arg_index=0) expected_constant_two = WirNode(4, "label", "Constant", CodeReference(4, 5, 4, 12)) expected_graph.add_node(expected_constant_two) expected_constant_three = WirNode(5, "test", "Constant", CodeReference(4, 16, 4, 22)) expected_graph.add_node(expected_constant_three) expected_subscript_assign = WirNode(6, 'data.label', 'Subscript-Assign', CodeReference(4, 0, 4, 13)) expected_graph.add_edge(expected_assign, expected_subscript_assign, type="caller", arg_index=-1) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_tuple_assign(): """ Tests whether the WIR Extraction works for a very simple var usage """ test_code = cleandoc(""" x, y = (1, 2) print(x)""") test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant_one = WirNode(1, '1', 'Constant', CodeReference(1, 8, 1, 9)) expected_constant_two = WirNode(2, '2', 'Constant', CodeReference(1, 11, 1, 12)) expetected_constant_tuple = WirNode(3, 'as_tuple', 'Tuple', CodeReference(1, 7, 1, 13)) expected_graph.add_edge(expected_constant_one, expetected_constant_tuple, type="input", arg_index=0) expected_graph.add_edge(expected_constant_two, expetected_constant_tuple, type="input", arg_index=1) expected_var_x = WirNode(4, 'x', 'Assign', CodeReference(1, 0, 1, 1)) expected_var_y = WirNode(5, 'y', 'Assign', CodeReference(1, 3, 1, 4)) expected_graph.add_edge(expetected_constant_tuple, expected_var_x, type="input", arg_index=0) expected_graph.add_edge(expetected_constant_tuple, expected_var_y, type="input", arg_index=0) expected_call = WirNode(6, 'print', 'Call', CodeReference(2, 0, 2, 8)) expected_graph.add_edge(expected_var_x, expected_call, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_keyword(): """ Tests whether the WIR Extraction works for function calls with keyword usage """ test_code = cleandoc(""" print('comma', 'separated', 'words', sep=', ') """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant_one = WirNode(0, "comma", "Constant", CodeReference(1, 6, 1, 13)) expected_constant_two = WirNode(1, "separated", "Constant", CodeReference(1, 15, 1, 26)) expected_constant_three = WirNode(2, "words", "Constant", CodeReference(1, 28, 1, 35)) expected_constant_four = WirNode(3, ", ", "Constant", CodeReference(1, 41, 1, 45)) expected_keyword = WirNode(4, "sep", "Keyword") expected_call = WirNode(5, "print", "Call", CodeReference(1, 0, 1, 46)) expected_graph.add_edge(expected_constant_four, expected_keyword, type="input", arg_index=0) expected_graph.add_edge(expected_constant_one, expected_call, type="input", arg_index=0) expected_graph.add_edge(expected_constant_two, expected_call, type="input", arg_index=1) expected_graph.add_edge(expected_constant_three, expected_call, type="input", arg_index=2) expected_graph.add_edge(expected_keyword, expected_call, type="input", arg_index=3) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_remove_all_nodes_but_calls_and_subscripts(): """ Tests whether the WIR Extraction works for the adult_easy pipeline """ test_ast = get_adult_easy_py_ast() extractor = WirExtractor(test_ast) extractor.extract_wir() extracted_wir_with_module_info = extractor.add_runtime_info( get_module_info(), {}) cleaned_wir = WirToDagTransformer( ).remove_all_nodes_but_calls_and_subscripts(extracted_wir_with_module_info) assert len(cleaned_wir) == 15 expected_graph = get_expected_cleaned_wir_adult_easy() compare(networkx.to_dict_of_dicts(cleaned_wir), networkx.to_dict_of_dicts(expected_graph))
def test_print_stmt(): """ Tests whether the WIR Extraction works for a very simple print statement """ test_code = "print('test')" test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_constant = WirNode(0, "test", "Constant", CodeReference(1, 6, 1, 12)) expected_call = WirNode(1, "print", "Call", CodeReference(1, 0, 1, 13)) expected_graph.add_edge(expected_constant, expected_call, type="input", arg_index=0) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))
def test_tuples(): """ Tests whether the WIR Extraction works for tuples """ test_code = cleandoc(""" from sklearn import preprocessing ('categorical', preprocessing.OneHotEncoder(handle_unknown='ignore'), ['education', 'workclass']) """) test_ast = ast.parse(test_code) extractor = WirExtractor(test_ast) extracted_wir = extractor.extract_wir() expected_graph = networkx.DiGraph() expected_import_from = WirNode(0, "sklearn", "Import", CodeReference(1, 0, 1, 33)) expected_constant_one = WirNode(1, "categorical", "Constant", CodeReference(3, 1, 3, 14)) expected_constant_two = WirNode(2, "ignore", "Constant", CodeReference(3, 59, 3, 67)) expected_keyword = WirNode(3, "handle_unknown", "Keyword") expected_graph.add_edge(expected_constant_two, expected_keyword, type="input", arg_index=0) expected_call = WirNode(4, "OneHotEncoder", "Call", CodeReference(3, 16, 3, 68)) expected_graph.add_edge(expected_import_from, expected_call, type="caller", arg_index=-1) expected_graph.add_edge(expected_keyword, expected_call, type="input", arg_index=0) expected_constant_three = WirNode(5, "education", "Constant", CodeReference(3, 71, 3, 82)) expected_constant_four = WirNode(6, "workclass", "Constant", CodeReference(3, 84, 3, 95)) expected_list = WirNode(7, "as_list", "List", CodeReference(3, 70, 3, 96)) expected_graph.add_edge(expected_constant_three, expected_list, type="input", arg_index=0) expected_graph.add_edge(expected_constant_four, expected_list, type="input", arg_index=1) expected_tuple = WirNode(8, "as_tuple", "Tuple", CodeReference(3, 0, 3, 97)) expected_graph.add_edge(expected_constant_one, expected_tuple, type="input", arg_index=0) expected_graph.add_edge(expected_call, expected_tuple, type="input", arg_index=1) expected_graph.add_edge(expected_list, expected_tuple, type="input", arg_index=2) compare(networkx.to_dict_of_dicts(extracted_wir), networkx.to_dict_of_dicts(expected_graph))