Exemple #1
0
def test_analyze_func():
    """analyze_func() should annotate emptiness, arg count, docstring and generator-ness."""

    # --- fixture functions ---
    def empty_fn_pass():
        pass

    def empty_fn_docstr():
        """
        docs for empty function
        """

    def empty_fn_docstr_pass():
        """
        docs for empty function
        """
        pass

    def nonempty_fn():
        print()

    def multi_args(a, b, c=2):
        pass

    def gen_fn():
        yield 2

    def analyzed_fn_ast(fn):
        # parse fn, run analyze_func() and return the annotated function node
        return get_fn_ast(analyze_func(parse_ast(fn)))

    # empty function detection: fn => whether it is expected to be empty
    emptiness_cases = (
        (empty_fn_pass, True),
        (empty_fn_docstr, True),
        (empty_fn_docstr_pass, True),
        (nonempty_fn, False),
    )
    for fn, expected_empty in emptiness_cases:
        assert analyzed_fn_ast(fn).is_empty == expected_empty

    # argument counting (arguments with defaults are counted too)
    assert all(
        analyzed_fn_ast(fn).n_args == n_args
        for n_args, fn in ((0, empty_fn_pass), (3, multi_args))
    )

    # docstring detection
    assert analyzed_fn_ast(empty_fn_docstr).docstr is not None
    assert analyzed_fn_ast(empty_fn_pass).docstr is None

    # generator detection
    assert analyzed_fn_ast(gen_fn).is_generator
def test_transform_build_graph():
    """transform_build_graph() output should load and expose a runnable build_graph()."""

    def convert_fn(g: Plotter):
        pass

    def convert_fn_with_long_name(g: Plotter):
        pass

    def identity(fn):
        return fn

    @identity
    def convert_fn_with_annotation(g: Plotter):
        pass

    # analyzers transform_build_graph() depends on, applied in order
    required_analyzers = (analyze_func, analyze_convert_fn)
    candidates = (
        convert_fn,
        convert_fn_with_long_name,
        convert_fn_with_annotation,
    )

    for candidate_fn in candidates:
        fn_ast = parse_ast(candidate_fn)
        for analyzer in required_analyzers:
            fn_ast = analyzer(fn_ast)

        # the transformed function is renamed to 'build_graph': it must run with a Plotter
        module = load_ast_module(transform_build_graph(fn_ast))
        module.build_graph(Plotter(()))
Exemple #3
0
def test_symbol_analyzer():
    """analyze_symbol() should label assignment targets with their symbol and base symbol,
    and leave assigned values unlabeled as symbols."""

    def symbol_fn():
        x = 2
        a.b.c = "str"
        y, z = True, False
        m[k1] = "v1"
        m["k2"] = "v2"
        m["k" + "3"] = "v3"

    ast = parse_ast(symbol_fn)
    analyzed_ast = analyze_symbol(ast)
    fn_ast = analyzed_ast.body[0]

    def assert_symbol(target, symbol, base_symbol):
        # target must be flagged as a symbol with the given full & base symbol names
        assert target.is_symbol
        assert target.symbol == symbol
        assert target.base_symbol == base_symbol

    # plain name and qualified attribute targets
    x_target, abc_target = [fn_ast.body[i].targets[0] for i in range(2)]
    x_value, abc_value = [fn_ast.body[i].value for i in range(2)]
    assert_symbol(x_target, "x", "x")
    assert_symbol(abc_target, "a.b.c", "a")

    # tuple unpacking targets
    y_target, z_target = fn_ast.body[2].targets[0].elts
    y_value, z_value = fn_ast.body[2].value.elts
    assert_symbol(y_target, "y", "y")
    assert_symbol(z_target, "z", "z")

    # subscript targets: the base symbol is the subscripted name
    mk1_target, mk2_target, mk3_target = [
        fn_ast.body[i].targets[0] for i in range(3, 6)
    ]
    mk1_value, mk2_value, mk3_value = [fn_ast.body[i].value for i in range(3, 6)]
    assert_symbol(mk1_target, "m[k1]", "m")
    assert_symbol(mk2_target, "m['k2']", "m")
    assert_symbol(mk3_target, "m[('k' + '3')]", "m")

    # every assigned value (including the last subscript assignment) is not a symbol
    values = [x_value, abc_value, y_value, z_value, mk1_value, mk2_value, mk3_value]
    assert not any(val.is_symbol for val in values)
