Ejemplo n.º 1
0
def test_compile():
  """compile_function round-trips a parsed function, with an optional namespace.

  The second case relies on the mapping passed as the second argument to
  supply the otherwise-undefined name `y`.
  """
  def f(x):
    return x * 2

  doubled = compile_.compile_function(quoting.parse_function(f))
  assert doubled(2) == 4
  # The regenerated source should start with the original signature.
  header = inspect.getsource(doubled).split('\n')[0]
  assert header == 'def f(x):'

  def f(x):
    return y * 2

  scaled = compile_.compile_function(quoting.parse_function(f), {'y': 3})
  assert scaled(2) == 6
Ejemplo n.º 2
0
def resolve_calls(func):
    """Parse a function into an AST with function calls resolved.

    Calls are resolved using the global and local namespace of `func`. As a
    consequence, procedural parameters (functions passed in as arguments) are
    not resolved. Likewise, functions defined inside the body of `func` are not
    resolved, because they are not in the outer function's local namespace.

    The function definition itself is also annotated, so that it can be matched
    to calls to it in other functions.

    Args:
      func: The function whose calls are being resolved.

    Returns:
      node: An AST where each `Call` node has a `func` annotation with the
          function handle that the call resolves to.

    Raises:
      AttributeError: When a function used on the RHS of an assignment cannot
          be resolved (because it was passed as an argument or was defined in
          the body of the function).
    """
    tree = quoting.parse_function(func)
    resolver = ResolveCalls(func)
    resolver.visit(tree)
    return tree
Ejemplo n.º 3
0
def _wrap(body):
    """Wrap a list of statements in a function definition so it can be compiled.

    Args:
      body: The statements to install as the function's body.

    Returns:
      The tree produced by `quoting.parse_function` for a dummy zero-argument
      function `f`, with the body of its definition (`tree.body[0]`) replaced
      by `body`.
    """
    def f():
        pass

    module = quoting.parse_function(f)
    func_def = module.body[0]
    func_def.body = body
    return module
Ejemplo n.º 4
0
def test_full_gradient_replace():
    """Replace.FULL turns `d[x] = d[y]` into a plain `bx = by` assignment."""
    def f(x, y):
        d[x] = d[y]

    tree = quoting.parse_function(f)
    full_replacer = template.ReplaceGradTransformer(template.Replace.FULL)
    new_tree = full_replacer.visit(tree)
    assign = new_tree.body[0].body[0]
    lhs = assign.targets[0]
    assert isinstance(lhs, gast.Name)
    assert lhs.id == 'bx'
    assert assign.value.id == 'by'
    # The rewritten tree must still be compilable.
    compile_.compile_function(new_tree)
Ejemplo n.º 5
0
def test_unused():
    """annotate.unused reports assignments whose value is never read."""
    def f(x):
        y = x * 2
        return x

    tree = quoting.parse_function(f)
    found = annotate.unused(tree)
    # `y = x * 2` is the first statement of the function body.
    assert found == {('y', tree.body[0].body[0])}

    def f(x):
        y = x * 2
        return y

    # `y` is returned, so nothing is unused.
    assert not annotate.unused(quoting.parse_function(f))

    def f(x):
        while True:
            y = x * 2
            x = 3
        return y

    # Both loop-body assignments are read (x in the next iteration, y after).
    assert not annotate.unused(quoting.parse_function(f))
Ejemplo n.º 6
0
def test_resolve():
    """resolve_calls annotates calls with their target and rejects unknowns."""
    def g(x):
        return 2 * x

    def f(x):
        return g(x)

    node = annotate.resolve_calls(f)
    assert anno.getanno(node.body[0].body[0].value, 'func') == g

    def f(x):
        return h(x)

    # `h` is not bound in f's namespace, so resolution is expected to fail.
    # (A dead `quoting.parse_function(f)` assignment used to precede this.)
    with pytest.raises(AttributeError):
        annotate.resolve_calls(f)
