Example #1
0
 def _create_state_functions(self, loop_vars, nonlocal_declarations,
                             getter_name, setter_name):
     """Generates the getter/setter function pair for the loop state.

     Args:
       loop_vars: Symbols that make up the state; may be empty.
       nonlocal_declarations: Statements spliced at the top of the setter.
       getter_name: Name given to the generated getter function.
       setter_name: Name given to the generated setter function.

     Returns:
       The result of `templates.replace` for the selected template.
     """
     if not loop_vars:
         # No state to carry: the getter yields an empty tuple and the setter
         # does nothing.
         stateless_template = """
     def getter_name():
       return ()
     def setter_name(loop_vars):
       pass
   """
         return templates.replace(stateless_template,
                                  getter_name=getter_name,
                                  setter_name=setter_name)

     stateful_template = """
     def getter_name():
       return state_vars,
     def setter_name(loop_vars):
       nonlocal_declarations
       state_vars, = loop_vars
   """
     replacements = {
         'nonlocal_declarations': nonlocal_declarations,
         'getter_name': getter_name,
         'setter_name': setter_name,
         'state_vars': tuple(loop_vars),
     }
     return templates.replace(stateful_template, **replacements)
Example #2
0
 def _create_cond_expr(self, results, test, body_name, orelse_name,
                       state_getter_name, state_setter_name,
                       basic_symbol_names, composite_symbol_names):
     """Builds the `ag__.if_stmt` call that replaces a conditional.

     Args:
       results: Assignment target for the call's value, or None when the call
         is emitted as a bare statement.
       test: The condition expression.
       body_name: Name of the function holding the `if` branch.
       orelse_name: Name of the function holding the `else` branch.
       state_getter_name: Name of the state getter function.
       state_setter_name: Name of the state setter function.
       basic_symbol_names: Expanded into the first symbol-name tuple.
       composite_symbol_names: Expanded into the second symbol-name tuple.

     Returns:
       The result of `templates.replace` for the selected template.
     """
     # Replacements shared by both template variants.
     shared = {
         'test': test,
         'body_name': body_name,
         'orelse_name': orelse_name,
         'basic_symbol_names': basic_symbol_names,
         'composite_symbol_names': composite_symbol_names,
     }

     if results is None:
         statement_template = """
     ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name,
                  (basic_symbol_names,), (composite_symbol_names,))
   """
         return templates.replace(statement_template,
                                  getter_name=state_getter_name,
                                  setter_name=state_setter_name,
                                  **shared)

     assignment_template = """
     results = ag__.if_stmt(test, body_name, orelse_name,
                            state_getter_name, state_setter_name,
                            (basic_symbol_names,),
                            (composite_symbol_names,))
   """
     return templates.replace(assignment_template,
                              results=results,
                              state_getter_name=state_getter_name,
                              state_setter_name=state_setter_name,
                              **shared)
Example #3
0
 def _do_transform_node(self, node):
     """Hoists `node` into a fresh temporary and returns a reference to it.

     An assignment of `node` to a newly generated name is queued through
     `_add_pending_statement`; the returned node is the first result of
     expanding that bare name.
     """
     fresh_name = self._gensym.new_name()
     assignment_nodes = templates.replace('temp_name = expr',
                                          temp_name=fresh_name,
                                          expr=node)
     self._add_pending_statement(assignment_nodes[0])
     return templates.replace('temp_name', temp_name=fresh_name)[0]
Example #4
0
 def create_assignment(self, target, expression):
     """Expands the `target = expression` template with the given pieces.

     Args:
       target: Replacement for the left-hand side.
       expression: Replacement for the right-hand side.

     Returns:
       The result of `templates.replace`.
     """
     assignment_template = """
   target = expression
 """
     replacements = {'target': target, 'expression': expression}
     return templates.replace(assignment_template, **replacements)
Example #5
0
 def visit_Return(self, node):
     """Routes a returned value through the function context.

     A `return` without a value is passed through untouched; otherwise the
     value is wrapped in `<context>.mark_return_value(...)`.
     """
     returned_value = node.value
     if returned_value is None:
         return node

     context_name = self.state[_Function].context_name
     return templates.replace(
         'return function_context_name.mark_return_value(value)',
         function_context_name=context_name,
         value=returned_value)