Exemple #4
0
def test_analyze_parent():
    """analyze_parent() should attach a `parent` back edge to every child node."""

    def simple_fn():
        simple = 2

    module_ast = analyze_parent(parse_ast(simple_fn))
    fn_ast = module_ast.body[0]
    assign_ast = fn_ast.body[0]

    # (child node, expected parent node) from the module down to the leaves
    expected_parents = [
        (fn_ast, module_ast),
        (assign_ast, fn_ast),
        (assign_ast.targets[0], assign_ast),
        (assign_ast.value, assign_ast),
    ]
    for child, expected_parent in expected_parents:
        assert child.parent == expected_parent
Exemple #5
0
 def do_transform(ternary_ast: AST) -> AST:
     """Rewrite a ternary `a if cond else b` node into a plotter `switch()` call.

     Nodes that are not `IfExp` are returned unchanged. Uses `ast` from the
     enclosing scope to find the plotter argument name of the convert fn.
     """
     if isinstance(ternary_ast, IfExp):
         # name of the Plotter method that plots a switch node
         switch_fn_name = parse_ast(Plotter.switch).body[0].name
         switch_args = {
             "condition": ternary_ast.test,
             "true": ternary_ast.body,
             "false": ternary_ast.orelse,
         }
         return call_func_ast(
             fn_name=switch_fn_name,
             args=switch_args,
             # call switch() as a method of the convert fn's plotter argument
             attr_parent=ast.convert_fn.plotter_name,
         )
     return ternary_ast
Exemple #6
0
def test_lint_convert_fn():
    """lint_convert_fn() should reject non-convertible fns and accept valid convert fns."""

    class NotConvertFn:
        pass

    def no_args():
        pass

    def too_many_args(a, b, c=3):
        pass

    def generator_fn(g):
        yield 2

    def convert_fn(g):
        pass

    def convert_fn_type(g: Plotter):
        pass

    # define test cases: lint fn to expected error (None => linting must succeed)
    lint_fns = [
        (NotConvertFn, NotImplementedError),
        (no_args, TypeError),
        (too_many_args, TypeError),
        (generator_fn, ValueError),
        (convert_fn, None),
        (convert_fn_type, None),
    ]
    # required ast analyzers in order for linting to work
    req_analyzers = [
        analyze_func,
        analyze_convert_fn,
    ]

    for fn, expected_err in lint_fns:
        ast = parse_ast(fn)
        for analyze in req_analyzers:
            ast = analyze(ast)

        if expected_err is None:
            # valid convert fn: any exception propagates with its own traceback.
            # (`except None:` would instead mask it behind a TypeError.)
            lint_convert_fn(ast)
            continue

        try:
            lint_convert_fn(ast)
        except expected_err:
            pass
        else:
            raise AssertionError(
                f"expected {expected_err.__name__} when linting {fn.__name__}"
            )
Exemple #7
0
def test_const_analyzer():
    """analyze_const() should flag exactly the constant-valued nodes as `is_constant`."""

    def const_fn(g):
        int_const = 2
        float_const = 3.14
        str_const = "str"
        list_const = [1, 2, 3]
        tuple_const = (1, 2, 3)
        boolean_const, boolean_const_2 = True, False

    analyzed_ast = analyze_const(analyze_assign(parse_ast(const_fn)))
    fn_ast = analyzed_ast.body[0]

    # the first 5 statements each assign a single constant value
    expected_asts = {fn_ast.body[i].value for i in range(5)}
    # the last statement unpacks a tuple: each of its elements is a constant
    expected_asts.update(fn_ast.body[5].value.elts)

    # search the tree returned by the analyzers, not the pre-analysis input
    const_asts = {n for n in gast.walk(analyzed_ast) if n.is_constant}
    assert const_asts == expected_asts
