Example 1
0
 def visit_Subscript(self, node):
     """Rewrite subscripts of the gradient placeholder ``d[...]``.

     Subscripts whose value is the name ``d`` are replaced by a gradient
     node (or a temporary gradient for stores).  When the subscript's index
     is not a plain name/subscript/string — i.e. the gradient of a constant
     — the result is a literal zero (tangent mode) or a discarded ``_``
     target.
     """
     if not (isinstance(node.value, (gast.Name, gast.Num))
             and node.value.id == 'd'):
         return node
     index_is_symbol = (isinstance(node.slice, gast.Index)
                        and isinstance(node.slice.value,
                                       (gast.Subscript, gast.Name, gast.Str)))
     if not index_is_symbol:
         # This happens when the gradient of a constant is taken
         if self.replace_grad == Replace.TANGENT:
             replacement = gast.Num(0)
         else:
             replacement = gast.Name(id='_', ctx=None, annotation=None)
             self.remove(replacement)
     elif (self.replace_grad in (Replace.FULL, Replace.TANGENT)
           or isinstance(node.ctx, gast.Load)):
         replacement = create.create_grad(node.slice.value, self.namer,
                                          self.tangent)
     elif isinstance(node.ctx, gast.Store):
         replacement = create.create_temp_grad(node.slice.value,
                                               self.namer, self.tangent)
     else:
         raise ValueError
     replacement.ctx = node.ctx
     if isinstance(replacement, gast.Tuple):
         for element in replacement.elts:
             element.ctx = node.ctx
     return replacement
Example 2
0
    def visit_FunctionDef(self, node):
        """Build the forward-mode (tangent) version of a function.

        Visits the body statement by statement, appends seed-derivative
        arguments for the ``wrt`` parameters, optionally inserts runtime
        shape checks, and prepends gradient initialization statements for
        the remaining arguments.

        Raises:
          ValueError: If a derivative argument could not be created for
            every requested ``wrt`` index.
        """
        self.namer = naming.Namer.build(node)

        # Get the tangent of the body
        new_body = []
        for n in node.body:
            new = self.visit(n)
            if isinstance(new, (list, tuple)):
                new_body.extend(new)
            else:
                new_body.append(new)
        node.body = new_body

        # Add in the initial gradient argument
        grad_args = [
            create.create_grad(arg, self.namer, tangent=True)
            for i, arg in enumerate(node.args.args) if i in self.wrt
        ]
        if len(self.wrt) != len(grad_args):
            # BUG FIX: `%` formatting was previously applied to the
            # ValueError instance itself (`raise ValueError(...) % (...)`),
            # which raised a TypeError instead of the intended message.
            raise ValueError(
                'Mismatch between requested and retrieved derivative arguments. '
                'Requested %d, found %d' % (len(self.wrt), len(grad_args)))

        node.args.args += grad_args

        if self.check_dims:
            # Code quote: raises at runtime when a seed derivative's shape
            # does not match its primal argument's shape.
            def shape_match_template(primal, tangent_):
                if not tangent.shapes_match(primal, tangent_):
                    raise ValueError(
                        'Shape mismatch between argument value (%s) and seed derivative '
                        '(%s)' % (numpy.shape(primal), numpy.shape(tangent_)))

            # Add a shape check for each seed derivative & primal pair.
            shape_check_nodes = []
            for iwrt, tangent_var in zip(self.wrt, grad_args):
                primal = node.args.args[iwrt]
                shape_check = template.replace(shape_match_template,
                                               primal=primal,
                                               tangent_=tangent_var)[0]
                shape_check_nodes.append(shape_check)
            node.body = shape_check_nodes + node.body

        # Add in gradient initialization statements for everything else
        grad_init_nodes = [
            template.replace('d[z] = init_grad(z)',
                             replace_grad=template.Replace.TANGENT,
                             namer=self.namer,
                             z=arg,
                             init_grad=utils.INIT_GRAD)
            for i, arg in enumerate(node.args.args) if i not in self.wrt
        ]
        node.body = grad_init_nodes + node.body

        # Rename the function to its canonical tangent name
        func = anno.getanno(node, 'func')
        node.name = naming.tangent_name(func, self.wrt)

        return node
