Esempio n. 1
0
 def _create_cond_expr(self, results, test, body_name, orelse_name,
                       state_getter_name,
                       state_setter_name):
   """Emits the `ag__.if_stmt` call for a converted conditional.

   Args:
     results: assignment target for the call's return value, or None to emit
       the call as a bare expression statement.
     test: the condition expression node.
     body_name: name passed to the call for the true branch.
     orelse_name: name passed to the call for the false branch.
     state_getter_name: name of the state getter passed to the call.
     state_setter_name: name of the state setter passed to the call.

   Returns:
     The nodes produced by `templates.replace`.
   """
   if results is None:
     code_template = """
       ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name)
     """
     return templates.replace(
         code_template,
         test=test,
         body_name=body_name,
         orelse_name=orelse_name,
         getter_name=state_getter_name,
         setter_name=state_setter_name)
   else:
     code_template = """
       results = ag__.if_stmt(test, body_name, orelse_name,
                              state_getter_name, state_setter_name)
     """
     return templates.replace(
         code_template,
         test=test,
         results=results,
         body_name=body_name,
         orelse_name=orelse_name,
         state_getter_name=state_getter_name,
         state_setter_name=state_setter_name)
Esempio n. 2
0
  def _create_cond_branch(self, body_name, aliased_orig_names,
                          aliased_new_names, body, returns):
    """Wraps one branch of a conditional into a zero-argument function.

    Args:
      body_name: name of the generated function.
      aliased_orig_names: names whose values are copied into
        `aliased_new_names` at the start of the function; may be empty.
      aliased_new_names: local names assigned from `aliased_orig_names`.
      body: the statements of the branch.
      returns: the values the function returns; a single value is returned
        directly, otherwise they are returned as a tuple.

    Returns:
      The nodes produced by `templates.replace`.
    """
    if len(returns) != 1:
      template = """
        return (retvals,)
      """
      return_stmt = templates.replace(template, retvals=returns)
    else:
      template = """
        return retval
      """
      return_stmt = templates.replace(template, retval=returns[0])

    if not aliased_orig_names:
      template = """
        def body_name():
          body
          return_stmt
      """
      return templates.replace(
          template, body_name=body_name, body=body, return_stmt=return_stmt)

    # Aliases are bound first so the branch body sees the new names.
    template = """
        def body_name():
          aliased_new_names, = aliased_orig_names,
          body
          return_stmt
      """
    return templates.replace(
        template,
        body_name=body_name,
        body=body,
        aliased_orig_names=aliased_orig_names,
        aliased_new_names=aliased_new_names,
        return_stmt=return_stmt)
Esempio n. 3
0
  def _create_state_functions(self, composites,
                              state_getter_name, state_setter_name):
    """Generates the getter/setter pair for a conditional's state.

    With no composites, the getter returns an empty tuple and the setter
    ignores its argument. Otherwise the getter returns the composites as a
    tuple and the setter unpacks its argument back into them.

    Returns:
      The nodes produced by `templates.replace`.
    """
    if not composites:
      template = """
        def state_getter_name():
          return ()
        def state_setter_name(_):
          pass
        """
      return templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name)

    template = """
        def state_getter_name():
          return composite_tuple,
        def state_setter_name(vals):
          composite_tuple, = vals
      """
    return templates.replace(
        template,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        composite_tuple=tuple(composites))
Esempio n. 4
0
  def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    The loop body and the optional `extra_test` annotation are turned into
    functions. When the loop carries state, both functions take it and the
    body returns it, and the call's result is assigned back to it.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_test'):
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), ssf_map)
    else:
      extra_test = parser.parse_expression('True')

    # Generate names in the same order as before: extra test, then body.
    extra_test_name = self.ctx.namer.new_symbol('extra_test', reserved_symbols)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)
    shared = dict(
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=node_body)

    if not loop_state:
      template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
      return templates.replace(template, **shared)

    template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
    return templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        **shared)
Esempio n. 5
0
  def test_replace_name_mixed_attr_subscript(self, expression_source):
    """Checks ctx inference for attribute/subscript replacements."""
    template = 'foo = bar'
    replacement = _parse_with_unset_ctx(expression_source)

    # On the left-hand side the replacement must be marked as a store.
    assign = templates.replace(template, foo=replacement)[0]
    self.assertExpectedCtxSet(assign.targets[0], gast.Store)

    # On the right-hand side it must be marked as a load.
    assign = templates.replace(template, bar=replacement)[0]
    self.assertExpectedCtxSet(assign.value, gast.Load)
Esempio n. 6
0
  def visit_With(self, node):
    """Lifts a trailing `return` out of a `with` block.

    A final `return <value>` in the block becomes `<common_return_name> =
    <value>`, and `return <common_return_name>` is emitted right after the
    `with` statement. Sets `self.changes_made` when this happens.
    """
    # Depth-first traversal of syntax
    node = self.generic_visit(node)

    last_stmt = node.body[-1]
    if not isinstance(last_stmt, gast.Return):
      return node

    node.body[-1] = templates.replace(
        'a = b', a=self.common_return_name, b=last_stmt.value)[0]
    return_node = templates.replace('return a', a=self.common_return_name)[0]
    node = self.generic_visit(node)
    self.changes_made = True
    return [node, return_node]