Exemple #8
0
def test_preprocess_augassign():
    """preprocess_augassign() should desugar `a op= b` into `a = a op b`."""
    # test cases: in each fn, the 2nd statement (body[1]) is the input AugAssign
    # and the 3rd statement (body[2]) is the expected desugared Assign.
    def add_augassign_fn():
        x, y = 1, 2
        x += y
        x = x + y

    def sub_augassign_fn():
        x, y = 1, 2
        x -= y
        x = x - y

    def mul_augassign_fn():
        x, y = 1, 2
        x *= y
        x = x * y

    def div_augassign_fn():
        x, y = 1, 2
        x /= y
        x = x / y

    class C:
        x = 1

    def attribute_augassign_fn():
        # 'y' is a number (not a tuple) so the fixture would also be valid if run
        y = 2
        C.x /= y
        C.x = C.x / y

    augassign_fns = [
        add_augassign_fn,
        sub_augassign_fn,
        mul_augassign_fn,
        div_augassign_fn,
        attribute_augassign_fn,
    ]

    for fn in augassign_fns:
        fn_ast = parse_ast(fn).body[0]
        aug_ast, expected_ast = fn_ast.body[1], fn_ast.body[2]
        actual_ast = preprocess_augassign(aug_ast)
        # compare structure only (dump() omits line/column attributes by default)
        assert gast.dump(actual_ast) == gast.dump(expected_ast)
Exemple #9
0
def test_convert_fn_analyzer():
    """analyze_convert_fn() should locate the convert fn and record its plotter name."""

    class NotConvertFn:
        def still_not_convert_fn(self):
            pass

    def convert_fn(g):
        pass

    # test case => whether a convert fn should be detected
    cases = [
        (NotConvertFn, False),
        (convert_fn, True),
    ]
    for candidate, is_convert_fn in cases:
        analyzed = analyze_convert_fn(analyze_func(parse_ast(candidate)))
        found_fn = analyzed.convert_fn
        assert (found_fn is not None) == is_convert_fn
        if found_fn is None:
            continue
        # the plotter name is the name of the convert fn's first argument
        assert found_fn.plotter_name == found_fn.args.args[0].id
def test_transform_ternary():
    """transform_ternary() should turn `a if cond else b` into a plotter switch() call."""

    def ternary_fn(g: Plotter):
        int_ternary = 1 if True else 2

    fn_ast = parse_ast(ternary_fn)
    # analyzers the transforms depend on, applied in order
    for analyzer in (analyze_func, analyze_convert_fn):
        fn_ast = analyzer(fn_ast)
    trans_ast = transform_build_graph(transform_ternary(fn_ast))

    # run the compiled build_graph() with a spy plotter to record switch() calls
    module = load_ast_module(trans_ast)
    spy_plotter = Mock(wraps=Plotter())
    module.build_graph(spy_plotter)
    spy_plotter.switch.assert_called_once_with(condition=True, true=1, false=2)
