Пример #1
0
 def visit_Assign(self, node):
     """Rewrite a single-name assignment so it updates masked values in place.

     ``lhs = rhs`` becomes::

         lhs = (lhs._update(rhs, matchbox.EXECUTION_MASK)
                if ('lhs' in vars() or 'lhs' in globals())
                and isinstance(lhs, (matchbox.MaskedBatch,
                                     matchbox.TENSOR_TYPE))
                else rhs)

     Args:
       node: a ``gast.Assign`` node with exactly one ``gast.Name`` target.

     Returns:
       The same node, with ``node.value`` replaced by the conditional above.

     Raises:
       NotImplementedError: if there are several targets or the target is not
         a plain name.
     """
     if len(node.targets) > 1:
         raise NotImplementedError("cannot process multiple assignment")
     if not isinstance(node.targets[0], gast.Name):
         raise NotImplementedError("cannot process indexed assignment")
     target_id = node.targets[0].id

     def load_name(name):
         # Every operand gets a fresh Name with a Load *instance* as context.
         # The assignment target carries Store context and must not be reused
         # as an expression operand (compile() rejects that).
         return gast.Name(name, gast.Load(), None)

     def matchbox_attr(attr):
         # NOTE(review): the attribute name is passed as a Name node rather
         # than a str, exactly as the original code did; presumably the
         # downstream unparser prints it by id -- confirm before compiling the
         # tree directly.
         return gast.Attribute(load_name('matchbox'), load_name(attr),
                               gast.Load())

     def defined_in(scope_function):
         # Builds: 'lhs' in scope_function()
         return gast.Compare(gast.Str(target_id), [gast.In()],
                             [gast.Call(load_name(scope_function), [], [])])

     already_defined = gast.BoolOp(
         gast.Or(), [defined_in('vars'), defined_in('globals')])
     is_masked_value = gast.Call(load_name('isinstance'), [
         load_name(target_id),
         gast.Tuple([matchbox_attr('MaskedBatch'),
                     matchbox_attr('TENSOR_TYPE')], gast.Load()),
     ], [])
     masked_update = gast.Call(
         gast.Attribute(load_name(target_id), load_name('_update'),
                        gast.Load()),
         [node.value, matchbox_attr('EXECUTION_MASK')], [])

     node.value = gast.IfExp(
         gast.BoolOp(gast.And(), [already_defined, is_masked_value]),
         masked_update,
         node.value)
     return node
Пример #2
0
 def visit_Attribute(self, node):
     """Turn attribute accesses on plain values into explicit getattr calls.

     Method names and attributes not listed in ``attributes`` are returned
     unchanged. Members of imported modules are validated against
     ``MODULES`` and the module alias is patched to the real module name.
     Any other attribute becomes ``__builtin__.getattr(value, 'attr')``; in a
     Store context (only allowed for ``real``/``imag``) that call is wrapped
     in a full-slice subscript so it remains a valid assignment target.

     Raises:
       PythranSyntaxError: if an imported module has no such member.
     """
     node = self.generic_visit(node)
     attr_name = node.attr
     base = node.value

     if attr_name in methods:
         return node

     if isinstance(base, ast.Name) and base.id in self.imports:
         real_module = self.imports[base.id]
         known_members = MODULES[self.renamer(real_module, MODULES)[1]]
         if attr_name not in known_members:
             raise PythranSyntaxError(
                 "`" + attr_name + "' is not a member of " + real_module +
                 " or Pythran does not support it", node)
         base.id = real_module  # patch module aliasing
         self.update = True
         return node

     if attr_name not in attributes:
         return node

     self.update = True
     getattr_call = ast.Call(
         ast.Attribute(ast.Name('__builtin__', ast.Load(), None),
                       'getattr', ast.Load()),
         [base, ast.Str(attr_name)], [])
     if not isinstance(node.ctx, ast.Store):
         return getattr_call

     # A call cannot be assigned to; for real/imag of an ndarray, slicing the
     # call result yields a valid left-hand side.
     assert attr_name in ('real', 'imag'), "only store to imag/real"
     return ast.Subscript(getattr_call, ast.Slice(None, None, None),
                          node.ctx)
Пример #3
0
    def visit_Lambda(self, node):
        """Wrap a top-level lambda's body in ``ag__.with_function_scope``.

        The lambda is visited first. When the ``_Function`` nesting level is
        above 2 the node is returned unwrapped, to avoid piling up boilerplate
        for nested lambdas. Otherwise a fresh ``lscope`` symbol is created,
        recorded on the ``_Function`` state and as the node's
        ``function_context_name`` annotation, and used as the scope parameter.
        """
        template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
        function_state = self.state[_Function]
        function_state.enter()
        node = self.generic_visit(node)

        # Theoretically every lambda could be wrapped, but nested ones would
        # produce excessive boilerplate.
        # TODO(mdan): Looks more closely for use cases that actually require this.
        if function_state.level <= 2:
            lambda_scope = anno.getanno(node, anno.Static.SCOPE)
            context_name = self.ctx.namer.new_symbol(
                'lscope', lambda_scope.referenced)
            function_state.context_name = context_name
            anno.setanno(node, 'function_context_name', context_name)
            node.body = templates.replace_as_expression(
                template,
                options=self.ctx.program.options.to_ast(),
                function_context=context_name,
                function_context_name=gast.Str(context_name),
                body=node.body)

        function_state.exit()
        return node
Пример #4
0
def to_ast(value):
    """
    Turn a value into ast expression.

    >>> a = 1
    >>> print(ast.dump(to_ast(a)))
    Num(n=1)
    >>> a = [1, 2, 3]
    >>> print(ast.dump(to_ast(a)))
    List(elts=[Num(n=1), Num(n=2), Num(n=3)], ctx=Load())

    Raises:
      ToNotEval: if ``value`` has no supported AST representation.
    """
    if isinstance(value, (type(None), bool)):
        return builtin_folding(value)
    # The builtin types themselves (not their instances) fold to builtins.
    if any(value is t for t in (bool, int, float)):
        return builtin_folding(value)
    elif isinstance(value, numpy.generic):
        # Convert a NumPy scalar to the matching Python scalar first.
        # ``.item()`` replaces ``numpy.asscalar``, removed in NumPy 1.23.
        return to_ast(value.item())
    elif isinstance(value, numbers.Number):
        return ast.Num(value)
    elif isinstance(value, str):
        return ast.Str(value)
    elif isinstance(value, (list, tuple, set, dict, numpy.ndarray)):
        return size_container_folding(value)
    elif hasattr(value, "__module__") and value.__module__ == "__builtin__":
        # TODO Can be done the same way for others modules
        return builtin_folding(value)
    # only meaningful for python3
    elif sys.version_info.major == 3:
        if isinstance(value, (filter, map, zip)):
            return to_ast(list(value))
    raise ToNotEval()
