示例#1
0
def test_variable_replace():
    """Replacing placeholder `x` with a Name node renames every use of it."""
    # Template source is parsed by template.replace; keep it exactly as is.
    def f(x):
        x = 2
        return x

    replacement = gast.Name(id='y', ctx=None, annotation=None)
    stmts = template.replace(f, x=replacement)

    # The assignment target is expected to be renamed and given a Store ctx.
    target = stmts[0].targets[0]
    assert target.id == 'y'
    assert isinstance(target.ctx, gast.Store)
    # The returned name is expected to carry a Load ctx.
    assert isinstance(stmts[1].value.ctx, gast.Load)

    compile_.compile_function(_wrap(stmts))
示例#2
0
def test_full_gradient_replace():
    """FULL gradient replacement rewrites `d[x]`/`d[y]` into `bx`/`by`."""
    # Template source is parsed; `d[...]` is template notation, not a real dict.
    def f(x, y):
        d[x] = d[y]

    parsed = quoting.parse_function(f)
    replacer = template.ReplaceGradTransformer(template.Replace.FULL)
    rewritten = replacer.visit(parsed)

    # First statement inside the function definition of the parsed module.
    assign = rewritten.body[0].body[0]
    lhs = assign.targets[0]
    assert isinstance(lhs, gast.Name)
    assert lhs.id == 'bx'
    assert assign.value.id == 'by'

    compile_.compile_function(rewritten)
示例#3
0
def test_statement_replace():
    """A statement placeholder expands into a list of statements."""
    # Template source is parsed; the bare `body` is the placeholder.
    def f(body):
        body

    # Build one `Expr(Name(...))` statement per name, in the order x, y.
    stmts = []
    for name in ('x', 'y'):
        loaded = gast.Name(id=name, ctx=gast.Load(), annotation=None)
        stmts.append(gast.Expr(value=loaded))

    expanded = template.replace(f, body=stmts)
    assert len(expanded) == 2
    assert isinstance(expanded[0], gast.Expr)

    compile_.compile_function(_wrap(expanded))
示例#4
0
def test_compile():
  """compile_function turns a parsed function back into a callable."""
  # The header `def f(x):` is asserted below, so the name `f` must stay.
  def f(x):
    return x * 2

  doubled = compile_.compile_function(quoting.parse_function(f))
  assert doubled(2) == 4
  # The test expects the compiled function's source to be retrievable
  # and to start with the original header line.
  header = inspect.getsource(doubled).split('\n')[0]
  assert header == 'def f(x):'

  def f(x):
    return y * 2

  # `y` is free in the body; it is supplied via the globals mapping.
  scaled = compile_.compile_function(quoting.parse_function(f), {'y': 3})
  assert scaled(2) == 6
示例#5
0
def test_function_replace():
    """Function name and argument placeholders can both be substituted."""
    # Template source is parsed; `f` and `args` are placeholders.
    def f(f, args):
        def f(args):
            pass

    arg_nodes = [
        gast.Name(id=name, ctx=None, annotation=None) for name in ('a', 'b')
    ]
    stmts = template.replace(f, f='g', args=arg_nodes)

    func_def = stmts[0]
    assert isinstance(func_def, gast.FunctionDef)
    assert func_def.name == 'g'
    params = func_def.args.args
    assert len(params) == 2
    # Argument names are expected to be given a Param ctx by the replacer.
    assert isinstance(params[0].ctx, gast.Param)
    assert params[1].id == 'b'

    compile_.compile_function(_wrap(stmts))
示例#6
0
def tangent(f):
    """Decorator that strips `with grad_of` blocks so `f` runs normally.

    The function is resolved and parsed, every `with grad_of` context
    manager is removed from the tree, and the tree is recompiled.

    Args:
      f: A function

    Returns:
      A function with any `with grad_of` context managers removed. It
      carries `f`'s metadata (via functools.wraps) and exposes the
      original function as its `tangent` attribute.
    """
    tree = annotate.resolve_calls(f)
    # The return value of visit is discarded; RemoveWith is relied on to
    # edit the tree in place.
    RemoveWith().visit(tree)
    compiled = compile_.compile_function(tree)
    decorated = functools.wraps(f)(compiled)
    decorated.tangent = f
    return decorated
示例#7
0
def test_function_compile():
  """compile_function rejects inputs that are not compilable functions."""
  # A node from quoting.quote of a bare assignment is expected to be
  # rejected with TypeError.
  with pytest.raises(TypeError):
    quoted = quoting.quote('x = y')
    compile_.compile_function(quoted)
  # A module containing only `x = y` (no def) is expected to raise
  # ValueError.
  with pytest.raises(ValueError):
    module = gast.parse('x = y')
    compile_.compile_function(module)