Exemplo n.º 1
0
def test_mutiple_returns():
    """
    Check ``mangle`` on a function with several ``return`` statements:
    the arguments and the local variable are expected to be replaced
    by ``__peval_mangled_N`` names, with the structure left intact.
    """

    original_tree = ast.parse(unindent('''
    def f(x, y, z='foo'):
        if x:
            b = y + list(x)
            return b
        else:
            return z
    '''))

    mangled_tree = ast.parse(unindent('''
    def f(__peval_mangled_1, __peval_mangled_2, __peval_mangled_3='foo'):
        if __peval_mangled_1:
            __peval_mangled_4 = __peval_mangled_2 + list(__peval_mangled_1)
            return __peval_mangled_4
        else:
            return __peval_mangled_3
    '''))

    # The updated symbol generator returned by ``mangle`` is not needed here.
    _, result_tree = mangle(GenSym.for_tree(original_tree), original_tree)

    assert_ast_equal(result_tree, mangled_tree)
Exemplo n.º 2
0
def function_from_source(source, globals_=None):
    """
    A helper function to construct a Function object from a source
    with custom __future__ imports.

    :param source: text of a module containing at least one top-level
        function definition; it is passed through ``unindent`` before parsing,
        so it may be uniformly indented.
    :param globals_: the globals mapping handed to ``eval`` when the module
        is executed (``None`` is passed through unchanged).
    :returns: the result of ``Function.from_object`` for the first top-level
        function defined in ``source``.
    :raises ValueError: if ``source`` has no top-level function definition.
    """

    module = ast.parse(unindent(source))
    ast.fix_missing_locations(module)

    # Pick the first top-level function; ``isinstance`` is the idiomatic
    # type check and also accepts subclasses of ``ast.FunctionDef``.
    for stmt in module.body:
        if isinstance(stmt, ast.FunctionDef):
            tree = stmt
            name = stmt.name
            break
    else:
        raise ValueError("No function definitions found in the provided source")

    # ``dont_inherit=True`` keeps the __future__ flags of this module
    # from leaking into the compiled source, so that only the imports
    # written in ``source`` itself take effect.
    code_object = compile(module, '<nofile>', 'exec', dont_inherit=True)

    # NOTE: this executes ``source``; only use it with trusted input.
    # Top-level names defined by the module end up in ``locals_``.
    locals_ = {}
    eval(code_object, globals_, locals_)

    function_obj = locals_[name]
    # Attach the source of the selected function to the function object.
    function_obj._peval_source = astunparse.unparse(tree)

    return Function.from_object(function_obj)
Exemplo n.º 3
0
def check_partial_apply(func, args=None, kwds=None,
        expected_source=None, expected_new_bindings=None):
    '''
    Partially apply ``func`` to ``args`` and ``kwds`` and check the result.

    If :expected_source: is given, the tree of the resulting function
    must be equal to the first statement parsed from it.
    If :expected_new_bindings: is given, every key in it must be present
    in the globals of the resulting function with an equal value.
    '''

    positional = () if args is None else args
    keywords = {} if kwds is None else kwds

    specialized = Function.from_object(
        partial_apply(func, *positional, **keywords))

    if expected_source is not None:
        assert_ast_equal(
            specialized.tree,
            ast.parse(unindent(expected_source)).body[0])

    if expected_new_bindings is None:
        return

    for name in expected_new_bindings:
        if name not in specialized.globals:
            # Leave a hint in the output; the lookup below fails afterwards.
            print('Expected binding missing:', name)

        actual = specialized.globals[name]
        wanted = expected_new_bindings[name]

        # Python 3.2 defines equality for range objects incorrectly
        # (namely, the result is always False),
        # so ranges are compared by type and contents there.
        if sys.version_info < (3, 3) and isinstance(wanted, range):
            assert type(actual) == type(wanted)
            assert list(actual) == list(wanted)
        else:
            assert actual == wanted