Пример #5
0
 def visit_Attribute(self, node):
     """Replace reads of non-module attributes with ``__builtin__.getattr``.

     Nodes outside a Load context, method names, and attributes not listed in
     ``attributes`` are returned unchanged. Members of imported modules are
     validated against ``MODULES`` and the module alias is patched to the
     real module name.

     Raises:
       PythranSyntaxError: if an imported module has no such member.
     """
     node = self.generic_visit(node)
     attr_name = node.attr

     # Stores into attributes and method names are never getattr calls.
     if not isinstance(node.ctx, ast.Load) or attr_name in methods:
         return node

     owner = node.value
     if isinstance(owner, ast.Name) and owner.id in self.imports:
         real_module = self.imports[owner.id]
         members = MODULES[self.renamer(real_module, MODULES)[1]]
         if attr_name not in members:
             raise PythranSyntaxError(
                 "`" + attr_name + "' is not a member of " + real_module +
                 " or Pythran does not support it", node)
         owner.id = real_module  # patch module aliasing
         self.update = True
         return node

     if attr_name not in attributes:
         return node

     self.update = True
     builtin_getattr = ast.Attribute(
         ast.Name('__builtin__', ast.Load(), None), 'getattr', ast.Load())
     return ast.Call(builtin_getattr, [owner, ast.Str(attr_name)], [])
Пример #6
0
def keywords_to_dict(keywords):
    """Build a ``gast.Dict`` mapping each keyword's name to its value node.

    Args:
      keywords: iterable of keyword nodes, each with ``arg`` and ``value``.

    Returns:
      A ``gast.Dict`` whose keys are ``gast.Str`` nodes of the keyword names,
      in the same order as ``keywords``.
    """
    pairs = [(gast.Str(kw.arg), kw.value) for kw in keywords]
    return gast.Dict(keys=[key for key, _ in pairs],
                     values=[value for _, value in pairs])
Пример #7
0
    def visit_Call(self, node):
        """Route a call through ``ag__.converted_call``.

        Calls into the internal ``ag__`` module are never converted (only
        their arguments are visited), and neither is ``print`` unless builtin
        function conversion is enabled. For method calls, ``func`` is the
        attribute name as a string and ``owner`` is the receiver; otherwise
        ``owner`` is ``None``. The original keywords are copied onto the new
        call.
        """
        # TODO(mdan): Refactor converted_call as a 'Call' operator.
        qualified_name = str(
            anno.getanno(node.func, anno.Basic.QN, default=''))
        program_options = self.ctx.program.options
        leave_unconverted = qualified_name.startswith('ag__.') or (
            qualified_name == 'print' and
            not program_options.uses(converter.Feature.BUILTIN_FUNCTIONS))
        if leave_unconverted:
            return self.generic_visit(node)

        template = """
      ag__.converted_call(func, owner, options, args)
    """
        callee = node.func
        if isinstance(callee, gast.Attribute):
            func, owner = gast.Str(callee.attr), callee.value
        else:
            func, owner = callee, parser.parse_expression('None')

        converted = templates.replace_as_expression(
            template,
            func=func,
            owner=owner,
            options=program_options.to_ast(
                self.ctx,
                internal_convert_user_code=program_options.recursive),
            args=node.args)
        # TODO(mdan): Improve the template mechanism to better support this.
        converted.keywords = node.keywords
        return converted
Пример #8
0
def keywords_to_dict(keywords):
    """Converts a list of ast.keyword objects to a dict.

    Args:
      keywords: iterable of keyword nodes, each with ``arg`` and ``value``.

    Returns:
      A ``gast.Dict`` node whose keys are ``gast.Str`` nodes holding the
      keyword names and whose values are the keyword value nodes.
    """
    entries = [(gast.Str(keyword.arg), keyword.value) for keyword in keywords]
    key_nodes = [key for key, _ in entries]
    value_nodes = [value for _, value in entries]
    return gast.Dict(keys=key_nodes, values=value_nodes)
Пример #9
0
  def visit_FunctionDef(self, node):
    """Intercepts function definitions.

    Converts function definitions to the corresponding `ProgramBuilder.function`
    construction: each argument becomes a
    `_tfp_autobatching_context_.param(...)` declaration, and the whole body is
    placed inside a `with _tfp_autobatching_context_.define_function(...)`
    block in a generated factory function.

    Args:
      node: An `ast.AST` node representing the function to convert.

    Returns:
      node: An updated node, representing the result.

    Raises:
      ValueError: If the input node does not adhere to the restrictions,
        e.g., failing to have a `return` statement at the end.
    """
    final_statement = node.body[-1]
    if not isinstance(final_statement, gast.Return):
      raise ValueError(
          'Last node in function body should be Return, not {}.'.format(
              final_statement))

    # One `x = _tfp_autobatching_context_.param(name='x')` per argument.
    param_declarations = [
        templates.replace(
            'target = _tfp_autobatching_context_.param(name=target_name)',
            target=arg.id,
            target_name=gast.Str(arg.id))[0]
        for arg in node.args.args
    ]

    node = self.generic_visit(node)
    node.body = param_declarations + node.body

    # Wrapping the `with` block in a factory function lets the auto-batching
    # `ProgramBuilder` and the callable `instruction.Function`s be passed in
    # through regular Python variable references.
    function_names = gast.List(
        [gast.Name(name, ctx=gast.Store(), annotation=None)
         for name in self.known_functions],
        ctx=gast.Store())
    wrapped = templates.replace(
        '''
        def func(_tfp_autobatching_context_,
                 _tfp_autobatching_available_functions_):
          names = _tfp_autobatching_available_functions_
          with _tfp_autobatching_context_.define_function(func):
            body
          return func''',
        func=node.name,
        names=function_names,
        body=node.body)
    return wrapped[0]
Пример #10
0
 def _create_undefined_assigns(self, undefined_symbols):
     """Emit ``var = ag__.Undefined(<name>)`` for each given symbol.

     Args:
       undefined_symbols: iterable of symbols; each is used as the assignment
         target and its ``ssf()`` string as the ``Undefined`` argument.

     Returns:
       A flat list of the generated assignment statements, in input order.
     """
     template = '''
     var = ag__.Undefined(symbol_name)
   '''
     statements = []
     for symbol in undefined_symbols:
         statements.extend(
             templates.replace(template,
                               var=symbol,
                               symbol_name=gast.Str(symbol.ssf())))
     return statements