Esempio n. 7
0
  def test_replace_attribute(self):
    """Attribute names can be replaced by strings, but not by other values."""
    # `imp` (and `imp.new_module`) was removed in Python 3.12;
    # types.ModuleType is the supported equivalent.
    import types

    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    mod = types.ModuleType('test')
    mod.b = 3
    # `a.foo` was rewritten to `a.b`, so the function reads mod.b.
    # assertEqual, not the alias assertEquals, which Python 3.12 removed.
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)
Esempio n. 8
0
 def _for_loop_with_extra_test(self, loop_state, state_ssf, state_ast_tuple,
                               original_node, extra_test_name, extra_test,
                               body_name, loop_body, ssf_map):
   """Builds a stateful `ag__.for_stmt` call with an extra loop test.

   The loop target is renamed through `ssf_map` before substitution. The
   generated extra-test and body functions both take the loop state, and
   the call's result is assigned to `state_ast_tuple`.
   """
   renamed_target = ast_util.rename_symbols(original_node.target, ssf_map)
   template = """
     def extra_test_name(state_ssf):
       return extra_test_expr
     def body_name(loop_vars, state_ssf):
       # Workaround for PEP-3113
       target = loop_vars
       body
       return state_ssf,
     state_ast_tuple = ag__.for_stmt(
         iter_, extra_test_name, body_name, (state,))
   """
   replacements = {
       'state': loop_state,
       'state_ssf': state_ssf,
       'state_ast_tuple': state_ast_tuple,
       'iter_': original_node.iter,
       'target': renamed_target,
       'extra_test_name': extra_test_name,
       'extra_test_expr': extra_test,
       'body_name': body_name,
       'body': loop_body,
   }
   return templates.replace(template, **replacements)
 def visit_Continue(self, node):
   """Emits `<control var> = tf.constant(True)` in place of `continue`."""
   self.set_local(CONTINUE_USED, True)
   control_var = self.get_local(CONTROL_VAR_NAME)
   template = """
     var_name = tf.constant(True)
   """
   return templates.replace(template, var_name=control_var)
Esempio n. 10
0
  def visit_Assert(self, node):
    """Converts an `assert` statement into a `tf.Assert` call.

    A missing message defaults to 'Assertion error'.

    Raises:
      NotImplementedError: if the message is not a string literal.
    """
    self.generic_visit(node)

    if node.msg is None:
      message = gast.Str('Assertion error')
    elif isinstance(node.msg, gast.Str):
      message = node.msg
    else:
      raise NotImplementedError('can only convert string messages for now.')

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    template = """
      tf.Assert(test, (msg,))
    """
    return templates.replace(template, test=node.test, msg=message)
    def visit_While(self, node):
        """Lowers `break` inside a `while` loop to a boolean flag.

        If the body uses `break`, the loop is rewritten to initialize the
        flag to False and to keep running only while `test` holds and the
        flag is unset. The `else` clause is guarded by the same flag. The
        loop's directive annotations are copied onto the new loop node.
        """
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

        node.test = self.visit(node.test)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            return node

        # Python's else clause only triggers if the loop exited cleanly (e.g.
        # break did not trigger).
        guarded_orelse = self._guard_if_present(node.orelse, break_var)

        template = """
        var_name = False
        while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
          body
        else:
          orelse
      """
        new_nodes = templates.replace(template,
                                      var_name=break_var,
                                      test=node.test,
                                      body=node.body,
                                      orelse=guarded_orelse)

        # new_nodes[0] initializes the flag; new_nodes[1] is the loop itself.
        anno.copyanno(node, new_nodes[1], anno.Basic.DIRECTIVES)
        return new_nodes
Esempio n. 12
0
 def create_assignment(self, target, expression):
     """Returns the nodes for the statement `target = expression`."""
     template = """
   target = expression
 """
     assignment_nodes = templates.replace(
         template, target=target, expression=expression)
     return assignment_nodes
Esempio n. 13
0
    def test_replace_code_block(self):
        """A placeholder statement can be replaced by a list of statements."""
        template = """
      def test_fn(a):
        block
        return a
    """

        class ShouldBeReplaced(object):
            pass

        def make_name():
            # The bogus ctx must be overwritten by templates.replace.
            return gast.Name('a',
                             ctx=ShouldBeReplaced,
                             annotation=None,
                             type_comment=None)

        increment = gast.Assign(
            [make_name()],
            gast.BinOp(make_name(), gast.Add(), gast.Constant(1, kind=None)),
        )
        node = templates.replace(template, block=[increment] * 2)[0]
        result, _, _ = loader.load_ast(node)
        # `a = a + 1` runs twice, so 1 becomes 3.
        self.assertEqual(3, result.test_fn(1))
Esempio n. 14
0
    def test_replace_call_keyword(self):
        """Keyword arguments can be replaced by a non-empty list of keywords.

        Replacing them with an empty list or a non-keyword value must raise
        ValueError.
        """
        template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

        source = parser.parse_expression('f(d=3, f=5)')
        node = templates.replace(template, kws=source.keywords)[0]
        result, _, _ = loader.load_ast(node)
        # f(1, d=3, f=5) == 1 + 3 + 5.
        self.assertEqual(9, result.test_fn())

        # One assertRaises block per call: inside a single block, statements
        # after the first one that raises are never executed.
        with self.assertRaises(ValueError):
            templates.replace(template, kws=[])
        with self.assertRaises(ValueError):
            templates.replace(template, kws=1)
 def _insert_dynamic_conversion(self, node):
     """Returns an `ag__.converted_call` node that stands in for call `node`.

     The new call receives the original callee, a
     `ag__.ConversionOptions.new(...)` built from the program's `recursive`
     setting, and the original positional arguments. The original keyword
     arguments are copied onto it.
     """
     # TODO(mdan): Pass information on the statically compiled functions.
     # Having access to the statically compiled functions can help avoid
     # unnecessary compilation.
     # For example, this would lead to function `a` being compiled twice:
     #
     #   def a():
     #     v = b
     #     b()
     #   def b():
     #     a()
     #
     # This is really a problem with recursive calls, which currently can
     # only be gated by a static condition, and should be rare.
     # TODO(mdan): It probably makes sense to use dynamic conversion every time.
     # Before we could convert all the time though, we'd need a reasonable
     # caching mechanism.
     recursive_ast = parser.parse_expression(str(self.ctx.program.recursive))
     template = """
   ag__.converted_call(
       func,
       ag__.ConversionOptions.new(recursive=recursive_val),
       args)
 """
     call_stmt = templates.replace(template,
                                   func=node.func,
                                   recursive_val=recursive_ast,
                                   args=node.args)[0]
     new_call = call_stmt.value
     # TODO(mdan): Improve the template mechanism to better support this.
     new_call.keywords = node.keywords
     return new_call