Exemplo n.º 4
0
def check_component(component, func, additional_bindings=None,
        expected_source=None, expected_new_bindings=None):
    """
    Run ``component`` on the tree and external variables of ``func``
    (optionally extended with ``additional_bindings``) and check the result.

    The returned tree must be equal to the first statement parsed from
    ``expected_source``, or to the tree of ``func`` when no expected source
    is given. If ``expected_new_bindings`` is given, every key in it must be
    present in the returned bindings with an equal value.
    """

    wrapped = Function.from_object(func)
    env = wrapped.get_external_variables()
    if additional_bindings is not None:
        env.update(additional_bindings)

    result_tree, result_bindings = component(wrapped.tree, env)

    if expected_source is not None:
        reference_tree = ast.parse(unindent(expected_source)).body[0]
    else:
        reference_tree = wrapped.tree
    assert_ast_equal(result_tree, reference_tree)

    if expected_new_bindings is None:
        return

    for name in expected_new_bindings:
        if name not in result_bindings:
            # Leave a hint in the output; the lookup below fails afterwards.
            print('Expected binding missing:', name)

        actual = result_bindings[name]
        wanted = expected_new_bindings[name]

        # Python 3.2 defines equality for range objects incorrectly
        # (namely, the result is always False),
        # so ranges are compared by type and contents there.
        if sys.version_info < (3, 3) and isinstance(wanted, range):
            assert type(actual) == type(wanted)
            assert list(actual) == list(wanted)
        else:
            assert actual == wanted
Exemplo n.º 5
0
def check_partial_apply(func, args=None, kwds=None,
        expected_source=None, expected_new_bindings=None):
    """
    Partially apply ``func`` to ``args`` and ``kwds`` and check the result.

    If :expected_source: is given, the tree of the resulting function
    must be equal to the first statement parsed from it.
    If :expected_new_bindings: is given, every key in it must be present
    in the globals of the resulting function with an equal value.
    """

    positional = () if args is None else args
    keywords = {} if kwds is None else kwds

    specialized = Function.from_object(
        partial_apply(func, *positional, **keywords))

    if expected_source is not None:
        assert_ast_equal(
            specialized.tree,
            ast.parse(unindent(expected_source)).body[0])

    if expected_new_bindings is None:
        return

    for name in expected_new_bindings:
        if name not in specialized.globals:
            # Leave a hint in the output; the lookup below fails afterwards.
            print('Expected binding missing:', name)

        actual = specialized.globals[name]
        wanted = expected_new_bindings[name]

        assert actual == wanted
Exemplo n.º 6
0
def check_component(component, func, additional_bindings=None,
        expected_source=None, expected_new_bindings=None):
    """
    Run ``component`` on the tree and external variables of ``func``
    (optionally extended with ``additional_bindings``) and check the result.

    The returned tree must be equal to the first statement parsed from
    ``expected_source``, or to the tree of ``func`` when no expected source
    is given. If ``expected_new_bindings`` is given, every key in it must be
    present in the returned bindings with an equal value.
    """

    wrapped = Function.from_object(func)
    env = wrapped.get_external_variables()
    if additional_bindings is not None:
        env.update(additional_bindings)

    result_tree, result_bindings = component(wrapped.tree, env)

    if expected_source is not None:
        reference_tree = ast.parse(unindent(expected_source)).body[0]
    else:
        reference_tree = wrapped.tree
    assert_ast_equal(result_tree, reference_tree)

    if expected_new_bindings is None:
        return

    for name in expected_new_bindings:
        if name not in result_bindings:
            # Leave a hint in the output; the lookup below fails afterwards.
            print('Expected binding missing:', name)

        actual = result_bindings[name]
        wanted = expected_new_bindings[name]
        assert actual == wanted