Пример #11
0
def to_ast(value):
    """
    Turn a value into ast expression.

    >>> a = 1
    >>> print(ast.dump(to_ast(a)))
    Num(n=1)
    >>> a = [1, 2, 3]
    >>> print(ast.dump(to_ast(a)))
    List(elts=[Num(n=1), Num(n=2), Num(n=3)], ctx=Load())

    Raises:
      ToNotEval: for values known not to be foldable (builtin functions and
        methods, ufuncs, exceptions, generators, types, itertools objects,
        ...) and for the listed numpy scalar types themselves.
      ConversionError: for any other value without an AST representation.
    """
    # numpy.float_ and numpy.complex_ were aliases of float64 and complex128
    # (both listed here) and no longer exist in NumPy 2.0.
    numpy_type = (numpy.float64, numpy.float32, numpy.float16,
                  numpy.complex64, numpy.complex128, numpy.uint8,
                  numpy.uint16, numpy.uint32, numpy.uint64, numpy.int8,
                  numpy.int16, numpy.int32, numpy.int64, numpy.intp,
                  numpy.intc, numpy.int_, numpy.bool_)
    itertools_t = [
        getattr(itertools, fun) for fun in dir(itertools)
        if isinstance(getattr(itertools, fun), type)
    ]
    unfolded_type = (types.BuiltinFunctionType, types.BuiltinMethodType,
                     numpy.ufunc, type(list.append), BaseException,
                     types.GeneratorType) + tuple(itertools_t)
    if sys.version_info.major == 2:
        unfolded_type += (types.FunctionType, types.FileType, types.TypeType,
                          types.XRangeType)
        # `long` only exists on Python 2; naming it unconditionally raised a
        # NameError on Python 3 for every value reaching the numeric check.
        number_type = (int, long, float, complex)  # noqa: F821
    else:
        unfolded_type += type, range, type(numpy.array2string)
        number_type = (int, float, complex)

    if isinstance(value, (type(None), bool)):
        return ast.Attribute(ast.Name('__builtin__', ast.Load(), None),
                             str(value), ast.Load())
    elif isinstance(value, numpy_type):
        # ``.item()`` replaces ``numpy.asscalar``, removed in NumPy 1.23.
        return to_ast(value.item())
    elif isinstance(value, number_type):
        return ast.Num(value)
    elif isinstance(value, str):
        return ast.Str(value)
    elif isinstance(value, (list, tuple, set, dict, numpy.ndarray)):
        return size_container_folding(value)
    elif hasattr(value, "__module__") and value.__module__ == "__builtin__":
        # TODO Can be done the same way for others modules
        return builtin_folding(value)
    elif isinstance(value, unfolded_type):
        raise ToNotEval()
    elif value in numpy_type:
        raise ToNotEval()
    # filter/map/zip are plain functions (not types) on Python 2, where this
    # isinstance check would raise TypeError.
    elif sys.version_info.major == 3 and isinstance(value, (filter, map, zip)):
        return to_ast(list(value))
    else:
        raise ConversionError()
    def _create_loop_options(self, node):
        """Build a ``gast.Dict`` literal from the loop's set_loop_options directive.

        Args:
          node: the loop node, possibly carrying an ``anno.Basic.DIRECTIVES``
            annotation.

        Returns:
          A ``gast.Dict`` with one string key per option passed to
          ``directives.set_loop_options``. The dict is empty when there is no
          such directive, or when it was given no options.
        """
        if not anno.hasanno(node, anno.Basic.DIRECTIVES):
            return gast.Dict([], [])

        loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
        if directives.set_loop_options not in loop_directives:
            return gast.Dict([], [])

        opts_dict = loop_directives[directives.set_loop_options]
        if not opts_dict:
            # zip(*{}.items()) yields nothing, so the two-name unpacking below
            # would raise ValueError.
            return gast.Dict([], [])

        str_keys, values = zip(*opts_dict.items())
        keys = [gast.Str(s) for s in str_keys]
        values = list(values)  # ast and gast don't play well with tuples.
        return gast.Dict(keys, values)
Пример #13
0
    def visit_FunctionDef(self, node):
        """Wrap a function body in ``with ag__.FunctionScope(...)``.

        A fresh ``fscope`` symbol names the function context. It is recorded
        on the ``_Function`` state and as the node's
        ``function_context_name`` annotation before the body is visited. A
        leading docstring stays outside the ``with`` block.
        """
        fn_state = self.state[_Function]
        fn_state.enter()
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        context_name = self.ctx.namer.new_symbol(
            'fscope', body_scope.referenced)
        fn_state.context_name = context_name
        anno.setanno(node, 'function_context_name', context_name)

        node = self.generic_visit(node)

        preserved = []
        if node.body:
            opening = node.body[0]
            if (isinstance(opening, gast.Expr)
                    and isinstance(opening.value, gast.Str)):
                # Keep the docstring as the first statement of the function.
                preserved = [opening]
                node.body = node.body[1:]

        template = """
      with ag__.FunctionScope(
      function_name, context_name, options) as function_context:
        body
    """
        node.body = preserved + templates.replace(
            template,
            function_name=gast.Str(node.name),
            context_name=gast.Str(context_name),
            options=self.ctx.program.options.to_ast(),
            function_context=context_name,
            body=node.body)

        fn_state.exit()
        return node
Пример #14
0
 def visit_FunctionDef(self, node):
     """Wrap a function's body in ``with tf.name_scope(scope_name):``.

     ``scope_name`` is the function's name. For a top-level function
     (``self._function_level == 0`` after visiting) with a non-None
     ``self.context.owner_type``, it is prefixed as
     ``'<owner_type.__name__>/<name>'``. A leading docstring is kept outside
     the ``with`` block so it remains the function's docstring.
     """
     self._function_level += 1
     try:
         self.generic_visit(node)
     finally:
         self._function_level -= 1

     scope_name = node.name
     if self._function_level == 0 and self.context.owner_type is not None:
         scope_name = '{}/{}'.format(self.context.owner_type.__name__,
                                     scope_name)

     docstring = []
     body = node.body
     # Only split off the docstring when statements remain to be scoped;
     # a `with` block needs a non-empty body.
     if (len(body) > 1 and isinstance(body[0], gast.Expr)
             and isinstance(body[0].value, gast.Str)):
         docstring, body = body[:1], body[1:]

     node.body = docstring + templates.replace(
         'with tf.name_scope(scope_name): body',
         scope_name=gast.Str(scope_name),
         body=body)
     return node
