コード例 #1
0
ファイル: transformer.py プロジェクト: Cerebras/ptwse
    def visit(self, node):
        """Visits `node`, keeping origin information consistent on the result.

        Wraps the base visitor's `visit` with three adjustments:
          * nodes annotated with `anno.Basic.SKIP_PROCESSING` are returned as-is;
          * `self.ctx.current_origin` is set to the node's ORIGIN annotation (if
            any) for the duration of the visit and restored afterwards;
          * when the value of an `Expr` node is replaced by a list/tuple, an
            `Assign` or an `AugAssign`, the enclosing `Expr` is dropped.
        Replacement nodes that lack an ORIGIN annotation inherit the origin of
        `node`, or failing that the origin that was current on entry.

        Args:
          node: gast.AST, the node to visit. Lists of nodes are rejected; use
            `visit_block` for those.

        Returns:
          Whatever the underlying visitor produced: a node, a list/tuple of
          nodes, or None.

        Raises:
          ValueError: if `node` is not a `gast.AST` instance.
        """
        if not isinstance(node, gast.AST):
            # Node bodies are frequently lists, so it is easy to call `visit` on
            # one by mistake. Fail here, before `current_origin` is touched and
            # before the annotation helpers below are handed a non-node.
            msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
                   ' visit lists of nodes, use "visit_block" instead').format(
                       type(node))
            raise ValueError(msg)

        if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
            return node

        parent_origin = self.ctx.current_origin
        if anno.hasanno(node, anno.Basic.ORIGIN):
            self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

        try:
            processing_expr_node = isinstance(node, gast.Expr)
            if processing_expr_node:
                entry_expr_value = node.value

            result = super(Base, self).visit(node)

            # Adjust for consistency: replacing the value of an Expr with
            # an Assign node removes the need for the Expr node.
            if (processing_expr_node and isinstance(result, gast.Expr)
                    and (result.value is not entry_expr_value)):
                # When the replacement is a list, it is assumed that the list came
                # from a template that contained a number of statements, which
                # themselves are standalone and don't require an enclosing Expr.
                if isinstance(result.value,
                              (list, tuple, gast.Assign, gast.AugAssign)):
                    result = result.value

            # By default, all replacements receive the origin info of the replaced
            # node.
            if result is not node and result is not None:
                inherited_origin = anno.getanno(node,
                                                anno.Basic.ORIGIN,
                                                default=parent_origin)
                if inherited_origin is not None:
                    if isinstance(result, (list, tuple)):
                        nodes_to_adjust = result
                    else:
                        nodes_to_adjust = (result, )
                    for n in nodes_to_adjust:
                        # Never overwrite an origin the replacement already has.
                        if not anno.hasanno(n, anno.Basic.ORIGIN):
                            anno.setanno(n, anno.Basic.ORIGIN,
                                         inherited_origin)
        finally:
            # Restore the origin even if the visit raised.
            self.ctx.current_origin = parent_origin

        return result
コード例 #2
0
  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz', None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, the value returned when no definition reaching `node`
        (its ORIG_DEFINITIONS annotation) sets `arg` for `directive`.

    Returns:
      The value recorded for `arg` by the reaching definitions (presumably an
      AST node, since values are compared with `ast_util.matches` and
      unparsed), or `default` if none of them set it.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    # Collect `arg` from every reaching definition that carries the directive
    # with that argument.
    arg_values_found = [
        def_.directives[directive][arg]
        for def_ in defs
        if directive in def_.directives and arg in def_.directives[directive]
    ]
    if not arg_values_found:
      return default

    # If multiple annotations reach the symbol, they must all match. If they do,
    # return any of them. A single value trivially matches itself.
    first_value = arg_values_found[0]
    for other_value in arg_values_found[1:]:
      if not ast_util.matches(first_value, other_value):
        qn = anno.getanno(node, anno.Basic.QN)
        raise ValueError(
            '%s has ambiguous annotations for %s(%s): %s, %s' %
            (qn, directive.__name__, arg, parser.unparse(other_value).strip(),
             parser.unparse(first_value).strip()))
    return first_value
