Beispiel #1
0
    def visit_Lambda(self, node):
        """Wraps the body of a top-level lambda in a function scope.

        Lambdas nested deeper than the top-level function (state level above
        2) are only wrapped in `ag__.autograph_artifact`. Otherwise a fresh
        `lscope` symbol is allocated, stored as the scope's context name and
        as the node's 'function_context_name' annotation, and the lambda body
        is rewritten to run inside `ag__.with_function_scope`.

        Args:
          node: The lambda node being visited.

        Returns:
          The rewritten lambda node, or an expression wrapping it.
        """
        with self.state[_Function] as lambda_state:
            node = self.generic_visit(node)

            # Wrapping every nested lambda would produce excessive boilerplate,
            # so anything below the top-level function is only marked.
            # TODO(mdan): Looks more closely for use cases that actually require this.
            if lambda_state.level > 2:
                return templates.replace_as_expression(
                    'ag__.autograph_artifact(l)', l=node)

            referenced = anno.getanno(node, anno.Static.SCOPE).referenced
            scope_name = self.ctx.namer.new_symbol('lscope', referenced)
            lambda_state.context_name = scope_name
            anno.setanno(node, 'function_context_name', scope_name)

            scope_options = self._function_scope_options(lambda_state).to_ast()
            name_literal = gast.Constant(scope_name, kind=None)
            template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
            node.body = templates.replace_as_expression(
                template,
                options=scope_options,
                function_context=scope_name,
                function_context_name=name_literal,
                body=node.body)
            return node
  def visit_Lambda(self, node):
    """Wraps the body of a top-level lambda in a function scope.

    Lambdas nested deeper than the top-level function (state level above 2)
    are only wrapped in `ag__.autograph_artifact`. Otherwise a fresh `lscope`
    symbol becomes the scope's context name and the node's
    'function_context_name' annotation, and the body is rewritten to run
    inside `ag__.with_function_scope`.

    Args:
      node: The lambda node being visited.

    Returns:
      The rewritten lambda, or an expression wrapping it.
    """
    with self.state[_Function] as lambda_scope:
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if lambda_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      static_scope = anno.getanno(node, anno.Static.SCOPE)
      context_name = self.ctx.namer.new_symbol('lscope',
                                               static_scope.referenced)
      lambda_scope.context_name = context_name
      anno.setanno(node, 'function_context_name', context_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      options_ast = self._function_scope_options(lambda_scope).to_ast()
      node.body = templates.replace_as_expression(
          template,
          options=options_ast,
          function_context=context_name,
          function_context_name=gast.Constant(context_name, kind=None),
          body=node.body)
      return node
Beispiel #3
0
    def visit_Subscript(self, node):
        """Rewrites single-key subscript loads into `ag__.get_item` calls.

        Tuple and slice keys, and subscripts in a non-Load context, are
        returned unchanged after their children are visited. The element
        dtype comes from the `set_element_type` directive on the subscripted
        value, defaulting to `None`.

        Args:
          node: The subscript node being visited.

        Returns:
          Either the visited node or an `ag__.get_item(...)` expression.
        """
        node = self.generic_visit(node)
        key = node.slice
        # Tuple/slice keys are left alone. Index writes are handled at a higher
        # level, one at which the rvalue is also available.
        if isinstance(key, (gast.Tuple, gast.Slice)) or not isinstance(
                node.ctx, gast.Load):
            return node

        element_dtype = self.get_definition_directive(
            node.value,
            directives.set_element_type,
            'dtype',
            default=templates.replace_as_expression('None'))

        template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
        return templates.replace_as_expression(
            template, target=node.value, key=key, dtype=element_dtype)
Beispiel #4
0
 def test_replace_as_expression_restrictions(self):
     """A two-statement template is rejected with ValueError."""
     template = """
   foo(a)
   bar(b)
 """
     self.assertRaises(ValueError, templates.replace_as_expression, template)
Beispiel #5
0
 def test_replace_as_expression_restrictions(self):
   """replace_as_expression raises ValueError for a two-statement template."""
   two_statements = """
     foo(a)
     bar(b)
   """
   self.assertRaises(ValueError, templates.replace_as_expression,
                     two_statements)
 def visit_Name(self, node):
     """Wraps loads of original-code names in `ag__.ld(...)`.

     Only names carrying the `anno.Static.ORIG_DEFINITIONS` annotation and
     used in a Load context are rewritten; any other name node is returned
     unchanged.
     """
     from_original_code = anno.hasanno(node, anno.Static.ORIG_DEFINITIONS)
     if from_original_code and isinstance(node.ctx, gast.Load):
         return templates.replace_as_expression('ag__.ld(var_)', var_=node)
     return node
  def visit_For(self, node):
    """Lowers `break` statements in a `for` loop to a boolean flag.

    If the body uses no `break`, the loop is rebuilt from its visited parts
    and keeps the original EXTRA_LOOP_TEST and DIRECTIVES annotations.
    Otherwise a fresh `break_` flag is set to False before the loop, the
    `else` clause is guarded on that flag, and `ag__.not_(flag)` is attached
    as the loop's EXTRA_LOOP_TEST annotation.

    Args:
      node: The `for` node being visited.

    Returns:
      The statements produced by `templates.replace` that replace the loop.
    """
    source_loop = node
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_flag = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, has_break = self._process_body(node.body, break_flag)
    # A `break` in the else clause belongs to the enclosing loop, so the else
    # block is visited without the flag.
    node.orelse = self.visit_block(node.orelse)

    if has_break:
      # Python runs the else clause only when the loop exits without a break.
      guarded_orelse = self._guard_if_present(node.orelse, break_flag)
      loop_test = templates.replace_as_expression(
          'ag__.not_(var_name)', var_name=break_flag)
      # The extra test is hidden in an annotation that static analysis cannot
      # see; the `(var_name,)` no-op statement keeps the flag marked as used.
      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
      template = """
      var_name = False
      for target in iter_:
        (var_name,)
        body
      orelse
    """
      statements = templates.replace(
          template,
          var_name=break_flag,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=guarded_orelse)
      loop = statements[1]
      anno.setanno(loop, anno.Basic.EXTRA_LOOP_TEST, loop_test)
    else:
      template = """
        for target in iter_:
          body
        orelse
      """
      statements = templates.replace(
          template,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=node.orelse)
      loop = statements[0]
      anno.copyanno(source_loop, loop, anno.Basic.EXTRA_LOOP_TEST)

    anno.copyanno(source_loop, loop, anno.Basic.DIRECTIVES)
    return statements
Beispiel #8
0
    def visit_Lambda(self, node):
        """Wraps the body of a top-level lambda in a function scope.

        Enters a new `_Function` state for the duration of the visit and
        always exits it, even if visiting or rewriting raises. Without that
        guarantee the state stack would stay unbalanced. Lambdas nested
        deeper than the top-level function (state level above 2) are
        returned after visiting, without wrapping. Otherwise a fresh `lscope`
        symbol becomes the state's context name and the node's
        'function_context_name' annotation, and the body is rewritten to run
        inside `ag__.with_function_scope`.

        Args:
          node: The lambda node being visited.

        Returns:
          The visited (and possibly rewritten) lambda node.
        """
        self.state[_Function].enter()
        try:
            node = self.generic_visit(node)

            # Only wrap the top-level function. Theoretically, we can and should
            # wrap everything, but that can lead to excessive boilerplate when
            # lambdas are nested.
            # TODO(mdan): Looks more closely for use cases that actually require this.
            if self.state[_Function].level > 2:
                return node

            scope = anno.getanno(node, anno.Static.SCOPE)
            function_context_name = self.ctx.namer.new_symbol(
                'lscope', scope.referenced)
            self.state[_Function].context_name = function_context_name
            anno.setanno(node, 'function_context_name', function_context_name)

            template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
            node.body = templates.replace_as_expression(
                template,
                options=self.ctx.program.options.to_ast(),
                function_context=function_context_name,
                function_context_name=gast.Str(function_context_name),
                body=node.body)
            return node
        finally:
            # Paired with enter() above on every exit path, including errors.
            self.state[_Function].exit()
Beispiel #9
0
  def visit_Call(self, node):
    """Routes a call through `ag__.converted_call`.

    Calls into the internal `ag__` module, and `print` calls when the
    BUILTIN_FUNCTIONS feature is not used, are only visited recursively.
    For attribute calls the attribute name becomes a string `func` and the
    receiver becomes `owner`; otherwise `owner` is `None`. Keyword arguments
    are copied onto the new call unchanged.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.
    qualified_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    is_internal = qualified_name.startswith('ag__.')
    if is_internal or (
        qualified_name == 'print' and
        not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
      return self.generic_visit(node)

    callee = node.func
    if isinstance(callee, gast.Attribute):
      func, owner = gast.Str(callee.attr), callee.value
    else:
      func, owner = callee, parser.parse_expression('None')

    program_options = self.ctx.program.options
    template = """
      ag__.converted_call(func, owner, options, args)
    """
    converted = templates.replace_as_expression(
        template,
        func=func,
        owner=owner,
        options=program_options.to_ast(
            self.ctx, internal_convert_user_code=program_options.recursive),
        args=node.args)
    # TODO(mdan): Improve the template mechanism to better support this.
    converted.keywords = node.keywords
    return converted
Beispiel #10
0
 def visit_IfExp(self, node):
     """Rewrites `a if test else b` into an `ag__.if_stmt(...)` call.

     Both branches are wrapped in zero-argument lambdas. The last two
     arguments are the template's fixed `lambda: ()` and `lambda _: None`.
     """
     template = '''ag__.if_stmt(test, lambda: true_expr,
                     lambda: false_expr, lambda: (), lambda _: None)'''
     return templates.replace_as_expression(
         template, test=node.test, true_expr=node.body, false_expr=node.orelse)
Beispiel #11
0
  def _rename_compilable_function(self, node):
    """Points a call at the compiled name of its target, if renamed.

    Requires the 'live_val' and 'fqn' annotations on `node.func`. Asks the
    namer for a compiled class name when `node` is annotated
    'is_constructor', and for a compiled function name otherwise. When a
    rename is chosen and the target is a method, the receiver is prepended
    to the positional arguments before `node.func` is replaced.

    Returns:
      `node`, modified in place.
    """
    callee = node.func
    assert anno.hasanno(callee, 'live_val')
    assert anno.hasanno(callee, 'fqn')
    live_target = anno.getanno(callee, 'live_val')
    qualified_name = anno.getanno(callee, 'fqn')

    if anno.hasanno(node, 'is_constructor'):
      compiled_name = self.ctx.namer.compiled_class_name(
          qualified_name, live_entity=live_target)
      should_rename = True
    else:
      # The getmethodclass fallback is not reliable.
      owner_cls = (
          anno.getanno(callee, 'parent_type')
          if anno.hasanno(callee, 'parent_type') else
          inspect_utils.getmethodclass(live_target))
      compiled_name, should_rename = self.ctx.namer.compiled_function_name(
          qualified_name, live_entity=live_target, owner_type=owner_cls)

    if not should_rename:
      return node

    if live_target is not None and tf_inspect.ismethod(live_target):
      # The renaming process will transform it into a regular function.
      # TODO(mdan): Is this complete? How does it work with nested members?
      node.args = [callee.value] + node.args
    node.func = templates.replace_as_expression(
        'func_name', func_name=compiled_name)
    return node
Beispiel #12
0
 def _as_binary_operation(self, op, arg1, arg2):
     """Builds a comparison expression `arg1 <op> arg2`.

     The template is parsed with `is` as a placeholder operator, and its
     first entry in `ops` is then replaced with `op`.
     """
     comparison = templates.replace_as_expression(
         'arg1 is arg2', arg1=arg1, arg2=arg2)
     comparison.ops[0] = op
     return comparison
Beispiel #13
0
    def visit_Call(self, node):
        """Routes a call through `ag__.converted_call`.

        `ag__.*` calls, and `print` when the BUILTIN_FUNCTIONS feature is
        not used, are only visited recursively. An attribute call contributes
        its attribute name as a string `func` and its receiver as `owner`.
        Any other call uses `None` as the owner. Keywords are copied onto the
        new call as-is.
        """
        # TODO(mdan): Refactor converted_call as a 'Call' operator.
        qualified_name = str(anno.getanno(node.func, anno.Basic.QN,
                                          default=''))
        # Calls to the internal 'ag__' module are never converted (though
        # their arguments might be).
        skip_conversion = qualified_name.startswith('ag__.') or (
            qualified_name == 'print' and not self.ctx.program.options.uses(
                converter.Feature.BUILTIN_FUNCTIONS))
        if skip_conversion:
            return self.generic_visit(node)

        callee = node.func
        if isinstance(callee, gast.Attribute):
            func, owner = gast.Str(callee.attr), callee.value
        else:
            func, owner = callee, parser.parse_expression('None')

        program_options = self.ctx.program.options
        template = """
      ag__.converted_call(func, owner, options, args)
    """
        converted = templates.replace_as_expression(
            template,
            func=func,
            owner=owner,
            options=program_options.to_ast(
                self.ctx,
                internal_convert_user_code=program_options.recursive),
            args=node.args)
        # TODO(mdan): Improve the template mechanism to better support this.
        converted.keywords = node.keywords
        return converted
 def _as_function(self, func_name, args):
   """Builds `func_name(args)` and marks it as a safe boolean operand.

   Args:
     func_name: Source text of the callee, parsed with
       `parser.parse_expression`.
     args: Node(s) substituted for `args` in the call template.

   Returns:
     The call expression, annotated with SAFE_BOOLEAN_OPERAND=True.
   """
   template = """
     func_name(args)
   """
   callee = parser.parse_expression(func_name)
   call_node = templates.replace_as_expression(
       template, func_name=callee, args=args)
   anno.setanno(call_node, SAFE_BOOLEAN_OPERAND, True)
   return call_node
 def _as_function(self, func_name, args):
     """Returns `func_name(args)` annotated with SAFE_BOOLEAN_OPERAND=True.

     `func_name` is source text parsed into an expression; `args` is
     substituted into the call template.
     """
     template = """
   func_name(args)
 """
     callee = parser.parse_expression(func_name)
     call_node = templates.replace_as_expression(
         template, func_name=callee, args=args)
     anno.setanno(call_node, SAFE_BOOLEAN_OPERAND, True)
     return call_node
Beispiel #16
0
    def test_replace_as_expression(self):
        """Substituting names into `foo(a)` yields a Call `bar(baz)`."""
        template = """
      foo(a)
    """

        call = templates.replace_as_expression(template, foo='bar', a='baz')
        self.assertIsInstance(call, gast.Call)
        self.assertEqual('bar', call.func.id)
        self.assertEqual('baz', call.args[0].id)
Beispiel #17
0
 def _replace_stack_call(self, node):
     """Rewrites a one-argument stack call into `ag__.list_stack`.

     The element dtype comes from the `set_element_type` directive on the
     argument, defaulting to `None`. The original callee is passed through
     as `original_call`.
     """
     assert len(node.args) == 1
     stacked = node.args[0]
     element_dtype = self.get_definition_directive(
         stacked,
         directives.set_element_type,
         'dtype',
         default=templates.replace_as_expression('None'))
     template = """
   ag__.list_stack(
       target,
       opts=ag__.ListStackOpts(
           element_dtype=dtype,
           original_call=orig_call))
 """
     return templates.replace_as_expression(
         template, dtype=element_dtype, target=stacked, orig_call=node.func)
Beispiel #18
0
 def _to_reference(self, node):
   """Returns `_tfp_autobatching_context_.var.<name>` for a variable node.

   Args:
     node: A `gast.Name` or `qual_names.QN`.

   Raises:
     ValueError: If `node` is a literal (not yet supported) or any other
       non-trivial node.
   """
   if not isinstance(node, (gast.Name, qual_names.QN)):
     if gast_util.is_literal(node):
       raise ValueError('TODO(axch): Support literals, not just variables')
     raise ValueError(
         'Expected trivial node, got {}.  Is the input in A-normal form?'
         .format(node))
   return templates.replace_as_expression(
       '_tfp_autobatching_context_.var.name', name=node)
Beispiel #19
0
  def test_replace_as_expression(self):
    """`foo(a)` with foo='bar', a='baz' becomes the Call `bar(baz)`."""
    template = """
      foo(a)
    """

    call = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(call, gast.Call)
    self.assertEqual('bar', call.func.id)
    self.assertEqual('baz', call.args[0].id)
Beispiel #20
0
    def visit_If(self, node):
        """Converts an `if` statement into `with`-guarded branches.

        The consequent becomes `with _tfp_autobatching_context_.if_(cond):`.
        When an `else` branch exists, it becomes
        `with _tfp_autobatching_context_.else_():`.

        Args:
          node: An `ast.AST` node representing the `if` statement to convert.

        Returns:
          The consequent's `with` node when there is no `else` branch;
          otherwise the list `[then_node, else_node]`.
        """
        # Each branch body is wrapped in a Module so the AST-visiting
        # machinery sees one node rather than a bare list, which keeps
        # prepending statements working.
        then_body = self.generic_visit(gast_util.Module(node.body)).body
        then_header = templates.replace_as_expression(
            '_tfp_autobatching_context_.if_(cond)',
            cond=self._to_reference(node.test))
        # TODO(axch): Test that this form actually works with multiline bodies.
        then_node = templates.replace(
            'with header: body', header=then_header, body=then_body)[0]

        if not node.orelse:
            return then_node

        else_body = self.generic_visit(gast_util.Module(node.orelse)).body
        else_header = templates.replace_as_expression(
            '_tfp_autobatching_context_.else_()')
        else_node = templates.replace(
            'with header: body', header=else_header, body=else_body)[0]
        return [then_node, else_node]
 def _convert_builtin(self, f, args, as_expression):
   """Replaces a builtin call with `ag__.<overload>(args)`.

   Args:
     f: The builtin; `py_builtins.overload_of(f).__name__` names the
       `ag__` function to call.
     args: Argument node(s) for the call.
     as_expression: If true, build via `replace_as_expression`; otherwise
       via `replace`.
   """
   template = """
     ag__.func(args)
   """
   build = (templates.replace_as_expression
            if as_expression else templates.replace)
   return build(template, func=py_builtins.overload_of(f).__name__, args=args)
Beispiel #22
0
 def _replace_stack_call(self, node):
   """Rewrites a one-argument stack call into `ag__.list_stack`.

   The element dtype comes from the `set_element_type` directive on the
   argument (default `None`). The original callee is forwarded as
   `original_call`.
   """
   assert len(node.args) == 1
   stacked = node.args[0]
   element_dtype = self.get_definition_directive(
       stacked,
       directives.set_element_type,
       'dtype',
       default=templates.replace_as_expression('None'))
   template = """
     ag__.list_stack(
         target,
         opts=ag__.ListStackOpts(
             element_dtype=dtype,
             original_call=orig_call))
   """
   return templates.replace_as_expression(
       template, dtype=element_dtype, target=stacked, orig_call=node.func)
Beispiel #23
0
 def _wrap_to_py_func_single_return(self, node, dtype):
   """Wraps a call in `ag__.utils.wrap_py_func` with a single return dtype.

   Args:
     node: The call node; its func, positional args and keywords are reused.
     dtype: Source text of the return dtype, parsed as an expression.
   """
   # TODO(mdan): Properly handle varargs, etc.
   dtype_expr = parser.parse_expression(dtype)
   keyword_dict = ast_util.keywords_to_dict(node.keywords)
   template = """
     ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
   """
   return templates.replace_as_expression(
       template,
       func=node.func,
       dtype=dtype_expr,
       args=node.args,
       kwargs=keyword_dict)
 def _convert_builtin(self, f, args, as_expression):
     """Emits a call to the `ag__` overload of builtin `f` with `args`.

     Returns an expression when `as_expression` is true, and the output of
     `templates.replace` otherwise.
     """
     template = """
   ag__.func(args)
 """
     overload_name = py_builtins.overload_of(f).__name__
     if not as_expression:
         return templates.replace(template, func=overload_name, args=args)
     return templates.replace_as_expression(
         template, func=overload_name, args=args)
Beispiel #25
0
  def visit_For(self, node):
    """Adds the early-return check to a `for` loop's 'extra_test' annotation.

    After the body is visited, if the `_Return` state reports `used`, the
    annotation becomes `ag__.not_(<do_return var>)`. If an extra test
    already exists, the two are combined with `ag__.and_`.
    """
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)

    # Add the check for return to the loop condition.
    node.body = self._visit_statement_block(node, node.body)
    if self.state[_Return].used:
      return_flag = self.state[_Function].do_return_var_name
      previous_test = anno.getanno(node, 'extra_test', default=None)
      if previous_test is None:
        loop_test = templates.replace_as_expression(
            'ag__.not_(control_var)', control_var=return_flag)
      else:
        loop_test = templates.replace_as_expression(
            'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
            extra_test=previous_test,
            control_var=return_flag)
      anno.setanno(node, 'extra_test', loop_test)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node
Beispiel #26
0
  def visit_For(self, node):
    """Adds the early-return check to a `for` loop's EXTRA_LOOP_TEST.

    When the `_Block` state reports `return_used` after the body is visited,
    the annotation becomes `not <do_return var>`, joined with `and` to any
    existing extra test.
    """
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)

    # Add the check for return to the loop condition.
    node.body = self._visit_statement_block(node, node.body)
    if self.state[_Block].return_used:
      return_flag = self.state[_Function].do_return_var_name
      previous_test = anno.getanno(
          node, anno.Basic.EXTRA_LOOP_TEST, default=None)
      if previous_test is None:
        loop_test = templates.replace_as_expression(
            'not control_var', control_var=return_flag)
      else:
        loop_test = templates.replace_as_expression(
            'not control_var and extra_test',
            extra_test=previous_test,
            control_var=return_flag)
      anno.setanno(node, anno.Basic.EXTRA_LOOP_TEST, loop_test)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node
Beispiel #27
0
    def visit_For(self, node):
        """Adds the early-return check to a `for` loop's 'extra_test'.

        If the `_Block` state reports `return_used` after the body is
        visited, 'extra_test' is set to `ag__.not_(<do_return var>)`. Any
        existing test is combined with it via `ag__.and_`.
        """
        node.iter = self.visit(node.iter)
        node.target = self.visit(node.target)

        # Add the check for return to the loop condition.
        node.body = self._visit_statement_block(node, node.body)
        if self.state[_Block].return_used:
            return_flag = self.state[_Function].do_return_var_name
            previous_test = anno.getanno(node, 'extra_test', default=None)
            if previous_test is None:
                loop_test = templates.replace_as_expression(
                    'ag__.not_(control_var)', control_var=return_flag)
            else:
                loop_test = templates.replace_as_expression(
                    'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
                    extra_test=previous_test,
                    control_var=return_flag)
            anno.setanno(node, 'extra_test', loop_test)

        node.orelse = self._visit_statement_block(node, node.orelse)
        return node
Beispiel #28
0
    def _as_function(self, func_name, args, args_as_lambda=False):
        """Builds a call to `func_name` with zero, one or two arguments.

        Args:
          func_name: Source text of the callee, parsed with
            `parser.parse_expression`.
          args: The argument nodes.
          args_as_lambda: If true, each argument is first wrapped as
            `lambda: arg`.

        Returns:
          The call expression, annotated with SAFE_BOOLEAN_OPERAND=True.

        Raises:
          NotImplementedError: If more than two arguments are given.
        """
        if args_as_lambda:
            lambda_template = """
          lambda: arg
        """
            args = [
                templates.replace_as_expression(lambda_template, arg=a)
                for a in args
            ]

        if not args:
            template = """
        func_name()
      """
            call_node = templates.replace_as_expression(
                template, func_name=parser.parse_expression(func_name))
        elif len(args) == 1:
            template = """
        func_name(arg)
      """
            call_node = templates.replace_as_expression(
                template,
                func_name=parser.parse_expression(func_name),
                arg=args[0])
        elif len(args) == 2:
            template = """
        func_name(arg1, arg2)
      """
            first, second = args
            call_node = templates.replace_as_expression(
                template,
                func_name=parser.parse_expression(func_name),
                arg1=first,
                arg2=second)
        else:
            raise NotImplementedError('{} arguments for {}'.format(
                len(args), func_name))

        anno.setanno(call_node, SAFE_BOOLEAN_OPERAND, True)
        return call_node
Beispiel #29
0
  def visit_While(self, node):
    """Adds the early-return check to a `while` loop's condition.

    When the `_Return` state reports `used` after the body is visited, the
    test becomes `ag__.and_(lambda: ag__.not_(<do_return var>), lambda: test)`.
    """
    node.test = self.visit(node.test)

    # Add the check for return to the loop condition.
    node.body = self._visit_statement_block(node, node.body)
    if self.state[_Return].used:
      return_flag = self.state[_Function].do_return_var_name
      node.test = templates.replace_as_expression(
          'ag__.and_(lambda: ag__.not_(control_var), lambda: test)',
          test=node.test,
          control_var=return_flag)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node
Beispiel #30
0
    def visit_Return(self, node):
        """Replaces `return x` with `_tfp_autobatching_context_.return_(x)`.

        The returned value goes through `_to_reference`, so it must be a
        variable node.

        Args:
          node: An `ast.AST` node representing the `return` statement to convert.

        Returns:
          A `gast.Expr` statement wrapping the `return_` call.
        """
        value_ref = self._to_reference(node.value)
        return gast.Expr(templates.replace_as_expression(
            '_tfp_autobatching_context_.return_(value)', value=value_ref))
Beispiel #31
0
 def _assignment_construct_recur(self, target):
     """Builds a pattern object mirroring an assignment target.

     Tuple and List targets are rebuilt recursively as Tuple/List nodes in a
     Load context, because this builds the pattern object rather than
     storing into it. Any other target becomes
     `_tfp_autobatching_context_.var.<name>`.
     """
     if not isinstance(target, (gast.Tuple, gast.List)):
         return templates.replace_as_expression(
             '_tfp_autobatching_context_.var.name', name=target)
     parts = [self._assignment_construct_recur(t) for t in target.elts]
     container = gast.Tuple if isinstance(target, gast.Tuple) else gast.List
     return container(parts, ctx=gast.Load())
Beispiel #32
0
    def visit_While(self, node):
        """Adds the early-return check to a `while` loop's condition.

        If the `_Block` state reports `return_used` after the body is
        visited, the test is wrapped as
        `ag__.and_(lambda: ag__.not_(<do_return var>), lambda: test)`.
        """
        node.test = self.visit(node.test)

        # Add the check for return to the loop condition.
        node.body = self._visit_statement_block(node, node.body)
        if self.state[_Block].return_used:
            return_flag = self.state[_Function].do_return_var_name
            node.test = templates.replace_as_expression(
                'ag__.and_(lambda: ag__.not_(control_var), lambda: test)',
                test=node.test,
                control_var=return_flag)

        node.orelse = self._visit_statement_block(node, node.orelse)
        return node
  def _as_function(self, func_name, args, args_as_lambda=False):
    """Builds `func_name(...)` for up to two argument nodes.

    Args:
      func_name: Source text of the callee, parsed as an expression.
      args: The argument nodes.
      args_as_lambda: If true, wraps each argument as `lambda: arg` first.

    Returns:
      The call expression, annotated with SAFE_BOOLEAN_OPERAND=True.

    Raises:
      NotImplementedError: For more than two arguments.
    """
    if args_as_lambda:
      lambda_template = """
          lambda: arg
        """
      args = [
          templates.replace_as_expression(lambda_template, arg=a)
          for a in args
      ]

    if not args:
      template = """
        func_name()
      """
      call_node = templates.replace_as_expression(
          template, func_name=parser.parse_expression(func_name))
    elif len(args) == 1:
      template = """
        func_name(arg)
      """
      call_node = templates.replace_as_expression(
          template, func_name=parser.parse_expression(func_name), arg=args[0])
    elif len(args) == 2:
      template = """
        func_name(arg1, arg2)
      """
      call_node = templates.replace_as_expression(
          template,
          func_name=parser.parse_expression(func_name),
          arg1=args[0],
          arg2=args[1])
    else:
      raise NotImplementedError('{} arguments for {}'.format(
          len(args), func_name))

    anno.setanno(call_node, SAFE_BOOLEAN_OPERAND, True)
    return call_node
  def visit_IfExp(self, node):
    """Lowers `a if test else b` to `ag__.utils.run_cond`.

    Each branch becomes a function via `_create_branch`. The functions are
    named after the test's qualified name ('ifexp' when it has none) plus a
    '_true' or '_false' suffix.
    """
    test = node.test
    name_root = (anno.getanno(test, anno.Basic.QN).ssf()
                 if anno.hasanno(test, anno.Basic.QN) else 'ifexp')

    then_fn = self._create_branch(node.body, '%s_true' % name_root)
    else_fn = self._create_branch(node.orelse, '%s_false' % name_root)

    return templates.replace_as_expression(
        'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
        test=test,
        true_fn_name=then_fn,
        false_fn_name=else_fn)
Beispiel #35
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Rewrites `target.pop(...)` into an `ag__.list_pop` assignment.

    Args:
      original_call_node: The `pop` call; its `func` must be a
        `gast.Attribute`.
      pop_var_name: Name that receives the popped element.

    Returns:
      The statements from `templates.replace` that assign the results of
      `ag__.list_pop` to `target` and `pop_var_name`.
    """
    assert isinstance(original_call_node.func, gast.Attribute)
    list_node = original_call_node.func.value
    pop_args = original_call_node.args
    element = pop_args[0] if pop_args else parser.parse_expression('None')

    # The call looks like "target.pop()" and the directives are hooked to
    # target, i.e. to func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    def element_directive(arg_name):
      return self.get_definition_directive(
          list_node,
          directives.set_element_type,
          arg_name,
          default=templates.replace_as_expression('None'))

    dtype = element_directive('dtype')
    shape = element_directive('shape')

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=element,
        dtype=dtype,
        shape=shape)