Example 3
0
 def visit_Return(self, node):
   """Rewrite a return statement to return seed derivatives.

   The returned expression becomes a tuple of the tangents of the original
   return values; when ``preserve_result`` is set, the original value is
   appended as the final tuple element.
   """
   saved_retval = ast_.copy_node(node.value)
   value = node.value
   if isinstance(value, gast.Tuple):
     tangent_elts = []
     for element in value.elts:
       tangent_elts.append(
           create.create_grad(element, self.namer, tangent=True))
     value.elts = tangent_elts
   elif isinstance(value, (gast.Name, gast.Subscript)):
     grad_elt = create.create_grad(value, self.namer, tangent=True)
     value = gast.Tuple(elts=[grad_elt], ctx=gast.Load())
   else:
     raise ValueError
   for element in value.elts:
     element.ctx = gast.Load()
   if self.preserve_result:
     value.elts.append(saved_retval)
   node.value = value
   return node
Example 4
0
  def create_grad_list(self, node):
    """Build a List/Tuple of gradient nodes mirroring ``node``'s structure.

    Names and subscripts become gradient nodes, numeric literals become
    zeros, and nested lists/tuples are handled recursively.

    Raises:
      ValueError: For element node types that cannot be differentiated.
    """
    assert isinstance(node, (gast.List, gast.Tuple)), 'Must be list or tuple'
    elts = []
    for child in node.elts:
      if isinstance(child, (gast.Name, gast.Subscript)):
        grad_node = create.create_grad(child, self.namer, tangent=True)
        grad_node.ctx = node.ctx
        elts.append(grad_node)
      elif isinstance(child, gast.Num):
        elts.append(gast.Num(0))
      elif isinstance(child, (gast.List, gast.Tuple)):
        # BUG FIX: recurse on the nested node itself, not on its `.elts`
        # list — passing the raw Python list would trip the isinstance
        # assertion at the top of this method.
        elts.append(self.create_grad_list(child))
      else:
        raise ValueError('Cannot handle node type %s' % type(child))

    return node.__class__(elts=elts, ctx=node.ctx)
Example 5
0
 def visit_With(self, node):
   """Deal with the special with insert_grad_of(x) statement."""
   if not ast_.is_insert_grad_of_statement(node):
     return node, []
   adjoint = node.body
   # A nested `with` means more insert_grad_of statements follow; recurse.
   if isinstance(adjoint[0], gast.With):
     _, adjoint = self.visit(adjoint[0])
   node.body[0] = comments.add_comment(node.body[0], 'Inserted code')
   # Map each with-statement alias onto the gradient of its argument
   replacements = {}
   for item in node.items:
     if (not isinstance(item.context_expr.args[0], gast.Name) or
         not isinstance(item.optional_vars, gast.Name)):
       raise ValueError
     replacements[item.optional_vars.id] = create.create_grad(
         item.context_expr.args[0], self.namer)
   template.ReplaceTransformer(replacements).visit(node)
   # No primal statements result from an inserted-gradient block
   return [], adjoint