Ejemplo n.º 7
0
def test_comment():
    """add_comment can place a comment above, right of, or below a statement."""
    # NOTE(review): `f` is a module-level function defined outside this view;
    # its first body statement is presumably `y = x`.
    func_def = quoting.parse_function(f).body[0]

    def rendered_lines():
        return quoting.to_source(func_def).split('\n')

    comments.add_comment(func_def.body[0], 'foo', 'above')
    assert rendered_lines()[1].strip() == '# foo'

    comments.add_comment(func_def.body[0], 'foo', 'right')
    assert rendered_lines()[1].strip() == 'y = x # foo'

    comments.add_comment(func_def.body[0], 'foo', 'below')
    assert rendered_lines()[2].strip() == '# foo'
Ejemplo n.º 8
0
def test_insert():
    """TreeTransformer.prepend inserts a statement before the visited one."""
    def f(x):
        y = x
        return y

    tree = quoting.parse_function(f)

    class PrependDoubling(transformers.TreeTransformer):
        def visit_Assign(self, node):
            # React only to the assignment to `y`. The inserted `x = 2 * x`
            # is itself an Assign, so without this guard we would keep
            # prepending forever.
            if node.targets[0].id == 'y':
                self.prepend(quoting.quote("x = 2 * x"))
            return node

    PrependDoubling().visit(tree)
    second_line = quoting.unquote(tree).split('\n')[1]
    assert second_line.strip() == "x = 2 * x"
Ejemplo n.º 9
0
def test_remove():
    """TreeTransformer.remove drops the visited statement from its body."""
    def f(x):
        while True:
            y = x
            z = y
        return y

    node = quoting.parse_function(f)

    class RemoveZ(transformers.TreeTransformer):
        def visit_Assign(self, node):
            # Remove only the assignment to `z`; all other statements stay.
            if node.targets[0].id == 'z':
                self.remove(node)
            return node

    RemoveZ().visit(node)
    # With `z = y` gone the source reads:
    #   def f(x): / while True: / y = x / return y
    assert quoting.unquote(node).split('\n')[3].strip() == "return y"
Ejemplo n.º 10
0
def _create_joint(fwdbwd, func, wrt, input_derivative):
    """Create a user-friendly gradient function.

    By default, gradient functions expect the stack to be passed to them
    explicitly. This function modifies the function so that the stack doesn't
    need to be passed and gets initialized in the function body instead.

    For consistency, gradient functions always return a tuple, even if the
    gradient of only one input was required. We unpack the tuple if it is of
    length one.

    Args:
      fwdbwd: An AST. The function definition of the joint primal and adjoint.
          Its first argument is used as the stack variable's name and its
          second as the initial gradient's name. Its last body statement is
          expected to return a tuple (it is read via `retval.value.elts`).
          The node is modified in place.
      func: A function handle. The original function that was differentiated.
          Its parsed arguments replace the arguments of `fwdbwd`.
      wrt: A tuple of integers. The arguments with respect to which we
          differentiated.
      input_derivative: Controls the initial-gradient argument. When equal to
          `INPUT_DERIVATIVE.DefaultOne`, that argument is given a default of
          `1.0`; otherwise this function appends no default itself.

    Returns:
      The function definition of the new function (the result of
      `ast_.append_args` applied to the modified `fwdbwd`).
    """
    # Correct return to be a non-tuple if there's only one element
    retval = fwdbwd.body[-1]
    if len(retval.value.elts) == 1:
        retval.value = retval.value.elts[0]

    # Make a stack init statement, named after the first argument (the stack
    # that callers no longer have to pass). This must run before
    # `fwdbwd.args` is replaced below.
    init_stack = quoting.quote('%s = tangent.Stack()' % fwdbwd.args.args[0].id)
    init_stack = comments.add_comment(init_stack, 'Initialize the tape')

    # Prepend the stack init to the top of the function
    fwdbwd.body = [init_stack] + fwdbwd.body

    # Replace the function arguments with the original ones; remember the
    # initial-gradient argument's name first, since it is re-added below.
    grad_name = fwdbwd.args.args[1].id
    fwdbwd.args = quoting.parse_function(func).body[0].args

    # Give the function a nice name
    fwdbwd.name = naming.joint_name(func, wrt)

    # Allow the initial gradient to be passed as a keyword argument
    fwdbwd = ast_.append_args(fwdbwd, [grad_name])
    if input_derivative == INPUT_DERIVATIVE.DefaultOne:
        fwdbwd.args.defaults.append(quoting.quote('1.0'))
    return fwdbwd