Example #6
0
    def visit_While(self, node):
        """Rewrites a `while` loop whose body uses `break`.

        When the processed body reports that `break` was used, the loop is
        re-emitted with a control variable: it is initialised to False before
        the loop, folded into the loop test, and used to guard the `else`
        clause. Loops without `break` are returned with only their children
        visited.
        """
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

        node.test = self.visit(node.test)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break inside the else clause belongs to the enclosing scope, so it
        # is visited as an ordinary block.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            return node

        # Python only runs the else clause when the loop finished without a
        # break, hence the guard on the control variable.
        guarded_orelse = self._guard_if_present(node.orelse, break_var)

        loop_template = """
        var_name = False
        while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
          body
        else:
          orelse
      """
        rewritten = templates.replace(loop_template,
                                      var_name=break_var,
                                      test=node.test,
                                      body=node.body,
                                      orelse=guarded_orelse)

        # Element 0 is the control variable initialisation; element 1 is the
        # new loop, which inherits the directives of the original.
        anno.copyanno(node, rewritten[1], anno.Basic.DIRECTIVES)
        return rewritten
Example #7
0
    def visit_Return(self, node):
        """Replaces a `return` with assignments to the return-control variables.

        Every enclosing block up to and including the function block is marked
        as containing a return and as needing a guard on its next statement.
        """
        for block in reversed(self.state[_Block].stack):
            block.return_used = True
            block.create_guard_next = True
            if block.is_function:
                break

        if node.value:
            retval = node.value
        else:
            retval = parser.parse_expression('None')

        # Note: if evaluating the returned expression raises, the return is
        # aborted. The try/except in the template resets the control variable
        # so that it stays consistent in that case.
        template = """
      try:
        do_return_var_name = True
        retval_var_name = retval
      except:
        do_return_var_name = False
        raise
    """
        function_state = self.state[_Function]
        return templates.replace(
            template,
            do_return_var_name=function_state.do_return_var_name,
            retval_var_name=function_state.retval_var_name,
            retval=retval)
Example #8
0
  def to_ast(self):
    """Builds an AST expression that reconstructs this object.

    Options equal to `STANDARD_OPTIONS` are rendered as a reference to
    `ag__.STD`; anything else is rendered as an explicit
    `ag__.ConversionOptions(...)` constructor call.

    Returns:
      ast.Node
    """
    if self == STANDARD_OPTIONS:
      return parser.parse_expression('ag__.STD')

    constructor_template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    recursive_ast = parser.parse_expression(str(self.recursive))
    user_requested_ast = parser.parse_expression(str(self.user_requested))
    internal_convert_ast = parser.parse_expression(
        str(self.internal_convert_user_code))

    # Each feature is referenced through the `ag__` module.
    # NOTE(review): with exactly one feature this renders `(ag__.X)`, which is
    # a parenthesized expression rather than a one-element tuple - confirm
    # that the consumer accepts that form.
    feature_refs = ', '.join('ag__.' + str(v) for v in self.optional_features)
    features_ast = parser.parse_expression('(' + feature_refs + ')')

    expr_ast = templates.replace(
        constructor_template,
        recursive_val=recursive_ast,
        user_requested_val=user_requested_ast,
        internal_convert_user_code_val=internal_convert_ast,
        optional_features_val=features_ast)
    # The template expands to an expression statement; hand back the bare
    # expression it wraps.
    return expr_ast[0].value
Example #9
0
 def _replace_append_call(self, node):
     """Rewrites `x.append(e)` as `x = ag__.list_append(x, e)`.

     Args:
       node: The call node; must have exactly one argument and an attribute
         access as its callee.

     Returns:
       The result of `templates.replace`.
     """
     assert len(node.args) == 1
     assert isinstance(node.func, gast.Attribute)

     append_template = """
   target = ag__.list_append(target, element)
 """
     list_expr = node.func.value
     appended = node.args[0]
     return templates.replace(append_template,
                              target=list_expr,
                              element=appended)