Exemple #11
0
def compile_graph(
    convert_fn: ConvertFn,
    entity_defs: "List[EntityDef] | None" = None,
    component_defs: "List[ComponentDef] | None" = None,
    preprocessors: List[Transform] = [
        preprocess_augassign,
    ],
    analyzers: List[Analyzer] = [
        analyze_parent,
        analyze_func,
        analyze_convert_fn,
        analyze_symbol,
        analyze_assign,
        resolve_symbol,
        analyze_block,
        analyze_activity,
    ],
    linters: List[Linter] = [],
    transforms: List[Transform] = [
        transform_build_graph,
        transform_ternary,
        transform_ifelse,
    ],
) -> Graph:
    """Compiles the given `convert_fn` into a computation Graph running the given sim.

    Globals can be used read only in the `convert_fn`. Writing to globals is not supported.

    Compiles by converting the given `convert_fn` function to AST,
    applying the given `preprocessors` transforms to perform preprocessing on the AST,
    applying the given `analyzers` on the AST to perform static analysis,
    linting the AST with the given `linters` to perform static checks, and
    applying the given `transforms` to transform the AST to a function that
    plots the computational graph when run.

    Note:
        Even though both `preprocessors` and `transforms` are comprised of a list of
        `Transform`s, `preprocessors` transforms are applied before any static analysis
        is done while `transforms` are applied after static analysis. This allows
        `preprocessors` to focus on transforming the AST to make static analysis easier
        while `transforms` focus on transforming the AST to plot a computation graph.

    The transformed AST is converted back to source where it can be imported
    to provide a compiled function that builds the graph using the given `Plotter` on call.
    The graph obtained from the `Plotter` is finally returned.

    Example:
        def car_pos_fn(g: Plotter):
            car = g.entity(
                components=[
                    "position",
                ]
            )
            env = g.entity(components=["clock"])
            x_delta = 20 if env["clock"].tick_ms > 2000 else 10
            car["position"].x = x_delta

        car_pos_graph = compile_graph(car_pos_fn, entity_defs, component_defs)
        # use compiled graph 'car_pos_graph' in code ...

    Args:
        convert_fn: Function containing the code that should be compiled into a computational graph.
            The target `convert_fn` should take in one parameter: a `Plotter` instance which
            allows users to access graphing specific operations. Must be a plain Python
            Function, not a Callable class, method, classmethod or staticmethod.

        entity_defs: List of EntityDef representing the ECS entities available for
            use in `convert_fn` via the Plotter instance. If omitted, an empty list is used.

        component_defs: List of ComponentDef representing the ECS component types
            available for use in `convert_fn` via the Plotter instance.
            If omitted, an empty list is used.

        preprocessors: List of `Transform`s that are run sequentially to apply
            preprocessing transforms to the AST before any static analysis is done.
            Typically these AST transforms make static analysis easier by simplifying the AST.

        analyzers: List of `Analyzer`s that are run sequentially on the AST to perform
            static analysis. Analyzers can add attributes to AST nodes but not
            modify the AST tree.

        linters: List of `Linter`s that are run sequentially on the AST to perform
            static checks on the convertability of the AST. `Linter`s are expected
            to throw exception when failing a check.

        transforms: List of `Transform`s that are run sequentially to transform the
            AST to a compiled function (in AST form) that builds the computation
            graph when called.

    Returns:
        The converted computational Graph as a `Graph`.

    Raises:
        Any exception raised by a linter, or by the compiled `build_graph()` while
        plotting. In the latter case the path of the generated source file is
        printed and the file is kept on disk for inspection.
    """
    # omitted definitions => plot with no entity / component definitions
    entity_defs = [] if entity_defs is None else entity_defs
    component_defs = [] if component_defs is None else component_defs

    # parse ast from function source
    ast = parse_ast(convert_fn)

    # apply preprocessors to apply preprocessing transforms on the AST
    for preprocessor in preprocessors:
        ast = preprocessor(ast)
    # apply analyzers to conduct static analysis
    for analyzer in analyzers:
        ast = analyzer(ast)
    # check that the AST can be converted by applying linters (they raise on failure)
    for linter in linters:
        linter(ast)
    # convert AST to computation graph by applying transforms
    for transform in transforms:
        ast = transform(ast)

    # load AST back as a module, keeping the generated source file for now
    compiled, src_path = load_ast_module(ast, remove_src=False)
    # take a direct reference *before* merging globals below: a global named
    # `build_graph` in convert_fn's module would otherwise shadow the compiled one
    build_graph = compiled.build_graph
    # allow the use of globals symbols with respect to convert_fn function
    # to be used during graph plotting
    build_graph.__globals__.update(convert_fn.__globals__)  # type: ignore

    # run build graph function with plotter to build final computation graph
    g = Plotter(entity_defs, component_defs)
    try:
        build_graph(g)
    except Exception:
        # the generated source is deliberately not removed so it can be inspected
        print(f"Compilation generated source code with errors: {src_path}")
        raise

    # remove the intermediate source file generated by load_ast_module()
    os.remove(src_path)
    return g.graph()
