def test_mutiple_returns():
    # A function with several ``return`` statements: all of its arguments and
    # its assigned local are expected to come out as mangled symbols.
    original = unindent('''
        def f(x, y, z='foo'):
            if x:
                b = y + list(x)
                return b
            else:
                return z
        ''')
    mangled = unindent('''
        def f(__peval_mangled_1, __peval_mangled_2, __peval_mangled_3='foo'):
            if __peval_mangled_1:
                __peval_mangled_4 = __peval_mangled_2 + list(__peval_mangled_1)
                return __peval_mangled_4
            else:
                return __peval_mangled_3
        ''')

    tree = ast.parse(original)
    _, mangled_tree = mangle(GenSym.for_tree(tree), tree)

    assert_ast_equal(mangled_tree, ast.parse(mangled))
def function_from_source(source, globals_=None):
    """
    A helper function to construct a Function object from a source
    with custom __future__ imports.

    ``source`` is unindented and parsed. The first top-level ``def`` in it is
    selected, and the whole module is compiled and executed. The resulting
    function object is then wrapped with ``Function.from_object()``.

    :param source: Python source containing at least one top-level function.
    :param globals_: globals used when executing ``source``. If None, the
        globals of this module are used, which is the ``exec`` default.
    :raises ValueError: if ``source`` has no top-level function definition.
    """
    module = ast.parse(unindent(source))
    ast.fix_missing_locations(module)

    for stmt in module.body:
        if isinstance(stmt, ast.FunctionDef):
            tree = stmt
            name = stmt.name
            break
    else:
        raise ValueError("No function definitions found in the provided source")

    # dont_inherit=True: only the __future__ imports written in ``source``
    # apply, not the ones in effect in this module.
    code_object = compile(module, '<nofile>', 'exec', dont_inherit=True)
    locals_ = {}
    # Module-mode code is executed as statements; the ``def`` binds the
    # function in ``locals_``.
    exec(code_object, globals_, locals_)
    function_obj = locals_[name]
    # Attach the unparsed source of the selected function.
    # NOTE(review): presumably read back instead of ``inspect`` — confirm.
    function_obj._peval_source = astunparse.unparse(tree)
    return Function.from_object(function_obj)
def check_partial_apply(func, args=None, kwds=None,
                        expected_source=None, expected_new_bindings=None):
    '''
    Test that partially applying ``func`` (via ``partial_apply``) produces
    the expected result.

    The new function's AST must match the first definition in
    ``expected_source``. ``args`` and ``kwds`` are passed to
    ``partial_apply``.
    If ``expected_new_bindings`` is given, we check that each of its keys is
    among the globals of the new function and bound to an equal value.
    '''
    if args is None:
        args = ()
    if kwds is None:
        kwds = {}

    new_func = partial_apply(func, *args, **kwds)
    function = Function.from_object(new_func)

    if expected_source is not None:
        expected_ast = ast.parse(unindent(expected_source)).body[0]
        assert_ast_equal(function.tree, expected_ast)

    if expected_new_bindings is not None:
        for k, expected_binding in expected_new_bindings.items():
            # Fail with a descriptive message instead of printing and then
            # crashing with a bare KeyError on the lookup below.
            assert k in function.globals, 'Expected binding missing: {0}'.format(k)
            binding = function.globals[k]
            # Before Python 3.3, ``range`` objects compared by identity,
            # so two ranges with the same values were not equal.
            # So we compare type and contents manually.
            if sys.version_info < (3, 3) and isinstance(expected_binding, range):
                assert type(binding) == type(expected_binding)
                assert list(binding) == list(expected_binding)
            else:
                assert binding == expected_binding