Example #10
0
 def visit_Break(self, node):
     """Replaces `break` with setting the control variable plus `continue`."""
     break_state = self.state[_Break]
     break_state.used = True
     # TODO(mdan): This will fail when expanded inside a top-level else block.
     break_template = """
   var_name = True
   continue
 """
     return templates.replace(break_template,
                              var_name=break_state.control_var_name)
Example #11
0
    def visit_Assert(self, node):
        """Converts an `assert` statement into an `ag__.assert_stmt` call.

        Raises:
          NotImplementedError: if a message is present but is not a constant.
        """
        self.generic_visit(node)

        # Resolve the message first: default text when absent, the constant
        # itself when present, and no support for anything else.
        if node.msg is None:
            message = gast.Constant('Assertion error', kind=None)
        elif isinstance(node.msg, gast.Constant):
            message = node.msg
        else:
            raise NotImplementedError(
                'can only convert string messages for now.')

        # Note: The lone tf.Assert call will be wrapped with
        # control_dependencies by side_effect_guards.
        assert_template = """
      ag__.assert_stmt(test, lambda: msg)
    """
        return templates.replace(assert_template,
                                 test=node.test,
                                 msg=message)
Example #12
0
  def _process_single_assignment(self, target, value):
    """Rewrites `x[key] = value` as `x = ag__.set_item(x, key, value)`.

    Args:
      target: The assignment target.
      value: The assigned expression.

    Returns:
      The result of `templates.replace`, or None when `target` is not a
      subscript whose slice is a `gast.Index`.
    """
    is_indexed_subscript = (
        isinstance(target, gast.Subscript)
        and isinstance(target.slice, gast.Index))
    if not is_indexed_subscript:
      return None

    set_item_template = """
      target = ag__.set_item(target, key, item)
    """
    return templates.replace(set_item_template,
                             target=target.value,
                             key=target.slice.value,
                             item=value)
Example #13
0
    def visit_FunctionDef(self, node):
        """Rewrites a function body so that it ends in a single `return`.

        Fresh names for the "do return" flag and the return value are
        generated and recorded on the function state before the body is
        visited. If any return was seen, the converted body is wrapped in one
        of two templates depending on `self.default_to_null_return`, and a
        leading docstring is kept as the first statement.
        """
        self.state[_Function].enter()
        self.state[_Block].enter()
        self.state[_Block].is_function = True

        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        do_return_var_name = self.ctx.namer.new_symbol(
            'do_return', scope.referenced)
        retval_var_name = self.ctx.namer.new_symbol(
            'retval_', scope.referenced)
        self.state[_Function].do_return_var_name = do_return_var_name
        self.state[_Function].retval_var_name = retval_var_name

        converted_body = self._visit_statement_block(node, node.body)

        # Split off a leading docstring so that no statement is placed before
        # it.
        # TODO(mdan): Should a docstring even be included in the output?
        docstring = None
        if converted_body:
            first_stmt = converted_body[0]
            if (isinstance(first_stmt, gast.Expr)
                    and isinstance(first_stmt.value, gast.Constant)):
                docstring = first_stmt
                converted_body = converted_body[1:]

        if self.state[_Block].return_used:
            if not self.default_to_null_return:
                template = """
          body
          return retval_var_name
        """
            else:
                # TODO(mdan): Remove the (do_return_var_name,) below.
                # Currently, that line ensures the variable is both defined
                # and alive throughout the function.
                template = """
          do_return_var_name = False
          retval_var_name = ag__.UndefinedReturnValue()
          body
          (do_return_var_name,)
          return ag__.retval(retval_var_name)
        """

            node.body = templates.replace(
                template,
                body=converted_body,
                do_return_var_name=do_return_var_name,
                retval_var_name=retval_var_name)

            if docstring:
                node.body.insert(0, docstring)

        self.state[_Block].exit()
        self.state[_Function].exit()
        return node
Example #14
0
 def _create_undefined_assigns(self, undefined_symbols):
     """Builds `sym = ag__.Undefined('<name>')` statements for each symbol.

     Args:
       undefined_symbols: Iterable of symbols; each must provide `ssf()`,
         whose result is embedded as the constant argument.

     Returns:
       A list with the statements generated for all symbols, in order.
     """
     undefined_template = '''
     var = ag__.Undefined(symbol_name)
   '''
     assignments = []
     for symbol in undefined_symbols:
         name_literal = gast.Constant(symbol.ssf(), kind=None)
         assignments.extend(
             templates.replace(undefined_template,
                               var=symbol,
                               symbol_name=name_literal))
     return assignments