Exemple #12
0
    def do_transform(ifelse_ast: AST) -> AST:
        """Rewrite an `If` statement into traced branch functions plus plotter switch() calls.

        Non-`If` nodes are returned unchanged. For an `If` node, the returned node
        (built with wrap_block_ast) contains, in order:

        1. `__if_block` / `__else_block` function definitions built from the
           if / else bodies. They take the base input symbols as arguments and
           return the output symbols as a tuple.
        2. An import of `deepcopy` from the `copy` module.
        3. `__if_condition = deepcopy(<if test>)`.
        4. `__if_outputs = __if_block(<arg>=deepcopy(<arg>), ...)`, and likewise
           `__else_outputs` from `__else_block`.
        5. An assignment of the output symbols from
           `[<plotter>.switch(condition=__if_condition, true=if_out, false=else_out)
             for if_out, else_out in zip(__if_outputs, __else_outputs)]`.

        Uses `ast` from the enclosing scope to get the convert fn's plotter name.
        Assumes `base_in_syms` / `output_syms` were populated on the If node by
        analyze_activity beforehand — TODO confirm against the caller.
        """
        # filter out non-ifelse statements
        if not isinstance(ifelse_ast, If):
            return ifelse_ast

        # convert ifelse condition branches into functions with the arguments
        # set to the names of the base input symbols and return values set to output symbols.
        # base symbols are used to generate the arguments as the full symbol might be qualified
        # ie A.x which is not a valid argument name: https://github.com/joeltio/bento-box/issues/37
        args = list(ifelse_ast.base_in_syms.keys())
        returns = list(ifelse_ast.output_syms.keys())
        fn_asts = [
            wrap_func_ast(
                name=name,
                args=args,
                block=block,
                returns=returns,
                # zip() requires the returned outputs to be iterable
                return_tuple=True,
            ) for name, block in zip(["__if_block", "__else_block"],
                                     [ifelse_ast.body, ifelse_ast.orelse])
        ]

        # deepcopy the condition before tracing the if/else block functions to
        # prevent side effects of tracing from interfering with the condition.
        # NOTE(review): the same `condition_ast` node object is used both as the
        # assignment target here and as the switch() condition below; confirm
        # assign_ast() does not mutate its ctx (Store vs Load).
        condition_ast = name_ast("__if_condition")
        eval_condition_ast = assign_ast(
            targets=[condition_ast],
            values=[call_func_ast("deepcopy", args=[ifelse_ast.test])],
        )

        # call the if/else block functions to trace the results of evaluating each
        # branch of the conditional. The block functions have arguments with the
        # same names as the symbols we have to pass in.
        # deepcopy to prevent input symbols from being passed by reference and
        # causing interference between branches https://github.com/joeltio/bento-box/issues/39
        import_deepcopy_ast = import_from_ast(module="copy",
                                              names=["deepcopy"])

        # keyword args: <base input symbol>=deepcopy(<base input symbol>)
        call_args = {
            a: call_func_ast(
                fn_name="deepcopy",
                args=[name_ast(a)],
            )
            for a in args
        }
        branch_outputs = [
            name_ast(n) for n in ["__if_outputs", "__else_outputs"]
        ]
        # __if_outputs = __if_block(...); __else_outputs = __else_block(...)
        call_fn_asts = [
            assign_ast(
                targets=[target],
                values=[call_func_ast(fn_ast.name, args=call_args)],
            ) for target, fn_ast in zip(branch_outputs, fn_asts)
        ]

        # create switch nodes for each output symbol via list comprehension
        plot_switch_fn = parse_ast(Plotter.switch).body[0]
        # g.switch(test, if_out, else_out)
        call_switch_ast = call_func_ast(
            fn_name=plot_switch_fn.name,
            args={
                "condition": condition_ast,
                "true": name_ast("if_out"),
                "false": name_ast("else_out"),
            },
            attr_parent=ast.convert_fn.plotter_name,
        )

        # (symbol, ...) = [g.switch(...) for if_out, else_out in zip(if_outputs, else_outputs)]
        switch_asts = assign_ast(
            targets=[name_ast(r, ctx=Store()) for r in returns],
            values=[
                ListComp(
                    elt=call_switch_ast,
                    generators=[
                        comprehension(
                            target=Tuple(
                                elts=[
                                    name_ast("if_out"),
                                    name_ast("else_out"),
                                ],
                                ctx=Load(),
                            ),
                            iter=call_func_ast(
                                fn_name="zip",
                                args=branch_outputs,
                            ),
                            ifs=[],
                            is_async=False,
                        )
                    ],
                )
            ],
            force_tuple=True,
        )
        # wrap transformed code block as single AST node
        return wrap_block_ast(block=fn_asts +
                              [import_deepcopy_ast, eval_condition_ast] +
                              call_fn_asts + [switch_asts], )