Ejemplo n.º 11
0
def test_anf():
    """ANF conversion names intermediate call results and keeps semantics."""
    def g(x):
        return x * 2

    h = g

    def f(x):
        y = g(h(x))
        return y

    lines = anf_lines(f)
    # The inner call h(x) is hoisted into its own temporary.
    assert lines[1].strip() == "h_x = h(x)"
    assert anf_function(f, locals())(2) == 8

    def f(x):
        return x * x * x

    # The products are hoisted, so the return line holds no multiplication.
    last_line = anf_lines(f)[-1]
    assert 'return' in last_line and '*' not in last_line
    assert anf_function(f)(2) == 8

    def f(x):
        y = [(x.y[0], ), 3]
        y += x * f(x[g(x)].b, (3, x / -2))

    assert anf.anf(quoting.parse_function(f))
Ejemplo n.º 12
0
def replace(template,
            replace_grad=Replace.PARTIAL,
            namer=None,
            **replacements):
    """Replace placeholders in a Python template (quote).

    Args:
      template: A function, AST node or string to be used as a template. Note
          that if a function is passed, any placeholder is expected to also be
          a function argument (including a `*args` parameter). If a string is
          passed, it must represent valid Python code, and any variable it
          references is a placeholder.
      replace_grad: If Replace.NONE, statements of the form `d[x]` are
          ignored. For the other possible values, see `ReplaceGradTransformer`.
      namer: See `ReplaceGradTransformer`.
      **replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by. If a string is passed,
          `quote` will be called on it to turn it into a node.

    Returns:
      body: An AST node or list of AST nodes with the replacements made. If
          the template was a function, a list will be returned. If the
          template was a node, the same node will be returned. If the template
          was a string, an AST node will be returned (a `Module` node in the
          case of a multi-line string, an `Expr` node otherwise).

    Raises:
      ValueError: If a function is used as a template and an incorrect set of
          replacements was passed.
    """
    # Handle the 3 different types of templates: funcs, nodes, and strings
    is_function = isinstance(template, types.FunctionType)
    if is_function:
        tree = quoting.parse_function(template).body[0]
        placeholders = set(arg.id for arg in tree.args.args)
        tree.args.args = []
        if tree.args.vararg:
            vararg = tree.args.vararg
            # The *args parameter may be a Name node (like the entries of
            # args.args, which are compared via `.id`) or a bare identifier
            # string. Placeholders are matched against the replacement keyword
            # names, so record the identifier; adding the node itself made
            # every *args template fail the check below.
            placeholders.add(getattr(vararg, 'id', vararg))
            tree.args.vararg = None
        if set(replacements) != placeholders:
            raise ValueError('too many or few replacements')
    elif isinstance(template, gast.AST):
        tree = template
    else:
        tree = quoting.quote(template, return_expr=True)
    # If the replacements are strings, turn them into nodes
    for k, v in replacements.items():
        if isinstance(v, six.string_types):
            replacements[k] = quoting.quote(v)
    # Perform the replacement
    ReplaceTransformer(replacements).visit(tree)
    # Handle the d[x] operator
    if replace_grad is not Replace.NONE:
        rgt = ReplaceGradTransformer(replace_grad=replace_grad,
                                     namer=namer,
                                     tangent=replace_grad is Replace.TANGENT)
        rgt.visit(tree)
    # Return the AST node with replacements made
    if is_function:
        return tree.body
    else:
        return tree