Beispiel #36
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds `target, pop_var_name = ag__.list_pop(target, element, opts)`.

    `element` is the first positional argument of the `pop` call, or `None`.
    The `dtype` and `shape` options come from the `set_element_type`
    directive on the list being popped, defaulting to `None`.
    """
    assert isinstance(original_call_node.func, gast.Attribute)
    list_node = original_call_node.func.value

    if not original_call_node.args:
      element = parser.parse_expression('None')
    else:
      element = original_call_node.args[0]

    # Directives are attached to `target` in "target.pop()" (func.value).
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    dtype, shape = (
        self.get_definition_directive(
            list_node,
            directives.set_element_type,
            arg_name,
            default=templates.replace_as_expression('None'))
        for arg_name in ('dtype', 'shape'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=list_node,
        pop_var_name=pop_var_name,
        element=element,
        dtype=dtype,
        shape=shape)
Beispiel #37
0
    def visit_Call(self, node):
        """Routes a user call through `ag__.converted_call`.

        The qualified name and function context are read before children are
        visited. The call is then left as-is when it targets the `ag__`
        module, the current function context manager,
        `pdb.set_trace`/`ipdb.set_trace`/`breakpoint`, or `print` while the
        BUILTIN_FUNCTIONS feature is not used. The first debugger call seen
        logs a one-time warning.
        """
        global set_trace_warned

        qualified_name = str(anno.getanno(node.func, anno.Basic.QN,
                                          default=''))
        context_name = self.state[_Function].context_name
        node = self.generic_visit(node)

        # TODO(mdan): Refactor converted_call as a 'Call' operator.

        # Calls to the internal 'ag__' module (whose arguments might still be
        # converted) and to the function context manager inserted by
        # function_scopes are never converted.
        if (qualified_name.startswith('ag__.') or
                qualified_name.startswith(context_name + '.')):
            return node

        # Debugger entry points are sensitive to the frame they are called
        # from, so the normal bypass mechanisms are not used for them.
        # TODO(mdan): Generalize this to a "static whitelist" config.
        if qualified_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
            if not set_trace_warned:
                # TODO(mdan): Update and shorten once available on tensorflow.org.
                ag_logging.warn(
                    'Detected `pdb.set_trace()` in converted code. The code'
                    ' generated by AutoGraph is not optimized for step-by-step'
                    ' debugging. See https://github.com/tensorflow/tensorflow/'
                    'blob/master/tensorflow/python/autograph/g3doc/reference/'
                    'debugging.md.')
                set_trace_warned = True
            return node

        if qualified_name == 'print' and not self.ctx.program.options.uses(
                converter.Feature.BUILTIN_FUNCTIONS):
            return node

        template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
        return templates.replace_as_expression(
            template,
            func=node.func,
            args=self._args_to_tuple(node),
            kwargs=self._kwargs_to_dict(node),
            function_ctx=context_name)
Beispiel #38
0
  def visit_Print(self, node):
    """Rewrites a `Print` node into `ag__.converted_call` on 'print'.

    A single tuple value (as produced for `print(a, b)`) is unpacked into
    its elements, which become the call's arguments.
    """
    node = self.generic_visit(node)
    values = node.values
    single_tuple = len(values) == 1 and isinstance(values[0], gast.Tuple)
    print_args = values[0].elts if single_tuple else values

    template = """
      ag__.converted_call(func, None, options, args, {})
    """
    return templates.replace_as_expression(
        template,
        func='print',
        options=self.ctx.program.options.to_ast(),
        args=print_args)
    def visit_IfExp(self, node):
        """Lowers a conditional expression to `ag__.utils.run_cond`.

        Both branches are turned into functions by `_create_branch`, named
        '<root>_true' and '<root>_false'. The root is the test's qualified
        name, or 'ifexp' when the test has none.
        """
        if not anno.hasanno(node.test, anno.Basic.QN):
            name_root = 'ifexp'
        else:
            name_root = anno.getanno(node.test, anno.Basic.QN).ssf()

        branch_fns = (
            self._create_branch(node.body, '%s_true' % name_root),
            self._create_branch(node.orelse, '%s_false' % name_root),
        )

        return templates.replace_as_expression(
            'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
            test=node.test,
            true_fn_name=branch_fns[0],
            false_fn_name=branch_fns[1])
Beispiel #40
0
  def visit_Subscript(self, node):
    """Rewrites plain-index subscript loads into `ag__.get_item` calls.

    Non-`gast.Index` slices and non-Load contexts are returned after visiting
    their children. The element dtype comes from the `set_element_type`
    directive on the subscripted value (default `None`).
    """
    node = self.generic_visit(node)
    # Index writes are handled at a higher level, one at which the rvalue is
    # also available.
    if not isinstance(node.slice, gast.Index) or not isinstance(
        node.ctx, gast.Load):
      return node

    element_dtype = self.get_definition_directive(
        node.value,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))

    template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
    index_expr = node.slice.value
    return templates.replace_as_expression(
        template, target=node.value, key=index_expr, dtype=element_dtype)
Beispiel #41
0
    def visit_Print(self, node):
        """Replaces a `print` statement with `ag__.converted_call(print, ...)`.

        After the children are visited, the printed values are forwarded as
        positional arguments. The program's conversion options are passed
        along as well.

        Args:
          node: the `Print` node being visited.

        Returns:
          The replacement expression node.
        """
        node = self.generic_visit(node)
        printed = node.values
        # `print(a, b)` written as a statement yields one Tuple value. Spread
        # it so the tuple elements become individual arguments.
        is_single_tuple = (len(printed) == 1 and
                           isinstance(printed[0], gast.Tuple))
        if is_single_tuple:
            printed = printed[0].elts

        call_template = """
      ag__.converted_call(func, None, options, args, {})
    """
        return templates.replace_as_expression(
            call_template,
            func='print',
            options=self.ctx.program.options.to_ast(),
            args=printed)
 def visit_IfExp(self, node):
   """Converts `a if test else b` into a call to `ag__.if_exp`.

   Both branch expressions are wrapped in zero-argument lambdas, which defers
   their evaluation to `ag__.if_exp`. The unparsed source text of the test is
   passed as a string constant, presumably for use in diagnostics.

   Args:
     node: the `IfExp` node being visited.

   Returns:
     The `ag__.if_exp(...)` expression node.
   """
   test_source = parser.unparse(
       node.test, include_encoding_marker=False).strip()
   if_exp_template = '''
     ag__.if_exp(
         test,
         lambda: true_expr,
         lambda: false_expr,
         expr_repr)
 '''
   return templates.replace_as_expression(
       if_exp_template,
       test=node.test,
       true_expr=node.body,
       false_expr=node.orelse,
       expr_repr=gast.Constant(test_source, kind=None))
  def visit_For(self, node):
    """Rewrites a for loop whose body uses `break` into flag-based form.

    The loop target, iterable, body and else clause are visited in that
    order. The body is processed by `self._process_body` together with a
    fresh `break_` flag symbol.

    If no break was found, the loop node is returned unchanged. Otherwise the
    loop is re-emitted from a template that first initializes the flag. The
    else clause is guarded by the flag, and a `not <flag>` expression is
    attached to the new loop as its 'extra_test' annotation.

    Args:
      node: the `For` node being visited.

    Returns:
      The original node, or the statements produced by the template.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_flag = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, has_break = self._process_body(node.body, break_flag)
    # A break in the else clause applies to the containing scope, so the else
    # block is visited normally rather than processed with break_flag.
    node.orelse = self.visit_block(node.orelse)

    if not has_break:
      return node

    # Python runs the else clause only if the loop was not exited by break.
    guarded_orelse = self._guard_if_present(node.orelse, break_flag)
    loop_condition = templates.replace_as_expression(
        'not var_name', var_name=break_flag)

    # The extra test is kept in an annotation, out of the AST proper, which
    # would hide the flag's use from static analysis. The `(var_name,)`
    # expression statement in the template is a no-op that keeps the flag
    # visibly used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    for_template = """
        var_name = tf.constant(False)
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
    replacement = templates.replace(
        for_template,
        var_name=break_flag,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=guarded_orelse)

    # Index 1 is the rewritten `for` statement. Index 0 is the flag
    # initialization.
    anno.setanno(replacement[1], 'extra_test', loop_condition)
    return replacement