Esempio n. 16
0
 def _insert_dynamic_conversion(self, node):
   """Returns an `ag__.converted_call` node that stands in for call `node`.

   The new call receives the original callee, the program's conversion
   options serialized as an AST, and the original positional arguments. The
   original keyword arguments are copied onto it.
   """
   # TODO(mdan): Pass information on the statically compiled functions.
   # Having access to the statically compiled functions can help avoid
   # unnecessary compilation.
   # For example, this would lead to function `a` being compiled twice:
   #
   #   def a():
   #     v = b
   #     b()
   #   def b():
   #     a()
   #
   # This is really a problem with recursive calls, which currently can
   # only be gated by a static condition, and should be rare.
   # TODO(mdan): It probably makes sense to use dynamic conversion every time.
   # Before we could convert all the time though, we'd need a reasonable
   # caching mechanism.
   options_ast = self.ctx.program.options.to_ast(self.ctx.info.namespace)
   template = """
     ag__.converted_call(func, options, args)
   """
   call_stmt = templates.replace(
       template, func=node.func, options=options_ast, args=node.args)[0]
   new_call = call_stmt.value
   # TODO(mdan): Improve the template mechanism to better support this.
   new_call.keywords = node.keywords
   return new_call
Esempio n. 17
0
    def to_ast(self):
        """Serializes these options as an AST constructor expression.

        Options equal to `STANDARD_OPTIONS` are represented by the shorthand
        `ag__.STD`. Any other options become an `ag__.ConversionOptions(...)`
        call whose keyword arguments mirror this object's attributes.

        Returns:
          ast.Node
        """
        if self == STANDARD_OPTIONS:
            return parser.parse_expression('ag__.STD')

        template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

        def literal(value):
            # Round-trips a Python literal through its source representation.
            return parser.parse_expression(str(value))

        feature_refs = ', '.join(
            'ag__.{}'.format(str(v)) for v in self.optional_features)
        expr_ast = templates.replace(
            template,
            recursive_val=literal(self.recursive),
            user_requested_val=literal(self.user_requested),
            internal_convert_user_code_val=literal(
                self.internal_convert_user_code),
            optional_features_val=parser.parse_expression(
                '({})'.format(feature_refs)))
        return expr_ast[0].value
Esempio n. 18
0
  def test_replace_call_keyword(self):
    """Keyword arguments can be replaced by a non-empty list of keywords.

    Replacing them with an empty list or a non-keyword value must raise
    ValueError.
    """
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    # f(1, d=3, f=5) == 1 + 3 + 5. assertEqual, not the alias assertEquals,
    # which Python 3.12 removed.
    self.assertEqual(9, result.test_fn())

    # One assertRaises block per call: inside a single block, statements
    # after the first one that raises are never executed.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)
Esempio n. 19
0
    def visit_Delete(self, node):
        """Rewrites `del name` into `name = ag__.Undefined('name')`.

        Only plain-name targets are rewritten. Any other targets are kept in
        a residual `del` statement appended after the assignments. If there
        are no name targets, the node is returned unchanged.
        """
        node = self.generic_visit(node)

        # Don't rewrite composites like `del a[0]`.
        name_targets = [t for t in node.targets if isinstance(t, gast.Name)]
        if not name_targets:
            return node

        template = """
        var_ = ag__.Undefined(var_name)
      """
        replacement_stmts = []
        for name_node in name_targets:
            replacement_stmts.extend(
                templates.replace(
                    template,
                    var_=name_node,
                    var_name=gast.Constant(name_node.id, kind=None)))

        leftover_targets = [t for t in node.targets if t not in name_targets]
        if leftover_targets:
            replacement_stmts.append(gast.Delete(targets=leftover_targets))
        return replacement_stmts
    def _rename_compilable_function(self, node):
        """Points a call at the compiled name of its target, when applicable.

        Constructors always get the compiled class name. Other callables get
        whatever `compiled_function_name` returns, and are renamed only if it
        says so. For methods, the receiver is prepended to the positional
        arguments.
        """
        assert anno.hasanno(node.func, 'live_val')
        assert anno.hasanno(node.func, 'fqn')
        target_entity = anno.getanno(node.func, 'live_val')
        target_fqn = anno.getanno(node.func, 'fqn')

        if not self._should_compile(node, target_fqn):
            return node

        if anno.hasanno(node, 'is_constructor'):
            new_name = self.ctx.namer.compiled_class_name(
                target_fqn, live_entity=target_entity)
            do_rename = True
        else:
            if anno.hasanno(node.func, 'parent_type'):
                owner_type = anno.getanno(node.func, 'parent_type')
            else:
                # Fallback - not reliable.
                owner_type = inspect_utils.getmethodclass(target_entity)
            new_name, do_rename = self.ctx.namer.compiled_function_name(
                target_fqn, live_entity=target_entity, owner_type=owner_type)

        if not do_rename:
            return node

        if target_entity is not None and tf_inspect.ismethod(target_entity):
            # The renaming process will transform it into a regular function.
            # TODO(mdan): Is this complete? How does it work with nested members?
            node.args = [node.func.value] + node.args
        node.func = templates.replace('func_name', func_name=new_name)[0]
        return node
Esempio n. 21
0
 def visit_Continue(self, node):
   """Marks `continue` as used and emits `<control var> = True` in its place."""
   continue_state = self.state[_Continue]
   continue_state.used = True
   template = """
     var_name = True
   """
   return templates.replace(
       template, var_name=continue_state.control_var_name)