Пример #15
0
    def visit_FunctionDef(self, node):
        """Wrap a function body in ``with ag__.FunctionScope(...)``.

        The scope receives whether name scopes are enabled (and then the
        sanitized function name as scope name, otherwise ``None``) and whether
        automatic control dependencies are enabled. The context manager is
        bound to a fresh ``fn_context`` symbol, also recorded on the
        ``_Function`` state. A leading docstring stays outside the block.
        """
        fn_state = self.state[_Function]
        fn_state.enter()
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        context_name = self.ctx.namer.new_symbol(
            'fn_context', body_scope.referenced)
        fn_state.context_name = context_name

        node = self.generic_visit(node)

        kept_docstring = []
        if node.body:
            opening = node.body[0]
            if (isinstance(opening, gast.Expr)
                    and isinstance(opening.value, gast.Str)):
                kept_docstring = [opening]
                node.body = node.body[1:]

        program_options = self.ctx.program.options
        if program_options.uses(converter.Feature.NAME_SCOPES):
            use_name_scopes = parser.parse_expression('True')
            scope_name = gast.Str(self._sanitize(node.name))
        else:
            use_name_scopes = parser.parse_expression('False')
            scope_name = parser.parse_expression('None')
        auto_deps_enabled = program_options.uses(
            converter.Feature.AUTO_CONTROL_DEPS)
        use_auto_deps = parser.parse_expression(str(auto_deps_enabled))

        template = """
      with ag__.FunctionScope(
          use_name_scopes, scope_name, use_auto_deps) as function_context_name:
        body
    """
        node.body = kept_docstring + templates.replace(
            template,
            use_name_scopes=use_name_scopes,
            scope_name=scope_name,
            use_auto_deps=use_auto_deps,
            function_context_name=context_name,
            body=node.body)

        fn_state.exit()
        return node
Пример #16
0
  def visit_Assert(self, node):
    """Converts `assert test, msg` into `ag__.assert_stmt(test, lambda: msg)`.

    A missing message becomes the string 'Assertion error'.

    Raises:
      NotImplementedError: if the message is present but not a string literal.
    """
    self.generic_visit(node)

    if node.msg is None:
      message = gast.Str('Assertion error')
    elif isinstance(node.msg, gast.Str):
      message = node.msg
    else:
      raise NotImplementedError('can only convert string messages for now.')

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    template = """
      ag__.assert_stmt(test, lambda: msg)
    """
    return templates.replace(template, test=node.test, msg=message)
Пример #17
0
def create_grad(node, namer, tangent=False):
    """Given a variable, create a variable for the gradient.

    Args:
      node: A node to create a gradient for, can be a normal variable (`x`), a
          subscript (`x[i]`) or a string literal holding a variable name.
      namer: The namer object which will determine the name to use for the
          gradient.
      tangent: Whether a tangent (instead of adjoint) is created.

    Returns:
      node: A node representing the gradient with the correct name e.g. the
          gradient of `x[i]` is `dx[i]`. A string literal input yields a
          string literal holding the gradient's name.

          Note that this returns an invalid node, with the `ctx` attribute
          missing. It is assumed that this attribute is filled in later.

          The generated `gast.Name` (the whole result for a name, the base
          for a subscript) has an `adjoint_var` annotation referring to the
          node it is an adjoint of.

    Raises:
      TypeError: If `node` (or the base of a subscript) is not a `gast.Name`,
          `gast.Subscript` or `gast.Str`.
    """
    if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)):
        raise TypeError('cannot create a gradient for a {} node'.format(
            type(node).__name__))

    # Nodes standing in for a temporary variable get the gradient of that
    # variable instead.
    if anno.hasanno(node, 'temp_var'):
        return create_grad(anno.getanno(node, 'temp_var'), namer, tangent)

    if isinstance(node, gast.Subscript):
        grad_node = create_grad(node.value, namer, tangent=tangent)
        grad_node.ctx = gast.Load()
        return gast.Subscript(value=grad_node, slice=node.slice, ctx=None)

    if isinstance(node, gast.Str):
        grad_node = create_grad(gast.Name(id=node.s, ctx=None,
                                          annotation=None),
                                namer,
                                tangent=tangent)
        return gast.Str(grad_node.id)

    # Only gast.Name is left at this point.
    grad_node = gast.Name(id=namer.grad(node.id, tangent), ctx=None,
                          annotation=None)
    anno.setanno(grad_node, 'adjoint_var', node)
    return grad_node
Пример #18
0
    def ast(self):
        """Build a gast expression for this qualified name.

        Subscripts and attributes are built recursively from the parent name.
        A root of type str becomes a ``Name``, a StringLiteral a ``Str`` and a
        NumberLiteral a ``Num``. The returned node's ``ctx`` is None.
        """
        # The caller must adjust the context appropriately.
        if self.has_subscript():
            subscripted = self.parent.ast()
            index = gast.Index(self.qn[-1].ast())
            return gast.Subscript(subscripted, index, None)
        if self.has_attr():
            return gast.Attribute(self.parent.ast(), self.qn[-1], None)

        root = self.qn[0]
        if isinstance(root, str):
            return gast.Name(root, None, None)
        if isinstance(root, StringLiteral):
            return gast.Str(root.value)
        if isinstance(root, NumberLiteral):
            return gast.Num(root.value)
        assert False, ('the constructor should prevent types other than '
                       'str, StringLiteral and NumberLiteral')
Пример #19
0
    def test_ast_to_source(self):
        """An if/else of two assignments renders with two-space indentation."""
        def assign_to_a(value):
            return gast.Assign(targets=[gast.Name('a', gast.Store(), None)],
                               value=value)

        node = gast.If(
            test=gast.Num(1),
            body=[assign_to_a(gast.Name('b', gast.Load(), None))],
            orelse=[assign_to_a(gast.Str('c'))])

        source = compiler.ast_to_source(node, indentation='  ')
        expected = textwrap.dedent("""
            if 1:
              a = b
            else:
              a = 'c'
        """).strip()
        self.assertEqual(expected, source.strip())