Example 6
0
  def visit_FunctionDef(self, node):
    """Build the forward-mode (tangent) version of a function.

    Visits the body statement by statement, appends seed-derivative
    arguments for the ``wrt`` parameters, and prepends gradient
    initialization statements for the remaining arguments.

    Raises:
      ValueError: If a derivative argument could not be created for every
        requested ``wrt`` index.
    """
    self.namer = naming.Namer.build(node)

    # Get the tangent of the body
    new_body = []
    for n in node.body:
      new = self.visit(n)
      if isinstance(new, (list, tuple)):
        new_body.extend(new)
      else:
        new_body.append(new)
    node.body = new_body

    # Add in the initial gradient argument
    grad_args = [
        create.create_grad(arg, self.namer, tangent=True)
        for i, arg in enumerate(node.args.args) if i in self.wrt
    ]
    # Robustness: fail loudly when a requested wrt index has no matching
    # argument, instead of silently generating a broken tangent function.
    if len(self.wrt) != len(grad_args):
      raise ValueError(
          'Mismatch between requested and retrieved derivative arguments. '
          'Requested %d, found %d' % (len(self.wrt), len(grad_args)))

    node.args.args += grad_args

    # Add in gradient initialization statements for everything else
    grad_init_nodes = [
        template.replace(
            'd[z] = init_grad(z)',
            replace_grad=template.Replace.TANGENT,
            namer=self.namer,
            z=arg,
            init_grad=utils.INIT_GRAD) for i, arg in enumerate(node.args.args)
        if i not in self.wrt
    ]
    node.body = grad_init_nodes + node.body

    # Rename the function to its canonical tangent name
    func = anno.getanno(node, 'func')
    node.name = naming.tangent_name(func, self.wrt)

    return node
Example 7
0
  def visit_Call(self, node):
    """Create the forward-mode (tangent) statement(s) for a function call.

    If a tangent template is registered for the called function, it is
    instantiated with the call's bound arguments.  Otherwise the called
    function itself is scheduled for differentiation (via ``self.required``)
    and the call is rewritten to invoke its tangent function directly.
    """
    # Nothing to differentiate into if there is no assignment target.
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      # No template: we must be able to read the function's source so it
      # can be differentiated itself.
      # NOTE(review): bare `except:` also swallows KeyboardInterrupt /
      # SystemExit; an `except Exception:` would be safer — confirm intent.
      try:
        quoting.parse_function(func)
      except:
        raise ValueError('No tangent found for %s, and could not get source.' %
                         func.__name__)

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      # Record (function, active arg indices) pairs at most once.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      # Pass each active argument's tangent as a keyword argument named
      # after the gradient of the callee's corresponding parameter.
      tangent_keywords = []
      for i in active_args:
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    # We have a registered tangent template to fill in.
    template_ = tangents.tangents[func]

    # Match the function call to the template
    # (drop the template's first parameter: it is the output placeholder)
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      # Zip from the right: only the trailing positional args have defaults.
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      # Surplus positional arguments (beyond the template's named
      # parameters) are collected into one tuple bound to the vararg name.
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(template_).co_varnames[
          -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      # NOTE(review): unlike the reverse-mode counterpart, `value` and
      # `target` computed here are not wrapped into Assign statements or
      # returned — confirm whether unpacking is intentionally a no-op here.
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if len(self.metastack):
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
Example 8
0
  def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.

    Returns:
      A (primal, adjoint) pair: the (possibly rewritten) primal call and a
      list of adjoint statements.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
      return node, []

    if func == tracing.Traceable:
      return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
      raise errors.ReverseNotImplementedError(func)


    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
      # Only arguments that are active variables need derivatives.
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if arg.id in self.active_variables)

      # Record (function, active arg indices) pairs at most once.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      # Primal call: the generated primal function additionally receives
      # the tape stack as its first argument.
      pri_name = naming.primal_name(func, active_args)
      pri_call = gast.Call(
          func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
          args=[self.substack] + node.args,
          keywords=node.keywords)
      anno.setanno(pri_call, 'pri_call', True)

      dy = create.create_grad(self.target, self.namer)
      dy.ctx = gast.Load()
      # NOTE(review): `dx` is built here but never referenced below —
      # confirm whether it is needed or dead code.
      dx = create.create_grad(node.args[0], self.namer)
      dx.ctx = gast.Store()
      # Adjoint call: receives the stack, the output gradient, and the
      # original arguments; returns the tuple of input gradients.
      adj_name = naming.adjoint_name(func, active_args)
      adj_call = gast.Call(
          func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
          args=[self.substack, dy] + node.args,
          keywords=node.keywords)
      anno.setanno(adj_call, 'adj_call', True)
      adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
      # Unpack the returned gradient tuple into per-argument gradients.
      for j, i in enumerate(active_args):
        adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                        x=node.args[i].id, i=gast.Num(n=j)))
      return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]

    # Match the function call to the template
    # (drop the template's first parameter: it is the output placeholder)
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    # Zip from the right: only the trailing positional args have defaults.
    kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
    kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in kwargs.items():
      if arg.id not in bound_args.arguments:
        bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      # Surplus positional arguments (beyond the template's named
      # parameters) are collected into one tuple bound to the vararg name.
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())
      packing = [gast.Assign(targets=[target], value=value)]

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(
          template_).co_varnames[-1]] = target
    adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
    unpacking = []
    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [create.create_temp_grad(arg, self.namer)
                  for arg in to_pack]
      value = create.create_grad(target, self.namer)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
      unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking
Example 9
0
  def visit_FunctionDef(self, node):
    """Split a function into its primal and adjoint function definitions.

    The primal re-runs the forward pass while recording intermediates on a
    stack; the adjoint replays the stack to compute the backward pass.

    Returns:
      A (primal, adjoint) pair of FunctionDef nodes.

    Raises:
      ValueError: If the function does not end in exactly one return
        statement.
    """
    # Construct a namer to guarantee we create unique names that don't
    # override existing names
    self.namer = naming.Namer.build(node)

    # Check that this function has exactly one return statement at the end
    return_nodes = [n for n in gast.walk(node) if isinstance(n, gast.Return)]
    if ((len(return_nodes) > 1) or not isinstance(node.body[-1], gast.Return)):
      raise ValueError('function must have exactly one return statement')
    return_node = ast_.copy_node(return_nodes[0])

    # Perform AD on the function body
    body, adjoint_body = self.visit_statements(node.body[:-1])

    # Annotate the first statement of the primal and adjoint as such
    if body:
      body[0] = comments.add_comment(body[0], 'Beginning of forward pass')
    if adjoint_body:
      adjoint_body[0] = comments.add_comment(
          adjoint_body[0], 'Beginning of backward pass')

    # Before updating the primal arguments, extract the arguments we want
    # to differentiate with respect to
    dx = gast.Tuple([create.create_grad(node.args.args[i], self.namer)
                     for i in self.wrt], ctx=gast.Load())

    if self.preserve_result:
      # Append an extra Assign operation to the primal body
      # that saves the original output value
      stored_result_node = quoting.quote(self.namer.unique('result'))
      assign_stored_result = template.replace(
          'result=orig_result',
          result=stored_result_node,
          orig_result=return_node.value)
      body.append(assign_stored_result)
      dx.elts.append(stored_result_node)

    # The gradients are read (loaded) when returned from the adjoint.
    for _dx in dx.elts:
      _dx.ctx = gast.Load()
    return_dx = gast.Return(value=dx)

    # We add the stack as first argument of the primal
    node.args.args = [self.stack] + node.args.args

    # Rename the function to its primal name
    func = anno.getanno(node, 'func')
    node.name = naming.primal_name(func, self.wrt)

    # The new body is the primal body plus the return statement
    node.body = body + node.body[-1:]

    # Find the cost; the first variable of potentially multiple return values
    # The adjoint will receive a value for the initial gradient of the cost
    y = node.body[-1].value
    if isinstance(y, gast.Tuple):
      y = y.elts[0]
    dy = gast.Name(id=self.namer.grad(y.id), ctx=gast.Param(),
                   annotation=None)

    # Construct the adjoint
    adjoint_template = grads.adjoints[gast.FunctionDef]
    adjoint, = template.replace(adjoint_template, namer=self.namer,
                                adjoint_body=adjoint_body, return_dx=return_dx)
    # The adjoint takes the stack and the seed gradient, then the same
    # arguments as the primal (minus the stack we just prepended).
    adjoint.args.args.extend([self.stack, dy])
    adjoint.args.args.extend(node.args.args[1:])
    adjoint.name = naming.adjoint_name(func, self.wrt)

    return node, adjoint