def test_variable_replace():
  """Substituting a template variable assigns the right AST contexts.

  The replacement name on the assignment target must get a Store context,
  and the same name read in the return statement must get a Load context.
  """
  def f(x):
    x = 2
    return x

  replacement = gast.Name(id='y', ctx=None, annotation=None)
  body = template.replace(f, x=replacement)
  assign_target = body[0].targets[0]
  assert assign_target.id == 'y'
  assert isinstance(assign_target.ctx, gast.Store)
  assert isinstance(body[1].value.ctx, gast.Load)
  compile_.compile_function(_wrap(body))
def test_full_gradient_replace():
  """Replace.FULL rewrites `d[x]` / `d[y]` into the plain names `bx` / `by`."""
  def f(x, y):
    d[x] = d[y]

  parsed = quoting.parse_function(f)
  replacer = template.ReplaceGradTransformer(template.Replace.FULL)
  new_tree = replacer.visit(parsed)
  # body[0] is the function definition; its first statement is the assignment.
  assign = new_tree.body[0].body[0]
  target = assign.targets[0]
  assert isinstance(target, gast.Name)
  assert target.id == 'bx'
  assert assign.value.id == 'by'
  compile_.compile_function(new_tree)
def test_statement_replace():
  """A bare statement placeholder can be replaced by several statements."""
  def f(body):
    body

  statements = [
      gast.Expr(value=gast.Name(id=name, ctx=gast.Load(), annotation=None))
      for name in ('x', 'y')
  ]
  new_body = template.replace(f, body=statements)
  assert len(new_body) == 2
  assert isinstance(new_body[0], gast.Expr)
  compile_.compile_function(_wrap(new_body))
def test_compile():
  """compile_function rebuilds a callable from a parsed function.

  Also checks that the regenerated source starts with the original
  signature, and that an extra namespace mapping is used to resolve free
  names.
  """
  def f(x):
    return x * 2

  doubled = compile_.compile_function(quoting.parse_function(f))
  assert doubled(2) == 4
  first_line = inspect.getsource(doubled).split('\n')[0]
  assert first_line == 'def f(x):'

  # `y` is deliberately not defined in this test; its value is supplied via
  # the namespace mapping passed to compile_function below.
  def f(x):
    return y * 2

  scaled = compile_.compile_function(quoting.parse_function(f), {'y': 3})
  assert scaled(2) == 6
def test_function_replace():
  """Template substitution can rename a nested def and fill in its arguments."""
  def f(f, args):
    def f(args):
      pass

  arg_names = [gast.Name(id=name, ctx=None, annotation=None) for name in 'ab']
  body = template.replace(f, f='g', args=arg_names)
  func_def = body[0]
  assert isinstance(func_def, gast.FunctionDef)
  assert func_def.name == 'g'
  params = func_def.args.args
  assert len(params) == 2
  # Names substituted into an argument list should be marked as parameters.
  assert isinstance(params[0].ctx, gast.Param)
  assert params[1].id == 'b'
  compile_.compile_function(_wrap(body))
def tangent(f):
  """Decorator that strips `with grad_of` statements from a function.

  With those context managers removed, the decorated function can be called
  like an ordinary function.

  Args:
    f: The function to transform.

  Returns:
    A freshly compiled version of `f` with any `with grad_of` context
    managers removed. It carries `f`'s metadata via `functools.wraps`, and
    the original, untransformed function is stored on its `tangent`
    attribute.
  """
  tree = annotate.resolve_calls(f)
  # The visitor's return value is discarded; the transformation is relied on
  # to modify `tree` in place before it is compiled.
  RemoveWith().visit(tree)
  compiled = compile_.compile_function(tree)
  wrapped = functools.wraps(f)(compiled)
  wrapped.tangent = f
  return wrapped
def test_function_compile():
  """compile_function rejects inputs that are not a single function definition.

  Each node is built inside the `pytest.raises` block so that construction
  and compilation are checked together.
  """
  cases = [
      (TypeError, lambda: quoting.quote('x = y')),
      (ValueError, lambda: gast.parse('x = y')),
  ]
  for expected_error, make_node in cases:
    with pytest.raises(expected_error):
      compile_.compile_function(make_node())