Ejemplo n.º 13
0
def anf_lines(f):
    """Convert `f` to A-normal form and return the resulting source lines."""
    converted = anf.anf(quoting.parse_function(f))
    source = quoting.unquote(converted)
    return source.split('\n')
Ejemplo n.º 14
0
def _assert_tangent_parse_error(func, fragment):
    """Assert that `fence.validate` rejects `func` with a matching message.

    Args:
      func: A function whose source is expected to be rejected.
      fragment: A substring that must appear in the raised error's message.

    Raises:
      AssertionError: If validation succeeds, or if the error message does not
          contain `fragment`.
    """
    try:
        fence.validate(quoting.parse_function(func), inspect.getsource(func))
    except fence.TangentParseError as expected:
        if fragment not in str(expected):
            raise AssertionError(
                '%r not found in error message: %s' % (fragment, expected))
    else:
        # An explicit raise (instead of `assert False` inside the try) stays
        # active under `python -O`, where assert statements are stripped.
        raise AssertionError(
            'Expected fence.TangentParseError for %r' % (func,))
Ejemplo n.º 15
0
  def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.

    Args:
      node: The `gast.Call` node to differentiate. It must carry a `func`
          annotation holding the resolved function handle.

    Returns:
      A pair `(primal, adjoint)`: the expression to use in the primal pass and
      the list of adjoint statements. Non-differentiable calls return
      `(node, [])`. Calls without a registered adjoint are rewritten into
      calls to generated primal/adjoint functions, which are recorded in
      `self.required`. Traced calls are delegated to
      `primal_and_adjoint_for_tracing`.

    Raises:
      errors.ReverseNotImplementedError: If `func` is listed in
          `grads.UNIMPLEMENTED_ADJOINTS`.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
      return node, []

    if func == tracing.Traceable:
      return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
      raise errors.ReverseNotImplementedError(func)

    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
      # Only plain names can be active: a literal argument has no `.id` and
      # no adjoint. This matches the check used in forward mode.
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name) and
                          arg.id in self.active_variables)

      # Request the primal/adjoint pair for each (function, active args)
      # combination only once. Functions are matched by name.
      already_counted = any(
          f.__name__ == func.__name__ and set(a) == set(active_args)
          for f, a in self.required)
      if not already_counted:
        self.required.append((func, active_args))

      pri_name = naming.primal_name(func, active_args)
      pri_call = gast.Call(
          func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
          args=[self.substack] + node.args,
          keywords=node.keywords)
      anno.setanno(pri_call, 'pri_call', True)

      dy = create.create_grad(self.target, self.namer)
      dy.ctx = gast.Load()
      # NOTE(review): `dx` is never used below. It is kept in case
      # `create_grad` has side effects on `self.namer`; confirm and remove.
      dx = create.create_grad(node.args[0], self.namer)
      dx.ctx = gast.Store()
      adj_name = naming.adjoint_name(func, active_args)
      adj_call = gast.Call(
          func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
          args=[self.substack, dy] + node.args,
          keywords=node.keywords)
      anno.setanno(adj_call, 'adj_call', True)
      # Call the generated adjoint, then scatter the returned partials
      # (ordered like `active_args`) into the gradients of the active args.
      adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
      for j, i in enumerate(active_args):
        adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                        x=node.args[i].id, i=gast.Num(n=j)))
      return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]
    code = six.get_function_code(template_)

    # Match the function call to the template. The template's first parameter
    # stands for the call's output, so it is dropped from the signature.
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
      if arg.id not in bound_args.arguments:
        bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = code.co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = code.co_flags

    if flags & inspect.CO_VARARGS:
      # co_argcount includes the leading output parameter, which has no
      # counterpart among the call's arguments.
      to_pack = node.args[code.co_argcount - 1:]
      # In co_varnames the *args name directly follows the positional and
      # keyword-only parameter names. co_varnames[-1] would instead pick the
      # template's last local variable whenever it defines any.
      vararg_name = code.co_varnames[
          code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())
      packing = [gast.Assign(targets=[target], value=value)]

      # And we fill in the packed tuple into the template
      arg_replacements[vararg_name] = target
    adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
    unpacking = []
    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [create.create_temp_grad(arg, self.namer)
                  for arg in to_pack]
      value = create.create_grad(target, self.namer)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
      unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking
Ejemplo n.º 16
0
def test_parses(func):
  """Smoke test: the source of every supplied function must be parseable.

  There is no explicit assertion; the test fails if parsing raises.
  """
  _ = quoting.parse_function(func)
Ejemplo n.º 17
0
  def visit_Call(self, node):
    """Create the forward-mode tangent for a call.

    Args:
      node: The `gast.Call` node being differentiated. It must carry a `func`
          annotation holding the resolved function handle.

    Returns:
      The result depends on the case:

      * `node` unchanged when there is no `self.target`.
      * If `func` has no registered tangent template, a one-element list
        holding a call to the generated tangent function. In this case
        `self.value` is set to False so the whole primal statement is
        replaced.
      * Otherwise, the list of statements produced by filling in the tangent
        template.

    Raises:
      errors.ForwardNotImplementedError: If `func` is listed in
          `tangents.UNIMPLEMENTED_TANGENTS`.
      NotImplementedError: If `func` is `tracing.Traceable`.
      ValueError: If `func` has no tangent and its source cannot be parsed.
    """
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      try:
        quoting.parse_function(func)
      except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
        # propagate. On Python 3 the parse error stays attached as context.
        raise ValueError('No tangent found for %s, and could not get source.' %
                         func.__name__)

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      # Request each (function, active args) combination only once; functions
      # are matched by name.
      already_counted = any(
          f.__name__ == func.__name__ and set(a) == set(active_args)
          for f, a in self.required)
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      tangent_keywords = []
      for i in active_args:
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]
    code = six.get_function_code(template_)

    # Match the function call to the template. The template's first parameter
    # stands for the call's output, so it is dropped from the signature.
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = code.co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = code.co_flags

    if flags & inspect.CO_VARARGS:
      # co_argcount includes the leading output parameter, which has no
      # counterpart among the call's arguments.
      to_pack = node.args[code.co_argcount - 1:]
      # In co_varnames the *args name directly follows the positional and
      # keyword-only parameter names. co_varnames[-1] would instead pick the
      # template's last local variable whenever it defines any.
      vararg_name = code.co_varnames[
          code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      # NOTE(review): unlike reverse mode, no packing assignment is built from
      # `value` or returned here; confirm whether one should be emitted.
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[vararg_name] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      # NOTE(review): `dto_pack`, `value` and `target` are not used afterwards
      # (no unpacking statement is returned); confirm intended behavior.
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if self.metastack:
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
Ejemplo n.º 18
0
def anf_function(f, globals_=None):
    """Compile the ANF-transformed version of `f` and return it.

    Args:
      f: The function to transform; its source must be retrievable.
      globals_: Optional mapping used as the global namespace of the compiled
          function, e.g. to supply names `f` reads from an enclosing scope.
          Defaults to `f.__globals__`. The mapping is copied, so the caller's
          dict is not modified.

    Returns:
      The newly compiled, ANF-transformed function (not `f` itself).
    """
    m = gast.gast_to_ast(anf.anf(quoting.parse_function(f)))
    m = gast.fix_missing_locations(m)
    if globals_ is None:
        globals_ = getattr(f, '__globals__', {})
    namespace = dict(globals_)
    exec(compile(m, '<string>', 'exec'), namespace)
    # Executing the module binds the transformed function under its own name.
    # Returning the `f` argument instead (as before) silently handed back the
    # untransformed original, so callers never exercised the ANF output.
    return namespace[m.body[0].name]