Esempio n. 22
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled name of its target, when applicable.

    Constructors always get the compiled class name. Other callables get
    whatever `compiled_function_name` returns, and are renamed only if it
    says so. For methods, the receiver is prepended to the positional
    arguments.
    """
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = anno.getanno(node.func, 'fqn')

    if not self._should_compile(node, target_fqn):
      return node

    if anno.hasanno(node, 'is_constructor'):
      new_name = self.ctx.namer.compiled_class_name(
          target_fqn, live_entity=target_entity)
      do_rename = True
    else:
      if anno.hasanno(node.func, 'parent_type'):
        owner_type = anno.getanno(node.func, 'parent_type')
      else:
        # Fallback - not reliable.
        owner_type = inspect_utils.getmethodclass(target_entity)
      new_name, do_rename = self.ctx.namer.compiled_function_name(
          target_fqn, live_entity=target_entity, owner_type=owner_type)

    if not do_rename:
      return node

    if target_entity is not None and tf_inspect.ismethod(target_entity):
      # The renaming process will transform it into a regular function.
      # TODO(mdan): Is this complete? How does it work with nested members?
      node.args = [node.func.value] + node.args
    node.func = templates.replace('func_name', func_name=new_name)[0]
    return node
Esempio n. 23
0
  def visit_While(self, node):
    """Lowers `break` inside a `while` loop to a flag variable.

    If the body uses `break`, the loop is rewritten to initialize the flag
    to `tf.constant(False)` and to run only while `test` holds and the flag
    is unset. The `else` clause is guarded by the same flag.
    """
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)

    template = """
        var_name = tf.constant(False)
        while test and not var_name:
          body
        else:
          orelse
      """
    return templates.replace(
        template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=guarded_orelse)
Esempio n. 24
0
    def to_ast(self, ctx, internal_convert_user_code=None):
        """Returns a representation of this object as an AST node.

        The AST node encodes a constructor that would create an object with the
        same contents.

        Args:
          ctx: EntityContext, the entity with which this AST needs to be
            consistent.
          internal_convert_user_code: Optional[bool], allows overriding the
            corresponding value. When None, this object's own
            `internal_convert_user_code` is used.

        Returns:
          ast.Node
        """
        template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

        def as_qualified_name(o):
            name = inspect_utils.getqualifiedname(ctx.info.namespace,
                                                  o,
                                                  max_depth=1)
            if not name:
                # TODO(mdan): This needs to account for the symbols defined locally.
                name = ctx.namer.new_symbol(o.__name__, ())
                ctx.program.add_symbol(name, weakref.ref(o))
            return name

        def list_of_names(values):
            return parser.parse_expression('({})'.format(', '.join(
                tuple(as_qualified_name(v) for v in values))))

        def list_of_features(values):
            return parser.parse_expression('({})'.format(', '.join(
                'ag__.Feature.{}'.format(v) for v in Feature.__members__
                if v in values)))

        # An explicit argument takes precedence; only fall back to this
        # object's own setting when the caller did not supply one.
        if internal_convert_user_code is None:
            internal_convert_user_code = self.internal_convert_user_code

        expr_ast = templates.replace(
            template,
            constructor_name=parser.parse_expression(
                as_qualified_name(ConversionOptions)),
            recursive_val=parser.parse_expression(str(self.recursive)),
            verbose_val=parser.parse_expression(str(int(self.verbose))),
            strip_decorators_val=list_of_names(self._strip_decorators),
            force_conversion_val=parser.parse_expression(
                str(self.force_conversion)),
            internal_convert_user_code_val=parser.parse_expression(
                str(internal_convert_user_code)),
            optional_features_val=list_of_features(self.optional_features))
        return expr_ast[0].value
Esempio n. 25
0
    def visit_If(self, node):
        """Converts an `if` statement into a call to `ag__.if_stmt`.

        Both branches become nested functions, preceded by generated state
        getter/setter functions. Possibly-undefined symbols are assigned
        before the call. A missing `else` branch becomes `pass`.
        """
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

        cond_vars, undefined, nouts = self._get_block_vars(
            node, body_scope.bound | orelse_scope.bound)

        undefined_assigns = self._create_undefined_assigns(undefined)
        nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

        # Names are generated in a fixed order: getter, setter, body, else.
        reserved = body_scope.referenced | orelse_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
        state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
        state_functions = self._create_state_functions(
            cond_vars, nonlocal_declarations, state_getter_name,
            state_setter_name)
        body_name = self.ctx.namer.new_symbol('if_body', reserved)
        orelse_name = self.ctx.namer.new_symbol('else_body', reserved)

        orelse_body = node.orelse or [gast.Pass()]
        symbol_names = tuple(
            gast.Constant(str(s), kind=None) for s in cond_vars)

        template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def orelse_name():
        nonlocal_declarations
        orelse
      undefined_assigns
      ag__.if_stmt(
        test,
        body_name,
        orelse_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        nouts)
    """
        new_nodes = templates.replace(
            template,
            body=node.body,
            body_name=body_name,
            orelse=orelse_body,
            orelse_name=orelse_name,
            nonlocal_declarations=nonlocal_declarations,
            nouts=gast.Constant(nouts, kind=None),
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            symbol_names=symbol_names,
            test=node.test,
            undefined_assigns=undefined_assigns)
        # The last generated node is the `ag__.if_stmt` call.
        origin_info.copy_origin(node, new_nodes[-1])
        return new_nodes