Exemple #13
0
def test_analyze_activity():
    """analyze_activity() should record a code block's input/output symbols
    (full and base forms), with symbol ASTs ordered by position in source."""

    def output_only_fn():
        if True:
            # since 'x' is assigned inside the if block
            # it should not be labeled as input_syms wrt. if block
            x = 1
            y = x + 1

    def input_only_fn():
        x, f = 1, (lambda x: x)
        if True:
            f(x)
            f(x + 3)

    def input_output_fn():
        x = 1
        if True:
            y = x + 1
            x = 2
            y = x + 3

    def nested_in_out_fn():
        x = 1
        # parent if block should obtain child node's inputs and outputs
        if False:
            if True:
                y = x + 1
                x = 2
                y = x + 3

    def multi_in_out_fn():
        x = 0
        if True:
            x = x + 1
            y = 2
        else:
            x = x - 1
            y = 3

    class A:
        x = 1

    def qualified_in_out():
        A.x = 1
        # test with qualified symbol A.x
        if True:
            y = A.x + 1
            A.x = 2
            y = A.x + 3

    # test case, expected attributes
    activity_fns = [
        (
            output_only_fn,
            {
                "input_syms": [],
                "output_syms": ["x", "y"],
                "base_in_syms": [],
                "base_out_syms": ["x", "y"],
            },
        ),
        (
            input_only_fn,
            {
                "input_syms": ["x", "f"],
                "output_syms": [],
                "base_in_syms": ["x", "f"],
                "base_out_syms": [],
            },
        ),
        (
            input_output_fn,
            {
                "input_syms": ["x"],
                "output_syms": ["x", "y"],
                "base_in_syms": ["x"],
                "base_out_syms": ["x", "y"],
            },
        ),
        (
            nested_in_out_fn,
            {
                "input_syms": ["x"],
                "output_syms": ["x", "y"],
                "base_in_syms": ["x"],
                "base_out_syms": ["x", "y"],
            },
        ),
        (
            multi_in_out_fn,
            {
                "input_syms": ["x"],
                "output_syms": ["x", "y"],
                "base_in_syms": ["x"],
                "base_out_syms": ["x", "y"],
            },
        ),
        (
            qualified_in_out,
            {
                "input_syms": ["A.x"],
                "output_syms": ["A.x", "y"],
                "base_in_syms": ["A"],
                "base_out_syms": ["A", "y"],
            },
        ),
    ]

    # analyzers analyze_activity() depends on, applied in order
    required_analyzers = [
        analyze_assign,
        analyze_symbol,
        resolve_symbol,
        analyze_block,
    ]

    def code_pos(node):
        # (line, column) of a node, used to compare order of appearance in source
        return (node.lineno, node.col_offset)

    for fn, expected_attrs in activity_fns:
        ast = parse_ast(fn)
        for analyzer in required_analyzers:
            ast = analyzer(ast)
        analyzed_ast = analyze_activity(ast)
        fn_ast = analyzed_ast.body[0]

        # extract the first code block AST node nested in the function
        block_ast = [n for n in gast.walk(fn_ast) if n.is_block and n != fn_ast][0]
        actual_attrs = {
            "output_syms": block_ast.output_syms.keys(),
            "input_syms": block_ast.input_syms.keys(),
            "base_in_syms": block_ast.base_in_syms.keys(),
            "base_out_syms": block_ast.base_out_syms.keys(),
        }
        # check detection of input and output symbols
        for attr, expected_syms in expected_attrs.items():
            actual_syms = set(actual_attrs[attr])
            assert actual_syms == set(expected_syms), (
                f"{fn.__name__}: {attr} expected {sorted(expected_syms)}, "
                f"got {sorted(actual_syms)}"
            )

        # check contents of symbol dict
        combined_syms = [
            *block_ast.input_syms.items(),
            *block_ast.output_syms.items(),
        ]
        for symbol, sym_asts in combined_syms:
            assert all(sym_ast.symbol == symbol for sym_ast in sym_asts)
            # check sym_asts sorted by order of appearance in source code
            # (ie code pos does not decrease) https://stackoverflow.com/a/4983359
            if len(sym_asts) > 1:
                assert all(
                    code_pos(x) <= code_pos(y) for x, y in zip(sym_asts, sym_asts[1:])
                )