Exemplo n.º 7
0
def test_ast_equal():
    """
    ``ast_equal`` must accept a tree compared with itself and reject trees
    that differ in a node type, in a node value, or in the length of a body.
    """

    def parsed(text):
        # Strip the common indentation before handing the text to the parser.
        return ast.parse(unindent(text))

    reference = parsed("""
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
        """)

    # Different node type (`-` instead of `+`)
    other_operator = parsed("""
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x - y)
            else:
                return kw['zzz']
        """)

    # Different value in a node ('zzy' instead of 'zzz')
    other_constant = parsed("""
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzy']
        """)

    # Additional element in a body
    longer_body = parsed("""
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
                return 1
            else:
                return kw['zzz']
        """)

    assert ast_equal(reference, reference)
    assert not ast_equal(reference, other_operator)
    assert not ast_equal(reference, other_constant)
    assert not ast_equal(reference, longer_body)
Exemplo n.º 8
0
def test_get_fn_arg_id():
    """
    ``get_fn_arg_id`` applied to the only argument node of ``def f(x)``
    must give the argument name.
    """
    module = ast.parse(unindent("""
        def f(x):
            pass
        """))

    # First statement is the function definition; take its first argument.
    first_arg = module.body[0].args.args[0]

    assert get_fn_arg_id(first_arg) == 'x'
Exemplo n.º 9
0
def getsource(func):
    """
    Returns the source of a function ``func``.

    If ``func`` carries the attribute named by ``SOURCE_ATTRIBUTE``,
    its value is returned as is; this covers partially evaluated functions.
    Otherwise the result of ``inspect.getsource()`` is unindented and returned.
    """

    if not hasattr(func, SOURCE_ATTRIBUTE):
        # A regular function: ask ``inspect`` and strip the indentation.
        return unindent(inspect.getsource(func))

    # An attribute created in ``Function.eval()``
    return getattr(func, SOURCE_ATTRIBUTE)
Exemplo n.º 10
0
def test_unindent_unexpected_indentation():
    """
    ``unindent`` must raise ``ValueError`` when a line is indented less
    than the first line of the text.
    """
    malformed = """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
       some_code() # indentation here does not start from the same position as the first line!
        """

    # Only the exception matters; the return value is never produced.
    with pytest.raises(ValueError):
        unindent(malformed)
Exemplo n.º 11
0
def test_get_source():
    """
    The normalized source reported by a ``Function`` built from
    ``sample_fn`` must match the expected text.
    """
    wrapped = Function.from_object(sample_fn)
    actual = normalize_source(wrapped.get_source())

    expected = unindent(
        """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
        """)

    assert actual == expected
Exemplo n.º 12
0
def test_unindent():
    """
    ``unindent`` must strip the common indentation together with
    the leading and trailing blank lines of the text.
    """
    indented = """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
        """

    # The expected text, one line per entry, without a trailing newline.
    dedented = "\n".join([
        "def sample_fn(x, y, foo='bar', **kw):",
        "    if (foo == 'bar'):",
        "        return (x + y)",
        "    else:",
        "        return kw['zzz']",
    ])

    assert unindent(indented) == dedented
Exemplo n.º 13
0
def test_unindent_empty_line():
    """
    Check ``unindent`` on a text with a whitespace-only line that is
    shorter than the common indentation: the expected result has
    an empty line in its place.

    The input is assembled from adjacent string literals so that
    the whitespace-only line can be written out explicitly.
    """
    src = (
        """
        def sample_fn(x, y, foo='bar', **kw):\n"""
        # Technically, this line would be an unexpected indentation,
        # because it does not start with 8 spaces.
        # But `unindent` will see that it's just an empty line
        # and just replace it with a single `\n`.
        "    \n"
        """            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
        """)

    expected_src = (
        "def sample_fn(x, y, foo='bar', **kw):\n"
        "\n"
        """    if (foo == 'bar'):
        return (x + y)
    else:
        return kw['zzz']""")

    assert unindent(src) == expected_src
Exemplo n.º 14
0
def get_ast(function):
    """
    Parse ``function`` into an AST module.

    A string is treated as source text and unindented before parsing;
    for any other object the source is obtained with ``inspect.getsource()``
    and parsed as is.
    """
    source = (
        unindent(function) if isinstance(function, str)
        else inspect.getsource(function)
    )
    return ast.parse(source)