Beispiel #44
0
    def _replace_pop_call(self, node):
        """Replaces a `target.pop()` call with a fresh placeholder variable.

        Expressions that use pop() are converted to a statement plus an
        expression. For example,

            print(target.pop())

        becomes

            target, target_pop = ag__.list_pop(target)
            print(target_pop)

        This method only picks the placeholder name and swaps it in. The
        (call, name) pair is recorded on the current `_Statement` state, and
        `_generate_pop_operation` handles the rest. Several pop() uses in one
        statement are supported, e.g. `print(target.pop(), target.pop())` or
        `print(target.pop().pop())`.

        Args:
          node: a `gast.Call` whose `func` is a `gast.Attribute`.

        Returns:
          An expression node referencing the placeholder variable.
        """
        assert isinstance(node.func, gast.Attribute)
        args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
        owner = node.func.value

        # Name the placeholder after the popped object when it has a qualified
        # name. Otherwise fall back to a generic prefix.
        base_name = (anno.getanno(owner, anno.Basic.QN).ssf()
                     if anno.hasanno(owner, anno.Basic.QN) else 'list_')
        placeholder = self.ctx.namer.new_symbol(base_name,
                                                args_scope.referenced)

        statement_state = self.state[_Statement]
        if statement_state.pop_uses is None:
            statement_state.pop_uses = []
        statement_state.pop_uses.append((node, placeholder))

        return templates.replace_as_expression(
            'var_name', var_name=placeholder)