Exemple #14
0
def test_analyze_block():
    """analyze_block() should label code block statements and link their children
    back to the enclosing block via `block`."""

    def ternary_fn():
        x = 1 if True else False

    def list_comp_fn():
        x = [i for i in [1, 2, 3]]

    def dict_comp_fn():
        x = {i: j for i, j in [[1, 2], [2, 3], [3, 4]]}

    def lambda_fn():
        x = lambda y: y + 2

    def func_fn():
        def fn(y):
            return y + 2

    def ifelse_fn():
        if True:
            x = 1
        else:
            x = 3

    def for_fn():
        for i in [1, 2, 3]:
            x = i
        else:
            y = i

    def while_fn():
        while True:
            x = 2

    def with_fn():
        with 1 as x:
            y = x

    def try_fn():
        try:
            x = 1
        finally:
            z = x == 1

    # test case fn => whether the fn's first statement is a code block
    block_fns = [
        (ternary_fn, False),
        (list_comp_fn, False),
        (dict_comp_fn, False),
        (lambda_fn, False),
        (func_fn, True),
        (ifelse_fn, True),
        (for_fn, True),
        (while_fn, True),
        (with_fn, True),
        (try_fn, True),
    ]

    for fn, is_expected_block in block_fns:
        analyzed_ast = analyze_block(parse_ast(fn))
        fn_ast = analyzed_ast.body[0]
        block_ast = fn_ast.body[0]
        # check code blocks are labeled correctly
        assert block_ast.is_block == is_expected_block
        if not is_expected_block:
            continue

        # check back edges to the code block are created correctly on child nodes.
        # walk() includes the root block node itself, which is skipped.
        for child in gast.walk(block_ast):
            if child == block_ast:
                continue
            assert child.block == block_ast, (
                f"{fn.__name__}: {type(child).__name__} node is not linked "
                f"to its enclosing {type(block_ast).__name__} code block"
            )
Exemple #15
0
def test_symbol_resolution():
    """resolve_symbol() should link symbol references to their latest definition
    (`definition`) and to every definition seen (`definitions`)."""

    def simple_fn():
        simple = 2
        simple

    def multi_assign():
        multi_a, multi_b = True, False
        multi_a, multi_b

    def repeated_assign():
        repeated = "first"
        repeated = "second"
        repeated

    def scoped_assign():
        scoped = False

        def fn():
            scoped = True

        # should reference the first definition of 'scoped' as the second definition
        # is scoped to only within function
        scoped

    def aug_assign():
        aug = 0
        aug = aug + 1
        aug

    class Qualified:
        a = 1

    def qualified_assign():
        Qualified.a = 1
        Qualified.a

    # test case fn, index (into the fn body) of the statement holding the latest
    # definition, and indices of all statements defining the symbol(s).
    # The last statement of each fn references the symbol(s).
    symbol_fns = [
        (simple_fn, 0, [0]),
        (multi_assign, 0, [0]),
        (repeated_assign, 1, [0, 1]),
        (scoped_assign, 0, [0]),
        (aug_assign, 1, [0, 1]),
        (qualified_assign, 0, [0]),
    ]
    # analyzers resolve_symbol() depends on, applied in order
    required_analyzers = [
        analyze_symbol,
        analyze_assign,
    ]

    for symbol_fn, n_latest_def_line, n_def_lines in symbol_fns:
        ast = parse_ast(symbol_fn)
        for analyzer in required_analyzers:
            ast = analyzer(ast)
        fn_ast = resolve_symbol(ast).body[0]

        # the last statement references the symbol(s); a tuple references several
        ref_expr = fn_ast.body[-1].value
        sym_refs = ref_expr.elts if isinstance(ref_expr, Tuple) else [ref_expr]

        # check latest symbol definition labeled as 'definition'
        latest_def_values = fn_ast.body[n_latest_def_line].values
        for def_value, sym_ref in zip(latest_def_values, sym_refs):
            assert sym_ref.definition == def_value, (
                f"{symbol_fn.__name__}: wrong latest definition for "
                f"{gast.dump(sym_ref)}"
            )

        # check every symbol definition appears exactly once in 'definitions'
        for n_line in n_def_lines:
            for def_value, sym_ref in zip(fn_ast.body[n_line].values, sym_refs):
                assert sym_ref.definitions.count(def_value) == 1, (
                    f"{symbol_fn.__name__}: definition on statement {n_line} not "
                    f"recorded exactly once for {gast.dump(sym_ref)}"
                )
