Example #1
    def visit_FunctionDef(self, node):
        self.namer = naming.Namer.build(node)

        # Get the tangent of the body
        new_body = []
        for n in node.body:
            new = self.visit(n)
            if isinstance(new, (list, tuple)):
                new_body.extend(new)
            else:
                new_body.append(new)
        node.body = new_body

        # Add in the initial gradient argument
        grad_args = [
            create.create_grad(arg, self.namer, tangent=True)
            for i, arg in enumerate(node.args.args) if i in self.wrt
        ]
        if len(self.wrt) != len(grad_args):
            raise ValueError(
                'Mismatch between requested and retrieved derivative arguments. '
                'Requested %d, found %d' % (len(self.wrt), len(grad_args)))

        node.args.args += grad_args

        if self.check_dims:
            # Define the shape check code quote
            def shape_match_template(primal, tangent_):
                if not tangent.shapes_match(primal, tangent_):
                    raise ValueError(
                        'Shape mismatch between argument value (%s) and seed derivative '
                        '(%s)' % (numpy.shape(primal), numpy.shape(tangent_)))

            # Add a shape check for each seed derivative & primal pair.
            shape_check_nodes = []
            for iwrt, tangent_var in zip(self.wrt, grad_args):
                primal = node.args.args[iwrt]
                shape_check = template.replace(shape_match_template,
                                               primal=primal,
                                               tangent_=tangent_var)[0]
                shape_check_nodes.append(shape_check)
            node.body = shape_check_nodes + node.body

        # Add in gradient initialization statements for everything else
        grad_init_nodes = [
            template.replace('d[z] = init_grad(z)',
                             replace_grad=template.Replace.TANGENT,
                             namer=self.namer,
                             z=arg,
                             init_grad=utils.INIT_GRAD)
            for i, arg in enumerate(node.args.args) if i not in self.wrt
        ]
        node.body = grad_init_nodes + node.body

        # Rename the function
        func = anno.getanno(node, 'func')
        node.name = naming.tangent_name(func, self.wrt)

        return node
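For intuition, a hand-written sketch of the kind of function this pass conceptually produces is shown below. This is not Tangent's generated output; `f` and `df_wrt_x` are illustrative names for a function of two arguments differentiated with respect to `x` only. It seeds the derivative of `x`, checks the seed's shape against the primal argument, and zero-initializes the derivative of the other argument.

import numpy


def f(x, y):
    return x * y


def df_wrt_x(x, y, dx):
    # Shape check for the (primal, seed derivative) pair, cf. shape_match_template.
    if numpy.shape(x) != numpy.shape(dx):
        raise ValueError(
            'Shape mismatch between argument value (%s) and seed derivative (%s)'
            % (numpy.shape(x), numpy.shape(dx)))
    dy = numpy.zeros_like(y)  # gradient initialization for the argument not in wrt
    dz = dx * y + x * dy      # tangent of z = x * y
    z = x * y
    return dz, z


assert df_wrt_x(3.0, 4.0, 1.0) == (4.0, 12.0)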
Example #2
    def visit_Expr(self, node):
        # We need to special-case pushes, e.g. utils.push(_stack,x,op_id)
        adjoint = []
        if (isinstance(node.value, gast.Call)
                and anno.getanno(node.value, 'func') == utils.push):
            orig_src = quoting.unquote(node)
            stack, val, op_id = node.value.args
            push_template = grads.adjoints[utils.push]
            adjoint_rhs = template.replace(push_template,
                                           namer=self.namer,
                                           stack=stack,
                                           val=val,
                                           op_id=op_id)

            # Walk the RHS adjoint AST to find temporary adjoint variables to
            # sum
            accumulation = template.replace(
                'd[xi] = add_grad(d[xi], dxi_partial)',
                namer=self.namer,
                replace_grad=template.Replace.FULL,
                xi=val,
                dxi_partial=ast_.copy_node(adjoint_rhs[0].targets[0]),
                add_grad=utils.ADD_GRAD)
            adjoint = adjoint_rhs + [accumulation]
            for i, adj in enumerate(adjoint):
                adjoint[i] = comments.add_comment(adj,
                                                  'Grad of: %s' % orig_src)
        return node, adjoint
Example #3
  def visit_If(self, node):
    # Get the primal and adjoint of the blocks
    body, adjoint_body = self.visit_statements(node.body)
    orelse, adjoint_orelse = self.visit_statements(node.orelse)

    # We will store the condition on the stack
    cond = self.namer.cond()
    push, pop, op_id = get_push_pop()

    # Fill in the templates
    primal_template = grads.primals[gast.If]
    primal = template.replace(
        primal_template,
        body=body,
        cond=cond,
        test=node.test,
        orelse=orelse,
        push=push,
        _stack=self.stack,
        op_id=op_id)
    adjoint_template = grads.adjoints[gast.If]
    adjoint = template.replace(
        adjoint_template,
        cond=cond,
        adjoint_body=adjoint_body,
        adjoint_orelse=adjoint_orelse,
        pop=pop,
        _stack=self.stack,
        op_id=op_id)
    return primal, adjoint
Example #4
  def visit_While(self, node):
    if node.orelse:
      raise ValueError

    body, adjoint_body = self.visit_statements(node.body)

    # We create a loop counter which will be pushed on the stack
    push, pop, op_id = get_push_pop()
    counter = self.namer.counter()

    primal_template = grads.primals[gast.While]
    primal = template.replace(
        primal_template,
        namer=self.namer,
        body=body,
        i=counter,
        push=push,
        test=node.test,
        _stack=self.stack,
        op_id=op_id)

    adjoint_template = grads.adjoints[gast.While]
    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        adjoint_body=adjoint_body,
        i=counter,
        pop=pop,
        _stack=self.stack,
        op_id=op_id)

    return primal, adjoint
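To make the role of the loop counter concrete, below is a hand-written sketch of the while-loop primal/adjoint pattern (an illustration under simplifying assumptions, not Tangent's generated code): the primal pushes overwritten state each iteration and finally pushes the iteration count; the adjoint pops the count and runs the reversed body that many times.

def primal_while(stack, x):
    cnt = 0
    while x < 100.0:
        stack.append(x)   # store the pre-value overwritten by the assignment
        x = x * 3.0
        cnt += 1
    stack.append(cnt)     # push the loop counter for the adjoint
    return x


def adjoint_while(stack, dx_out):
    cnt = stack.pop()     # pop the loop counter
    dx = dx_out
    for _ in range(cnt):  # replay the adjoint of the body cnt times, in reverse
        _x_pre = stack.pop()  # restored state (unused here: the body is linear)
        dx = dx * 3.0         # adjoint of x = x * 3.0
    return dx


stack = []
primal_while(stack, 2.0)                # 2 -> 6 -> 18 -> 54 -> 162: 4 iterations
assert adjoint_while(stack, 1.0) == 3.0 ** 4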
Example #5
  def visit_Assign(self, node):
    """Visit assignment statement.

    Notes
    -----
    This method sets the `self.target` attribute to the first assignment
    target for callees to use.
    """

    # Generate tangent name of the assignment
    self.target = node.targets[0]
    self.value = node.value

    # Get the tangent node
    tangent_node = self.visit(self.value)

    # If no forward-mode statement was created, then no extra work is needed.
    if tangent_node == self.value:
      self.target = None
      return node

    if self.value:
      new_node = template.replace(
          tangents.tangents[gast.Assign],
          replace_grad=template.Replace.TANGENT,
          namer=self.namer,
          temp=self.tmp_node,
          tangent=tangent_node,
          target=self.target,
          value=self.value)
      # Ensure that we use a unique tmp node per primal/tangent pair
      self.reset_tmp_node()
    else:
      # We're already in ANF form right now,
      # so we can assume LHS is a single Name "z"
      def template_(z, f):
        tmp = f
        d[z] = tmp[0]
        z = tmp[1]

      new_node = template.replace(
          template_,
          replace_grad=template.Replace.TANGENT,
          namer=self.namer,
          z=self.target,
          f=tangent_node[0])

    # Add it after the original statement
    self.target = None

    # Add in some cool comments
    for i in range(len(new_node)):
      new_node[i] = comments.add_comment(new_node[i], 'Primal and tangent of: '
                                         '%s' % quoting.unquote(node))
    return new_node
Example #6
 def visit_Subscript(self, node):
     adjoint = template.replace('d[x[i]] = d[y]',
                                namer=self.namer,
                                y=self.target,
                                x=node.value,
                                i=node.slice.value)
     return node, adjoint
Example #7
 def visit_UnaryOp(self, node):
   op = type(node.op)
   if op not in grads.adjoints:
     raise ValueError('unknown unary operator')
   adjoint_template = grads.adjoints[op]
   adjoint = template.replace(adjoint_template, namer=self.namer,
                              x=node.operand, y=self.target)
   return node, adjoint
Example #8
  def primal_and_adjoint_for_tracing(self, node):
    """Build the primal and adjoint of a traceable function.

    Args:
      node: ast.Call node of a function we wish to trace, instead of transform

    Returns:
      primal: new ast.Assign node to replace the original primal call
      adjoint: new ast.Assign node using the VJP generated in primal to
        calculate the adjoint.
    """
    primal_template = grads.primals[tracing.Traceable]
    adjoint_template = grads.adjoints[tracing.Traceable]

    # Prep
    to_pack = node.args
    target = ast_.copy_node(self.orig_target)
    vjp = quoting.quote(self.namer.unique('%s_grad' % node.func.id))
    tmp = create.create_temp(quoting.quote('tmp'), self.namer)
    assert len(node.keywords) == 0

    # Full replacement of primal
    # TODO: do we need to set 'pri_call' on this?
    primal = template.replace(
        primal_template,
        namer=self.namer,
        result=target,
        fn=node.func,
        tmp=tmp,
        vjp=vjp,
        args=gast.Tuple(elts=to_pack, ctx=gast.Load()))

    # Building adjoint using the vjp generated with the primal
    dto_pack = gast.Tuple(
        elts=[create.create_temp_grad(arg, self.namer) for arg in to_pack],
        ctx=gast.Store())

    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        result=target,
        vjp=vjp,
        dargs=dto_pack)

    return primal, adjoint
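The pattern the docstring describes, written out by hand (illustrative only; `traced_fn_with_vjp` is a made-up name, not something Tangent emits): the primal call also yields a VJP closure, and the adjoint applies that VJP to the incoming gradient.

def traced_fn_with_vjp(x, y):
    result = x * y
    def vjp(dresult):
        # Vector-Jacobian product of result = x * y.
        return dresult * y, dresult * x
    return result, vjp


# Primal: replaces the original call and keeps the VJP around.
result, fn_vjp = traced_fn_with_vjp(2.0, 3.0)
# Adjoint: uses the VJP generated in the primal.
dx, dy = fn_vjp(1.0)
assert (result, dx, dy) == (6.0, 3.0, 2.0)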
Example #9
 def visit_container(self, node):
     adjoint = []
     for i, elt in enumerate(node.elts):
         adjoint.append(
             template.replace('d[x] = d[t[i]]',
                              namer=self.namer,
                              t=self.target,
                              i=gast.Num(n=i),
                              x=elt))
     return node, adjoint
Example #10
def test_variable_replace():
    def f(x):
        x = 2
        return x

    body = template.replace(f, x=gast.Name(id='y', ctx=None, annotation=None))
    assert body[0].targets[0].id == 'y'
    assert isinstance(body[0].targets[0].ctx, gast.Store)
    assert isinstance(body[1].value.ctx, gast.Load)
    compile_.compile_function(_wrap(body))
Example #11
  def visit_For(self, node):
    if node.orelse:
      raise ValueError

    # Construct the primal and adjoint of the loop
    body, adjoint_body = self.visit_statements(node.body)

    # We create a loop counter which will be pushed on the stack
    push, pop, op_id = get_push_pop()
    counter = self.namer.counter()

    # In `for i in range ...` the variable `i` is the target, which we
    # temporarily set aside each iteration to push to the stack later
    push_target, pop_target, op_id_target = get_push_pop()
    tmp_target = create.create_temp(node.target, self.namer)

    primal_template = grads.primals[gast.For]
    primal = template.replace(
        primal_template,
        body=body,
        i=counter,
        push=push,
        target=node.target,
        iter_=node.iter,
        push_target=push_target,
        _target=tmp_target,
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    adjoint_template = grads.adjoints[gast.For]
    adjoint = template.replace(
        adjoint_template,
        adjoint_body=adjoint_body,
        i=counter,
        pop=pop,
        pop_target=pop_target,
        target=ast_.copy_node(node.target),
        _stack=self.stack,
        op_id_iter=op_id,
        op_id_target=op_id_target)

    return primal, adjoint
Example #12
    def visit_UnaryOp(self, node):
        if not self.target:
            return node

        template_ = tangents.tangents[node.op.__class__]
        tangent_node = template.replace(template=template_,
                                        replace_grad=template.Replace.TANGENT,
                                        namer=self.namer,
                                        x=node.operand,
                                        z=self.target)
        return tangent_node
Example #13
def test_statement_replace():
    def f(body):
        body

    body = [
        gast.Expr(value=gast.Name(id=var, ctx=gast.Load(), annotation=None))
        for var in 'xy'
    ]
    new_body = template.replace(f, body=body)
    assert len(new_body) == 2
    assert isinstance(new_body[0], gast.Expr)
    compile_.compile_function(_wrap(new_body))
Example #14
 def visit_List(self, node):
     if not self.target:
         return node
     dlist = self.create_grad_list(node)
     tangent_node = [
         template.replace('d[z] = dlist',
                          replace_grad=template.Replace.TANGENT,
                          namer=self.namer,
                          z=self.target,
                          dlist=dlist)
     ]
     return tangent_node
Example #15
    def visit_Subscript(self, node):
        """Tangent of e.g. x = y[i]."""
        if not self.target:
            return node

        template_ = tangents.tangents[node.__class__]
        tangent_node = template.replace(template=template_,
                                        replace_grad=template.Replace.TANGENT,
                                        namer=self.namer,
                                        x=node,
                                        z=self.target)
        return tangent_node
Example #16
def test_function_replace():
    def f(f, args):
        def f(args):
            pass

    body = template.replace(
        f,
        f='g',
        args=[gast.Name(id=arg, ctx=None, annotation=None) for arg in 'ab'])
    assert isinstance(body[0], gast.FunctionDef)
    assert body[0].name == 'g'
    assert len(body[0].args.args) == 2
    assert isinstance(body[0].args.args[0].ctx, gast.Param)
    assert body[0].args.args[1].id == 'b'
    compile_.compile_function(_wrap(body))
Example #17
    def visit_Name(self, node):
        """Tangent of e.g. x = y."""
        if not self.target:
            return node

        # In some versions of Python, gast represents 'None' as a Name node,
        # not as a special NameConstant. So, we have to do a bit of
        # special-casing here. Ideally, this is fixed in gast.
        if node.id in ['None', 'True', 'False']:
            template_ = 'd[z] = x'
        else:
            template_ = tangents.tangents[node.__class__]
        tangent_node = template.replace(template=template_,
                                        replace_grad=template.Replace.TANGENT,
                                        namer=self.namer,
                                        x=node,
                                        z=self.target)
        return tangent_node
Example #18
  def visit_For(self, node):
    # If the iter is a Name that is active,
    # we need to rewrite the loop.
    # Iterators of the form `for a in x` rely on an implicit
    # indexing operation, which Tangent cannot reverse without
    # more information. So, we will create an explicit
    # indexing operation. Note that we will use
    # integer indexes, which will cause strange behavior if
    # the iterator's `next()` behavior deviates from
    # a plain incrementing index.
    # The right thing to do (eventually) is to write a multiple-dispatch
    # version of the `next` operator, and its adjoint, so that
    # we can handle e.g. dicts.

    if isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)):
      iter_name = ast.get_name(node.iter)
      if iter_name in anno.getanno(node, 'active_in'):
        # for a in x:
        #   f(a)
        # # becomes
        # for i in range(len(x)):
        #   a = x[i]
        #   f(a)

        # Get a unique iterator name
        old_target = copy.deepcopy(node.target)
        new_target = quoting.quote(self.namer.unique('_idx'))
        old_iter = copy.deepcopy(node.iter)

        item_access = template.replace(
          'old_target = x[i]',
          old_target=old_target,
          x=old_iter,
          i=new_target)

        node.target = gast.Name(id=new_target.id, ctx=gast.Store(), annotation=None)
        node.iter = quoting.quote('range(len(%s))' % iter_name)
        anno.setanno(node.iter, 'func', range)
        anno.setanno(node.iter.args[0], 'func', len)
        node.body = [item_access] + node.body

    return node
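The rewrite described in the comments above, written out by hand in plain Python (illustration only; `_idx` stands in for the unique name the namer would generate):

x = [1.0, 2.0, 3.0]

# Original form, with an implicit (hard-to-reverse) iteration:
total = 0.0
for a in x:
    total += a * a

# Rewritten form, with an explicit integer index that reverse mode can handle:
total_rewritten = 0.0
for _idx in range(len(x)):
    a = x[_idx]
    total_rewritten += a * a

assert total == total_rewritten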
Example #19
  def visit_FunctionDef(self, node):
    self.namer = naming.Namer.build(node)

    # Get the tangent of the body
    new_body = []
    for n in node.body:
      new = self.visit(n)
      if isinstance(new, (list, tuple)):
        new_body.extend(new)
      else:
        new_body.append(new)
    node.body = new_body

    # Add in the initial gradient argument
    grad_args = [
        create.create_grad(arg, self.namer, tangent=True)
        for i, arg in enumerate(node.args.args) if i in self.wrt
    ]

    node.args.args += grad_args

    # Add in gradient initialization statements for everything else
    grad_init_nodes = [
        template.replace(
            'd[z] = init_grad(z)',
            replace_grad=template.Replace.TANGENT,
            namer=self.namer,
            z=arg,
            init_grad=utils.INIT_GRAD) for i, arg in enumerate(node.args.args)
        if i not in self.wrt
    ]
    node.body = grad_init_nodes + node.body

    # Rename the function
    func = anno.getanno(node, 'func')
    node.name = naming.tangent_name(func, self.wrt)

    return node
Example #20
  def visit_Expr(self, node):
    # Special-case the push() expression (have to reverse the usual
    # tangent/primal order)

    if isinstance(node.value, gast.Call):
      fn = anno.getanno(node.value, 'func')
      if fn in [tangent.push, tangent.push_stack]:
        # Save the pop associated with this push
        template_ = tangents.tangents[fn]
        stack_node, var_node, op_id = node.value.args
        tangent_node = template.replace(
            template_,
            replace_grad=template.Replace.TANGENT,
            namer=self.namer,
            x=var_node,
            stack=stack_node,
            op_id=op_id)
        # Push the original node and the tangent_node push
        # onto a "meta-stack", so we can track the relationship
        # between the pushes and pops in tangent mode
        self.metastack.append(tangent_node[0])
        return [node] + tangent_node
    return node
Example #21
def test_node_replace():
    node = template.replace(quoting.quote("a = b"), a="y", b="x * 2")
    assert quoting.unquote(node) == "y = x * 2"
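For intuition, here is a minimal sketch of the idea behind this test, written against the standard `ast` module only; it is not Tangent's `template.replace` implementation, just a toy name substitution that yields the same result for this input (uses `ast.unparse`, so Python 3.9+).

import ast


class NameSubstitutor(ast.NodeTransformer):
    """Replace Name nodes according to a mapping of identifier -> source text."""

    def __init__(self, mapping):
        self.mapping = mapping

    def visit_Name(self, node):
        replacement = self.mapping.get(node.id)
        if replacement is None:
            return node
        new = ast.parse(replacement, mode='eval').body
        if isinstance(new, ast.Name):
            new.ctx = node.ctx  # keep the original Load/Store context
        return new


tree = ast.parse('a = b')
tree = NameSubstitutor({'a': 'y', 'b': 'x * 2'}).visit(tree)
assert ast.unparse(ast.fix_missing_locations(tree)) == 'y = x * 2'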
Example #22
  def visit_FunctionDef(self, node):
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or not isinstance(node.body[-1], gast.Return)):
      raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
      body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
      adjoint_body[0] = comments.add_comment(
          adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt], ctx=gast.Load())

    if self.preserve_result:
      # Append an extra Assign operation to the primal body
      # that saves the original output value
      stored_result_node = quoting.quote(self.namer.unique('result'))
      assign_stored_result = template.replace(
          'result=orig_result',
          result=stored_result_node,
          orig_result=return_node.value)
      body.append(assign_stored_result)
      dx.elts.append(stored_result_node)

    for _dx in dx.elts:
      _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return values
    # The adjoint will receive a value for the initial gradient of the cost
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
      y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body, return_dx=return_dx)
    adjoint.args.args.extend([self.stack, dy])
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint
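As a rough mental model of what this visitor assembles, below is a hand-written primal/adjoint pair for a tiny function (illustrative and simplified; real Tangent output also threads a stack argument, op_ids, and gradient initialization): the primal records intermediate state on the stack, and the adjoint takes the stack plus the seed gradient and replays the computation backwards.

def f(x):
    y = x * x
    z = y * y
    return z


def primal_f(stack, x):
    y = x * x
    stack.append(y)     # store the intermediate for the backward pass
    z = y * y
    return z


def adjoint_f(stack, dz, x):
    y = stack.pop()     # restore the intermediate stored by the primal
    dy = 2.0 * y * dz   # grad of: z = y * y
    dx = 2.0 * x * dy   # grad of: y = x * x
    return dx


stack = []
primal_f(stack, 3.0)
assert adjoint_f(stack, 1.0, 3.0) == 4.0 * 3.0 ** 3   # d/dx x**4 = 4 x**3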
Example #23
  def visit_Call(self, node):
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      try:
        quoting.parse_function(func)
      except:
        raise ValueError('No tangent found for %s, and could not get source.' %
                         func.__name__)

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      tangent_keywords = []
      for i in active_args:
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[vararg_name] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if len(self.metastack):
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
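The calling convention noted in the comment above, `z = f(x, y) -> d[z], z = df(x, y, dx=dx, dy=dy)`, written out by hand for a concrete function (`df` here is illustrative, not a Tangent-generated tangent function):

def f(x, y):
    return x * y


def df(x, y, dx=0.0, dy=0.0):
    # Tangent of z = x * y, returned alongside the primal value.
    dz = dx * y + x * dy
    z = x * y
    return dz, z


dz, z = df(2.0, 5.0, dx=1.0)   # derivative of f with respect to x at (2, 5)
assert (dz, z) == (5.0, 10.0)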
Example #24
  def visit_Assign(self, node):
    """Visit assignment statement."""
    if len(node.targets) != 1:
      raise ValueError('no support for chained assignment')

    # Before the node gets modified, get a source code representation
    # to add as a comment later on
    if anno.hasanno(node, 'pre_anf'):
      orig_src = anno.getanno(node, 'pre_anf')
    else:
      orig_src = quoting.unquote(node)

    # Set target for the RHS visitor to access
    self.orig_target = ast_.copy_node(node.targets[0])

    # If we know we're going to be putting another stack on the stack,
    # we should try to make that explicit
    if (isinstance(node.value, gast.Call) and
        anno.hasanno(node.value, 'func') and
        anno.getanno(node.value, 'func') in (utils.Stack, utils.pop_stack)):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    push_stack, pop_stack, op_id_stack = get_push_pop_stack()

    # Every assignment statement requires us to store the pre-value, and in the
    # adjoint restore the value, and reset the gradient
    store = template.replace(
        'push(_stack, y, op_id)',
        push=push,
        y=self.orig_target,
        _stack=self.stack,
        op_id=op_id)
    create_substack = template.replace(
        'substack = tangent.Stack()', substack=self.substack)
    store_substack = template.replace(
        'push(stack, substack, op_id)',
        push=push_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    restore = template.replace(
        'y = pop(_stack, op_id)',
        _stack=self.stack,
        pop=pop,
        y=ast_.copy_node(self.orig_target),
        op_id=op_id)
    restore_substack = template.replace(
        'substack = pop(stack, op_id)',
        pop=pop_stack,
        stack=self.stack,
        substack=self.substack,
        op_id=op_id_stack)
    reset = template.replace(
        'd[y] = init_grad(y, allow_lazy_initializer=True)',
        y=self.orig_target,
        init_grad=utils.INIT_GRAD,
        namer=self.namer,
        replace_grad=template.Replace.FULL)

    # If there are no active nodes, we don't need to find an adjoint
    # We simply store and restore the state, and reset the gradient
    if not self.is_active(node):
      return [store, node], [restore, reset]

    # We create a temporary variable for the target that the RHS can use
    self.target = create.create_temp(self.orig_target, self.namer)
    create_tmp = template.replace(
        'tmp = y', tmp=self.target, y=self.orig_target)

    # Get the primal and adjoint of the RHS expression
    try:
      fx, adjoint_rhs = self.visit(node.value)
    except ValueError as e:
      context = [t.id if hasattr(t, 'id') else t for t in node.targets]
      raise ValueError(
          'Failed to process assignment to: %s. Error: %s' % (context, e))
    if not isinstance(adjoint_rhs, list):
      adjoint_rhs = [adjoint_rhs]

    # Walk the RHS adjoint AST to find temporary adjoint variables to sum
    accumulations = []
    for n in adjoint_rhs:
      for succ in gast.walk(n):
        if anno.hasanno(succ, 'temp_adjoint_var'):
          xi = anno.getanno(succ, 'temp_adjoint_var')
          dxi_partial = ast_.copy_node(succ)
          accumulations.append(template.replace(
              'd[xi] = add_grad(d[xi], dxi_partial)',
              namer=self.namer, replace_grad=template.Replace.FULL,
              xi=xi, dxi_partial=dxi_partial, add_grad=utils.ADD_GRAD))

    # The primal consists of storing the state and then performing the
    # assignment with the new primal.
    # The primal `fx` may be optionally (but rarely) redefined when the
    # adjoint is generated, in `fx, adjoint_rhs = self.visit(node.value)`.
    # If we see that the primal value is an Assign node, or a list of nodes
    # (with at least one being an Assign) node, we allow the primal to change.
    # Otherwise, we'll build our own Assign node.
    if isinstance(fx, gast.Assign):
      assign = [fx]
    elif (isinstance(fx, list) and
          any([isinstance(ifx, gast.Assign) for ifx in fx])):
      assign = fx
    else:
      assign = template.replace(
          'y = fx', y=ast_.copy_node(self.orig_target), fx=fx)
      assign = [assign]
    primal = [store, create_substack, store_substack] + assign

    # The adjoint involves creating the temporary, restoring the store,
    # calculating the adjoint, resetting the gradient, and finally accumulating
    # the partials
    adjoint = [create_tmp, restore_substack, restore
              ] + adjoint_rhs + [reset] + accumulations

    # If the LHS is a subscript assignment with variable index, we need to
    # store and restore that as well
    if (isinstance(self.orig_target, gast.Subscript) and
        isinstance(self.orig_target.slice.value, gast.Name)):
      push, pop, op_id = get_push_pop()
      i = self.orig_target.slice.value
      push_index = template.replace(
          'push(_stack, i, op_id)',
          push=push,
          i=i,
          _stack=self.stack,
          op_id=op_id)
      pop_index = template.replace(
          'i = pop(_stack, op_id)',
          pop=pop,
          i=i,
          _stack=self.stack,
          op_id=op_id)

      primal.append(push_index)
      adjoint.insert(0, pop_index)

    # Add a comment in the backwards pass, indicating which
    # lines in the forward pass generated the adjoint
    for i, adj in enumerate(adjoint):
      adjoint[i] = comments.add_comment(adj, 'Grad of: %s' % orig_src)

    return primal, adjoint
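Stripped of the AST machinery, the store/restore/reset/accumulate pattern assembled above looks roughly like the following hand-written pair (illustration only):

def primal(stack, x):
    stack.append(x)       # push(_stack, y, op_id): store the value about to be overwritten
    x = x * x             # the assignment itself
    return x


def adjoint(stack, dx_out):
    x = stack.pop()       # y = pop(_stack, op_id): restore the overwritten value
    dx = 0.0              # d[y] = init_grad(y): reset the gradient of the target
    dx_partial = 2.0 * x * dx_out   # adjoint of the RHS of x = x * x
    dx = dx + dx_partial            # d[xi] = add_grad(d[xi], dxi_partial)
    return dx


stack = []
primal(stack, 3.0)
assert adjoint(stack, 1.0) == 6.0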
Example #25
 def visit_Name(self, node):
   adjoint = template.replace('d[x] = tangent.copy(d[y])',
                              namer=self.namer, x=node, y=self.target)
   return node, adjoint
Example #26
def test_string_replace():
    node = template.replace("a = b", a="y", b="x * 2")
    assert quoting.unquote(node) == "y = x * 2"
Example #27
  def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
      return node, []

    if func == tracing.Traceable:
      return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
      raise errors.ReverseNotImplementedError(func)


    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if arg.id in self.active_variables)

      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      pri_name = naming.primal_name(func, active_args)
      pri_call = gast.Call(
          func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
          args=[self.substack] + node.args,
          keywords=node.keywords)
      anno.setanno(pri_call, 'pri_call', True)

      dy = create.create_grad(self.target, self.namer)
      dy.ctx = gast.Load()
      dx = create.create_grad(node.args[0], self.namer)
      dx.ctx = gast.Store()
      adj_name = naming.adjoint_name(func, active_args)
      adj_call = gast.Call(
          func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
          args=[self.substack, dy] + node.args,
          keywords=node.keywords)
      anno.setanno(adj_call, 'adj_call', True)
      adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
      for j, i in enumerate(active_args):
        adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                        x=node.args[i].id, i=gast.Num(n=j)))
      return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
      if arg.id not in bound_args.arguments:
        bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())
      packing = [gast.Assign(targets=[target], value=value)]

      # And we fill in the packed tuple into the template
      arg_replacements[vararg_name] = target
    adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
    unpacking = []
    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [create.create_temp_grad(arg, self.namer)
                  for arg in to_pack]
      value = create.create_grad(target, self.namer)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
      unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking
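A small illustration of the packing/unpacking the docstring describes (not Tangent code; `add_all` and `dadd_all` are made-up names): a gradient template for a variadic function returns a single tuple of partials, which then has to be unpacked into per-argument gradients.

def add_all(*args):
    return sum(args)


def dadd_all(dz, *args):
    # One partial per packed argument; d(sum)/d(arg_i) = 1 for every i.
    return tuple(dz for _ in args)


dz = 1.0
dargs = dadd_all(dz, 2.0, 3.0, 4.0)   # packed tuple of partials
da, db, dc = dargs                    # unpacked into per-argument gradients
assert (da, db, dc) == (1.0, 1.0, 1.0)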
Example #28
def store_state(node, reaching, defined, stack):
  """Push the final state of the primal onto the stack for the adjoint.

  Python's scoping rules make it possible for variables to not be defined in
  certain blocks based on the control flow path taken at runtime. In order to
  make sure we don't try to push non-existing variables onto the stack, we
  defined these variables explicitly (by assigning `None` to them) at the
  beginning of the function.

  All the variables that reach the return statement are pushed onto the
  stack, and in the adjoint they are popped off in reverse order.

  Args:
    node: A module with the primal and adjoint function definitions as returned
        by `reverse_ad`.
    reaching: The variable definitions that reach the end of the primal.
    defined: The variables defined at the end of the primal.
    stack: The stack node to use for storing and restoring state.

  Returns:
    node: A node with the requisite pushes and pops added to make sure that
        state is transferred between primal and adjoint split motion calls.
  """
  defs = [def_ for def_ in reaching if not isinstance(def_[1], gast.arguments)]
  if not len(defs):
    return node
  reaching, original_defs = zip(*defs)

  # Explicitly define variables that might or might not be in scope at the end
  assignments = []
  for id_ in set(reaching) - defined:
    assignments.append(quoting.quote('{} = None'.format(id_)))

  # Store variables at the end of the function and restore them
  store = []
  load = []
  for id_, def_ in zip(reaching, original_defs):
    # If the original definition of a value that we need to store
    # was an initialization as a stack, then we should be using `push_stack`
    # to store its state, and `pop_stack` to restore it. This allows
    # us to avoid doing any `add_grad` calls on the stack, which result
    # in type errors in unoptimized mode (they are usually elided
    # after calling `dead_code_elimination`).
    if (isinstance(def_, gast.Assign) and
        'tangent.Stack()' in quoting.unquote(def_.value)):
      push, pop, op_id = get_push_pop_stack()
    else:
      push, pop, op_id = get_push_pop()
    store.append(
        template.replace(
            'push(_stack, val, op_id)',
            push=push,
            val=id_,
            _stack=stack,
            op_id=op_id))
    load.append(
        template.replace(
            'val = pop(_stack, op_id)',
            pop=pop,
            val=id_,
            _stack=stack,
            op_id=op_id))

  body, return_ = node.body[0].body[:-1], node.body[0].body[-1]
  node.body[0].body = assignments + body + store + [return_]
  node.body[1].body = load[::-1] + node.body[1].body

  return node
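A hand-written sketch of this behavior (illustrative, not generated code): variables that exist only on some control-flow paths are predefined to `None`, the final primal state is pushed in order, and the adjoint pops it back in reverse order.

def primal(stack, x):
    y = None              # explicit definition: y is only assigned if the branch is taken
    if x > 0.0:
        y = x * 2.0
    z = x + 1.0
    stack.append(y)       # store final state, in order
    stack.append(z)
    return z


def adjoint(stack):
    z = stack.pop()       # restore in reverse order
    y = stack.pop()
    return y, z


stack = []
primal(stack, 3.0)
assert adjoint(stack) == (6.0, 4.0)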