Beispiel #45
0
 def _insert_dynamic_conversion(self, node):
   """Inlines a dynamic conversion for a dynamic function.

   The call is wrapped in `ag__.converted_call`. For an attribute call
   `obj.name(...)`, the attribute name is passed as a string and `obj` as the
   owner. For any other callee, the callee expression itself is passed and the
   owner is `None`. Positional arguments go through the template, and the
   original keyword arguments are attached to the resulting call afterwards.

   Args:
     node: the `gast.Call` node to convert.

   Returns:
     The `ag__.converted_call(...)` call node.
   """
   # TODO(mdan): Pass information on the statically compiled functions.
   # Having access to the statically compiled functions can help avoid
   # unnecessary compilation.
   # For example, this would lead to function `a` being compiled twice:
   #
   #   def a():
   #     v = b
   #     b()
   #   def b():
   #     a()
   #
   # This is really a problem with recursive calls, which currently can
   # only be gated by a static condition, and should be rare.
   # TODO(mdan): It probably makes sense to use dynamic conversion every time.
   # Before we could convert all the time though, we'd need a reasonable
   # caching mechanism.
   if isinstance(node.func, gast.Attribute):
     func, owner = gast.Str(node.func.attr), node.func.value
   else:
     func, owner = node.func, parser.parse_expression('None')

   program_options = self.ctx.program.options
   options_ast = program_options.to_ast(
       self.ctx.info.namespace,
       internal_convert_user_code=program_options.recursive)
   call_template = """
     ag__.converted_call(func, owner, options, args)
   """
   converted = templates.replace_as_expression(
       call_template,
       func=func,
       owner=owner,
       options=options_ast,
       args=node.args)
   # TODO(mdan): Improve the template mechanism to better support this.
   converted.keywords = node.keywords
   return converted