Example #15
0
    def _guard_if_present(self, block, var_name):
        """Wraps `block` so that it only runs while `var_name` is not set.

        An empty (falsy) block is returned unchanged.
        """
        if block:
            guard_template = """
        if ag__.not_(var_name):
          block
      """
            return templates.replace(guard_template,
                                     var_name=var_name,
                                     block=block)
        return block
Example #16
0
    def visit_While(self, node):
        """Converts a `while` loop into a functional `ag__.pt_while_stmt` call.

        The loop body and test become local functions, the loop variables are
        exposed through generated state getter/setter functions, and symbols
        that may be undefined before the loop receive placeholder assignments.
        """
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

        loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
            node, body_scope.modified)
        undefined_assigns = self._create_undefined_assigns(possibly_undefs)
        nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

        namer = self.ctx.namer
        state_getter_name = namer.new_symbol('get_state', reserved_symbols)
        state_setter_name = namer.new_symbol('set_state', reserved_symbols)
        state_functions = self._create_state_functions(
            loop_vars, nonlocal_declarations, state_getter_name,
            state_setter_name)

        opts = self._create_loop_options(node)

        # Names are generated in the same order as before: body first, then
        # test.
        body_name = namer.new_symbol('loop_body', reserved_symbols)
        symbol_names = tuple(
            gast.Constant(str(s), kind=None) for s in loop_vars)
        test_name = namer.new_symbol('loop_test', reserved_symbols)

        loop_template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def test_name():
        return test
      undefined_assigns
      ag__.pt_while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
        return templates.replace(
            loop_template,
            body=node.body,
            body_name=body_name,
            nonlocal_declarations=nonlocal_declarations,
            opts=opts,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            symbol_names=symbol_names,
            test=node.test,
            test_name=test_name,
            undefined_assigns=undefined_assigns)
Example #17
0
    def visit_Assign(self, node):
        """Expands `target = [elt for ... if ...]` into explicit loops.

        Assignments whose value is not a list comprehension go through
        `generic_visit`. Otherwise the target is initialised to an empty list
        and filled by nested `for`/`if` statements that end in an `append`.

        Raises:
          NotImplementedError: if the assignment has more than one target.
        """
        if not isinstance(node.value, gast.ListComp):
            return self.generic_visit(node)
        if len(node.targets) > 1:
            raise NotImplementedError('multiple assignments')

        init_template = """
      target = []
    """
        append_template = """
      target.append(elt)
    """
        filter_template = """
          if test:
            body
        """
        loop_template = """
        for target in iter_:
          body
      """

        target, = node.targets
        comprehension = node.value

        initialization = templates.replace(init_template, target=target)
        body = templates.replace(append_template,
                                 target=target,
                                 elt=comprehension.elt)

        # Build from the innermost statement outwards: the last generator and
        # its last condition sit closest to the append.
        for gen in reversed(comprehension.generators):
            for condition in reversed(gen.ifs):
                body = templates.replace(filter_template,
                                         test=condition,
                                         body=body)
            body = templates.replace(loop_template,
                                     iter_=gen.iter,
                                     target=gen.target,
                                     body=body)

        return initialization + body
Example #18
0
    def _create_cond_branch(self, body_name, aliased_orig_names,
                            aliased_new_names, body, returns):
        """Builds the function definition for one branch of a conditional.

        Args:
          body_name: Name of the generated function.
          aliased_orig_names: Original names to alias; when empty, no aliasing
            statement is emitted.
          aliased_new_names: The aliases that receive the original values.
          body: Statements of the branch.
          returns: Values returned by the branch. Exactly one value is
            returned as is; any other count is returned as a tuple.

        Returns:
          The result of `templates.replace` for the selected template.
        """
        if len(returns) == 1:
            single_return_template = """
        return retval
      """
            return_stmt = templates.replace(single_return_template,
                                            retval=returns[0])
        else:
            tuple_return_template = """
        return (retvals,)
      """
            return_stmt = templates.replace(tuple_return_template,
                                            retvals=returns)

        replacements = {
            'body_name': body_name,
            'body': body,
            'return_stmt': return_stmt,
        }

        if aliased_orig_names:
            branch_template = """
        def body_name():
          aliased_new_names, = aliased_orig_names,
          body
          return_stmt
      """
            replacements['aliased_orig_names'] = aliased_orig_names
            replacements['aliased_new_names'] = aliased_new_names
        else:
            branch_template = """
        def body_name():
          body
          return_stmt
      """
        return templates.replace(branch_template, **replacements)