Пример #20
0
    def dispatch(self, tree):
        """Dispatcher function, dispatching tree type T to method _T.

        OpenMP directives attached to ``tree`` are first emitted as string
        expression statements. Each directive dependency is rendered to text
        by dispatching it into a temporary ``io.StringIO`` buffer. Lists are
        dispatched element by element.
        """
        # display omp directive in python dump
        for omp in metadata.get(tree, openmp.OMPDirective):
            deps = []
            for dep in omp.deps:
                saved_output = self.f
                self.f = io.StringIO()
                try:
                    self.dispatch(dep)
                    deps.append(self.f.getvalue())
                finally:
                    # Restore the real output stream even if rendering the
                    # dependency raises; otherwise later output would go to
                    # the throwaway buffer.
                    self.f = saved_output
            directive = omp.s.format(*deps)
            self._Expr(ast.Expr(ast.Str(s=directive)))

        if isinstance(tree, list):
            for t in tree:
                self.dispatch(t)
            return
        meth = getattr(self, "_" + tree.__class__.__name__)
        meth(tree)
Пример #21
0
def to_ast(value):
    """
    Turn a value into ast expression.

    >>> a = 1
    >>> print(ast.dump(to_ast(a)))
    Num(n=1)
    >>> a = [1, 2, 3]
    >>> print(ast.dump(to_ast(a)))
    List(elts=[Num(n=1), Num(n=2), Num(n=3)], ctx=Load())

    Raises:
      PythranSyntaxError: if folding produced an integer outside the range of
        the native ``int`` (as reported by ``np.iinfo(int)``), or a Python 2
        ``long``.
      ToNotEval: if ``value`` has no supported AST representation.
    """

    if isinstance(value, (type(None), bool)):
        return builtin_folding(value)
    if sys.version_info[0] == 2 and isinstance(value, long):  # noqa: F821
        from pythran.syntax import PythranSyntaxError
        raise PythranSyntaxError("constant folding results in big int")
    # The builtin types themselves (not their instances) fold to builtins.
    if any(value is t for t in (bool, int, float)):
        return builtin_folding(value)
    elif isinstance(value, np.generic):
        # ``.item()`` replaces ``np.asscalar``, removed in NumPy 1.23.
        return to_ast(value.item())
    elif isinstance(value, numbers.Number):
        # The range check has to happen here: in the branch above ``value`` is
        # a type object, so ``isinstance(value, int)`` could never hold there.
        iinfo = np.iinfo(int)
        if isinstance(value, int) and not (iinfo.min <= value <= iinfo.max):
            from pythran.syntax import PythranSyntaxError
            raise PythranSyntaxError("constant folding results in big int")
        return ast.Num(value)
    elif isinstance(value, str):
        return ast.Str(value)
    elif isinstance(value, (list, tuple, set, dict, np.ndarray)):
        return size_container_folding(value)
    elif hasattr(value, "__module__") and value.__module__ == "__builtin__":
        # TODO Can be done the same way for others modules
        return builtin_folding(value)
    # only meaningful for python3
    elif sys.version_info.major == 3:
        if isinstance(value, (filter, map, zip)):
            return to_ast(list(value))
    raise ToNotEval()
Пример #22
0
  def visit_FunctionDef(self, node):
    """Wraps the function body in `tf.name_scope(<current scope name>)`.

    A leading docstring is kept in front of the `with` block; every other
    statement goes inside it.
    """
    node = self.generic_visit(node)

    docstring_stmts = []
    scoped_stmts = node.body
    if scoped_stmts:
      leading = scoped_stmts[0]
      if isinstance(leading, gast.Expr) and isinstance(leading.value, gast.Str):
        # Skip any docstring.
        docstring_stmts, scoped_stmts = scoped_stmts[:1], scoped_stmts[1:]

    template = """
      with tf.name_scope(scope_name):
        body
    """
    node.body = docstring_stmts + templates.replace(
        template,
        scope_name=gast.Str(self._name_for_current_scope()),
        body=scoped_stmts)
    return node
Пример #23
0
 def _insert_dynamic_conversion(self, node):
   """Inlines a dynamic conversion for a dynamic function.

   Returns an `ag__.converted_call(func, owner, options, ...)` expression
   followed by the call's positional arguments, and carrying the original
   call's keywords. For method calls `func` is the attribute name as a string
   and `owner` the receiver; otherwise `owner` is `None`.
   """
   # TODO(mdan): Pass information on the statically compiled functions.
   # Having access to the statically compiled functions can help avoid
   # unnecessary compilation.
   # For example, this would lead to function `a` being compiled twice:
   #
   #   def a():
   #     v = b
   #     b()
   #   def b():
   #     a()
   #
   # This is really a problem with recursive calls, which currently can
   # only be gated by a static condition, and should be rare.
   # TODO(mdan): It probably makes sense to use dynamic conversion every time.
   # Before we could convert all the time though, we'd need a reasonable
   # caching mechanism.
   template = """
     ag__.converted_call(func, owner, options, args)
   """
   callee = node.func
   is_method_call = isinstance(callee, gast.Attribute)
   func = gast.Str(callee.attr) if is_method_call else callee
   owner = callee.value if is_method_call else parser.parse_expression('None')

   program_options = self.ctx.program.options
   converted = templates.replace_as_expression(
       template,
       func=func,
       owner=owner,
       options=program_options.to_ast(
           self.ctx.info.namespace,
           internal_convert_user_code=program_options.recursive),
       args=node.args)
   # TODO(mdan): Improve the template mechanism to better support this.
   converted.keywords = node.keywords
   return converted