Beispiel #46
0
  def _replace_pop_call(self, node):
    """Swaps a `target.pop()` call for a fresh placeholder variable.

    Expressions that use pop() are converted to a statement plus an
    expression. For example,

        print(target.pop())

    becomes

        target, target_pop = ag__.list_pop(target)
        print(target_pop)

    Only the placeholder name is generated and substituted here. The
    (call, name) pair is appended to the transformer-local POP_USES list, and
    `_generate_pop_operation` handles the rest. Several pop() uses are
    allowed, e.g. `print(target.pop(), target.pop())` or
    `print(target.pop().pop())`.

    Args:
      node: a `gast.Call` whose `func` is a `gast.Attribute`.

    Returns:
      An expression node referencing the placeholder variable.
    """
    assert isinstance(node.func, gast.Attribute)
    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    owner = node.func.value

    # Reuse the popped object's qualified name when available.
    has_qn = anno.hasanno(owner, anno.Basic.QN)
    base_name = anno.getanno(owner, anno.Basic.QN).ssf() if has_qn else 'list_'
    placeholder = self.ctx.namer.new_symbol(base_name, args_scope.referenced)

    recorded_uses = self.get_local(POP_USES, [])
    recorded_uses.append((node, placeholder))
    self.set_local(POP_USES, recorded_uses)

    return templates.replace_as_expression('var_name', var_name=placeholder)
