예제 #1
0
 def visit_Subscript(self, node):
     """Replace a ``d[...]`` gradient placeholder with a concrete node.

     Subscripts of the name ``d`` are rewritten; every other subscript is
     returned unchanged. The replacement depends on the slice, on
     ``self.replace_grad`` and on the subscript's context:

     * If the slice is not a gast.Index of a Subscript, Name or Str (this
       happens when the gradient of a constant is taken), the result is
       ``gast.Num(0)`` in TANGENT mode, and otherwise a ``_`` Name that is
       passed to ``self.remove``.
     * In FULL or TANGENT mode, or when the subscript is loaded, the result
       is ``create.create_grad(...)`` for the sliced expression.
     * When the subscript is stored to, the result is
       ``create.create_temp_grad(...)`` for the sliced expression.

     The replacement inherits ``node.ctx``; if it is a gast.Tuple, each
     element inherits it too. Children of ``node`` are not visited.

     Args:
       node: a gast.Subscript node.

     Returns:
       The replacement node, or ``node`` itself if it is not ``d[...]``.

     Raises:
       ValueError: if a ``d[...]`` subscript outside FULL/TANGENT mode has a
         context that is neither Load nor Store.
     """
     # The placeholder is spelled with the name ``d``. Only a gast.Name has
     # an ``id`` attribute (a gast.Num does not), so only Names can match.
     if isinstance(node.value, gast.Name) and node.value.id == 'd':
         if (not isinstance(node.slice, gast.Index)
                 or not isinstance(node.slice.value,
                                   (gast.Subscript, gast.Name, gast.Str))):
             # This happens when the gradient of a constant is taken
             if self.replace_grad == Replace.TANGENT:
                 new_node = gast.Num(0)
             else:
                 new_node = gast.Name(id='_', ctx=None, annotation=None)
                 self.remove(new_node)
         elif (self.replace_grad in (Replace.FULL, Replace.TANGENT)
               or isinstance(node.ctx, gast.Load)):
             new_node = create.create_grad(node.slice.value, self.namer,
                                           self.tangent)
         elif isinstance(node.ctx, gast.Store):
             new_node = create.create_temp_grad(node.slice.value,
                                                self.namer, self.tangent)
         else:
             raise ValueError(
                 'Unexpected context %s for gradient placeholder d[...]' %
                 type(node.ctx).__name__)
         new_node.ctx = node.ctx
         # A tuple of gradients needs the context on each element as well.
         if isinstance(new_node, gast.Tuple):
             for elt in new_node.elts:
                 elt.ctx = node.ctx
         node = new_node
     return node
예제 #2
0
  def primal_and_adjoint_for_tracing(self, node):
    """Build the primal and adjoint of a traceable function.

    The call is not transformed. Instead:

    * The primal template registered for ``tracing.Traceable`` is filled in
      with the called function, its positional arguments packed into a
      tuple, a fresh VJP name and a temporary.
    * The matching adjoint template is filled in with the same VJP name and
      a tuple of temporary gradients, one per argument.

    Args:
      node: ast.Call node of a function we wish to trace, instead of
        transform. Only positional arguments are supported.

    Returns:
      primal: new ast.Assign node to replace the original primal call
      adjoint: new ast.Assign node using the VJP generated in primal to
        calculate the adjoint.

    Raises:
      NotImplementedError: if the call passes keyword arguments.
    """
    # Only positional arguments are packed into the traced call. Reject
    # keywords before any names are reserved in the namer. An ``assert``
    # would be stripped under ``python -O``, silently dropping the keywords.
    if node.keywords:
      raise NotImplementedError(
          'Keyword arguments are not supported when tracing %s' %
          quoting.unquote(node))

    primal_template = grads.primals[tracing.Traceable]
    adjoint_template = grads.adjoints[tracing.Traceable]

    # Prep
    to_pack = node.args
    target = ast_.copy_node(self.orig_target)
    # NOTE(review): assumes the traced callee is referenced by a plain Name
    # (``node.func.id``); an attribute call would fail here — confirm.
    vjp = quoting.quote(self.namer.unique('%s_grad' % node.func.id))
    tmp = create.create_temp(quoting.quote('tmp'), self.namer)

    # Full replacement of primal
    # TODO: do we need to set 'pri_call' on this?
    primal = template.replace(
        primal_template,
        namer=self.namer,
        result=target,
        fn=node.func,
        tmp=tmp,
        vjp=vjp,
        args=gast.Tuple(elts=to_pack, ctx=gast.Load()))

    # Building adjoint using the vjp generated with the primal
    dto_pack = gast.Tuple(
        elts=[create.create_temp_grad(arg, self.namer) for arg in to_pack],
        ctx=gast.Store())

    adjoint = template.replace(
        adjoint_template,
        namer=self.namer,
        result=target,
        vjp=vjp,
        dargs=dto_pack)

    return primal, adjoint