Пример #24
0
    def visit_FunctionDef(self, node):
        """Wrap a function body in ``with ag__.function_scope(scope_name):``.

        The scope name comes from ``self._name_for_current_scope()``. A
        leading docstring is kept in front of the ``with`` block.
        """
        node = self.generic_visit(node)

        prefix = []
        statements = node.body
        if statements:
            head = statements[0]
            if isinstance(head, gast.Expr) and isinstance(head.value, gast.Str):
                # Skip the docstring, if any.
                prefix = [head]
                statements = statements[1:]

        template = """
      with ag__.function_scope(scope_name):
        body
    """
        node.body = prefix + templates.replace(
            template,
            scope_name=gast.Str(self._name_for_current_scope()),
            body=statements)
        return node
    def visit_While(self, node):
        """Convert a `while` loop into an `ag__.while_stmt` call.

        The loop test and body are moved into generated `loop_test` /
        `loop_body` functions. The basic loop variables returned by
        `_get_loop_vars` are passed to those functions and returned from the
        body; the composite loop variables are handled through the state
        getter/setter functions generated by `_create_state_functions`.
        Assignments built by `_create_undefined_assigns` for the possibly
        undefined symbols are prepended to the result.

        Args:
          node: the `while` loop node; it must carry static-analysis
            annotations such as `NodeAnno.BODY_SCOPE`.

        Returns:
          A list of statements that replaces the loop.
        """
        node = self.generic_visit(node)

        # Classify the symbols modified in the loop body into basic and
        # composite loop variables (plus reserved names and symbols that may
        # be undefined before the loop).
        (basic_loop_vars, composite_loop_vars, reserved_symbols,
         possibly_undefs) = self._get_loop_vars(
             node,
             anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
        loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
            basic_loop_vars)

        # Fresh names (avoiding reserved_symbols) for the state accessors.
        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      reserved_symbols)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      reserved_symbols)
        state_functions = self._create_state_functions(composite_loop_vars,
                                                       state_getter_name,
                                                       state_setter_name)

        # Loop variable names as string literals, passed to while_stmt next
        # to the values themselves.
        basic_symbol_names = tuple(
            gast.Str(str(symbol)) for symbol in basic_loop_vars)
        composite_symbol_names = tuple(
            gast.Str(str(symbol)) for symbol in composite_loop_vars)

        opts = self._create_loop_options(node)

        # TODO(mdan): Use a single template.
        # If the body and test functions took a single tuple for loop_vars, instead
        # of *loop_vars, then a single template could be used.
        if loop_vars:
            # Body and test take the basic loop variables as parameters; the
            # while_stmt result is assigned back to them.
            template = """
        state_functions
        def body_name(loop_vars):
          body
          return loop_vars,
        def test_name(loop_vars):
          return test
        loop_vars_ast_tuple = ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
            node = templates.replace(
                template,
                loop_vars=loop_vars,
                loop_vars_ast_tuple=loop_vars_ast_tuple,
                test_name=self.ctx.namer.new_symbol('loop_test',
                                                    reserved_symbols),
                test=node.test,
                body_name=self.ctx.namer.new_symbol('loop_body',
                                                    reserved_symbols),
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                basic_symbol_names=basic_symbol_names,
                composite_symbol_names=composite_symbol_names,
                opts=opts)
        else:
            # No basic loop variables: parameterless functions, and the
            # while_stmt result is discarded.
            template = """
        state_functions
        def body_name():
          body
          return ()
        def test_name():
          return test
        ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
            node = templates.replace(
                template,
                test_name=self.ctx.namer.new_symbol('loop_test',
                                                    reserved_symbols),
                test=node.test,
                body_name=self.ctx.namer.new_symbol('loop_body',
                                                    reserved_symbols),
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                composite_symbol_names=composite_symbol_names,
                opts=opts)

        # Placeholder assignments go before the generated loop statements.
        undefined_assigns = self._create_undefined_assigns(possibly_undefs)
        return undefined_assigns + node
