示例#1
0
 def _build_print_call_node(self, node):
     """Wrap `node` in a call to ``fluid.layers.Print``.

     The generated call is equivalent to the source text
     ``fluid.layers.Print(<node>, summarize=-1, print_phase='forward')``.

     Args:
         node: AST node used as the single positional argument.

     Returns:
         A ``gast.Call`` node representing the Print call.
     """
     # -1 is spelled as a unary minus applied to the constant 1.
     minus_one = gast.UnaryOp(
         op=gast.USub(), operand=gast.Constant(
             value=1, kind=None))
     summarize_kw = gast.keyword(arg='summarize', value=minus_one)
     phase_kw = gast.keyword(
         arg='print_phase', value=gast.Constant(
             value='forward', kind=None))

     # Parse the dotted name once and reuse its expression node as the callee.
     print_callee = gast.parse('fluid.layers.Print').body[0].value
     return gast.Call(
         func=print_callee, args=[node], keywords=[summarize_kw, phase_kw])
示例#2
0
def add_keywords_to(node, add_dict):
    """Append one constant-valued keyword argument per entry of `add_dict`.

    Args:
        node: The ``gast.Call`` node to extend; it is modified in place.
        add_dict: Mapping from keyword name to the Python value that is
            wrapped in a ``gast.Constant``.

    Returns:
        None.
    """
    assert isinstance(node, gast.Call)
    node.keywords.extend(
        gast.keyword(
            arg=name, value=gast.Constant(value=add_dict[name], kind=None))
        for name in add_dict)
示例#3
0
文件: ast3.py 项目: yiakwy/gast
        def visit_Call(self, node):
            """Convert a Call node into a ``gast.Call`` node.

            When the interpreter's minor version is below 5, ``node.starargs``
            is appended to the positional arguments as a ``gast.Starred`` and
            ``node.kwargs`` is appended to the keywords as a keyword whose
            name is ``None``. On later versions those attributes are not read.
            """
            extra_args = []
            extra_keywords = []
            if sys.version_info.minor < 5:
                if node.starargs:
                    unpacked = gast.Starred(self._visit(node.starargs),
                                            gast.Load())
                    gast.copy_location(unpacked, node)
                    extra_args.append(unpacked)

                if node.kwargs:
                    double_star = gast.keyword(None, self._visit(node.kwargs))
                    gast.copy_location(double_star, node.kwargs)
                    extra_keywords.append(double_star)

            callee = self._visit(node.func)
            positional = self._visit(node.args) + extra_args
            named = self._visit(node.keywords) + extra_keywords

            converted = gast.Call(callee, positional, named)
            gast.copy_location(converted, node)
            return converted
示例#4
0
    def visit_Call(self, node):
        """Convert a Call node with ``starargs``/``kwargs`` into ``gast.Call``.

        ``node.starargs`` (if set) is appended to the positional arguments as
        a ``gast.Starred``; ``node.kwargs`` (if set) is appended to the
        keywords as a keyword whose name is ``None``.
        """
        extra_args = []
        if node.starargs:
            unpacked = gast.Starred(self._visit(node.starargs), gast.Load())
            ast.copy_location(unpacked, node)
            extra_args.append(unpacked)

        extra_keywords = []
        if node.kwargs:
            extra_keywords.append(
                gast.keyword(None, self._visit(node.kwargs)))

        callee = self._visit(node.func)
        positional = self._visit(node.args) + extra_args
        named = self._visit(node.keywords) + extra_keywords

        converted = gast.Call(callee, positional, named)
        ast.copy_location(converted, node)
        return converted
示例#5
0
        def visit_Call(self, node):
            """Translate a Call node into the unified ``gast.Call`` form.

            A set ``node.starargs`` becomes a trailing ``gast.Starred``
            positional argument; a set ``node.kwargs`` becomes a trailing
            keyword whose name is ``None``.
            """
            star_items = []
            if node.starargs:
                star_node = gast.Starred(
                    self._visit(node.starargs), gast.Load())
                ast.copy_location(star_node, node)
                star_items = [star_node]

            kw_items = []
            if node.kwargs:
                kw_items = [gast.keyword(None, self._visit(node.kwargs))]

            result = gast.Call(
                self._visit(node.func),
                self._visit(node.args) + star_items,
                self._visit(node.keywords) + kw_items,
            )
            ast.copy_location(result, node)
            return result
示例#6
0
def _add_keywords_to(node, dygraph_api_name):
    """Adjust the keyword arguments of `node` for a given dygraph API name.

    * ``"Linear"``: keywords named ``output_dim`` are renamed to ``size`` and
      a ``num_flatten_dims=-1`` keyword is appended.
    * ``"BilinearTensorProduct"``: ``output_dim`` is renamed to ``size``.
    * ``"PRelu"``: ``input`` is renamed to ``x``.

    Any other name leaves `node` untouched.

    Args:
        node: The ``gast.Call`` node to modify in place.
        dygraph_api_name: Name of the dygraph API being called.

    Returns:
        None.
    """
    assert isinstance(node, gast.Call)

    def rename_keyword(old_name, new_name):
        # Rename every keyword argument called `old_name`, in place.
        for kw in node.keywords:
            if kw.arg == old_name:
                kw.arg = new_name

    if dygraph_api_name == "Linear":
        rename_keyword("output_dim", "size")
        flatten_dims = gast.Constant(value=-1, kind=None)
        node.keywords.append(
            gast.keyword(arg="num_flatten_dims", value=flatten_dims))
    elif dygraph_api_name == "BilinearTensorProduct":
        rename_keyword("output_dim", "size")
    elif dygraph_api_name == "PRelu":
        rename_keyword("input", "x")