Example #19
0
 def visit_Continue(self, node):
     """Replaces `continue` with setting the continue control variable."""
     continue_state = self.state[_Continue]
     continue_state.used = True

     # Guards are needed for every enclosing affected block, not only the
     # current one (see ContinueCanonicalizationTest.test_multiple_continues).
     # The walk stops at the innermost loop because `continue` does not reach
     # beyond it.
     for block in reversed(self.state[_Block].stack):
         block.create_guard_next = True
         if block.is_loop_type:
             break

     set_flag_template = """
   var_name = True
 """
     return templates.replace(set_flag_template,
                              var_name=continue_state.control_var_name)
Example #20
0
 def _postprocess_statement(self, node):
     """Guards a statement that follows a `continue`, when required.

     Returns:
       A `(node, block)` pair. When the statement is wrapped, `node` is the
       generated `if` and `block` is its body; otherwise the statement is
       returned unchanged together with None.
     """
     if not self.state[_Continue].used:
         return node, None

     block = self.state[_Block]
     wrap_this_statement = block.create_guard_current
     # Carry the pending request over to the next statement.
     block.create_guard_current = block.create_guard_next
     block.create_guard_next = False

     if not wrap_this_statement:
         return node, None

     guard_template = """
       if ag__.not_(var_name):
         original_node
     """
     cond, = templates.replace(
         guard_template,
         var_name=self.state[_Continue].control_var_name,
         original_node=node)
     return cond, cond.body