def test_transform_ifelse():
    """transform_ifelse() should plot each if/elif/else output through g.switch()."""

    def if_fn(g: Plotter):
        x, w = "str1", "str2"
        if True:
            x = w

    def ifelse_fn(g: Plotter):
        w, y = "str1", "str2"
        if True:
            x = w
            z = 1
        else:
            x = y
            z = 2

    def ifelse_elif_else_fn(g: Plotter):
        y, m, n = "str1", "str2", "str3"
        if True:
            x = y
            z = 1
        elif False:
            x = m
            z = 2
        else:
            x = n
            z = 3

    def ifelse_augassign_fn(g: Plotter):
        x = 1
        if True:
            x = x + 1
        else:
            x = x + 2

    def if_assign_condition_fn(g: Plotter):
        class A:
            b = True

        # test that the condition is evaluated immediately => True
        if A.b:
            A.b = True
        else:
            A.b = False

    # analyzers the transforms depend on, applied in order
    req_analyzers = (
        analyze_func,
        analyze_convert_fn,
        analyze_assign,
        analyze_symbol,
        resolve_symbol,
        analyze_block,
        analyze_activity,
    )

    # reference plotter, only used to build the expected nested switch() results
    ref_plotter = Plotter()
    # test case fn => expected keyword args of g.switch() calls made when plotting
    ifelse_fns = [
        (if_fn, [dict(condition=True, true="str2", false="str1")]),
        (
            ifelse_fn,
            [
                dict(condition=True, true="str1", false="str2"),
                dict(condition=True, true=1, false=2),
            ],
        ),
        (
            ifelse_elif_else_fn,
            [
                dict(
                    condition=True,
                    true="str1",
                    false=ref_plotter.switch(False, "str2", "str3"),
                ),
                dict(condition=True, true=1, false=ref_plotter.switch(False, 2, 3)),
                dict(condition=False, true="str2", false="str3"),
                dict(condition=False, true=2, false=3),
            ],
        ),
        (ifelse_augassign_fn, [dict(condition=True, true=2, false=3)]),
        (if_assign_condition_fn, [dict(condition=True, true=True, false=False)]),
    ]

    for fn, expected_switch_args in ifelse_fns:
        fn_ast = parse_ast(fn)
        for analyzer in req_analyzers:
            fn_ast = analyzer(fn_ast)
        module = load_ast_module(transform_build_graph(transform_ifelse(fn_ast)))

        # spy on a real plotter so every switch() call is recorded
        spy_plotter = Mock(wraps=Plotter())
        module.build_graph(spy_plotter)

        for switch_kwargs in expected_switch_args:
            spy_plotter.switch.assert_any_call(**switch_kwargs)
Exemple #17
0
def test_assign_analyzer():
    """analyze_assign() should annotate assignments with target/value counts,
    unpack/multi flags and the target & value nodes."""

    def single_assign():
        x = 2

    def multi_assign():
        a = b = True

    def unpack_assign():
        x, y = 1, 2

    assign_fns = [
        (
            single_assign,
            {
                "n_targets": 1,
                "n_values": 1,
                "is_unpack": False,
                "is_multi": False,
                "values": [2],
                "tgts": ["x"],
            },
        ),
        (
            multi_assign,
            {
                "n_targets": 2,
                "n_values": 1,
                "is_unpack": False,
                "is_multi": True,
                "values": [True],
                "tgts": ["a", "b"],
            },
        ),
        (
            unpack_assign,
            {
                "n_targets": 2,
                "n_values": 2,
                "is_unpack": True,
                "is_multi": False,
                "values": [1, 2],
                "tgts": ["x", "y"],
            },
        ),
    ]

    assign_types = (
        Assign,
        AnnAssign,
        AugAssign,
    )
    for fn, expected_annotations in assign_fns:
        analyzed_ast = analyze_assign(parse_ast(fn))
        # each fixture fn contains exactly one assignment statement
        assign_ast = [n for n in gast.walk(analyzed_ast) if isinstance(n, assign_types)][0]

        # check ast has expected annotations
        for name, expected_value in expected_annotations.items():
            value = getattr(assign_ast, name)
            # for tgts & values, directly check name/constant match instead of comparing ast nodes
            if name == "tgts":
                assert [t.id for t in value] == expected_value
            elif name == "values":
                assert [v.value for v in value] == expected_value
            else:
                assert value == expected_value, (
                    f"{fn.__name__}: {name} is {value!r}, expected {expected_value!r}"
                )