示例#7
0
    def visit_Call(self, node):
        """Convert a Call node into ``gast.Call`` and clear its end position.

        ``node.starargs`` (if set) becomes a trailing ``gast.Starred``
        positional argument and ``node.kwargs`` (if set) becomes a trailing
        keyword whose name is ``None``. After locations are copied from
        `node`, ``end_lineno`` and ``end_col_offset`` are set to ``None`` on
        both the starred node and the resulting call.
        """
        extra_args = []
        if node.starargs:
            unpacked = gast.Starred(self._visit(node.starargs), gast.Load())
            gast.copy_location(unpacked, node)
            unpacked.end_lineno = None
            unpacked.end_col_offset = None
            extra_args.append(unpacked)

        extra_keywords = []
        if node.kwargs:
            extra_keywords.append(
                gast.keyword(None, self._visit(node.kwargs)))

        callee = self._visit(node.func)
        positional = self._visit(node.args) + extra_args
        named = self._visit(node.keywords) + extra_keywords

        converted = gast.Call(callee, positional, named)
        gast.copy_location(converted, node)
        converted.end_lineno = None
        converted.end_col_offset = None
        return converted
示例#8
0
def call_func_ast(
    fn_name: str,
    args: Optional[Union[List[AST], Dict[str, AST]]] = None,
    attr_parent: Optional[str] = None,
) -> Call:
    """Build a Call AST node that calls the function with the given name.

    Args:
        fn_name: Name of the function to call.
        args: Either a list of argument-value AST nodes, used as positional
            arguments, or a dict mapping argument names (raw strings) to
            argument-value AST nodes, used as keyword arguments. ``None``
            (the default) produces a call with no arguments.
        attr_parent: Optionally specify the name of the parent object through
            which `fn_name` is referenced. Generates a call of the form
            ``attr_parent.fn_name(...)``.

    Returns:
        Call AST node representing the function call with `args` applied.
    """
    # A fresh list per call: the list becomes the node's ``args`` attribute,
    # so a shared default would be aliased by every no-argument call node.
    if args is None:
        args = []
    # create qualified reference to function
    if attr_parent is not None:
        func_ref = Attribute(
            value=name_ast(attr_parent),
            attr=fn_name,
            ctx=Load(),
        )
    else:
        func_ref = name_ast(fn_name)
    # apply arguments using keyword arguments syntax
    if isinstance(args, dict):
        return Call(
            args=[],
            func=func_ref,
            keywords=[keyword(name, value) for name, value in args.items()],
        )
    # apply arguments using positional arguments syntax
    return Call(args=args, func=func_ref, keywords=[])
示例#9
0
def update_args_of_func(node, dygraph_node, method_name):
    """Turn the positional arguments of `node` into keyword arguments.

    The parameter names are taken from ``<class>.<method_name>``, where
    ``<class>`` is the source text of ``dygraph_node.func``. The signature is
    looked up when `method_name` is ``"__init__"`` or when the class is a
    subclass of ``fluid.dygraph.Layer``; otherwise no names are available.
    Each positional argument is paired, by position, with the parameter names
    (``self`` excluded) and prepended to the existing keywords; ``node.args``
    is then emptied.

    Args:
        node: The ``gast.Call`` node to rewrite in place.
        dygraph_node: Call node whose ``func`` names the class to inspect.
        method_name: Either ``"__init__"`` or ``"forward"``.

    Raises:
        ValueError: If `method_name` is neither ``"__init__"`` nor
            ``"forward"``.
    """
    assert isinstance(node, gast.Call)
    if method_name not in ("__init__", "forward"):
        raise ValueError(
            "The method name of class to update args should be '__init__' or 'forward'"
        )

    class_src = astor.to_source(gast.gast_to_ast(dygraph_node.func))
    # `fluid` is referenced by name inside the eval'd expression below.
    import paddle.fluid as fluid

    # The evals stay at function scope so they can see the local `fluid`.
    signature_available = method_name == "__init__"
    if not signature_available:
        signature_available = eval(
            "issubclass({}, fluid.dygraph.Layer)".format(class_src))

    full_args_name = []
    if signature_available:
        arg_spec = eval("inspect.getargspec({}.{})".format(
            class_src, method_name))
        full_args_name = [name for name in arg_spec[0] if name != "self"]

    positional_as_keywords = [
        gast.keyword(arg=full_args_name[position], value=arg_node)
        for position, arg_node in enumerate(node.args)
    ]

    node.args = []
    node.keywords = positional_as_keywords + node.keywords