def check_component(component, func, additional_bindings=None,
                    expected_source=None, expected_new_bindings=None):
    """
    Run ``component`` on the AST of ``func`` and check the result.

    ``component`` is called as ``component(tree, bindings)`` and must return
    ``(new_tree, new_bindings)``. ``bindings`` are the external variables of
    ``func``, extended with ``additional_bindings`` if given.

    If ``expected_source`` is None, the returned tree is compared against
    ``func``'s own tree. Otherwise it is compared against the first
    definition in ``expected_source``.
    If ``expected_new_bindings`` is given, each of its keys must be present
    in ``new_bindings`` and bound to an equal value.
    """
    function = Function.from_object(func)
    bindings = function.get_external_variables()
    if additional_bindings is not None:
        # NOTE(review): this mutates the mapping returned by
        # get_external_variables(); presumably a fresh dict — confirm.
        bindings.update(additional_bindings)

    new_tree, new_bindings = component(function.tree, bindings)

    if expected_source is None:
        expected_ast = function.tree
    else:
        expected_ast = ast.parse(unindent(expected_source)).body[0]
    assert_ast_equal(new_tree, expected_ast)

    if expected_new_bindings is not None:
        for k, expected_binding in expected_new_bindings.items():
            # Fail with a descriptive message instead of printing and then
            # crashing with a bare KeyError on the lookup below.
            assert k in new_bindings, 'Expected binding missing: {0}'.format(k)
            binding = new_bindings[k]
            # Before Python 3.3, ``range`` objects compared by identity,
            # so two ranges with the same values were not equal.
            # So we compare type and contents manually.
            if sys.version_info < (3, 3) and isinstance(expected_binding, range):
                assert type(binding) == type(expected_binding)
                assert list(binding) == list(expected_binding)
            else:
                assert binding == expected_binding
def check_partial_apply(func, args=None, kwds=None,
                        expected_source=None, expected_new_bindings=None):
    """
    Test that partially applying ``func`` (via ``partial_apply``) produces
    the expected result.

    The new function's AST must match the first definition in
    ``expected_source``. ``args`` and ``kwds`` are passed to
    ``partial_apply``.
    If ``expected_new_bindings`` is given, we check that each of its keys is
    among the globals of the new function and bound to an equal value.
    """
    if args is None:
        args = ()
    if kwds is None:
        kwds = {}

    new_func = partial_apply(func, *args, **kwds)
    function = Function.from_object(new_func)

    if expected_source is not None:
        expected_ast = ast.parse(unindent(expected_source)).body[0]
        assert_ast_equal(function.tree, expected_ast)

    if expected_new_bindings is not None:
        for k, expected_binding in expected_new_bindings.items():
            # Fail with a descriptive message instead of printing and then
            # crashing with a bare KeyError on the lookup below.
            assert k in function.globals, 'Expected binding missing: {0}'.format(k)
            assert function.globals[k] == expected_binding
def check_component(component, func, additional_bindings=None,
                    expected_source=None, expected_new_bindings=None):
    """
    Run ``component`` on the AST of ``func`` and check the result.

    ``component`` is called as ``component(tree, bindings)`` and must return
    ``(new_tree, new_bindings)``. ``bindings`` are the external variables of
    ``func``, extended with ``additional_bindings`` if given.

    If ``expected_source`` is None, the returned tree is compared against
    ``func``'s own tree. Otherwise it is compared against the first
    definition in ``expected_source``.
    If ``expected_new_bindings`` is given, each of its keys must be present
    in ``new_bindings`` and bound to an equal value.
    """
    function = Function.from_object(func)
    bindings = function.get_external_variables()
    if additional_bindings is not None:
        # NOTE(review): this mutates the mapping returned by
        # get_external_variables(); presumably a fresh dict — confirm.
        bindings.update(additional_bindings)

    new_tree, new_bindings = component(function.tree, bindings)

    if expected_source is None:
        expected_ast = function.tree
    else:
        expected_ast = ast.parse(unindent(expected_source)).body[0]
    assert_ast_equal(new_tree, expected_ast)

    if expected_new_bindings is not None:
        for k, expected_binding in expected_new_bindings.items():
            # Fail with a descriptive message instead of printing and then
            # crashing with a bare KeyError on the lookup below.
            assert k in new_bindings, 'Expected binding missing: {0}'.format(k)
            assert new_bindings[k] == expected_binding