Esempio n. 26
0
  def to_ast(self, ctx, internal_convert_user_code=None):
    """Serializes these options as an `ag__.ConversionOptions(...)` AST.

    Args:
      ctx: EntityContext, the entity with which this AST needs to be
        consistent.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, this object's own setting is used.

    Returns:
      ast.Node
    """
    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
      name = inspect_utils.getqualifiedname(ctx.info.namespace, o, max_depth=1)
      if name:
        return name
      if isinstance(o, weakref.ref):
        # `o` might already be a weak reference, if this object was
        # constructed from code generated by `to_ast` itself.
        # If so, unpack it.
        o = o()
      # TODO(mdan): This needs to account for the symbols defined locally.
      name = ctx.namer.new_symbol(o.__name__, ())
      ctx.program.add_symbol(name, weakref.ref(o))
      return name

    def as_literal(value):
      return parser.parse_expression(str(value))

    def list_of_names(values):
      joined = ', '.join(tuple(as_qualified_name(v) for v in values))
      return parser.parse_expression('({})'.format(joined))

    def list_of_features(values):
      joined = ', '.join('ag__.{}'.format(str(v)) for v in values)
      return parser.parse_expression('({})'.format(joined))

    effective_convert_user_code = (
        self.internal_convert_user_code
        if internal_convert_user_code is None else internal_convert_user_code)

    expr_ast = templates.replace(
        template,
        recursive_val=as_literal(self.recursive),
        verbose_val=as_literal(int(self.verbose)),
        strip_decorators_val=list_of_names(self._strip_decorators),
        force_conversion_val=as_literal(self.force_conversion),
        internal_convert_user_code_val=as_literal(effective_convert_user_code),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
Esempio n. 27
0
 def _create_undefined_assigns(self, undefined_symbols):
   """Returns one `symbol = ag__.UNDEFINED` statement per symbol, in order."""
   template = '''
       var = ag__.UNDEFINED
     '''
   assignments = []
   for symbol in undefined_symbols:
     assignments.extend(templates.replace(template, var=symbol))
   return assignments
Esempio n. 28
0
 def visit_Continue(self, node):
     """Emits `<control var> = tf.constant(True)` in place of `continue`.

     Also marks `continue` as used and resets the enclosing block's guard
     state.
     """
     # NOTE(review): mixes self.state[...] with self.get_local(...); confirm
     # that CONTROL_VAR_NAME is still tracked via get_local in this version.
     self.state[_Continue].used = True
     self.state[_Block].reset_guard_state()
     control_var = self.get_local(CONTROL_VAR_NAME)
     template = """
   var_name = tf.constant(True)
 """
     return templates.replace(template, var_name=control_var)
Esempio n. 29
0
 def _create_cond_expr(self, results, test, body_name, orelse_name):
   """Emits an `ag__.utils.run_cond` call for a converted conditional.

   When `results` is not None, the call's value is assigned to it;
   otherwise the call is emitted as a bare expression statement.
   """
   if results is None:
     template = """
       ag__.utils.run_cond(test, body_name, orelse_name)
     """
     return templates.replace(
         template, test=test, body_name=body_name, orelse_name=orelse_name)

   template = """
       results = ag__.utils.run_cond(test, body_name, orelse_name)
     """
   return templates.replace(
       template,
       test=test,
       results=results,
       body_name=body_name,
       orelse_name=orelse_name)
Esempio n. 30
0
  def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    The test and body become functions, and take the loop state when there
    is any. Assignments for possibly-undefined symbols are placed before the
    loop nodes.
    """
    self.generic_visit(node)

    modified = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(
        node, modified)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    # Generate names in the same order as before: test, then body.
    test_name = self.ctx.namer.new_symbol('loop_test', reserved_symbols)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    if not loop_state:
      template = """
        def test_name():
          return test
        def body_name():
          body
          return ()
        ag__.while_stmt(test_name, body_name, ())
      """
      loop_nodes = templates.replace(
          template,
          test_name=test_name,
          test=test,
          body_name=body_name,
          body=node_body)
    else:
      template = """
        def test_name(state_ssf):
          return test
        def body_name(state_ssf):
          body
          return state_ssf,
        state_ast_tuple = ag__.while_stmt(test_name, body_name, (state,))
      """
      loop_nodes = templates.replace(
          template,
          state=loop_state,
          state_ssf=state_ssf,
          state_ast_tuple=state_ast_tuple,
          test_name=test_name,
          test=test,
          body_name=body_name,
          body=node_body)

    return self._create_undefined_assigns(possibly_undefs) + loop_nodes
Esempio n. 31
0
 def visit_Continue(self, node):
   """Emits `<control var> = True` in place of `continue`.

   Also marks `continue` as used and resets the enclosing block's guard
   state.
   """
   continue_state = self.state[_Continue]
   continue_state.used = True
   self.state[_Block].reset_guard_state()
   template = """
     var_name = True
   """
   return templates.replace(
       template, var_name=continue_state.control_var_name)
Esempio n. 32
0
 def visit_Return(self, node):
     """Wraps a returned value in `<function context>.mark_return_value(...)`.

     A bare `return` (no value) is left untouched.
     """
     if node.value is not None:
         node = self.generic_visit(node)
         return templates.replace(
             'return function_context_name.mark_return_value(value)',
             function_context_name=self.state[_Function].context_name,
             value=node.value)
     return node
Esempio n. 33
0
 def _create_undefined_assigns(self, undefined_symbols):
     """Builds `symbol = ag__.UNDEFINED` statements for each symbol, in order."""
     template = '''
     var = ag__.UNDEFINED
   '''
     return [
         stmt
         for symbol in undefined_symbols
         for stmt in templates.replace(template, var=symbol)
     ]
Esempio n. 34
0
 def _create_cond_expr(self, results, test, body_name, orelse_name):
   """Emits an `ag__.utils.run_cond` call, optionally assigned to `results`.

   Args:
     results: assignment target for the call's value, or None for a bare
       expression statement.
     test: the condition expression node.
     body_name: name passed for the true branch.
     orelse_name: name passed for the false branch.
   """
   replacements = {
       'test': test,
       'body_name': body_name,
       'orelse_name': orelse_name,
   }
   if results is not None:
     template = """
       results = ag__.utils.run_cond(test, body_name, orelse_name)
     """
     replacements['results'] = results
   else:
     template = """
       ag__.utils.run_cond(test, body_name, orelse_name)
     """
   return templates.replace(template, **replacements)
Esempio n. 35
0
  def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    Assignments for possibly-undefined symbols are prepended to the
    generated loop nodes.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(
        node, body_scope.modified)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    replacements = {
        'test_name': self.ctx.namer.new_symbol('loop_test', reserved_symbols),
        'test': test,
        'body_name': self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        'body': node_body,
    }
    if loop_state:
      template = """
        def test_name(state_ssf):
          return test
        def body_name(state_ssf):
          body
          return state_ssf,
        state_ast_tuple = ag__.while_stmt(test_name, body_name, (state,))
      """
      replacements.update(
          state=loop_state,
          state_ssf=state_ssf,
          state_ast_tuple=state_ast_tuple)
    else:
      template = """
        def test_name():
          return test
        def body_name():
          body
          return ()
        ag__.while_stmt(test_name, body_name, ())
      """
    loop_nodes = templates.replace(template, **replacements)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + loop_nodes
Esempio n. 36
0
  def to_ast(self, namespace, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      namespace: Dict[str, Any], the namespace to use when serializing values to
        names.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, `self.internal_convert_user_code` is
        used.

    Returns:
      ast.Node

    Raises:
      ValueError: if `ConversionOptions` or an entry of `strip_decorators`
        cannot be located in `namespace`.
    """
    template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
      name = inspect_utils.getqualifiedname(namespace, o)
      if not name:
        raise ValueError('Could not locate entity {} in {}'.format(
            o, namespace))
      return name

    def tuple_of(items):
      # Put a trailing comma after every element. "(x)" parses as a
      # parenthesized expression, not a one-element tuple, whereas "(x, )"
      # and "()" are both tuple displays.
      return parser.parse_expression(
          '({})'.format(''.join('{}, '.format(item) for item in items)))

    def list_of_names(values):
      return tuple_of(as_qualified_name(v) for v in values)

    def list_of_features(values):
      # Feature.__members__ maps member names to members. Match on either,
      # so `values` may hold Feature members as well as their names.
      # NOTE(review): a name-only check never matches a plain Enum member;
      # confirm against the Feature definition.
      return tuple_of(
          'ag__.Feature.{}'.format(name)
          for name, member in Feature.__members__.items()
          if name in values or member in values)

    # Fall back to the stored value only when the caller did not supply one.
    if internal_convert_user_code is None:
      internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        constructor_name=parser.parse_expression(
            as_qualified_name(ConversionOptions)),
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(int(self.verbose))),
        strip_decorators_val=list_of_names(self.strip_decorators),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