Пример #26
0
  def visit_Call(self, node):
    """Converts a call into an `ag__.converted_call` invocation.

    `f(a, *b, c, k=v, **kw)` becomes
    `ag__.converted_call(f, None, options, (a,) + tuple(b) + (c,),
    dict(kw, **{'k': v}))`. For method calls, `func` is the attribute name as
    a string and `owner` is the receiver expression.

    Positional arguments keep their original order around `*` arguments.
    Any number of `*` and `**` arguments is accepted; both forms are legal
    since Python 3.5 (PEP 448).

    Calls into the internal `ag__` module, and `print` when builtin function
    conversion is disabled, are only visited recursively.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    if full_name.startswith('ag__.'):
      return self.generic_visit(node)
    if (full_name == 'print' and
        not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
      return self.generic_visit(node)

    if isinstance(node.func, gast.Attribute):
      func = gast.Str(node.func.attr)
      owner = node.func.value
    else:
      func = node.func
      owner = parser.parse_expression('None')

    # Positional arguments. Each starred argument closes a segment
    # `(plain args before it,) + tuple(starred)`. Segments are joined with `+`,
    # and trailing plain arguments are appended as a final tuple, so the
    # original argument order is preserved.
    args = None
    pending = []
    for a in node.args:
      if not isinstance(a, gast.Starred):
        pending.append(a)
        continue
      segment = templates.replace_as_expression(
          '(args,) + tuple(stararg)', stararg=a.value, args=pending)
      if args is None:
        args = segment
      else:
        args = gast.BinOp(left=args, op=gast.Add(), right=segment)
      pending = []
    if args is None:
      args = templates.replace_as_expression('(args,)', args=pending)
    elif pending:
      args = gast.BinOp(
          left=args,
          op=gast.Add(),
          right=templates.replace_as_expression('(args,)', args=pending))

    # Keyword arguments: `**` mappings are merged left to right, then the
    # explicit keywords are merged on top.
    kwargs_values = []
    normal_keywords = []
    for k in node.keywords:
      if k.arg is None:
        kwargs_values.append(k.value)
      else:
        normal_keywords.append(k)
    kwargs = ast_util.keywords_to_dict(normal_keywords)
    if kwargs_values:
      merged = kwargs_values[0]
      for extra in kwargs_values[1:]:
        merged = templates.replace_as_expression(
            'dict(kwargs, **keywords)', kwargs=merged, keywords=extra)
      kwargs = templates.replace_as_expression(
          'dict(kwargs, **keywords)', kwargs=merged, keywords=kwargs)

    template = """
      ag__.converted_call(func, owner, options, args, kwargs)
    """
    new_call = templates.replace_as_expression(
        template,
        func=func,
        owner=owner,
        options=self.ctx.program.options.to_ast(
            self.ctx,
            internal_convert_user_code=self.ctx.program.options.recursive),
        args=args,
        kwargs=kwargs)

    return new_call
Пример #27
0
                          ast.BinOp(left=Placeholder(0),
                                    op=ast.Sub(),
                                    right=ast.Num(n=1)),
                          ast.Num(n=-1),
                          ast.Num(n=-1)
                      ],
                      keywords=[])),

    # X * X => X ** 2
    (ast.BinOp(left=Placeholder(0), op=ast.Mult(), right=Placeholder(0)),
     lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))),

    # a + "..." + b => "...".join((a, b))
    (ast.BinOp(left=ast.BinOp(left=Placeholder(0),
                              op=ast.Add(),
                              right=ast.Str(Placeholder(1))),
               op=ast.Add(),
               right=Placeholder(2)),
     lambda: ast.Call(func=ast.Attribute(
         ast.Attribute(ast.Name('__builtin__', ast.Load(), None), 'str',
                       ast.Load()), 'join', ast.Load()),
                      args=[
                          ast.Str(Placeholder(1)),
                          ast.Tuple([Placeholder(0),
                                     Placeholder(2)], ast.Load())
                      ],
                      keywords=[])),
]


class PlaceholderReplace(Transformation):
    """AST transformation that lowers `if` and `for` statements.

    Only `visit_If` and `visit_For` are defined here. They rely on helpers
    defined elsewhere (`_determine_aliased_symbols`, `_create_cond_branch`,
    `_create_undefined_assigns`, `_create_state_functions`,
    `_create_cond_expr`, `_get_loop_vars`, `_loop_var_constructs`,
    `_create_loop_options`, `create_assignment`) and on `self.ctx`.

    NOTE(review): the class name suggests placeholder substitution for the
    pattern tables defined above. The methods, however, perform control-flow
    conversion that emits `ag__.*` calls. The body may belong to a different
    class; confirm.
    """
    def visit_If(self, node):
        """Convert an `if` statement into branch functions plus a cond call.

        Each branch body becomes a generated function (named from `if_true` /
        `if_false`). The test is assigned to a fresh `cond` variable, and the
        conditional itself is expressed through `self._create_cond_expr`.

        Args:
          node: the `If` node. It must carry the BODY_SCOPE, ORELSE_SCOPE,
            DEFINED_VARS_IN and LIVE_VARS_OUT annotations read at the top of
            this method.

        Returns:
          The concatenation, in order, of:
            - placeholder assignments for symbols created in only one branch;
            - state getter/setter definitions for composite symbols;
            - the true-branch function;
            - the false-branch function;
            - the assignment of the test;
            - the cond expression statement(s).
        """
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

        # Note: this information needs to be extracted before the body conversion
        # that happens in the call to generic_visit below, because the conversion
        # generates nodes that lack static analysis annotations.
        need_alias_in_body = self._determine_aliased_symbols(
            body_scope, defined_in, node.body)
        need_alias_in_orelse = self._determine_aliased_symbols(
            orelse_scope, defined_in, node.orelse)

        node = self.generic_visit(node)

        # Symbols written in either branch.
        modified_in_cond = body_scope.modified | orelse_scope.modified
        # Non-composite symbols that are still live after the `if`. These
        # become the outputs of the conditional.
        returned_from_cond = set()
        composites = set()
        for s in modified_in_cond:
            if s in live_out and not s.is_composite():
                returned_from_cond.add(s)
            if s.is_composite():
                # Special treatment for compound objects, always return them.
                # This allows special handling within the if_stmt itself.
                # For example, in TensorFlow we need to restore the state of composite
                # symbols to ensure that only effects from the executed branch are seen.
                composites.add(s)

        # `-` binds tighter than `&`, so each line reads A & (B - C). That is
        # equal to (A & B) - C: cond outputs modified in this branch that were
        # not defined before the `if`.
        created_in_body = body_scope.modified & returned_from_cond - defined_in
        created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

        basic_created_in_body = tuple(s for s in created_in_body
                                      if not s.is_composite())
        basic_created_in_orelse = tuple(s for s in created_in_orelse
                                        if not s.is_composite())

        # These variables are defined only in a single branch. This is fine in
        # Python so we pass them through. Another backend, e.g. Tensorflow, may need
        # to handle these cases specially or throw an Error.
        possibly_undefined = (set(basic_created_in_body)
                              ^ set(basic_created_in_orelse))

        # Alias the closure variables inside the conditional functions, to allow
        # the functions access to the respective variables.
        # We will alias variables independently for body and orelse scope,
        # because different branches might write different variables.
        # The second argument to new_symbol is presumably the set of names the
        # new symbol must not collide with. Confirm against the namer.
        # NOTE(review): if the namer remembers the names it has issued, the
        # order of the new_symbol calls in this method determines the generated
        # identifiers. Keep that order stable.
        aliased_body_orig_names = tuple(need_alias_in_body)
        aliased_orelse_orig_names = tuple(need_alias_in_orelse)
        aliased_body_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
            for s in aliased_body_orig_names)
        aliased_orelse_new_names = tuple(
            self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
            for s in aliased_orelse_orig_names)

        # Original symbol -> alias, one map per branch.
        alias_body_map = dict(
            zip(aliased_body_orig_names, aliased_body_new_names))
        alias_orelse_map = dict(
            zip(aliased_orelse_orig_names, aliased_orelse_new_names))

        node_body = ast_util.rename_symbols(node.body, alias_body_map)
        node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

        cond_var_name = self.ctx.namer.new_symbol('cond',
                                                  body_scope.referenced)
        body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
        orelse_name = self.ctx.namer.new_symbol('if_false',
                                                orelse_scope.referenced)
        # The state functions are shared by both branches, so their names must
        # avoid symbols referenced in either one.
        all_referenced = body_scope.referenced | orelse_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      all_referenced)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      all_referenced)

        # Convert to tuples. This allows indexing ([0] below) and keeps one
        # fixed order across cond_results, the per-branch returns and the
        # symbol-name lists built further down.
        returned_from_cond = tuple(returned_from_cond)
        composites = tuple(composites)

        if returned_from_cond:
            if len(returned_from_cond) == 1:
                # A single output is used directly rather than as a 1-tuple.
                cond_results = returned_from_cond[0]
            else:
                # NOTE(review): the Tuple is built with ctx=None. Presumably the
                # template machinery assigns the context; confirm.
                cond_results = gast.Tuple(
                    [s.ast() for s in returned_from_cond], None)

            # Each branch returns the outputs under that branch's alias, if
            # one was created for it, and otherwise under the original name.
            returned_from_body = tuple(
                alias_body_map[s] if s in need_alias_in_body else s
                for s in returned_from_cond)
            returned_from_orelse = tuple(
                alias_orelse_map[s] if s in need_alias_in_orelse else s
                for s in returned_from_cond)

        else:
            # When the cond would return no value, we leave the cond called without
            # results. That in turn should trigger the side effect guards. The
            # branch functions will return a dummy value that ensures cond
            # actually has some return value as well.
            cond_results = None
            # TODO(mdan): Replace with None once side_effect_guards is retired.
            returned_from_body = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )
            returned_from_orelse = (templates.replace_as_expression(
                'ag__.match_staging_level(1, cond_var_name)',
                cond_var_name=cond_var_name), )

        cond_assign = self.create_assignment(cond_var_name, node.test)
        body_def = self._create_cond_branch(
            body_name,
            aliased_orig_names=aliased_body_orig_names,
            aliased_new_names=aliased_body_new_names,
            body=node_body,
            returns=returned_from_body)
        orelse_def = self._create_cond_branch(
            orelse_name,
            aliased_orig_names=aliased_orelse_orig_names,
            aliased_new_names=aliased_orelse_new_names,
            body=node_orelse,
            returns=returned_from_orelse)
        undefined_assigns = self._create_undefined_assigns(possibly_undefined)
        composite_defs = self._create_state_functions(composites,
                                                      state_getter_name,
                                                      state_setter_name)

        # Symbol names passed to the cond expression as string literals.
        basic_symbol_names = tuple(
            gast.Str(str(symbol)) for symbol in returned_from_cond)
        composite_symbol_names = tuple(
            gast.Str(str(symbol)) for symbol in composites)

        cond_expr = self._create_cond_expr(cond_results, cond_var_name,
                                           body_name, orelse_name,
                                           state_getter_name,
                                           state_setter_name,
                                           basic_symbol_names,
                                           composite_symbol_names)

        # Definitions come before the test assignment and the cond call that
        # uses them.
        if_ast = (undefined_assigns + composite_defs + body_def + orelse_def +
                  cond_assign + cond_expr)
        return if_ast

    def visit_For(self, node):
        """Convert a `for` loop into a body function plus an `ag__.for_stmt` call.

        Nested statements are converted first via generic_visit.

        - Non-composite loop variables are passed into, and returned from,
          the generated body function.
        - Composite loop variables are covered by generated get_state /
          set_state functions.
        - If the node carries an `extra_test` annotation, it is wrapped in its
          own function and passed to `ag__.for_stmt`. Otherwise `None` is
          passed in its place.

        Args:
          node: the `For` node, annotated with BODY_SCOPE and ITERATE_SCOPE.

        Returns:
          The statements produced by `templates.replace` for whichever of the
          two templates below applies.
        """
        node = self.generic_visit(node)

        # Loop variables are derived from symbols modified in the loop body or
        # in the iterate scope.
        (basic_loop_vars, composite_loop_vars,
         reserved_symbols, possibly_undefs) = self._get_loop_vars(
             node,
             (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
              | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
        loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
            basic_loop_vars)
        body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      reserved_symbols)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      reserved_symbols)
        state_functions = self._create_state_functions(composite_loop_vars,
                                                       state_getter_name,
                                                       state_setter_name)

        if anno.hasanno(node, 'extra_test'):
            # Wrap the extra test expression in a function of the loop vars.
            extra_test = anno.getanno(node, 'extra_test')
            extra_test_name = self.ctx.namer.new_symbol(
                'extra_test', reserved_symbols)
            template = """
        def extra_test_name(loop_vars):
          return extra_test_expr
      """
            extra_test_function = templates.replace(
                template,
                extra_test_name=extra_test_name,
                loop_vars=loop_vars,
                extra_test_expr=extra_test)
        else:
            # No extra test: pass a literal None and emit no function.
            extra_test_name = parser.parse_expression('None')
            extra_test_function = []

        # Workaround for PEP-3113
        # iterates_var holds a single variable with the iterates, which may be a
        # tuple.
        iterates_var_name = self.ctx.namer.new_symbol('iterates',
                                                      reserved_symbols)
        template = """
      iterates = iterates_var_name
    """
        # Unpacks the single parameter back into the original loop target at
        # the top of the body function.
        iterate_expansion = templates.replace(
            template,
            iterates=node.target,
            iterates_var_name=iterates_var_name)

        undefined_assigns = self._create_undefined_assigns(possibly_undefs)

        # Symbol names passed to ag__.for_stmt as string literals.
        basic_symbol_names = tuple(
            gast.Str(str(symbol)) for symbol in basic_loop_vars)
        composite_symbol_names = tuple(
            gast.Str(str(symbol)) for symbol in composite_loop_vars)

        opts = self._create_loop_options(node)

        # TODO(mdan): Use a single template.
        # If the body and test functions took a single tuple for loop_vars, instead
        # of *loop_vars, then a single template could be used.
        # With no basic loop variables, the second template's body function
        # takes only the iterate and returns (). Its for_stmt result is not
        # assigned.
        if loop_vars:
            template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name, loop_vars):
          iterate_expansion
          body
          return loop_vars,
        extra_test_function
        loop_vars_ast_tuple = ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
            return templates.replace(
                template,
                undefined_assigns=undefined_assigns,
                loop_vars=loop_vars,
                loop_vars_ast_tuple=loop_vars_ast_tuple,
                iter_=node.iter,
                iterate_expansion=iterate_expansion,
                iterates_var_name=iterates_var_name,
                extra_test_name=extra_test_name,
                extra_test_function=extra_test_function,
                body_name=body_name,
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                basic_symbol_names=basic_symbol_names,
                composite_symbol_names=composite_symbol_names,
                opts=opts)
        else:
            template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name):
          iterate_expansion
          body
          return ()
        extra_test_function
        ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
            return templates.replace(
                template,
                undefined_assigns=undefined_assigns,
                iter_=node.iter,
                iterate_expansion=iterate_expansion,
                iterates_var_name=iterates_var_name,
                extra_test_name=extra_test_name,
                extra_test_function=extra_test_function,
                body_name=body_name,
                body=node.body,
                state_functions=state_functions,
                state_getter_name=state_getter_name,
                state_setter_name=state_setter_name,
                composite_symbol_names=composite_symbol_names,
                opts=opts)
Пример #30
0
     lambda: ast.Call(
         func=ast.Attribute(value=ast.Name(id='__builtin__',
                                           ctx=ast.Load(), annotation=None),
                            attr="xrange", ctx=ast.Load()),
         args=[ast.BinOp(left=Placeholder(0), op=ast.Sub(),
                         right=ast.Num(n=1)),
               ast.Num(n=-1),
               ast.Num(n=-1)],
         keywords=[])),
    # X * X => X ** 2
    (ast.BinOp(left=Placeholder(0), op=ast.Mult(), right=Placeholder(0)),
     lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))),
    # a + "..." + b => "...".join((a, b))
    (ast.BinOp(left=ast.BinOp(left=Placeholder(0),
                              op=ast.Add(),
                              right=ast.Str(Placeholder(1))),
               op=ast.Add(),
               right=Placeholder(2)),
     lambda: ast.Call(func=ast.Attribute(
         ast.Attribute(
             ast.Name('__builtin__', ast.Load(), None),
             'str',
             ast.Load()),
         'join', ast.Load()),
         args=[ast.Str(Placeholder(1)),
               ast.Tuple([Placeholder(0), Placeholder(2)], ast.Load())],
         keywords=[])),
]


class PlaceholderReplace(Transformation):