示例#10
0
  def visit_Call(self, node):
    """Build the forward-mode (tangent) code for a function call node.

    The function being called is read from the 'func' annotation on `node`.
    Two paths exist:

    * If the function has no entry in `tangents.tangents`, its source is
      parsed and the call is rewritten into a call to a generated tangent
      function (named via `naming.tangent_name`), passing the tangents of
      the active arguments as extra keyword arguments. The pair
      `(func, active_args)` is recorded in `self.required` unless an
      equivalent entry is already present.
    * Otherwise the registered tangent template is filled in with the call's
      bound arguments and returned.

    Args:
      node: The call node being visited.

    Returns:
      `node` unchanged when `self.target` is falsy; a one-element list
      holding the rewritten call when no tangent template is registered;
      otherwise the node(s) produced by `template.replace`.

    Raises:
      errors.ForwardNotImplementedError: If the function is listed in
        `tangents.UNIMPLEMENTED_TANGENTS`.
      NotImplementedError: If the function is `tracing.Traceable`.
      ValueError: If no tangent is registered and the function's source
        could not be parsed.
    """
    if not self.target:
      return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
      raise errors.ForwardNotImplementedError(func)

    if func == tracing.Traceable:
      raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                quoting.unquote(node))

    if func not in tangents.tangents:
      # No registered template: fall back to differentiating the function's
      # own source, which therefore has to be retrievable.
      try:
        quoting.parse_function(func)
      # NOTE(review): a bare `except:` also intercepts KeyboardInterrupt and
      # SystemExit and reports them as a ValueError -- confirm this is wanted.
      except:
        raise ValueError('No tangent found for %s, and could not get source.' %
                         func.__name__)

      # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
      # Positions of the positional arguments that are plain names.
      active_args = tuple(i for i, arg in enumerate(node.args)
                          if isinstance(arg, gast.Name))
      # TODO: Stack arguments are currently not considered
      # active, but for forward-mode applied to call trees,
      # they have to be. When we figure out how to update activity
      # analysis to do the right thing, we'll want to add the extra check:
      # `and arg.id in self.active_variables`

      # TODO: Duplicate of code in reverse_ad.
      # Entries are compared by function name and by the *set* of active
      # argument positions.
      already_counted = False
      for f, a in self.required:
        if f.__name__ == func.__name__ and set(a) == set(active_args):
          already_counted = True
          break
      if not already_counted:
        self.required.append((func, active_args))

      fn_name = naming.tangent_name(func, active_args)
      orig_args = quoting.parse_function(func).body[0].args
      tangent_keywords = []
      for i in active_args:
        # Tangent of the value passed at the call site ...
        grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
        # ... and tangent of the matching parameter in the function's
        # definition, whose name becomes the keyword name.
        arg_grad_node = create.create_grad(
            orig_args.args[i], self.namer, tangent=True)
        grad_node.ctx = gast.Load()
        tangent_keywords.append(
            gast.keyword(arg=arg_grad_node.id, value=grad_node))
      # Update the original call
      rhs = gast.Call(
          func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
          args=node.args,
          keywords=tangent_keywords + node.keywords)
      # Set self.value to False to trigger whole primal replacement
      self.value = False
      return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template
    sig = funcsigs.signature(template_)
    # Drop the template's first parameter before binding the call's
    # arguments (the first argument is the output, see below).
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
      # Build a mapping from names to defaults
      args = quoting.parse_function(func).body[0].args
      defaults = {}
      # Defaults belong to the trailing positional parameters, so both lists
      # are walked from the end to pair them up.
      for arg, default in zip(*map(reversed, [args.args, args.defaults])):
        defaults[arg.id] = default
      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
          defaults[arg.id] = default
      for name, value in bound_args.arguments.items():
        if value is grads.DEFAULT:
          bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
      to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
      vararg_name = six.get_function_code(template_).co_varnames[-1]
      target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
      # NOTE(review): `value` is not read before it is reassigned further
      # down, so this packed tuple is currently unused -- confirm intent.
      value = gast.Tuple(elts=to_pack, ctx=gast.Load())

      # And we fill in the packed tuple into the template
      arg_replacements[six.get_function_code(template_).co_varnames[
          -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    # Rename every reference to the temporary's tangent to the tangent of
    # the real target.
    for _node in tangent_node:
      for succ in gast.walk(_node):
        if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
          succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
      # If the template packs arguments, then we have to unpack the
      # derivatives afterwards
      # We also have to update the replacements tuple then
      dto_pack = [
          create.create_temp_grad(arg, self.namer, True) for arg in to_pack
      ]
      # NOTE(review): `value` and `target` assigned below are not read again
      # in this method, so no unpacking statement is emitted here -- confirm
      # whether that step is missing.
      value = create.create_grad(target, self.namer, tangent=True)
      target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
      if len(self.metastack):
        anno.setanno(tangent_node[0], 'push', self.metastack.pop())
      else:
        anno.setanno(tangent_node[0], 'push', None)
    return tangent_node