コード例 #3
0
    def visit_While(self, node):
        """Lowers `break` statements inside a while loop.

        The test, body and else clause are visited first. If the body uses
        `break`, the loop is rewritten around a boolean control variable: it
        starts as False, the loop test is extended with
        `ag__.not_(<control variable>)`, and the else clause is guarded by the
        same variable. The new while statement keeps the original node's
        DIRECTIVES annotation.

        Returns:
          `node` itself when no break is used; otherwise the list of statements
          produced by the template.
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        control_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

        node.test = self.visit(node.test)
        node.body, has_break = self._process_body(node.body, control_var)
        # Breaks inside the else clause belong to the enclosing loop, so it is
        # visited as an ordinary block.
        node.orelse = self.visit_block(node.orelse)

        if not has_break:
            return node

        # Python runs the else clause only when the loop was not broken out of.
        orelse_if_no_break = self._guard_if_present(node.orelse, control_var)

        template = """
        var_name = False
        while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
          body
        else:
          orelse
      """
        statements = templates.replace(template,
                                       var_name=control_var,
                                       test=node.test,
                                       body=node.body,
                                       orelse=orelse_if_no_break)

        # Statement 1 of the template is the rewritten while loop.
        anno.copyanno(node, statements[1], anno.Basic.DIRECTIVES)
        return statements
コード例 #4
0
    def visit_Lambda(self, node):
        """Wraps the body of an outermost lambda in a function scope.

        Children are visited inside a new `_Function` state. Lambdas at a
        `_Function` level above 2 are left unwrapped. Otherwise a fresh `lscope`
        symbol is generated and recorded on the `_Function` state and as the
        'function_context_name' annotation, and the body becomes
        `ag__.with_function_scope(lambda <ctx>: <body>, '<ctx>', options)`.
        """
        self.state[_Function].enter()
        node = self.generic_visit(node)

        # Only the top-level function is wrapped. Wrapping every lambda would be
        # correct but produces excessive boilerplate when lambdas are nested.
        # TODO(mdan): Looks more closely for use cases that actually require this.
        if self.state[_Function].level <= 2:
            referenced = anno.getanno(node, anno.Static.SCOPE).referenced
            context_name = self.ctx.namer.new_symbol('lscope', referenced)
            self.state[_Function].context_name = context_name
            anno.setanno(node, 'function_context_name', context_name)

            template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
            node.body = templates.replace_as_expression(
                template,
                options=self._function_scope_options().to_ast(),
                function_context=context_name,
                function_context_name=gast.Constant(context_name, kind=None),
                body=node.body)

        self.state[_Function].exit()
        return node
コード例 #5
0
ファイル: call_trees.py プロジェクト: Cerebras/ptwse
    def visit_FunctionDef(self, node):
        """Visits a function definition under its own `_Function` state.

        The function's 'function_context_name' annotation must be present. It
        is stored on the state so that calls in the body can refer to it. The
        arguments, body and (if present) return annotation are visited.
        Functions at `_Function` level below 2 lose all decorators. Deeper
        functions get `ag__.autograph_artifact` appended to their decorators.
        """
        self.state[_Function].enter()
        # Note: if the conversion process ever creates helper functions, this
        # assumption will no longer hold.
        assert anno.hasanno(node, 'function_context_name'), (
            'The function_scopes converter always creates a scope for functions.'
        )
        context_name = anno.getanno(node, 'function_context_name')
        self.state[_Function].context_name = context_name
        node.args = self.visit(node.args)
        node.body = self.visit_block(node.body)

        if self.state[_Function].level >= 2:
            # TODO(mdan): Fix the tests so that we can always add this decorator.
            # Inner functions are already converted; the marker decorator avoids
            # converting them a second time (which would work, but costs time).
            artifact_marker = parser.parse_expression('ag__.autograph_artifact')
            node.decorator_list.append(artifact_marker)
        else:
            # Conversion is just-in-time: by the time it runs, a top-level
            # function's decorators have already been applied.
            node.decorator_list = []

        if node.returns:
            node.returns = self.visit(node.returns)

        self.state[_Function].exit()
        return node
コード例 #6
0
ファイル: qual_names.py プロジェクト: Cerebras/ptwse
 def visit_Attribute(self, node):
     """Annotates an attribute access with its qualified name.

     Children are visited first. If `node.value` then carries a QN, `node` is
     annotated with `QN(<value QN>, attr=node.attr)`. Otherwise it is left
     without a QN.
     """
     node = self.generic_visit(node)
     if not anno.hasanno(node.value, anno.Basic.QN):
         return node
     owner_qn = anno.getanno(node.value, anno.Basic.QN)
     anno.setanno(node, anno.Basic.QN, QN(owner_qn, attr=node.attr))
     return node
コード例 #7
0
    def visit_FunctionDef(self, node):
        """Analyzes a function's subgraph, then recurses into nested functions.

        Preorder tree processing:
          1. if this is a child function, the parent was already analyzed and it
             has the proper state value for the subgraph's entry;
          2. analyze the current function body;
          3. recursively walk the subtree; child functions will be processed.

        Only `node.args` and `node.body` are visited. The name, decorators and
        return annotation are not part of this analysis.

        Returns:
          `node`, with `args` and `body` replaced by their visited versions.
        """
        parent_analyzer = self.current_analyzer
        subgraph = self.graphs[node]

        analyzer = Analyzer(subgraph, self.definition_factory)
        if parent_analyzer is not None:
            # Wire the state between the two subgraphs' analyzers: the parent's
            # out-state at this def statement seeds the child's `node.args`.
            parent_out_state = parent_analyzer.out[
                parent_analyzer.graph.index[node]]
            # Exception: symbols modified in the child function are local to it.
            # A plain subtraction (not `-=`) guarantees the parent analyzer's
            # stored out-state is never modified in place, even if the state
            # type implements `__isub__`.
            body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
            parent_out_state = parent_out_state - body_scope.modified
            analyzer.extra_in[node.args] = parent_out_state

        # Complete the analysis for the local function and annotate its body.
        analyzer.visit_forward()

        # Recursively process any remaining subfunctions.
        self.current_analyzer = analyzer
        # Note: not visiting name, decorator_list and returns because they don't
        # apply to this analysis.
        # TODO(mdan): Should we still process the function name?
        node.args = self.visit(node.args)
        node.body = self.visit_block(node.body)
        self.current_analyzer = parent_analyzer

        return node
コード例 #8
0
ファイル: cfg.py プロジェクト: Cerebras/ptwse
    def visit_For(self, node):
        """Adds the CFG fragment for a `for` statement.

        The loop section is entered with `node.iter` (see the note below), and
        the EXTRA_LOOP_TEST annotation, if present, is processed as a basic
        statement. The body is visited inside the loop section and the lexical
        scope. The else clause is visited after both have been exited, but
        while the statement's section is still open.
        """
        self.builder.begin_statement(node)
        self._enter_lexical_scope(node)

        self.builder.enter_section(node)

        # Note: Strictly speaking, this should be node.target + node.iter.
        # However, the activity analysis accounts for this inconsistency,
        # so dataflow analysis produces the correct values.
        self.builder.enter_loop_section(node, node.iter)
        # Also include the "extra loop test" annotation, to capture things like the
        # control variable for return and break in for loops.
        if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
            self._process_basic_statement(
                anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
        for stmt in node.body:
            self.visit(stmt)
        self.builder.exit_loop_section(node)

        # Note: although the orelse is technically part of the loop node,
        # they don't count as loop bodies.  For example, a break in the loop's
        # orelse will affect the parent loop, not the current one.
        # The lexical scope is closed here, before the section opened after it,
        # so the orelse is outside the lexical scope but inside the section.
        self._exit_lexical_scope(node)

        for stmt in node.orelse:
            self.visit(stmt)

        self.builder.exit_section(node)
        self.builder.end_statement(node)
コード例 #9
0
 def _node_sets_self_attribute(self, node):
     """Returns whether `node` refers to a direct attribute of `self`.

     True when the node's QN is an attribute QN whose parent is the simple name
     `self` (e.g. `self.x`, but not `self.x.y`, `self[0]` or `other.x`). Nodes
     without a QN annotation yield False.
     """
     if not anno.hasanno(node, anno.Basic.QN):
         return False
     qn = anno.getanno(node, anno.Basic.QN)
     # `has_attr` must be called. The bare bound method is always truthy, which
     # made subscripts such as `self[0]` match and let simple names reach
     # `qn.parent`.
     # NOTE(review): assumes QN.has_attr is a method, like QN.is_composite().
     # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
     return bool(qn.has_attr() and qn.parent.qn == ('self', ))
コード例 #10
0
 def visit_Attribute(self, node):
     """Propagates STATIC_VALUE through attribute accesses on modules.

     Children are visited first. If `node.value` has a static value that is a
     module with an attribute named `node.attr`, `node` is annotated with that
     attribute's value.
     """
     node = self.generic_visit(node)
     owner = anno.getanno(node.value, STATIC_VALUE, default=None)
     owner_is_module = owner is not None and inspect.ismodule(owner)
     if owner_is_module and hasattr(owner, node.attr):
         anno.setanno(node, STATIC_VALUE, getattr(owner, node.attr))
     return node
コード例 #11
0
    def _postprocess_statement(self, node):
        """Records definite returns and selects the statements that follow.

        If `node` is annotated STMT_DEFINITELY_RETURNS, the current
        `_RewriteBlock` is marked as definitely returning.

        Returns:
          A pair `(node, rest)`. For an `If` whose body definitely returns,
          `rest` is its `orelse`; for an `If` whose else branch definitely
          returns, it is its `body`; otherwise it is None.
        """
        # A statement that definitely returns (e.g. a with statement containing
        # a return) makes the enclosing block definitely return as well.
        if anno.getanno(node, STMT_DEFINITELY_RETURNS, default=False):
            self.state[_RewriteBlock].definitely_returns = True

        # Collapse the typical conditional-return pattern into one conditional
        # that may return on both branches. This reduces the use of None return
        # values, which don't work with TF conditionals.
        if isinstance(node, gast.If):
            if anno.getanno(node, BODY_DEFINITELY_RETURNS, default=False):
                return node, node.orelse
            if anno.getanno(node, ORELSE_DEFINITELY_RETURNS, default=False):
                return node, node.body

        return node, None
コード例 #12
0
    def _get_loop_vars(self, node, modified):
        """Computes the variables carried by a loop.

        Args:
          node: the loop node. It is read for its BODY_SCOPE, DEFINED_VARS_IN,
            LIVE_VARS_IN and LIVE_VARS_OUT annotations.
          modified: the symbols modified inside the loop.

        Returns:
          A tuple `(loop_vars, reserved_symbols, undefined_lives)`:
          `loop_vars` is a tuple of the basic and composite loop variables,
          `reserved_symbols` is the set of symbols referenced in the body, and
          `undefined_lives` holds the basic loop variables not defined on entry.
        """
        reserved_symbols = anno.getanno(
            node, annos.NodeAnno.BODY_SCOPE).referenced
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

        basic_vars = self._get_basic_loop_vars(modified, live_in, live_out)
        composite_vars = self._get_composite_loop_vars(modified, live_in)

        # Variables used or defined in the loop but not defined before it. Only
        # simple variables must be defined up front; composite ones are checked
        # implicitly at runtime.
        undefined_lives = basic_vars - defined_in

        return tuple(basic_vars | composite_vars), reserved_symbols, undefined_lives
コード例 #13
0
 def visit_Name(self, node):
     """Resolves loaded names with no local definition against the namespace.

     After visiting children, a `Load` of a name that has no DEFINITIONS
     annotation (or an empty one) and is present in `self.ctx.info.namespace`
     is annotated with the namespace value as its STATIC_VALUE.
     """
     node = self.generic_visit(node)
     if not isinstance(node.ctx, gast.Load):
         return node
     has_local_defs = bool(anno.getanno(node, anno.Static.DEFINITIONS, ()))
     if not has_local_defs and node.id in self.ctx.info.namespace:
         anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id])
     return node
コード例 #14
0
 def _process_symbol_directive(self, call_node, directive):
     """Attaches a symbol-level directive to every definition of its target.

     The first positional argument of `call_node` names the target. Each entry
     of the target's ORIG_DEFINITIONS annotation gets `directive` mapped to the
     call's arguments (via `_map_args`).

     Returns:
       `call_node`, unchanged.

     Raises:
       ValueError: if `call_node` has no positional arguments.
     """
     if not call_node.args:
         raise ValueError('"%s" requires a positional first argument'
                          ' as the target' % directive.__name__)
     target_symbol = call_node.args[0]
     for definition in anno.getanno(target_symbol, anno.Static.ORIG_DEFINITIONS):
         definition.directives[directive] = _map_args(call_node, directive)
     return call_node
コード例 #15
0
 def visit_ExceptHandler(self, node):
     """Visits an except clause inside a scope of its own.

     Names defined in the handler leak out of it, as in Python, but the
     exception variable does not. So the QN of `node.name` (when there is one)
     is registered as an isolated name of the new scope before the children
     are visited.
     """
     self._enter_scope(False)
     exception_var = node.name
     if exception_var is not None:
         exception_qn = anno.getanno(exception_var, anno.Basic.QN)
         self.scope.isolated_names.add(exception_qn)
     node = self.generic_visit(node)
     self._exit_scope()
     return node
コード例 #16
0
    def visit_FunctionDef(self, node):
        """Restructures a function body around a single final return.

        Fresh `do_return` and `retval_` symbol names are generated, stored on the
        `_Function` state, and the body is converted with
        `_visit_statement_block` (which presumably rewrites the individual
        return statements; not visible here). If the function-level `_Block`
        state reports `return_used`, the body is replaced by a template that
        runs the converted body and ends in a single return. With
        `self.default_to_null_return`, that return is
        `return ag__.retval(retval_var_name)` and the variables are initialized
        first. Otherwise it is `return retval_var_name`. A leading docstring is
        moved back to the front of the new body.

        Returns:
          `node`; its body is replaced only when a return was used.
        """
        self.state[_Function].enter()
        self.state[_Block].enter()
        self.state[_Block].is_function = True

        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        # Names are generated against the body's referenced symbols.
        do_return_var_name = self.ctx.namer.new_symbol('do_return',
                                                       scope.referenced)
        retval_var_name = self.ctx.namer.new_symbol('retval_',
                                                    scope.referenced)
        self.state[_Function].do_return_var_name = do_return_var_name
        self.state[_Function].retval_var_name = retval_var_name

        converted_body = self._visit_statement_block(node, node.body)

        # Avoid placing statements before any eventual docstring.
        # TODO(mdan): Should a docstring even be included in the output?
        # NOTE(review): any leading Expr wrapping a gast.Constant (not only a
        # string) is treated as the docstring here.
        docstring = None
        if converted_body:
            if (isinstance(converted_body[0], gast.Expr)
                    and isinstance(converted_body[0].value, gast.Constant)):
                docstring = converted_body[0]
                converted_body = converted_body[1:]

        # NOTE(review): when no return was used, converted_body is discarded and
        # node.body keeps the statements it had before conversion — confirm this
        # is intended.
        if self.state[_Block].return_used:

            if self.default_to_null_return:
                # TODO(mdan): Remove the (do_return_var_name,) below.
                # Currently, that line ensures the variable is both defined and alive
                # throughout the function.
                template = """
          do_return_var_name = False
          retval_var_name = ag__.UndefinedReturnValue()
          body
          (do_return_var_name,)
          return ag__.retval(retval_var_name)
        """
            else:
                template = """
          body
          return retval_var_name
        """
            node.body = templates.replace(
                template,
                body=converted_body,
                do_return_var_name=do_return_var_name,
                retval_var_name=retval_var_name)

            if docstring:
                node.body.insert(0, docstring)

        self.state[_Block].exit()
        self.state[_Function].exit()
        return node
コード例 #17
0
ファイル: liveness.py プロジェクト: Cerebras/ptwse
 def _block_statement_live_in(self, node, entry_node):
     """Annotates `node` with the variables live on entry to `entry_node`.

     When `entry_node` has a CFG node in the current analyzer's graph, the
     analyzer's computed in-set for it is used, frozen into a frozenset.
     Otherwise `entry_node` must be a block statement that already carries a
     LIVE_VARS_IN annotation, which is reused as-is.

     Returns:
       `node`.
     """
     graph_index = self.current_analyzer.graph.index
     if entry_node not in graph_index:
         assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
             'If not matching a CFG node, must be a block statement:'
             ' {}'.format(entry_node))
         live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
     else:
         live_in = frozenset(self.current_analyzer.in_[graph_index[entry_node]])
     anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
     return node
コード例 #18
0
ファイル: call_trees.py プロジェクト: Cerebras/ptwse
 def visit_Lambda(self, node):
     """Visits a lambda, entering a new function state when it has a scope.

     Lambdas annotated with 'function_context_name' push a `_Function` state
     carrying that name while their children are visited. Lambdas created
     during the conversion process have no context manager of their own. They
     are visited under the enclosing function's state.
     """
     if not anno.hasanno(node, 'function_context_name'):
         return self.generic_visit(node)

     self.state[_Function].enter()
     self.state[_Function].context_name = anno.getanno(
         node, 'function_context_name')
     node = self.generic_visit(node)
     self.state[_Function].exit()
     return node
コード例 #19
0
ファイル: ast_util.py プロジェクト: Cerebras/ptwse
 def _process_name_node(self, node):
     """Renames a Name node according to `self.name_map`.

     If the node's QN is a key of `self.name_map`, a new `gast.Name` is
     returned. It has the mapped name (converted to str), the original `ctx`,
     and copies of all of the original node's annotations. Otherwise the
     node's children are visited and the result returned.
     """
     qn = anno.getanno(node, anno.Basic.QN)
     if qn not in self.name_map:
         return self.generic_visit(node)

     renamed = gast.Name(str(self.name_map[qn]),
                         ctx=node.ctx,
                         annotation=None,
                         type_comment=None)
     # Carry every annotation over to the replacement node.
     for key in anno.keys(node):
         anno.copyanno(node, renamed, key)
     return renamed
コード例 #20
0
    def visit_While(self, node):
        """Converts a while loop into an `ag__.pt_while_stmt` call.

        After visiting children, loop variables are computed from the symbols
        the body modifies. The loop is replaced by these statements:
          * state getter/setter functions for the loop variables;
          * a `loop_body` function holding the body, preceded by the
            declarations from `_create_nonlocal_declarations(loop_vars)`;
          * a `loop_test` function returning the loop test;
          * assignments for possibly-undefined loop variables;
          * the `ag__.pt_while_stmt` call, which receives those functions, the
            loop variable names as string constants and the loop options.

        Returns:
          The list of statements produced by `templates.replace`.
        """
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

        loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
            node, body_scope.modified)

        undefined_assigns = self._create_undefined_assigns(possibly_undefs)

        nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

        # Generated names are chosen against reserved_symbols (the symbols the
        # body references), presumably to avoid collisions.
        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      reserved_symbols)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      reserved_symbols)
        state_functions = self._create_state_functions(loop_vars,
                                                       nonlocal_declarations,
                                                       state_getter_name,
                                                       state_setter_name)

        # Options from a set_loop_options directive, or an empty dict.
        opts = self._create_loop_options(node)

        template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def test_name():
        return test
      undefined_assigns
      ag__.pt_while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
        return templates.replace(
            template,
            body=node.body,
            body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
            nonlocal_declarations=nonlocal_declarations,
            opts=opts,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            # Loop variable names are passed as string constants.
            symbol_names=tuple(
                gast.Constant(str(s), kind=None) for s in loop_vars),
            test=node.test,
            test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
            undefined_assigns=undefined_assigns)
コード例 #21
0
 def _process_statement_directive(self, call_node, directive):
     """Records a loop-level directive on the current `_LoopScope` AST node.

     The directive's mapped arguments (via `_map_args`) are stored under
     `directive` in the target node's DIRECTIVES annotation dict, which is
     created if missing.

     Returns:
       `call_node`, unchanged.

     Raises:
       ValueError: if more than one statement of the loop block has been
         visited, or if the `_LoopScope` level is below 2 (presumably: the
         call is not inside a loop).
     """
     if self.state[_LoopScope].statements_visited > 1:
         raise ValueError(
             '"%s" must be the first statement in the loop block' %
             directive.__name__)
     if self.state[_LoopScope].level < 2:
         raise ValueError('"%s" must be used inside a statement' %
                          directive.__name__)
     loop_node = self.state[_LoopScope].ast_node
     recorded = anno.getanno(loop_node, anno.Basic.DIRECTIVES, {})
     recorded[directive] = _map_args(call_node, directive)
     anno.setanno(loop_node, anno.Basic.DIRECTIVES, recorded)
     return call_node
コード例 #22
0
    def _create_loop_options(self, node):
        """Builds the AST of the options dict passed to the loop operator.

        The options come from the `directives.set_loop_options` entry of the
        node's DIRECTIVES annotation. Its keys become string constants and its
        values (used as-is) become the dict values.

        Returns:
          gast.Dict. It is empty when the node has no DIRECTIVES, no
          set_loop_options directive, or a directive carrying no options.
        """
        if not anno.hasanno(node, anno.Basic.DIRECTIVES):
            return gast.Dict([], [])

        loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
        if directives.set_loop_options not in loop_directives:
            return gast.Dict([], [])

        opts_dict = loop_directives[directives.set_loop_options]
        # Keys and values are built separately instead of via
        # `zip(*opts_dict.items())`, which fails to unpack an empty dict. Both
        # iterate the same dict, so their orders agree.
        keys = [gast.Constant(name, kind=None) for name in opts_dict]
        # ast and gast don't play well with tuples, so values must be a list.
        values = list(opts_dict.values())
        return gast.Dict(keys, values)
コード例 #23
0
ファイル: qual_names.py プロジェクト: Cerebras/ptwse
 def visit_Subscript(self, node):
     """Annotates a single-index subscript with its qualified name.

     The subscript part of the QN is a NumberLiteral QN for a constant index,
     or the index expression's own QN when it has one. Non-`Index` slices,
     index expressions without a QN, and subscripted values without a QN all
     leave `node` unannotated.
     """
     # TODO(mdan): This may no longer apply if we overload getitem.
     node = self.generic_visit(node)
     index = node.slice
     if not isinstance(index, gast.Index):
         # TODO(mdan): Support range and multi-dimensional indices.
         # Continuing silently because some demos use these.
         return node

     index_expr = index.value
     if isinstance(index_expr, gast.Constant):
         subscript = QN(NumberLiteral(index_expr.value))
     elif anno.hasanno(index_expr, anno.Basic.QN):
         subscript = anno.getanno(index_expr, anno.Basic.QN)
     else:
         # The index is an arbitrary expression, so no name makes sense.
         return node

     if anno.hasanno(node.value, anno.Basic.QN):
         base_qn = anno.getanno(node.value, anno.Basic.QN)
         anno.setanno(node, anno.Basic.QN, QN(base_qn, subscript=subscript))
     return node
コード例 #24
0
ファイル: lists.py プロジェクト: Cerebras/ptwse
    def _replace_pop_call(self, node):
        """Swaps a `<target>.pop()` call for a freshly generated variable.

        Expressions that use pop() are converted to a statement + expression.
        For example:

          print(target.pop())

        ... is converted to:

          target, target_pop = ag__.list_pop(target)
          print(target_pop)

        This method only generates the variable name, records `(node, name)` in
        the current `_Statement` state's `pop_uses`, and returns the variable
        reference. `_generate_pop_operation` handles the rest. Multiple uses of
        pop() are allowed:

          print(target.pop(), target.pop())
          print(target.pop().pop())

        Returns:
          The expression built by `templates.replace_as_expression` that refers
          to the generated variable.
        """
        assert isinstance(node.func, gast.Attribute)
        args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
        popped_from = node.func.value

        # Derive the name from the target when it has a QN; otherwise use a
        # generic base.
        base_name = (anno.getanno(popped_from, anno.Basic.QN).ssf()
                     if anno.hasanno(popped_from, anno.Basic.QN) else 'list_')
        pop_var_name = self.ctx.namer.new_symbol(base_name, args_scope.referenced)

        statement_state = self.state[_Statement]
        if statement_state.pop_uses is None:
            statement_state.pop_uses = []
        statement_state.pop_uses.append((node, pop_var_name))

        return templates.replace_as_expression('var_name',
                                               var_name=pop_var_name)
コード例 #25
0
    def _track_symbol(self, node, composite_writes_alter_parent=False):
        """Records the use of `node`'s symbol in the current scope.

        Depending on `node.ctx`, the node's QN is added to the scope's
        `modified`/`bound`/`read`/`deleted` sets, or marked as a parameter of
        the current function or class. Inside comprehensions, symbols that are
        (or are owned by) comprehension targets are ignored. Stores there become
        comprehension targets, and only Python 2 list comprehensions also
        record them in the scope.

        Args:
          node: a gast node with a `ctx` attribute. Nodes without a QN
            annotation are ignored.
          composite_writes_alter_parent: bool, whether a store to a composite
            symbol (e.g. `a.b`) also marks its parent (`a`) as modified.

        Raises:
          ValueError: if `node.ctx` is not Store, Load, Param or Del.
        """
        # A QN may be missing when we have an attribute (or subscript) on a function
        # call. Example: a().b
        if not anno.hasanno(node, anno.Basic.QN):
            return
        qn = anno.getanno(node, anno.Basic.QN)

        # When inside a comprehension, ignore reads to any of the comprehensions's
        # targets. This includes attributes or slices of those arguments.
        for comprehension in self.state[_Comprehension]:
            if qn in comprehension.targets:
                return
            if qn.owner_set & set(comprehension.targets):
                return

        if isinstance(node.ctx, gast.Store):
            # In comprehensions, modified symbols are the comprehension targets.
            if self.state[_Comprehension].level > 0:
                self.state[_Comprehension].targets.add(qn)
                # List comprehension targets leak in Python 2.
                # For details, see:
                # https://stackoverflow.com/questions/4198906/list-comprehension-rebinds-names-even-after-scope-of-comprehension-is-this-righ
                if not (six.PY2 and self.state[_Comprehension].is_list_comp):
                    return

            self.scope.modified.add(qn)
            self.scope.bound.add(qn)
            # `is_composite` is a method and must be called. The bare bound
            # method is always truthy, which let simple names reach `qn.parent`.
            if qn.is_composite() and composite_writes_alter_parent:
                self.scope.modified.add(qn.parent)
            if self._in_aug_assign:
                # `x += ...` both reads and writes x.
                self.scope.read.add(qn)

        elif isinstance(node.ctx, gast.Load):
            self.scope.read.add(qn)

        elif isinstance(node.ctx, gast.Param):
            self.scope.bound.add(qn)
            self.scope.mark_param(qn, self.state[_FunctionOrClass].node)

        elif isinstance(node.ctx, gast.Del):
            # The read matches the Python semantics - attempting to delete an
            # undefined symbol is illegal.
            self.scope.read.add(qn)
            # Targets of del are considered bound:
            # https://docs.python.org/3/reference/executionmodel.html#binding-of-names
            self.scope.bound.add(qn)
            self.scope.deleted.add(qn)

        else:
            raise ValueError('Unknown context {} for node "{}".'.format(
                type(node.ctx), qn))
コード例 #26
0
    def _determine_aliased_symbols(self, scope, node_defined_in, block):
        """Returns the simple symbols that the statement must carry through.

        A symbol qualifies when it is modified in `scope`, appears in
        `node_defined_in`, and is live on entry to the first statement of
        `block` (an empty block has no live symbols). Composite symbols and
        globals of the current function are excluded.

        Returns:
          A set of the qualifying symbols.
        """
        if not block:
            block_live_in = set()
        else:
            block_live_in = set(
                anno.getanno(block[0], anno.Static.LIVE_VARS_IN))

        candidates = scope.modified & node_defined_in & block_live_in
        aliased = set()
        for symbol in candidates:
            # Composite symbols are handled elsewhere, see _create_state_functions
            if symbol.is_composite():
                continue
            if symbol in self.state[_Function].scope.globals:
                continue
            aliased.add(symbol)
        return aliased
コード例 #27
0
ファイル: call_trees.py プロジェクト: Cerebras/ptwse
    def visit_Call(self, node):
        """Routes a call through `ag__.converted_call` unless it is exempt.

        Arguments are visited first. These calls are returned unconverted:
          * calls into the internal `ag__` module;
          * calls on the current function context object;
          * `pdb.set_trace`, `ipdb.set_trace` and `breakpoint` (a warning is
            logged the first time one is seen);
          * `print`, when the BUILTIN_FUNCTIONS feature is not enabled.
        Any other call becomes
        `ag__.converted_call(func, args, kwargs, function_ctx)`.
        """
        global set_trace_warned

        full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
        function_context_name = self.state[_Function].context_name
        node = self.generic_visit(node)

        # TODO(mdan): Refactor converted_call as a 'Call' operator.

        # Internal 'ag__' calls and calls to the function context manager
        # (inserted by function_scopes) are never converted, though their
        # arguments may have been.
        if (full_name.startswith('ag__.')
                or full_name.startswith(function_context_name + '.')):
            return node

        # Debugger entry points are sensitive to the frame they are called from,
        # so they bypass the normal mechanisms and are left alone.
        # TODO(mdan): Generalize this to a "static whitelist" config.
        if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
            if not set_trace_warned:
                # TODO(mdan): Update and shorten once available on tensorflow.org.
                ag_logging.warn(
                    'Detected `pdb.set_trace()` in user code. The code'
                    ' generated by AutoGraph is not optimized for step-by-step'
                    ' debugging. See https://github.com/tensorflow/tensorflow/'
                    'blob/master/tensorflow/python/autograph/g3doc/reference/'
                    'debugging.md.')
                set_trace_warned = True
            return node

        if full_name == 'print':
            if not self.ctx.program.options.uses(
                    converter.Feature.BUILTIN_FUNCTIONS):
                return node

        template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
        return templates.replace_as_expression(
            template,
            func=node.func,
            args=self._args_to_tuple(node),
            kwargs=self._kwargs_to_dict(node),
            function_ctx=function_context_name)
コード例 #28
0
    def visit_For(self, node):
        """Records the activity scopes of a `for` statement.

        Three scope groups are recorded, in order:
          1. a scope for the loop header (target and iter), recorded via
             `_exit_and_record_scope(node.iter)`;
          2. the iterate scope (the target again, plus EXTRA_LOOP_TEST if
             present), tagged `NodeAnno.ITERATE_SCOPE` on `node`;
          3. the body and else scopes, via `_process_parallel_blocks`
             (`BODY_SCOPE` and `ORELSE_SCOPE`).

        Returns:
          The node returned by `_process_parallel_blocks`.
        """
        self._enter_scope(False)
        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        self._exit_and_record_scope(node.iter)

        self._enter_scope(False)
        # The target is visited a second time (result discarded) so that its
        # symbols are also tracked in the iterate scope.
        self.visit(node.target)
        if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
            self._process_statement(
                anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST))
        self._exit_and_record_scope(node, tag=NodeAnno.ITERATE_SCOPE)

        node = self._process_parallel_blocks(
            node, ((node.body, NodeAnno.BODY_SCOPE),
                   (node.orelse, NodeAnno.ORELSE_SCOPE)))
        return node
コード例 #29
0
    def visit_Expr(self, node):
        """Consumes directive calls that appear as expression statements.

        The statement is counted for the current `_LoopScope`, then its
        children are visited. If the expression is a call whose callee has the
        static value `directives.set_element_type` or
        `directives.set_loop_options`, the directive is processed and None is
        returned, so the call does not appear in the generated code. Any other
        statement is returned unchanged.
        """
        self.state[_LoopScope].statements_visited += 1
        node = self.generic_visit(node)

        call_node = node.value
        if not isinstance(call_node, gast.Call):
            return node

        callee = anno.getanno(call_node.func, STATIC_VALUE, default=None)
        if callee is directives.set_element_type:
            self._process_symbol_directive(call_node, callee)
            return None
        if callee is directives.set_loop_options:
            self._process_statement_directive(call_node, callee)
            return None
        return node
コード例 #30
0
    def visit_For(self, node):
        """Lowers `break` statements inside a for loop.

        The target, iter, body and else clause are visited first. If the body
        uses `break`, the loop is rewritten around a boolean control variable
        that starts as False. The else clause is guarded by that variable. The
        new for statement is annotated with EXTRA_LOOP_TEST
        `ag__.not_(<control variable>)` and keeps the original node's
        DIRECTIVES annotation.

        Returns:
          `node` itself when no break is used; otherwise the list of statements
          produced by the template.
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        control_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        node.body, has_break = self._process_body(node.body, control_var)
        # Breaks inside the else clause belong to the enclosing loop.
        node.orelse = self.visit_block(node.orelse)

        if not has_break:
            return node

        # Python runs the else clause only when the loop was not broken out of.
        orelse_if_no_break = self._guard_if_present(node.orelse, control_var)
        loop_condition = templates.replace_as_expression('ag__.not_(var_name)',
                                                         var_name=control_var)

        # The extra test lives only in an annotation, invisible to static
        # analysis, so the template adds a no-op `(var_name,)` statement that
        # marks the control variable as used.
        # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
        template = """
        var_name = False
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
        statements = templates.replace(template,
                                       var_name=control_var,
                                       iter_=node.iter,
                                       target=node.target,
                                       body=node.body,
                                       orelse=orelse_if_no_break)

        # Statement 1 of the template is the rewritten for loop.
        rewritten_loop = statements[1]
        anno.setanno(rewritten_loop, anno.Basic.EXTRA_LOOP_TEST, loop_condition)
        anno.copyanno(node, rewritten_loop, anno.Basic.DIRECTIVES)
        return statements