예제 #3
0
  def visit_Call(self, node):
    """Generate forward-mode (tangent) code for a call.

    If ``self.target`` is falsy the call is returned untouched. Otherwise:

    * Functions listed in ``tangents.UNIMPLEMENTED_TANGENTS``, and
      ``tracing.Traceable``, raise.
    * Functions without an entry in ``tangents.tangents`` must have parseable
      source. The call is redirected to the generated function named by
      ``naming.tangent_name``, with the tangents of the Name arguments passed
      as extra keywords. ``(func, active_args)`` is recorded in
      ``self.required``, and ``self.value`` is set to False.
    * Functions with a tangent template get that template filled in with the
      call's bound arguments. Arguments left at ``grads.DEFAULT`` take the
      default from the original function's source.

    Args:
      node: a gast.Call node annotated with 'func'.

    Returns:
      ``node`` if ``self.target`` is falsy. Otherwise either ``[rhs]``, where
      ``rhs`` is the redirected call, or the nodes produced by
      ``template.replace`` for the tangent template.

    Raises:
      errors.ForwardNotImplementedError: if ``func`` is listed in
        ``tangents.UNIMPLEMENTED_TANGENTS``.
      NotImplementedError: if ``func`` is ``tracing.Traceable``.
      ValueError: if ``func`` has no tangent template and its source cannot
        be parsed.
    """
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      # Parse the source once; the argument names are needed further down.
      # Catch Exception rather than using a bare except, so that
      # KeyboardInterrupt/SystemExit propagate, and chain the original error.
      try:
        func_ast = quoting.parse_function(func)
      except Exception as e:  # pylint: disable=broad-except
        six.raise_from(
            ValueError('No tangent found for %s, and could not get source.' %
                       func.__name__), e)

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = func_ast.body[0].args
      tangent_keywords = []
      for i in active_args:
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    code = six.get_function_code(template_)
    output_name = code.co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = code.co_flags

    if flags & inspect.CO_VARARGS:
      # The template's first parameter is the output, which is not an input
      # of the call, hence the ``- 1``.
      to_pack = node.args[code.co_argcount - 1:]
      # co_varnames lists positional parameters, then keyword-only
      # parameters, then the *args name, then any other locals. The *args
      # name is therefore not necessarily the last entry (cf.
      # inspect.getargs). co_kwonlyargcount does not exist on Python 2.
      vararg_name = code.co_varnames[
          code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())

      # And we fill in the packed tuple into the template
      arg_replacements[vararg_name] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # NOTE(review): no packing/unpacking assignment is emitted in forward
      # mode, so ``dto_pack``, ``value`` and ``target`` below are unused.
      # The ``create`` calls are kept because they are passed ``self.namer``
      # and may reserve names; confirm before removing or wiring them up.
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if len(self.metastack):
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
예제 #4
0
  def visit_Call(self, node):
    """Create adjoint for call.

    We don't allow unpacking of parameters, so we know that each argument
    gets passed in explicitly, allowing us to create partials for each.
    However, templates might perform parameter unpacking (for cases where
    the number of arguments is variable) and express their gradient as a
    tuple. In this case, we have to unpack this tuple of partials.

    Args:
      node: a gast.Call node annotated with 'func'.

    Returns:
      A pair ``(primal, adjoint)``:

      * ``(node, [])`` for functions in
        ``non_differentiable.NON_DIFFERENTIABLE``.
      * The result of ``primal_and_adjoint_for_tracing`` for
        ``tracing.Traceable``.
      * For functions without an adjoint template, a call to the generated
        primal function, plus statements that call the generated adjoint and
        assign each returned partial to the gradient of its active argument.
      * Otherwise, ``node`` plus the filled-in adjoint template. The template
        is wrapped in *args packing/unpacking assignments when it takes
        *args.

    Raises:
      errors.ReverseNotImplementedError: if ``func`` is listed in
        ``grads.UNIMPLEMENTED_ADJOINTS``.
    """
    # Find the function we are differentiating
    func = anno.getanno(node, 'func')

    if func in non_differentiable.NON_DIFFERENTIABLE:
      return node, []

    if func == tracing.Traceable:
      return self.primal_and_adjoint_for_tracing(node)

    if func in grads.UNIMPLEMENTED_ADJOINTS:
      raise errors.ReverseNotImplementedError(func)

    # If we don't have an adjoint, we will have to step into the called
    # function and differentiate it
    if func not in grads.adjoints:
      # Only Name arguments can be active. Other arguments, such as
      # literals, have no ``id`` and are passed through unchanged.
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name) and
                          arg.id in self.active_variables)

      already_counted = any(
          f.__name__ == func.__name__ and set(a) == set(active_args)
          for f, a in self.required)
      if not already_counted:
        self.required.append((func, active_args))

      pri_name = naming.primal_name(func, active_args)
      pri_call = gast.Call(
          func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
          args=[self.substack] + node.args,
          keywords=node.keywords)
      anno.setanno(pri_call, 'pri_call', True)

      dy = create.create_grad(self.target, self.namer)
      dy.ctx = gast.Load()
      adj_name = naming.adjoint_name(func, active_args)
      adj_call = gast.Call(
          func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
          args=[self.substack, dy] + node.args,
          keywords=node.keywords)
      anno.setanno(adj_call, 'adj_call', True)
      adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
      # The generated adjoint returns one partial per active argument, in
      # the order of ``active_args``.
      for j, i in enumerate(active_args):
        adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                        x=node.args[i].id, i=gast.Num(n=j)))
      return pri_call, adjoint

    # We have a template for the gradient that we need to fill in
    template_ = grads.adjoints[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)

    # Fill in any missing kwargs with the defaults from the template
    args = quoting.parse_function(template_).body[0].args
    defaults = dict(zip(*map(reversed, [args.args, args.defaults])))
    defaults.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
    for arg, val in defaults.items():
      if arg.id not in bound_args.arguments:
        bound_args.arguments[arg.id] = val

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    code = six.get_function_code(template_)
    output_name = code.co_varnames[0]
    arg_replacements = {output_name: ast_.copy_node(self.target)}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    packing = []
    flags = code.co_flags

    if flags & inspect.CO_VARARGS:
      # The template's first parameter is the output, which is not an input
      # of the call, hence the ``- 1``.
      to_pack = node.args[code.co_argcount - 1:]
      # co_varnames lists positional parameters, then keyword-only
      # parameters, then the *args name, then any other locals. The *args
      # name is therefore not necessarily the last entry (cf.
      # inspect.getargs). co_kwonlyargcount does not exist on Python 2.
      vararg_name = code.co_varnames[
          code.co_argcount + getattr(code, 'co_kwonlyargcount', 0)]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())
      packing = [gast.Assign(targets=[target], value=value)]

      # And we fill in the packed tuple into the template
      arg_replacements[vararg_name] = target
    adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
    unpacking = []
    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [create.create_temp_grad(arg, self.namer)
                  for arg in to_pack]
      value = create.create_grad(target, self.namer)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
      unpacking = [gast.Assign(targets=[target], value=value)]

    return node, packing + adjoint + unpacking