Example #21
0
    def visit_For(self, node):
        """Rewrites a `for` loop whose body uses `break`.

        When the processed body reports that `break` was used, the loop is
        re-emitted with a control variable initialised to False, an `else`
        clause guarded by that variable, and an extra loop test attached to
        the new loop as an annotation. Loops without `break` are returned
        with only their children visited.
        """
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break inside the else clause belongs to the enclosing scope, so it
        # is visited as an ordinary block.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            return node

        # Python only runs the else clause when the loop finished without a
        # break, hence the guard on the control variable.
        guarded_orelse = self._guard_if_present(node.orelse, break_var)
        extra_test = templates.replace_as_expression('ag__.not_(var_name)',
                                                     var_name=break_var)

        # The extra test is hidden in the AST, which will confuse the static
        # analysis. To mitigate that, a no-op statement is inserted so that
        # the control variable is marked as used.
        # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
        loop_template = """
        var_name = False
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
        rewritten = templates.replace(loop_template,
                                      var_name=break_var,
                                      iter_=node.iter,
                                      target=node.target,
                                      body=node.body,
                                      orelse=guarded_orelse)

        # Element 0 is the control variable initialisation; element 1 is the
        # new loop.
        new_for_node = rewritten[1]
        anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
        anno.copyanno(node, new_for_node, anno.Basic.DIRECTIVES)
        return rewritten
Example #22
0
    def _postprocess_statement(self, node):
        """Guards a statement that follows a `return`, when required.

        Returns:
          A `(node, block)` pair. When the statement is wrapped, `node` is the
          generated `if` and `block` is its body; otherwise the statement is
          returned unchanged together with None.
        """
        state = self.state[_Block]
        if not state.return_used:
            return node, None

        guarded_body = None
        if state.create_guard_now:
            guard_template = """
        if ag__.not_(do_return_var_name):
          original_node
      """
            cond, = templates.replace(
                guard_template,
                do_return_var_name=self.state[_Function].do_return_var_name,
                original_node=node)
            node = cond
            guarded_body = cond.body

        # Carry the pending request over to the next statement.
        state.create_guard_now = state.create_guard_next
        state.create_guard_next = False

        return node, guarded_body
Example #23
0
    def visit_FunctionDef(self, node):
        """Wraps the function body in a `with ag__.FunctionScope(...)` block.

        A fresh context name is generated, stored on the function state and
        attached to the node as the 'function_context_name' annotation. A
        leading docstring is kept outside of (before) the `with` statement.
        """
        self.state[_Function].enter()
        scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

        function_context_name = self.ctx.namer.new_symbol(
            'fscope', scope.referenced)
        self.state[_Function].context_name = function_context_name
        anno.setanno(node, 'function_context_name', function_context_name)

        node = self.generic_visit(node)

        # Detach a leading constant expression (the docstring) from the body.
        statements = node.body
        docstring_node = None
        if (statements
                and isinstance(statements[0], gast.Expr)
                and isinstance(statements[0].value, gast.Constant)):
            docstring_node = statements[0]
            node.body = statements[1:]

        scope_template = """
      with ag__.FunctionScope(
          function_name, context_name, options) as function_context:
        body
    """
        wrapped_body = templates.replace(
            scope_template,
            function_name=gast.Constant(node.name, kind=None),
            context_name=gast.Constant(function_context_name, kind=None),
            options=self._function_scope_options().to_ast(),
            function_context=function_context_name,
            body=node.body)

        if docstring_node is None:
            node.body = wrapped_body
        else:
            node.body = [docstring_node] + wrapped_body

        self.state[_Function].exit()
        return node
Example #24
0
    def _visit_loop_body(self, node, nodes):
        """Visits the statements of a loop body, handling `continue`.

        Args:
          node: The loop node; its body scope supplies the names to avoid
            when generating the control variable.
          nodes: The statements of the loop body.

        Returns:
          The visited statements, preceded by the initialisation of the
          control variable when a `continue` was encountered.
        """
        self.state[_Continue].enter()
        self.state[_Block].enter()
        self.state[_Block].is_loop_type = True

        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced)
        self.state[_Continue].control_var_name = continue_var

        visited = self.visit_block(nodes,
                                   after_visit=self._postprocess_statement)

        if self.state[_Continue].used:
            init_template = """
        var_name = False
      """
            visited = templates.replace(init_template,
                                        var_name=continue_var) + visited

        self.state[_Block].exit()
        self.state[_Continue].exit()
        return visited
Example #25
0
    def _generate_pop_operation(self, original_call_node, pop_var_name):
        """Builds `target, var = ag__.list_pop(target, element, opts=...)`.

        Args:
          original_call_node: The original pop call; its callee must be an
            attribute access, whose value is the list being popped.
          pop_var_name: Name that receives the popped element.

        Returns:
          The result of `templates.replace`.
        """
        assert isinstance(original_call_node.func, gast.Attribute)

        call_args = original_call_node.args
        pop_element = (call_args[0]
                       if call_args else parser.parse_expression('None'))

        def element_type_directive(field):
            # Looks up one argument of the set_element_type directive for the
            # popped list, falling back to a `None` expression.
            return self.get_definition_directive(
                original_call_node.func.value,
                directives.set_element_type,
                field,
                default=templates.replace_as_expression('None'))

        # The call will be something like "target.pop()", and the dtype is
        # hooked to target, hence the func.value.
        # TODO(mdan): For lists of lists, this won't work.
        # The reason why it won't work is because it's unclear how to annotate
        # the list as a "list of lists with a certain element type" when using
        # operations like `l.pop().pop()`.
        dtype = element_type_directive('dtype')
        shape = element_type_directive('shape')

        pop_template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
        return templates.replace(pop_template,
                                 target=original_call_node.func.value,
                                 pop_var_name=pop_var_name,
                                 element=pop_element,
                                 dtype=dtype,
                                 shape=shape)
Example #26
0
    def visit_For(self, node):
        """Converts a `for` loop into a functional `ag__.for_stmt` call.

        The loop body becomes a local function taking the iterate as its only
        argument, the loop variables are exposed through generated state
        getter/setter functions, and symbols that may be undefined before the
        loop receive placeholder assignments. If the node carries an
        EXTRA_LOOP_TEST annotation, that expression is wrapped in its own
        function and passed along; otherwise `None` is passed in its place.
        """
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

        loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
            node, body_scope.modified | iter_scope.modified)
        undefined_assigns = self._create_undefined_assigns(possibly_undefs)
        nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

        namer = self.ctx.namer
        state_getter_name = namer.new_symbol('get_state', reserved_symbols)
        state_setter_name = namer.new_symbol('set_state', reserved_symbols)
        state_functions = self._create_state_functions(
            loop_vars, nonlocal_declarations, state_getter_name,
            state_setter_name)

        opts = self._create_loop_options(node)

        if not anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
            extra_test_name = parser.parse_expression('None')
            extra_test_function = []
        else:
            extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
            extra_test_name = namer.new_symbol('extra_test', reserved_symbols)
            extra_test_template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
            extra_test_function = templates.replace(
                extra_test_template,
                extra_test_expr=extra_test,
                extra_test_name=extra_test_name,
                loop_vars=loop_vars,
                nonlocal_declarations=nonlocal_declarations)

        # iterate_arg_name holds a single arg with the iterates, which may be
        # a tuple; the expansion assigns it to the loop target inside the body.
        iterate_arg_name = namer.new_symbol('itr', reserved_symbols)
        expansion_template = """
      iterates = iterate_arg_name
    """
        iterate_expansion = templates.replace(
            expansion_template,
            iterate_arg_name=iterate_arg_name,
            iterates=node.target)

        body_name = namer.new_symbol('loop_body', reserved_symbols)
        symbol_names = tuple(
            gast.Constant(str(s), kind=None) for s in loop_vars)

        loop_template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
        return templates.replace(
            loop_template,
            body=node.body,
            body_name=body_name,
            extra_test_function=extra_test_function,
            extra_test_name=extra_test_name,
            iterate_arg_name=iterate_arg_name,
            iterate_expansion=iterate_expansion,
            iterated=node.iter,
            nonlocal_declarations=nonlocal_declarations,
            opts=opts,
            symbol_names=symbol_names,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            undefined_assigns=undefined_assigns)
Example #27
0
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
                               factory_name, closure_vars, future_features):
    """Wraps an AST into the body of a dynamic factory.

    The dynamic factory (factory of factory) pattern is used so that:

     1. The inner factory dynamically creates the entity represented by nodes.
     2. The entity is parametrized by `ag__`, the internal AutoGraph module.
     3. The outer factory creates the inner factory within a lexical scope in
        which `closure_vars` are bound local variables. This lets the caller
        control the exact closure (i.e. non-global free variables) of the
        inner factory.

    The AST is expected to define some symbol named by `entity_name`.

    Args:
      nodes: ast.AST
      entity_name: Union[Text, ast.AST]
      factory_factory_name: Text
      factory_name: Text
      closure_vars: Iterable[Text]
      future_features: Iterable[Text], see EntityInfo.future_features.

    Returns:
      ast.AST
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes, )

    # These dummy symbol declarations create local variables in a function
    # scope, so that the Python parser correctly marks them as free non-global
    # variables upon load (that is, it creates cell slots for each symbol).
    # Their values are not used, as the cells are swapped with the original
    # entity's cells after the code has been loaded.
    dummy_def_template = """
      var_name = None
    """
    dummy_closure_defs = [
        stmt
        for var_name in closure_vars
        for stmt in templates.replace(dummy_def_template, var_name=var_name)
    ]

    if not future_features:
        future_imports = []
    else:
        imported_names = [
            gast.alias(name=name, asname=None) for name in future_features
        ]
        future_imports = gast.ImportFrom(module='__future__',
                                         names=imported_names,
                                         level=0)

    factory_template = """
    future_imports
    def factory_factory_name():
      dummy_closure_defs
      def factory_name(ag__, ag_source_map__, ag_module__):
        entity_defs
        entity_name.ag_source_map = ag_source_map__
        entity_name.ag_module = ag_module__
        entity_name = ag__.autograph_artifact(entity_name)
        return entity_name
      return factory_name
  """
    return templates.replace(factory_template,
                             future_imports=future_imports,
                             factory_factory_name=factory_factory_name,
                             factory_name=factory_name,
                             dummy_closure_defs=dummy_closure_defs,
                             entity_defs=nodes,
                             entity_name=entity_name)