Esempio n. 37
0
  def visit_For(self, node):
    """Lowers a `for` loop to an `ag__.for_stmt` call.

    The loop state is the set of symbols the body modifies but does not
    create. Each state symbol gets a fresh name from the namer, seeded with
    its `ssf()`. The body and the optional `extra_test` annotation are
    renamed to use those names and wrapped in two local functions.

    A single state symbol is passed on its own rather than as a tuple.
    Without an `extra_test` annotation, the extra test is the literal `True`.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    loop_state = list(scope.modified - scope.created)
    referenced = scope.referenced

    fresh_names = []
    for sym in loop_state:
      fresh_names.append(self.ctx.namer.new_symbol(sym.ssf(), referenced))
    renames = {
        sym: fresh
        for sym, fresh in zip(loop_state, fresh_names)
        if str(sym) != fresh
    }

    if len(loop_state) != 1:
      state_arg, ssf_arg = loop_state, fresh_names
      target = gast.Tuple([sym.ast() for sym in loop_state], None)
    else:
      state_arg, ssf_arg = loop_state[0], fresh_names[0]
      target = state_arg

    renamed_body = ast_util.rename_symbols(node.body, renames)
    if not anno.hasanno(node, 'extra_test'):
      extra_test = parser.parse_expression('True')
    else:
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), renames)

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    # Name requests stay in their original order: state names, then the
    # extra test function, then the body function.
    extra_test_fn = self.ctx.namer.new_symbol('extra_test', referenced)
    body_fn = self.ctx.namer.new_symbol('loop_body', referenced)
    return templates.replace(
        template,
        state=state_arg,
        state_ssf=ssf_arg,
        state_ast_tuple=target,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=extra_test_fn,
        extra_test_expr=extra_test,
        body_name=body_fn,
        body=renamed_body)
Esempio n. 38
0
  def to_ast(self, namespace, internal_convert_user_code=None):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents.

    Args:
      namespace: Dict[str, Any], the namespace to use when serializing values to
        names.
      internal_convert_user_code: Optional[bool], allows overriding the
        corresponding value. When None, `self.internal_convert_user_code` is
        used.

    Returns:
      ast.Node

    Raises:
      ValueError: if `ConversionOptions` or an entry of `_strip_decorators`
        cannot be located in `namespace`.
    """
    template = """
      constructor_name(
          recursive=recursive_val,
          verbose=verbose_val,
          strip_decorators=strip_decorators_val,
          force_conversion=force_conversion_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def as_qualified_name(o):
      name = inspect_utils.getqualifiedname(namespace, o)
      if not name:
        raise ValueError('Could not locate entity {} in {}'.format(
            o, namespace))
      return name

    def tuple_of(items):
      # Put a trailing comma after every element. "(x)" parses as a
      # parenthesized expression, not a one-element tuple, whereas "(x, )"
      # and "()" are both tuple displays.
      return parser.parse_expression(
          '({})'.format(''.join('{}, '.format(item) for item in items)))

    def list_of_names(values):
      return tuple_of(as_qualified_name(v) for v in values)

    def list_of_features(values):
      # Feature.__members__ maps member names to members. Match on either,
      # so `values` may hold Feature members as well as their names.
      # NOTE(review): a name-only check never matches a plain Enum member;
      # confirm against the Feature definition.
      return tuple_of(
          'ag__.Feature.{}'.format(name)
          for name, member in Feature.__members__.items()
          if name in values or member in values)

    # Fall back to the stored value only when the caller did not supply one.
    if internal_convert_user_code is None:
      internal_convert_user_code = self.internal_convert_user_code

    expr_ast = templates.replace(
        template,
        constructor_name=parser.parse_expression(
            as_qualified_name(ConversionOptions)),
        recursive_val=parser.parse_expression(str(self.recursive)),
        verbose_val=parser.parse_expression(str(int(self.verbose))),
        strip_decorators_val=list_of_names(self._strip_decorators),
        force_conversion_val=parser.parse_expression(
            str(self.force_conversion)),
        internal_convert_user_code_val=parser.parse_expression(
            str(internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value
Esempio n. 39
0
    def visit_For(self, node):
        """Lowers a `for` loop to an `ag__.for_stmt` call.

        Symbols modified but not created by the loop body are carried as
        loop state under fresh, namer-generated names.

        The body and the optional `extra_test` annotation are renamed
        accordingly and wrapped into an extra-test function and a body
        function. A single carried symbol is passed directly instead of as a
        tuple. A missing `extra_test` becomes the literal `True`.
        """
        self.generic_visit(node)
        self._validate_no_live_vars_created(node)

        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        carried = list(body_scope.modified - body_scope.created)
        seen_names = body_scope.referenced

        carried_ssf = [
            self.ctx.namer.new_symbol(sym.ssf(), seen_names) for sym in carried
        ]
        rename_map = dict((sym, ssf) for sym, ssf in zip(carried, carried_ssf)
                          if str(sym) != ssf)

        if len(carried) == 1:
            (state_arg,) = carried
            (ssf_arg,) = carried_ssf
            tuple_target = state_arg
        else:
            state_arg = carried
            ssf_arg = carried_ssf
            tuple_target = gast.Tuple([sym.ast() for sym in carried], None)

        new_body = ast_util.rename_symbols(node.body, rename_map)
        if anno.hasanno(node, 'extra_test'):
            cond = ast_util.rename_symbols(anno.getanno(node, 'extra_test'),
                                           rename_map)
        else:
            cond = parser.parse_expression('True')

        template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
        # Request the helper names after the state names, as before, so the
        # namer sees the same sequence of requests.
        extra_test_fn = self.ctx.namer.new_symbol('extra_test', seen_names)
        body_fn = self.ctx.namer.new_symbol('loop_body', seen_names)
        return templates.replace(
            template,
            state=state_arg,
            state_ssf=ssf_arg,
            state_ast_tuple=tuple_target,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=extra_test_fn,
            extra_test_expr=cond,
            body_name=body_fn,
            body=new_body)
Esempio n. 40
0
  def test_replace_name_with_subscript(self):
    """Substituting `dictionary[key]` as a target gets the right contexts.

    The subscript must be a Store target, and the subscripted value a Load.
    """
    template = """
        foo = bar
    """
    subscript = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))

    statements = templates.replace(template, foo=subscript)
    target = statements[0].targets[0]
    self.assertIsInstance(target.ctx, gast.Store)
    self.assertIsInstance(target.value.ctx, gast.Load)
Esempio n. 41
0
    def test_replace_name_with_subscript(self):
        """A subscript substituted as an assignment target is Store; its base is Load."""
        template = """
        foo = bar
    """
        dict_key = qn.QN(qn.QN('dictionary'), subscript=qn.QN('key'))

        assign = templates.replace(template, foo=dict_key)[0]
        lhs = assign.targets[0]
        self.assertIsInstance(lhs.ctx, gast.Store)
        self.assertIsInstance(lhs.value.ctx, gast.Load)
Esempio n. 42
0
 def test_lambda_in_function_call(self):
     """A lambda spliced into a call argument gets Param args and a Load body."""
     template = """
   a = foo(arg)
 """
     lambdas = parser.parse_expression('[lambda i: i]')
     statements = templates.replace(template, arg=lambdas)
     call = statements[0].value
     fn = call.args[0].elts[0]
     self.assertIsInstance(fn.args.args[0].ctx, gast.Param)
     self.assertIsInstance(fn.body.ctx, gast.Load)
Esempio n. 43
0
 def test_star_comprehension_in_function_call(self):
     """A starred comprehension argument keeps Store targets and a Load element."""
     template = """
   a = foo(func, args)
 """
     call_expr = parser.parse_expression('bar(*[i for i in range(j)])')
     statements = templates.replace(
         template, func=call_expr.func, args=call_expr.args)
     comprehension = statements[0].value.args[1].value
     first_generator = comprehension.generators[0]
     self.assertIsInstance(first_generator.target.ctx, gast.Store)
     self.assertIsInstance(comprehension.elt.ctx, gast.Load)
Esempio n. 44
0
  def replace_as_expression(self):
    """Checks that `templates.replace_as_expression` yields a bare Call node.

    NOTE(review): without a `test_` prefix, unittest does not collect this
    method. Consider renaming it to `test_replace_as_expression`.
    """
    template = """
      foo(a)
    """

    # `templates.replace` returns a list of statements. The expression
    # variant returns the expression node itself, which is what we inspect.
    node = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    self.assertEqual(node.args[0].id, 'baz')
Esempio n. 45
0
 def test_lambda_in_function_call(self):
   """Lambda parameters become Param and the lambda body is loaded."""
   template = """
     a = foo(arg)
   """
   list_of_lambdas = parser.parse_expression('[lambda i: i]')
   assign = templates.replace(template, arg=list_of_lambdas)[0]
   lam = assign.value.args[0].elts[0]
   self.assertIsInstance(lam.args.args[0].ctx, gast.Param)
   self.assertIsInstance(lam.body.ctx, gast.Load)
Esempio n. 46
0
    def visit_Assert(self, node):
        """Converts `assert test, msg` into `tf.Assert(test, (msg,))`.

        A missing message is replaced with 'Assertion error'. The lone
        tf.Assert call will be wrapped with control_dependencies by
        side_effect_guards.

        Raises:
          NotImplementedError: if a message is present but is not a gast.Str.
        """
        self.generic_visit(node)

        template = """
      tf.Assert(test, (msg,))
    """

        msg = node.msg
        if msg is None:
            msg = gast.Str('Assertion error')
        elif not isinstance(msg, gast.Str):
            raise NotImplementedError(
                'can only convert string messages for now.')
        return templates.replace(template, test=node.test, msg=msg)
Esempio n. 47
0
 def _replace_append_call(self, node):
     """Rewrites `<target>.<attr>(<element>)` as a `list_append` assignment.

     Produces `<target> = ag__.list_append(<target>, <element>)`. The call
     must have exactly one argument and an attribute callee; both are checked
     with `assert`. The attribute name itself is not checked here. Presumably
     the caller only passes `.append` calls.
     """
     assert len(node.args) == 1
     assert isinstance(node.func, gast.Attribute)
     receiver = node.func.value
     element = node.args[0]
     template = """
   target = ag__.list_append(target, element)
 """
     return templates.replace(template, target=receiver, element=element)
Esempio n. 48
0
    def replace_as_expression(self):
        """Checks that `templates.replace_as_expression` yields a bare Call node.

        NOTE(review): without a `test_` prefix, unittest does not collect
        this method. Consider renaming it to `test_replace_as_expression`.
        """
        template = """
      foo(a)
    """

        # `templates.replace` returns a list of statements. The expression
        # variant returns the expression node itself, which is what we
        # inspect.
        node = templates.replace_as_expression(template, foo='bar', a='baz')
        self.assertIsInstance(node, gast.Call)
        self.assertEqual(node.func.id, 'bar')
        self.assertEqual(node.args[0].id, 'baz')
Esempio n. 49
0
 def visit_Break(self, node):
   """Replaces `break` with setting the break control variable, then `continue`.

   Marks the current `_Break` state as used. It then emits
   `<control_var_name> = tf.constant(True)` followed by `continue`.
   """
   break_state = self.state[_Break]
   break_state.used = True
   # TODO(mdan): This will fail when expanded inside a top-level else block.
   template = """
     var_name = tf.constant(True)
     continue
   """
   return templates.replace(template, var_name=break_state.control_var_name)
Esempio n. 50
0
 def test_star_comprehension_in_function_call(self):
   """Comprehension targets are Store and elements Load when passed as *args."""
   template = """
     a = foo(func, args)
   """
   parsed_call = parser.parse_expression('bar(*[i for i in range(j)])')
   assign = templates.replace(
       template, func=parsed_call.func, args=parsed_call.args)[0]
   listcomp = assign.value.args[1].value
   self.assertIsInstance(listcomp.generators[0].target.ctx, gast.Store)
   self.assertIsInstance(listcomp.elt.ctx, gast.Load)
Esempio n. 51
0
    def visit_If(self, node):
        """Intercepts if statements.

        Converts each `if` into up to two `with` statements, guarded by
        `_tfp_autobatching_context_.if_(<cond>)` and
        `_tfp_autobatching_context_.else_()`. A one-armed `if` yields the
        single transformed node. An `if` with an `else` yields both nodes in
        a list.

        Args:
          node: An `ast.AST` node representing the `if` statement to convert.

        Returns:
          then_node: A node representing the `with`-guarded consequent branch.
          else_node: A node representing the `with`-guarded alternate branch,
            if present.
        """
        def visit_stmts(stmts):
            # Wrapping the statement list in a Module keeps the AST-visiting
            # machinery (including statement prepending) from choking on a
            # bare list.
            return self.generic_visit(gast_util.Module(stmts)).body

        def with_block(header, stmts):
            # TODO(axch): Test that this form actually works with multiline
            # bodies.
            return templates.replace('with header: body',
                                     header=header, body=stmts)[0]

        then_stmts = visit_stmts(node.body)
        cond_ref = self._to_reference(node.test)
        then_node = with_block(
            templates.replace_as_expression(
                '_tfp_autobatching_context_.if_(cond)', cond=cond_ref),
            then_stmts)
        if not node.orelse:
            return then_node

        else_stmts = visit_stmts(node.orelse)
        else_node = with_block(
            templates.replace_as_expression(
                '_tfp_autobatching_context_.else_()'),
            else_stmts)
        return [then_node, else_node]
Esempio n. 52
0
    def test_replace_expression_context(self):
        """Names inside a substituted arithmetic expression get Load context."""
        template = """
      def test_fn():
        foo
    """

        expr = parser.parse_expression('a + 2 * b / -c')
        fn_def = templates.replace(template, foo=expr)[0]
        first = fn_def.body[0]
        self.assertIsInstance(first.left.ctx, gast.Load)
        self.assertIsInstance(first.right.left.right.ctx, gast.Load)
Esempio n. 53
0
 def _create_undefined_assigns(self, undefined_symbols):
     """Builds one `<symbol> = ag__.Undefined('<name>')` statement per symbol.

     Args:
       undefined_symbols: iterable of symbols. Each is substituted for `var`,
         and its `ssf()` string is passed as a gast.Str for `symbol_name`.

     Returns:
       A flat list holding the statements for every symbol, in iteration
       order. The list is empty when there are no symbols.
     """
     template = '''
     var = ag__.Undefined(symbol_name)
   '''
     return [
         stmt
         for sym in undefined_symbols
         for stmt in templates.replace(template,
                                       var=sym,
                                       symbol_name=gast.Str(sym.ssf()))
     ]
Esempio n. 54
0
    def test_replace_name_with_dict(self):
        """A dict literal can stand in for a name that is then subscripted."""
        template = """
      def test_fn():
        return foo['bar']
    """

        dict_literal = parser.parse_expression('{\'bar\': 3}')
        fn_def = templates.replace(template, foo=dict_literal)[0]
        module, _, _ = loader.load_ast(fn_def)
        self.assertEqual(3, module.test_fn())
Esempio n. 55
0
  def test_replace_tuple(self):
    """Replacing `b` with the names ('a', 'c') makes test_fn return (a, c)."""
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    # assertEquals is a deprecated alias of assertEqual and was removed in
    # Python 3.12.
    self.assertEqual((2, 3), result.test_fn(2, 3))
Esempio n. 56
0
  def test_replace_name_with_dict(self):
    """A dict literal can stand in for a name that is then subscripted."""
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    # assertEquals is a deprecated alias of assertEqual and was removed in
    # Python 3.12.
    self.assertEqual(3, result.test_fn())
Esempio n. 57
0
    def test_replace_tuple(self):
        """Substituting the name tuple ('a', 'c') for `b` returns (a, c)."""
        template = """
      def test_fn(a, c):
        return b,
    """

        fn_def = templates.replace(template, b=('a', 'c'))[0]
        module, _, _ = loader.load_ast(fn_def)

        expected = (2, 3)
        self.assertEqual(expected, module.test_fn(2, 3))
Esempio n. 58
0
  def test_replace_expression_context(self):
    """Every name in a substituted arithmetic expression is loaded."""
    template = """
      def test_fn():
        foo
    """

    arithmetic = parser.parse_expression('a + 2 * b / -c')
    func = templates.replace(template, foo=arithmetic)[0]
    stmt = func.body[0]
    self.assertIsInstance(stmt.left.ctx, gast.Load)
    self.assertIsInstance(stmt.right.left.right.ctx, gast.Load)