Beispiel #47
0
  def visit_Call(self, node):
    """Rewrites `f(...)` as `ag__.converted_call(f, owner, options, args, kwargs)`.

    Calls into the internal `ag__` module are left in place, as are `print`
    calls when the BUILTIN_FUNCTIONS feature is disabled. In both cases only
    the call's children are visited.

    For an attribute call `obj.name(...)`, the attribute name is passed as a
    string and `obj` as the owner. For any other callee, the callee itself is
    passed and the owner is `None`.

    Positional arguments are packed into one tuple expression that preserves
    source order. This keeps the semantics of calls such as `f(*a, b)` and
    `f(*a, *b)`, which are legal since Python 3.5 (PEP 448). Keyword arguments
    are packed into a dict. Any `**` operands are merged with `dict()`, and
    explicit `name=value` keywords take precedence. The operands of `*` and
    `**` are visited like plain arguments, so calls nested inside them are
    converted too.

    Args:
      node: the `gast.Call` node being visited.

    Returns:
      The replacement expression node.
    """
    # TODO(mdan): Refactor converted_call as a 'Call' operator.

    # Calls to the internal 'ag__' module are never converted (though their
    # arguments might be).
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    if full_name.startswith('ag__.'):
      return self.generic_visit(node)
    if (full_name == 'print' and
        not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
      return self.generic_visit(node)

    if isinstance(node.func, gast.Attribute):
      func = gast.Str(node.func.attr)
      owner = node.func.value
    else:
      func = node.func
      owner = parser.parse_expression('None')

    # Positional arguments. Runs of consecutive plain arguments become tuple
    # literals, and each `*x` becomes `tuple(x)`. The pieces are joined with
    # `+` in source order. Previously every `*x` was moved after all plain
    # arguments, which reordered calls like f(*a, b).
    arg_pieces = []
    plain_run = []
    for a in node.args:
      if isinstance(a, gast.Starred):
        if plain_run:
          arg_pieces.append(
              templates.replace_as_expression('(args,)', args=plain_run))
          plain_run = []
        arg_pieces.append(
            templates.replace_as_expression(
                'tuple(stararg)', stararg=self.visit(a.value)))
      else:
        plain_run.append(self.visit(a))
    if plain_run or not arg_pieces:
      # Also covers calls with no positional arguments at all, which yield ().
      arg_pieces.append(
          templates.replace_as_expression('(args,)', args=plain_run))
    args = arg_pieces[0]
    for piece in arg_pieces[1:]:
      args = templates.replace_as_expression(
          'left + right', left=args, right=piece)

    # Keyword arguments. Explicit `name=value` pairs form a dict literal.
    # Each `**` operand is merged in with dict(), with the explicit keywords
    # applied last so they win on duplicate keys, as before.
    # NOTE(review): like the previous single-`**` behavior, duplicate keys
    # are merged silently instead of raising TypeError as a real call would.
    double_starred = []
    named_keywords = []
    for k in node.keywords:
      if k.arg is None:
        double_starred.append(self.visit(k.value))
      else:
        named_keywords.append(self.visit(k))
    kwargs = ast_util.keywords_to_dict(named_keywords)
    if double_starred:
      merged = double_starred[0]
      for extra in double_starred[1:]:
        merged = templates.replace_as_expression(
            'dict(kwargs, **keywords)', kwargs=merged, keywords=extra)
      kwargs = templates.replace_as_expression(
          'dict(kwargs, **keywords)', kwargs=merged, keywords=kwargs)

    template = """
      ag__.converted_call(func, owner, options, args, kwargs)
    """
    new_call = templates.replace_as_expression(
        template,
        func=func,
        owner=owner,
        options=self.ctx.program.options.to_ast(
            internal_convert_user_code=self.ctx.program.options.recursive),
        args=args,
        kwargs=kwargs)

    return new_call
  def visit_If(self, node):
    """Converts an `if` statement into functional (branch-function) form.

    The statement is replaced by the concatenation of, in this order:
      - assignments created by `_create_undefined_assigns` for symbols that
        only one branch creates,
      - an assignment of the test expression to a fresh `cond` variable,
      - the definitions of the `if_true` / `if_false` branch functions,
      - the expression built by `_create_cond_expr`.

    Variables the branches modify that were already defined before the `if`
    are aliased to fresh names inside each branch function.

    Args:
      node: the `gast.If` node being visited. It is expected to carry the
        body/orelse scope, DEFINED_VARS_IN and LIVE_VARS_OUT annotations read
        below.

    Returns:
      The concatenated replacement statements (presumably lists of nodes; the
      helper return types are not visible in this block).
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    # Symbols whose values must flow out of the conditional: anything either
    # branch modifies that is live after the `if`.
    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    for s in modified_in_cond:
      if s in live_out:
        returned_from_cond.add(s)
      elif s.is_composite():
        # Special treatment for compound objects: if any of their owner entities
        # are live, then they are outputs as well.
        if live_out & s.owner_set:
          returned_from_cond.add(s)

    # `-` binds tighter than `&`, so this reads A & (B - C), which equals
    # (A & B) - C: outputs first created inside the branch.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

    # Freeze one iteration order so cond_results and both branches' return
    # tuples below list the symbols in the same positions.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Each branch returns its aliased name where one was introduced.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name)

    return (undefined_assigns
            + cond_assign
            + body_def
            + orelse_def
            + cond_expr)
Beispiel #49
0
  def visit_If(self, node):
    """Converts an `if` statement into branch functions plus a cond expression.

    Children are visited first. The statement is then replaced by the
    `if_true` / `if_false` branch-function definitions followed by the
    expression built by `_create_cond_expr` from the original test. Variables
    the branches modify that were already defined before the `if` are aliased
    to fresh names inside each branch function.

    Args:
      node: the `gast.If` node being visited. It is expected to carry the
        body/orelse scope, DEFINED_VARS_IN and LIVE_VARS_OUT annotations read
        below.

    Returns:
      `body_def + orelse_def + cond_expr` (presumably lists of nodes; the
      helper return types are not visible in this block).

    Raises:
      ValueError: if the two branches create different sets of new variables
        that are live after the `if`.
    """
    node = self.generic_visit(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Symbols whose values must flow out of the conditional: anything either
    # branch modifies that is live after the `if`.
    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    for s in modified_in_cond:
      if s in live_out:
        returned_from_cond.add(s)
      elif s.is_composite():
        # Special treatment for compound objects: if any of their owner entities
        # are live, then they are outputs as well.
        if live_out & s.owner_set:
          returned_from_cond.add(s)

    # Pre-existing variables written by a branch; these get aliased below.
    need_alias_in_body = body_scope.modified & defined_in
    need_alias_in_orelse = orelse_scope.modified & defined_in

    # `-` binds tighter than `&`, so this reads A & (B - C), which equals
    # (A & B) - C: outputs first created inside the branch.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    # Both branches must produce the same new outputs, since each branch
    # function returns the full returned_from_cond tuple.
    if created_in_body != created_in_orelse:
      raise ValueError(
          'if statement may not initialize all variables: the true branch'
          ' creates %s, while the false branch creates %s. Make sure all'
          ' these variables are initialized either in both'
          ' branches or before the if statement.' %
          (self._fmt_symbols(created_in_body),
           self._fmt_symbols(created_in_orelse)))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # Freeze one iteration order so cond_results and both branches' return
    # tuples below list the symbols in the same positions.
    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Each branch returns its aliased name where one was introduced.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): This doesn't belong here; it's specific to the operator.
      returned_from_body = (templates.replace_as_expression('tf.constant(1)'),)
      returned_from_orelse = (
          templates.replace_as_expression('tf.constant(1)'),)

    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)

    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
                                       orelse_name)

    return body_def + orelse_def + cond_expr
 def visit_IfExp(self, node):
   """Converts `a if test else b` into `ag__.if_stmt` with lambda branches.

   Both branch expressions are wrapped in zero-argument lambdas, which defers
   their evaluation to `ag__.if_stmt`.

   Args:
     node: the `IfExp` node being visited.

   Returns:
     The `ag__.if_stmt(...)` expression node.
   """
   replacements = dict(
       test=node.test, true_expr=node.body, false_expr=node.orelse)
   return templates.replace_as_expression(
       'ag__.if_stmt(test, lambda: true_expr, lambda: false_expr)',
       **replacements)
Beispiel #51
0
 def test_function_call_in_list(self):
   """Substituting a list literal that contains nested calls completes.

   The test makes no assertions. It passes as long as `replace_as_expression`
   does not raise.
   """
   list_with_calls = parser.parse_expression('[a(b(1))]')
   call_template = """
       foo(bar)
   """
   templates.replace_as_expression(call_template, bar=list_with_calls)
Beispiel #52
0
 def visit_List(self, node):
   """Wraps a list literal as `ag__.new_list(<literal>)`.

   Children are visited first. The visited list node itself is then passed
   as the single argument.

   Args:
     node: the `List` node being visited.

   Returns:
     The `ag__.new_list(...)` expression node.
   """
   visited_list = self.generic_visit(node)
   new_list_template = """
     ag__.new_list(elements)
   """
   return templates.replace_as_expression(
       new_list_template, elements=visited_list)