def test_ast_equal():
    reference_src = """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
        """
    # Each variant differs from the reference in exactly one respect.
    variant_srcs = [
        # Different node type (`-` instead of `+`)
        """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x - y)
            else:
                return kw['zzz']
        """,
        # Different value in a node ('zzy' instead of 'zzz')
        """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzy']
        """,
        # Additional element in a body
        """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
                return 1
            else:
                return kw['zzz']
        """,
    ]

    reference = ast.parse(unindent(reference_src))
    variants = [ast.parse(unindent(variant_src)) for variant_src in variant_srcs]

    # A tree is always equal to itself...
    assert ast_equal(reference, reference)
    # ...and never to any of the single-difference variants.
    for variant in variants:
        assert not ast_equal(reference, variant)
def test_get_fn_arg_id():
    # Parse a one-argument function and extract its single argument node.
    module = ast.parse(unindent("""
        def f(x):
            pass
        """))
    func_def = module.body[0]
    first_arg = func_def.args.args[0]

    assert get_fn_arg_id(first_arg) == 'x'
def getsource(func):
    """
    Return the source code of ``func``.

    If ``func`` has an attribute named ``SOURCE_ATTRIBUTE``, its value is
    returned as-is. This is how a partially evaluated function's source is
    obtained. For any other function, the source is fetched with
    ``inspect.getsource()`` and passed through ``unindent``.
    """
    if not hasattr(func, SOURCE_ATTRIBUTE):
        return unindent(inspect.getsource(func))
    # An attribute created in ``Function.eval()``
    return getattr(func, SOURCE_ATTRIBUTE)
def test_unindent_unexpected_indentation():
    # The last code line is indented less than the first one, which
    # ``unindent`` is expected to reject with a ValueError.
    src = (
        "\n"
        "        def sample_fn(x, y, foo='bar', **kw):\n"
        "            if (foo == 'bar'):\n"
        "                return (x + y)\n"
        "            else:\n"
        "                return kw['zzz']\n"
        "      some_code() # indentation here does not start from the same position as the first line!\n"
        "        "
    )

    with pytest.raises(ValueError):
        unindent(src)
def test_get_source():
    expected_source = unindent("""
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
        """)

    # Normalize to compare independently of the original source's formatting.
    actual_source = normalize_source(Function.from_object(sample_fn).get_source())

    assert actual_source == expected_source
def test_unindent():
    # Leading newline, common 8-space indentation and trailing whitespace
    # are all expected to be stripped.
    src = (
        "\n"
        "        def sample_fn(x, y, foo='bar', **kw):\n"
        "            if (foo == 'bar'):\n"
        "                return (x + y)\n"
        "            else:\n"
        "                return kw['zzz']\n"
        "        "
    )
    expected_src = "\n".join([
        "def sample_fn(x, y, foo='bar', **kw):",
        "    if (foo == 'bar'):",
        "        return (x + y)",
        "    else:",
        "        return kw['zzz']",
    ])

    assert unindent(src) == expected_src
def test_unindent_empty_line():
    src = (
        "\n"
        "        def sample_fn(x, y, foo='bar', **kw):\n"
        # Technically, this line would be an unexpected indentation,
        # because it does not start with 8 spaces.
        # But `unindent` will see that it's just an empty line
        # and just replace it with a single `\n`.
        "    \n"
        "            if (foo == 'bar'):\n"
        "                return (x + y)\n"
        "            else:\n"
        "                return kw['zzz']\n"
        "        "
    )
    expected_src = "\n".join([
        "def sample_fn(x, y, foo='bar', **kw):",
        "",
        "    if (foo == 'bar'):",
        "        return (x + y)",
        "    else:",
        "        return kw['zzz']",
    ])

    assert unindent(src) == expected_src
def get_ast(function):
    """
    Return the ``ast.Module`` parsed from ``function``.

    :param function: either a source string, or an object whose source can be
        obtained with ``inspect.getsource()``.

    In both cases the source is passed through ``unindent`` before parsing.
    Without that, the source of a nested function or a method (which
    ``inspect.getsource()`` returns with its original indentation) would fail
    to parse with an ``IndentationError``.
    """
    if isinstance(function, str):
        source = function
    else:
        source = inspect.getsource(function)
    return ast.parse(unindent(source))