Example #1
0
  def visit_Compare(self, node):
    """Rewrites a comparison into calls to the matching comparison functions.

    A chain such as `a < b < c` is expanded into one call per adjacent pair.
    The calls are joined with tf.logical_and, each new pair going first:
    `tf.logical_and(<b < c>, <a < b>)`.

    Args:
      node: the Compare node being visited.

    Returns:
      `node` itself when it is a single comparison whose operator has no
      matching function; otherwise the expression built from _as_function calls.

    Raises:
      NotImplementedError: if a chained comparison contains an operator with no
        matching function.
    """
    node = self.generic_visit(node)

    if not all(self._has_matching_func(op) for op in node.ops):
      if len(node.ops) > 1:
        raise NotImplementedError(
            'compound expression with at least one unsupported '
            'operator: {}'.format(node.ops))
      # A lone comparison is safe to keep as a plain Python expression.
      return node

    conjunction = None
    lhs = node.left
    for op, rhs in zip(node.ops, node.comparators):
      comparison = self._as_function(self._matching_func(op), (lhs, rhs))
      if isinstance(lhs, gast.Name) and isinstance(rhs, gast.Name):
        anno.setanno(comparison, SAFE_BOOLEAN_OPERAND, True)
      if conjunction is None:
        conjunction = comparison
      else:
        # Both arguments of the generated call are evaluated up front, unlike
        # Python's short-circuiting chain, so later comparators must be simple.
        self._expect_simple_symbol(rhs)
        conjunction = self._as_function('tf.logical_and',
                                        (comparison, conjunction))
      # The right operand becomes the left operand of the next link.
      lhs = rhs
    assert conjunction is not None
    return conjunction
Example #2
0
  def _track_symbol(self,
                    node,
                    composite_writes_alter_parent=False,
                    writes_create_symbol=False):
    """Records a use of the symbol named by `node` in the current scope.

    Stores mark the symbol as written. For composite symbols with
    `composite_writes_alter_parent` set, the parent is also marked as written.
    Stores inside an augmented assignment additionally count as reads. Loads
    mark the symbol as read. Params are treated as writes and registered as
    parameters of the last entry in `self.enclosing_entities`.

    Args:
      node: an AST node carrying a `ctx` attribute.
      composite_writes_alter_parent: whether a write to a composite symbol also
        marks its parent symbol as written.
      writes_create_symbol: whether a write is also recorded as a creation.

    Raises:
      ValueError: if `node.ctx` is not a Store, Load or Param context.
    """
    # Nodes without a qualified name (e.g. the attribute in `a().b`) cannot be
    # tracked.
    if not anno.hasanno(node, anno.Basic.QN):
      return
    symbol = anno.getanno(node, anno.Basic.QN)
    ctx = node.ctx

    if isinstance(ctx, gast.Store):
      self.scope.mark_write(symbol)
      if symbol.is_composite and composite_writes_alter_parent:
        self.scope.mark_write(symbol.parent)
      if writes_create_symbol:
        self.scope.mark_creation(symbol, writes_create_symbol=True)
      if self._in_aug_assign:
        # `x += 1` reads `x` as well as writing it.
        self.scope.mark_read(symbol)
    elif isinstance(ctx, gast.Load):
      self.scope.mark_read(symbol)
    elif isinstance(ctx, gast.Param):
      # A parameter defines a variable, so it also counts as a write.
      self.scope.mark_write(symbol)
      self.scope.mark_param(symbol, self.enclosing_entities[-1])
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(ctx), symbol))

    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(symbol))

    if self._in_return_statement:
      self.scope.mark_returned(symbol)
Example #3
0
    def visit_Compare(self, node):
        """Rewrites a comparison into calls to the matching comparison functions.

        A chain such as `a < b < c` is expanded into one call per adjacent
        pair. The calls are combined left to right with `ag__.and_`, whose
        operands are passed with args_as_lambda=True.

        Args:
            node: the Compare node being visited.

        Returns:
            The expression built from _as_function calls.
        """
        node = self.generic_visit(node)

        result = None
        lhs = node.left
        for op, rhs in zip(node.ops, node.comparators):
            pair = self._as_function(self._matching_func(op), (lhs, rhs))
            if isinstance(lhs, gast.Name) and isinstance(rhs, gast.Name):
                anno.setanno(pair, SAFE_BOOLEAN_OPERAND, True)
            if result is None:
                result = pair
            else:
                self._expect_simple_symbol(rhs)
                result = self._as_function('ag__.and_', (result, pair),
                                           args_as_lambda=True)
            # Each comparator is the left operand of the next link.
            lhs = rhs
        assert result is not None
        return result
Example #4
0
 def visit_Lambda(self, node):
   """Visits a lambda inside a _Lambda state frame and records its scope.

   Args:
     node: the Lambda node being visited.

   Returns:
     The visited node, annotated with anno.Static.SCOPE set to the scope that
     is current after its children have been visited.
   """
   assert not self._in_function_def_args
   self.state[_Lambda].enter()
   try:
     node = self.generic_visit(node)
     anno.setanno(node, anno.Static.SCOPE, self.scope)
   finally:
     # Always leave the _Lambda frame, even when visiting fails, so enter/exit
     # calls stay paired.
     self.state[_Lambda].exit()
   return node
Example #5
0
    def visit_Attribute(self, node):
        """Infers the types of an attribute access from its parent expression.

        The attribute's static value is taken from the parent's static value
        when one is known. Otherwise it is taken from the parent's inferred
        types, provided every type yields the same attribute object.

        Args:
            node: the Attribute node being visited.

        Returns:
            The types resolved for the attribute's static value, or None when
            no static value can be determined.
        """
        parent_types = self.visit(node.value)

        # Attempt to use the static value if known.
        parent_value = anno.Static.VALUE.of(node.value, None)
        if parent_value is not None:
            # Sentinel separating "attribute missing" from an attribute whose
            # value really is None.
            missing = object()
            static_value = getattr(parent_value, node.attr, missing)
            if static_value is missing:
                # The attribute is not available on the parent value; do not
                # record a bogus static value of None for it.
                return None

        else:
            # Fall back to the type if that is known.
            if parent_types is None:
                return None

            inferred_values = [
                getattr(t, node.attr, None) for t in parent_types
            ]
            if not inferred_values:
                return None

            static_value = inferred_values[0]
            if static_value is None:
                return None

            if any(v is not static_value for v in inferred_values[1:]):
                # Static value not stable, assume it's dynamic.
                return None

        types = self.resolver.res_value(self.namespace, static_value)
        anno.setanno(node, anno.Static.VALUE, static_value)

        if __debug__:
            self._check_set(types)

        return types
 def visit_Attribute(self, node):
   """Propagates a static value from an object to one of its attributes.

   When the parent expression has a "static_value" annotation that has the
   accessed attribute, the attribute's value becomes the node's own
   "static_value" annotation.
   """
   node = self.generic_visit(node)
   owner = anno.getanno(node.value, "static_value", default=None)
   if owner is not None and hasattr(owner, node.attr):
     anno.setanno(node, "static_value", getattr(owner, node.attr))
   return node
Example #7
0
 def visit_With(self, node):
     """Visits a with statement and flags it when its body always returns.

     Returns:
       `node` with visited items and body. STMT_DEFINITELY_RETURNS is set to
       True on it when _visit_statement_block reports that the body
       definitely returns.
     """
     node.items = self.visit_block(node.items)
     new_body, always_returns = self._visit_statement_block(node, node.body)
     node.body = new_body
     if always_returns:
         anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
     return node
Example #8
0
 def _block_statement_live_out(self, node):
     """Annotates `node` with the symbols live after the statement.

     The live-out set is the union of the `in_` sets of every successor of
     `node` in the current analyzer's graph. It is stored as a frozenset under
     anno.Static.LIVE_VARS_OUT.
     """
     analyzer = self.current_analyzer
     live_out = frozenset().union(
         *(analyzer.in_[succ] for succ in analyzer.graph.stmt_next[node]))
     anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
     return node
  def test_create_source_map_multiple_nodes(self):
    """Each top-level node's origin is mapped to that node's own line."""
    source = """
        from __future__ import print_function
        def test_fn(x):
          return x + 1
    """
    source = textwrap.dedent(source)
    nodes = parser.parse_str(source, single_node=False)

    def make_fake_origin():
      return origin_info.OriginInfo(
          loc=origin_info.Location('fake_filename', 3, 7),
          function_name='fake_function_name',
          source_code_line='fake source line',
          comment=None)

    import_origin = make_fake_origin()
    anno.setanno(nodes[0], anno.Basic.ORIGIN, import_origin)
    function_origin = make_fake_origin()
    anno.setanno(nodes[1], anno.Basic.ORIGIN, function_origin)

    source_map = origin_info.create_source_map(nodes, source, 'test_filename')

    # The source starts with a blank line, so the import sits on line 2 and the
    # function definition on line 3.
    for lineno, expected in ((2, import_origin), (3, function_origin)):
      loc = origin_info.LineLocation('test_filename', lineno)
      self.assertIn(loc, source_map)
      self.assertIs(source_map[loc], expected)
Example #10
0
    def visit_Lambda(self, node):
        """Wraps the body of a top-level lambda in ag__.with_function_scope.

        When the _Function state level exceeds 2 (nested lambdas), the lambda
        is only wrapped in ag__.autograph_artifact. Otherwise a fresh `lscope`
        context name is allocated and recorded on both the state and the node,
        and the body is rewritten from the template below.

        Returns:
            The rewritten lambda node, or the autograph_artifact expression
            for nested lambdas.
        """
        with self.state[_Function] as lambda_scope:
            node = self.generic_visit(node)

            # Wrapping every nested lambda would add excessive boilerplate, so
            # only the outermost one gets a function scope.
            # TODO(mdan): Looks more closely for use cases that actually require this.
            if lambda_scope.level > 2:
                return templates.replace_as_expression(
                    'ag__.autograph_artifact(l)', l=node)

            referenced = anno.getanno(node, anno.Static.SCOPE).referenced
            context_name = self.ctx.namer.new_symbol('lscope', referenced)
            lambda_scope.context_name = context_name
            anno.setanno(node, 'function_context_name', context_name)

            template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
            scope_options = self._function_scope_options(lambda_scope).to_ast()
            node.body = templates.replace_as_expression(
                template,
                options=scope_options,
                function_context=context_name,
                function_context_name=gast.Constant(context_name, kind=None),
                body=node.body)

            return node
Example #11
0
 def _block_statement_live_out(self, node):
   """Stores, under anno.Static.LIVE_VARS_OUT, the union of the successors' in_ sets."""
   successor_live_ins = (
       self.current_analyzer.in_[successor]
       for successor in self.current_analyzer.graph.stmt_next[node])
   anno.setanno(node, anno.Static.LIVE_VARS_OUT,
                frozenset().union(*successor_live_ins))
   return node
Example #12
0
 def visit_Attribute(self, node):
   """Resolves attributes of statically known modules.

   If the parent expression's STATIC_VALUE is a module that has the accessed
   attribute, that attribute's value is stored as the node's STATIC_VALUE.
   """
   node = self.generic_visit(node)
   module = anno.getanno(node.value, STATIC_VALUE, default=None)
   is_known_module = module is not None and tf_inspect.ismodule(module)
   if is_known_module and hasattr(module, node.attr):
     anno.setanno(node, STATIC_VALUE, getattr(module, node.attr))
   return node
Example #13
0
    def _track_symbol(self,
                      node,
                      composite_writes_alter_parent=False,
                      writes_create_symbol=False):
        """Records a use of the symbol named by `node` in the current scope.

        Stores mark the symbol as written. For composite symbols with
        `composite_writes_alter_parent` set, the parent is also marked as
        written. Stores inside an augmented assignment additionally count as
        reads. Loads mark the symbol as read. Params are treated as writes and
        registered as parameters of the last entry in
        `self.enclosing_entities`.

        Args:
            node: an AST node carrying a `ctx` attribute.
            composite_writes_alter_parent: whether a write to a composite
                symbol also marks its parent symbol as written.
            writes_create_symbol: whether a write is also recorded as a
                creation.

        Raises:
            ValueError: if `node.ctx` is not a Store, Load or Param context.
        """
        # Nodes without a qualified name (e.g. the attribute in `a().b`)
        # cannot be tracked.
        if not anno.hasanno(node, anno.Basic.QN):
            return
        symbol = anno.getanno(node, anno.Basic.QN)
        ctx = node.ctx

        if isinstance(ctx, gast.Store):
            self.scope.mark_write(symbol)
            if symbol.is_composite and composite_writes_alter_parent:
                self.scope.mark_write(symbol.parent)
            if writes_create_symbol:
                self.scope.mark_creation(symbol, writes_create_symbol=True)
            if self._in_aug_assign:
                # `x += 1` reads `x` as well as writing it.
                self.scope.mark_read(symbol)
        elif isinstance(ctx, gast.Load):
            self.scope.mark_read(symbol)
        elif isinstance(ctx, gast.Param):
            # A parameter defines a variable, so it also counts as a write.
            self.scope.mark_write(symbol)
            self.scope.mark_param(symbol, self.enclosing_entities[-1])
        else:
            raise ValueError('Unknown context %s for node %s.' %
                             (type(ctx), symbol))

        anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(symbol))

        if self._in_return_statement:
            self.scope.mark_returned(symbol)
Example #14
0
    def visit_FunctionDef(self, node):
        """Records the scopes of a function definition.

        Three scopes are involved:
          * one covering the function's name and decorators, stored on `node`
            as anno.Static.SCOPE;
          * one for the definition itself, in which the arguments are visited;
          * one nested inside the latter for the body, stored on `node` as
            NodeAnno.BODY_SCOPE.

        Args:
            node: the FunctionDef node being visited.

        Returns:
            `node`, with decorator_list, args and body replaced by their
            visited versions.
        """
        # The FunctionDef node itself has a Scope object that tracks the creation
        # of its name, along with the usage of any decorators accompanying it.
        self._enter_scope(False)
        node.decorator_list = self.visit_block(node.decorator_list)
        # Defining the function modifies its name in the enclosing scope.
        self.scope.mark_modified(qual_names.QN(node.name))
        anno.setanno(node, anno.Static.SCOPE, self.scope)
        self._exit_scope()

        # A separate Scope tracks the actual function definition.
        # NOTE(review): the True flag presumably marks an isolated
        # (function-level) scope, unlike the False used elsewhere; confirm in
        # _enter_scope.
        self._enter_scope(True)
        # Not expected while already visiting function arguments or a lambda.
        assert not (self._in_function_def_args or self.state[_Lambda].level)
        # Flag is set only while the argument list is being visited.
        self._in_function_def_args = True
        node.args = self.visit(node.args)
        self._in_function_def_args = False

        # Track the body separately. This is for compatibility reasons, it may not
        # be strictly needed.
        self._enter_scope(False)
        node.body = self.visit_block(node.body)
        anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
        self._exit_scope()

        # Leave the function-definition scope entered above.
        self._exit_scope()
        return node
Example #15
0
 def visit_Attribute(self, node):
   """Copies a module attribute's value onto the node as STATIC_VALUE.

   Only applies when the parent's STATIC_VALUE is a module that has the
   accessed attribute.
   """
   node = self.generic_visit(node)
   parent = anno.getanno(node.value, STATIC_VALUE, default=None)
   if parent is None or not tf_inspect.ismodule(parent):
     return node
   if hasattr(parent, node.attr):
     anno.setanno(node, STATIC_VALUE, getattr(parent, node.attr))
   return node
Example #16
0
    def test_create_source_map_multiple_nodes(self):
        """Each top-level node's origin is mapped to that node's own line."""
        source = """
        from __future__ import print_function
        def test_fn(x):
          return x + 1
    """
        source = textwrap.dedent(source)
        nodes = parser.parse_str(source, single_node=False)

        def make_fake_origin():
            return origin_info.OriginInfo(
                loc=origin_info.Location('fake_filename', 3, 7),
                function_name='fake_function_name',
                source_code_line='fake source line',
                comment=None)

        import_origin = make_fake_origin()
        anno.setanno(nodes[0], anno.Basic.ORIGIN, import_origin)
        function_origin = make_fake_origin()
        anno.setanno(nodes[1], anno.Basic.ORIGIN, function_origin)

        source_map = origin_info.create_source_map(nodes, source,
                                                   'test_filename')

        # The source starts with a blank line, so the import sits on line 2
        # and the function definition on line 3.
        expectations = ((2, import_origin), (3, function_origin))
        for lineno, expected in expectations:
            loc = origin_info.LineLocation('test_filename', lineno)
            self.assertIn(loc, source_map)
            self.assertIs(source_map[loc], expected)
    def visit_Compare(self, node):
        """Rewrites a comparison into calls to the matching comparison functions.

        A chain such as `a < b < c` is expanded into one call per adjacent
        pair. The calls are joined with tf.logical_and, each new pair going
        first: `tf.logical_and(<b < c>, <a < b>)`.

        Args:
            node: the Compare node being visited.

        Returns:
            `node` itself when it is a single comparison whose operator has
            no matching function; otherwise the expression built from
            _as_function calls.

        Raises:
            NotImplementedError: if a chained comparison contains an operator
                with no matching function.
        """
        node = self.generic_visit(node)

        if not all(self._has_matching_func(op) for op in node.ops):
            if len(node.ops) > 1:
                raise NotImplementedError(
                    'compound expression with at least one unsupported '
                    'operator: {}'.format(node.ops))
            # A lone comparison is safe to keep as a plain Python expression.
            return node

        conjunction = None
        lhs = node.left
        for op, rhs in zip(node.ops, node.comparators):
            comparison = self._as_function(self._matching_func(op), (lhs, rhs))
            if isinstance(lhs, gast.Name) and isinstance(rhs, gast.Name):
                anno.setanno(comparison, SAFE_BOOLEAN_OPERAND, True)
            if conjunction is None:
                conjunction = comparison
            else:
                # Both arguments of the generated call are evaluated up front,
                # unlike Python's short-circuiting chain.
                self._expect_simple_symbol(rhs)
                conjunction = self._as_function('tf.logical_and',
                                                (comparison, conjunction))
            # The right operand becomes the left operand of the next link.
            lhs = rhs
        assert conjunction is not None
        return conjunction
Example #18
0
 def visit_Print(self, node):
     """Visits print arguments in their own scope and records that scope.

     The scope is stored on `node` under both anno.Static.SCOPE and
     NodeAnno.ARGS_SCOPE.
     """
     self._enter_scope(False)
     node.values = self.visit_block(node.values)
     args_scope = self.scope
     for key in (anno.Static.SCOPE, NodeAnno.ARGS_SCOPE):
         anno.setanno(node, key, args_scope)
     self._exit_scope()
     return node
Example #19
0
  def visit_FunctionDef(self, node):
    """Records the scopes of a function definition.

    Three scopes are involved:
      * one covering the function's name and decorators, stored on `node` as
        anno.Static.SCOPE;
      * one for the definition itself, in which the arguments are visited;
      * one nested inside the latter for the body, stored on `node` as
        NodeAnno.BODY_SCOPE.

    Args:
      node: the FunctionDef node being visited.

    Returns:
      `node`, with decorator_list, args and body replaced by their visited
      versions.
    """
    # The FunctionDef node itself has a Scope object that tracks the creation
    # of its name, along with the usage of any decorators accompanying it.
    self._enter_scope(False)
    node.decorator_list = self.visit_block(node.decorator_list)
    # Defining the function modifies its name in the enclosing scope.
    self.scope.mark_modified(qual_names.QN(node.name))
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()

    # A separate Scope tracks the actual function definition.
    # NOTE(review): the True flag presumably marks an isolated (function-level)
    # scope, unlike the False used elsewhere; confirm in _enter_scope.
    self._enter_scope(True)
    # Not expected while already visiting function arguments or a lambda.
    assert not (self._in_function_def_args or self.state[_Lambda].level)
    # Flag is set only while the argument list is being visited.
    self._in_function_def_args = True
    node.args = self.visit(node.args)
    self._in_function_def_args = False

    # Track the body separately. This is for compatibility reasons, it may not
    # be strictly needed.
    self._enter_scope(False)
    node.body = self.visit_block(node.body)
    anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
    self._exit_scope()

    # Leave the function-definition scope entered above.
    self._exit_scope()
    return node
Example #20
0
  def visit_Lambda(self, node):
    """Wraps the body of a top-level lambda in ag__.with_function_scope.

    When the _Function state level exceeds 2, the lambda is only wrapped in
    ag__.autograph_artifact. Otherwise a fresh `lscope` context name is
    allocated and recorded on both the state and the node, and the body is
    rewritten from the template below.

    Returns:
      The rewritten lambda node, or the autograph_artifact expression.
    """
    with self.state[_Function] as lambda_scope:
      node = self.generic_visit(node)

      # TODO(mdan): Fix the tests so that we can always add this decorator.
      if lambda_scope.level > 2:
        return templates.replace_as_expression(
            'ag__.autograph_artifact(l)', l=node)

      referenced = anno.getanno(node, anno.Static.SCOPE).referenced
      context_name = self.ctx.namer.new_symbol('lscope', referenced)
      lambda_scope.context_name = context_name
      anno.setanno(node, 'function_context_name', context_name)

      template = """
        ag__.with_function_scope(
            lambda function_context: body, function_context_name, options)
      """
      scope_options = self._function_scope_options(lambda_scope).to_ast()
      node.body = templates.replace_as_expression(
          template,
          options=scope_options,
          function_context=context_name,
          function_context_name=gast.Constant(context_name, kind=None),
          body=node.body)

      return node
Example #21
0
 def visit_Print(self, node):
   """Visits print arguments in a fresh scope, stored as SCOPE and ARGS_SCOPE."""
   self._enter_scope(False)
   node.values = self.visit_block(node.values)
   values_scope = self.scope
   anno.setanno(node, anno.Static.SCOPE, values_scope)
   anno.setanno(node, NodeAnno.ARGS_SCOPE, values_scope)
   self._exit_scope()
   return node
Example #22
0
 def visit_Attribute(self, node):
     """Derives the qualified name of `x.attr` from the qualified name of `x`.

     Nodes whose parent has no anno.Basic.QN annotation are left unannotated.
     """
     node = self.generic_visit(node)
     if not anno.hasanno(node.value, anno.Basic.QN):
         return node
     base_qn = anno.getanno(node.value, anno.Basic.QN)
     anno.setanno(node, anno.Basic.QN, QN(base_qn, attr=node.attr))
     return node
  def visit_For(self, node):
    """Lowers `break` statements in a for loop to a boolean control variable.

    A `break_` symbol is reserved and passed to _process_body together with
    the loop body. If _process_body reports that no break was used, the loop
    is rebuilt from a template with the same target, iterable, body and else
    clause. Otherwise:
      * the rebuilt loop is preceded by `break_var = False`;
      * its else clause is guarded on the variable (via _guard_if_present);
      * an `ag__.not_(break_var)` expression is attached to the loop as
        anno.Basic.EXTRA_LOOP_TEST.

    Args:
      node: the For node being visited.

    Returns:
      The statements produced by templates.replace. The rebuilt For node is
      `node[0]` in the no-break case and `node[1]` otherwise.
    """
    # Keep the original node so its annotations can be copied to the new loop.
    original_node = node
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # new_symbol receives the body's referenced symbols, presumably so the
    # generated name does not clash with them.
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      # No break in the body: rebuild the loop with the same parts.
      template = """
        for target in iter_:
          body
        orelse
      """
      node = templates.replace(
          template,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=node.orelse)

      new_for_node = node[0]
      # Copy loop annotations from the original node onto the rebuilt loop.
      anno.copyanno(original_node, new_for_node, anno.Basic.EXTRA_LOOP_TEST)
      anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

      return node

    # Python's else clause only triggers if the loop exited cleanly (i.e.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    extra_test = templates.replace_as_expression(
        'ag__.not_(var_name)', var_name=break_var)

    # The extra test is hidden in the AST, which will confuse the static
    # analysis. To mitigate that, we insert a no-op statement that ensures
    # the control variable is marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    template = """
      var_name = False
      for target in iter_:
        (var_name,)
        body
      orelse
    """
    node = templates.replace(
        template,
        var_name=break_var,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=guarded_orelse)

    # node[0] is the `var_name = False` initializer; node[1] is the loop.
    new_for_node = node[1]
    anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
    anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

    return node
Example #24
0
 def _aggregate_predecessors_defined_in(self, node):
     """Annotates `node` with the symbols defined on entry to it.

     The set is the union of the keys of `out[pred].value` over every
     predecessor `pred` of `node` in the current analyzer's graph. It is stored
     as a frozenset under anno.Static.DEFINED_VARS_IN. Returns nothing.
     """
     analyzer = self.current_analyzer
     defined_in = frozenset().union(
         *(analyzer.out[pred].value.keys()
           for pred in analyzer.graph.stmt_prev[node]))
     anno.setanno(node, anno.Static.DEFINED_VARS_IN, defined_in)
Example #25
0
    def visit_Lambda(self, node):
        """Wraps the body of a top-level lambda in ag__.with_function_scope.

        The lambda is visited inside a _Function state frame. When that
        frame's level exceeds 2, the node is returned without wrapping.
        Otherwise a fresh `lscope` context name is allocated, recorded on the
        state and the node, and the body is rewritten from the template below.

        Args:
            node: the Lambda node being visited.

        Returns:
            The (possibly rewritten) lambda node.
        """
        self.state[_Function].enter()
        try:
            node = self.generic_visit(node)

            # Only wrap the top-level function. Theoretically, we can and should wrap
            # everything, but that can lead to excessive boilerplate when lambdas are
            # nested.
            # TODO(mdan): Looks more closely for use cases that actually require this.
            if self.state[_Function].level > 2:
                return node

            scope = anno.getanno(node, anno.Static.SCOPE)
            function_context_name = self.ctx.namer.new_symbol(
                'lscope', scope.referenced)
            self.state[_Function].context_name = function_context_name
            anno.setanno(node, 'function_context_name', function_context_name)

            template = """
      ag__.with_function_scope(
          lambda function_context: body, function_context_name, options)
    """
            node.body = templates.replace_as_expression(
                template,
                options=self.ctx.program.options.to_ast(),
                function_context=function_context_name,
                function_context_name=gast.Str(function_context_name),
                body=node.body)

            return node
        finally:
            # Runs on every exit path (both returns and exceptions) so the
            # _Function frame is always left exactly once.
            self.state[_Function].exit()
 def visit_Name(self, node):
     """Attaches a "static_value" to loaded names found in the namespace."""
     node = self.generic_visit(node)
     if not isinstance(node.ctx, gast.Load):
         return node
     namespace = self.ctx.info.namespace
     if node.id in namespace:
         anno.setanno(node, "static_value", namespace[node.id])
     return node
 def _as_function(self, func_name, args):
     """Builds a call expression `func_name(args)` marked as a safe operand.

     Args:
       func_name: source text of the callee, parsed with
         parser.parse_expression.
       args: argument expression(s) substituted for `args` in the template.

     Returns:
       The call expression, annotated with SAFE_BOOLEAN_OPERAND set to True.
     """
     template = """
   func_name(args)
 """
     callee = parser.parse_expression(func_name)
     call = templates.replace_as_expression(template, func_name=callee, args=args)
     anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
     return call
Example #28
0
 def visit(self, node):
     """Visits `node` and records the types inferred for it, if any.

     Returns:
       The types returned by the parent visitor. When not None, they are also
       stored on `node` as a tuple under anno.Static.TYPES.
     """
     types = super().visit(node)
     if __debug__:
         self._check_set(types)
     if types is None:
         return None
     # TODO(mdan): Normalize by removing subtypes.
     anno.setanno(node, anno.Static.TYPES, tuple(types))
     return types
Example #29
0
 def _aggregate_successors_live_in(self, node):
     """Annotates `node` with its live-out symbols, then visits its children.

     The live-out set is the union of the `in_` sets of every successor of
     `node` in the current analyzer's graph. It is stored under
     anno.Static.LIVE_VARS_OUT.
     """
     analyzer = self.current_analyzer
     successors = analyzer.graph.stmt_next[node]
     live_out = frozenset().union(*(analyzer.in_[s] for s in successors))
     anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
     return self.generic_visit(node)
Example #30
0
 def visit(self, node):
     """Visits `node`. Statements in the current CFG also get a live-in annotation.

     The annotation is a frozenset of the analyzer's `in_` state for the
     statement's CFG node, stored under anno.Static.LIVE_VARS_IN.
     """
     node = super(Annotator, self).visit(node)
     analyzer = self.current_analyzer
     if analyzer is None or not isinstance(node, gast.stmt):
         return node
     graph_index = analyzer.graph.index
     if node in graph_index:
         anno.setanno(node, anno.Static.LIVE_VARS_IN,
                      frozenset(analyzer.in_[graph_index[node]]))
     return node
Example #31
0
 def visit_Name(self, node):
   """Resolves loaded names that have no local definitions via the namespace."""
   node = self.generic_visit(node)
   if not isinstance(node.ctx, gast.Load):
     return node
   definitions = anno.getanno(node, anno.Static.DEFINITIONS, ())
   if definitions:
     # Names with visible definitions are not resolved statically.
     return node
   if node.id in self.ctx.info.namespace:
     anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id])
   return node
Example #32
0
 def _aggregate_successors_live_in(self, node):
   """Stores the successors' combined in_ sets as LIVE_VARS_OUT, then visits children."""
   live_out = set()
   for successor in self.current_analyzer.graph.stmt_next[node]:
     live_out.update(self.current_analyzer.in_[successor])
   anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(live_out))
   return self.generic_visit(node)
 def _as_function(self, func_name, args):
   """Returns `func_name(args)` as an expression marked SAFE_BOOLEAN_OPERAND.

   `func_name` is source text parsed with parser.parse_expression; `args` is
   substituted for the `args` placeholder of the template.
   """
   template = """
     func_name(args)
   """
   call = templates.replace_as_expression(
       template,
       func_name=parser.parse_expression(func_name),
       args=args)
   anno.setanno(call, SAFE_BOOLEAN_OPERAND, True)
   return call
Example #34
0
    def visit(self, node):
        """Visits `node`, tracking origin information and normalizing results.

        While the node is visited, `self.ctx.current_origin` is set to the
        node's ORIGIN annotation (if it has one) and restored afterwards.
        Replacement nodes without an ORIGIN inherit the visited node's origin,
        falling back to the enclosing origin.

        Args:
            node: the AST node to visit.

        Returns:
            The visitor's result. When an Expr's value was replaced by a list,
            tuple, Assign or AugAssign, that replacement is returned unwrapped.
            Nodes annotated with anno.Basic.SKIP_PROCESSING are returned as-is.

        Raises:
            ValueError: if `node` is not a gast.AST instance.
        """
        if not isinstance(node, gast.AST):
            # This is not that uncommon a mistake: various node bodies are lists, for
            # example, posing a land mine for transformers that need to recursively
            # call `visit`. Fail fast with a descriptive message before any of the
            # annotation and origin bookkeeping below touches `node`.
            msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
                   ' visit lists of nodes, use "visit_block" instead').format(
                       type(node))
            raise ValueError(msg)

        if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
            return node

        parent_origin = self.ctx.current_origin
        if anno.hasanno(node, anno.Basic.ORIGIN):
            self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

        try:
            processing_expr_node = isinstance(node, gast.Expr)
            if processing_expr_node:
                entry_expr_value = node.value

            result = super(Base, self).visit(node)

            # Adjust for consistency: replacing the value of an Expr with
            # an Assign node removes the need for the Expr node.
            if (processing_expr_node and isinstance(result, gast.Expr)
                    and (result.value is not entry_expr_value)):
                # When the replacement is a list, it is assumed that the list came
                # from a template that contained a number of statements, which
                # themselves are standalone and don't require an enclosing Expr.
                if isinstance(result.value,
                              (list, tuple, gast.Assign, gast.AugAssign)):
                    result = result.value

            # By default, all replacements receive the origin info of the replaced
            # node.
            if result is not node and result is not None:
                inherited_origin = anno.getanno(node,
                                                anno.Basic.ORIGIN,
                                                default=parent_origin)
                if inherited_origin is not None:
                    if isinstance(result, (list, tuple)):
                        nodes_to_adjust = result
                    else:
                        nodes_to_adjust = (result, )
                    for n in nodes_to_adjust:
                        if not anno.hasanno(n, anno.Basic.ORIGIN):
                            anno.setanno(n, anno.Basic.ORIGIN,
                                         inherited_origin)
        finally:
            # Restore the enclosing origin whether or not visiting succeeded.
            self.ctx.current_origin = parent_origin

        return result
Example #35
0
 def visit_Call(self, node):
     """Visits call arguments in their own scope, then the callee.

     The scope holding the positional and keyword arguments is recorded on
     `node` as NodeAnno.ARGS_SCOPE. The callee is visited after that scope
     is exited.
     """
     self._enter_scope(False)
     for field in ('args', 'keywords'):
         setattr(node, field, self.visit_block(getattr(node, field)))
     # TODO(mdan): Account starargs, kwargs
     anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
     self._exit_scope()
     node.func = self.visit(node.func)
     return node
Example #36
0
 def visit(self, node):
   """Visits `node`, then records live-in symbols for CFG statements.

   Statements indexed in the current analyzer's graph get a frozenset of their
   `in_` state stored under anno.Static.LIVE_VARS_IN.
   """
   node = super(Annotator, self).visit(node)
   analyzer = self.current_analyzer
   is_cfg_statement = (
       analyzer is not None and
       isinstance(node, gast.stmt) and
       node in analyzer.graph.index)
   if is_cfg_statement:
     live_in = analyzer.in_[analyzer.graph.index[node]]
     anno.setanno(node, anno.Static.LIVE_VARS_IN, frozenset(live_in))
   return node
Example #37
0
 def _process_statement_directive(self, call_node, directive):
   """Attaches a directive's arguments to the statement stored as ENCLOSING_LOOP.

   The mapped arguments are merged into that statement's
   converter.AgAnno.DIRECTIVES dict, keyed by `directive`.

   Returns:
     `call_node`, unchanged.

   Raises:
     ValueError: if `self.local_scope_level` is below 1.
   """
   if self.local_scope_level < 1:
     raise ValueError(
         '"%s" must be used inside a statement' % directive.__name__)
   loop_node = self.get_local(ENCLOSING_LOOP)
   directives = anno.getanno(loop_node, converter.AgAnno.DIRECTIVES, {})
   directives[directive] = _map_args(call_node, directive)
   anno.setanno(loop_node, converter.AgAnno.DIRECTIVES, directives)
   return call_node
Example #38
0
 def visit_Call(self, node):
   """Visits call arguments in a fresh scope (stored as ARGS_SCOPE), then the callee."""
   self._enter_scope(False)
   node.args = self.visit_block(node.args)
   node.keywords = self.visit_block(node.keywords)
   # TODO(mdan): Account starargs, kwargs
   args_scope = self.scope
   anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
   self._exit_scope()
   # The callee is visited outside the argument scope.
   node.func = self.visit(node.func)
   return node
Example #39
0
  def test_rename_symbols_annotations(self):
    """Renaming symbols keeps the annotations on the root node."""
    tree = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(tree, 'foo', 'bar')
    expected = anno.getanno(tree, 'foo')

    renames = {qual_names.QN('a'): qual_names.QN('b')}
    tree = ast_util.rename_symbols(tree, renames)

    self.assertIs(anno.getanno(tree, 'foo'), expected)
Example #40
0
 def visit_For(self, node):
   """Visits a for loop, giving its target and iterable their own scope.

   That header scope is recorded on `node.iter` as anno.Static.SCOPE. The body
   and else clause are handled by _process_parallel_blocks under
   NodeAnno.BODY_SCOPE and NodeAnno.ORELSE_SCOPE.
   """
   self._enter_scope(False)
   node.target = self.visit(node.target)
   node.iter = self.visit(node.iter)
   anno.setanno(node.iter, anno.Static.SCOPE, self.scope)
   self._exit_scope()
   blocks = (
       (node.body, NodeAnno.BODY_SCOPE),
       (node.orelse, NodeAnno.ORELSE_SCOPE),
   )
   return self._process_parallel_blocks(node, blocks)
Example #41
0
 def visit_While(self, node):
     """Visits a while loop, giving its condition a scope of its own.

     The condition scope is recorded on `node` as NodeAnno.COND_SCOPE and on
     `node.test` as anno.Static.SCOPE. The body and else clause are handled
     by _process_parallel_blocks.
     """
     self._enter_scope(False)
     node.test = self.visit(node.test)
     cond_scope = self.scope
     anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
     anno.setanno(node.test, anno.Static.SCOPE, cond_scope)
     self._exit_scope()
     return self._process_parallel_blocks(
         node, ((node.body, NodeAnno.BODY_SCOPE),
                (node.orelse, NodeAnno.ORELSE_SCOPE)))
Example #42
0
  def test_copy(self):
    """copyanno copies present annotations and ignores absent ones."""
    src = ast.Name()
    anno.setanno(src, 'foo', 3)

    dst = ast.Name()
    for key in ('foo', 'bar'):
      anno.copyanno(src, dst, key)

    self.assertTrue(anno.hasanno(dst, 'foo'))
    self.assertFalse(anno.hasanno(dst, 'bar'))
Example #43
0
    def visit_If(self, node):
        """Visits an if statement, giving its condition a scope of its own.

        The condition scope, as returned by _exit_and_record_scope, is stored
        on `node` as NodeAnno.COND_SCOPE. Both branches are then handled by
        _process_parallel_blocks.
        """
        self._enter_scope(False)
        node.test = self.visit(node.test)
        anno.setanno(node, NodeAnno.COND_SCOPE,
                     self._exit_and_record_scope(node.test))

        branches = ((node.body, NodeAnno.BODY_SCOPE),
                    (node.orelse, NodeAnno.ORELSE_SCOPE))
        return self._process_parallel_blocks(node, branches)
    def test_rename_symbols_annotations(self):
        """Renaming symbols keeps the annotations on the root node."""
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        expected = anno.getanno(tree, 'foo')

        renames = {qual_names.QN('a'): qual_names.QN('b')}
        tree = ast_util.rename_symbols(tree, renames)

        self.assertIs(anno.getanno(tree, 'foo'), expected)
Example #45
0
 def visit_While(self, node):
   """Visits a while loop; its condition scope is stored on both node and test."""
   self._enter_scope(False)
   node.test = self.visit(node.test)
   for target, key in ((node, NodeAnno.COND_SCOPE),
                       (node.test, anno.Static.SCOPE)):
     anno.setanno(target, key, self.scope)
   self._exit_scope()
   blocks = ((node.body, NodeAnno.BODY_SCOPE),
             (node.orelse, NodeAnno.ORELSE_SCOPE))
   return self._process_parallel_blocks(node, blocks)
Example #46
0
def resolve(nodes, source, function=None):
  """Adds origin information to all nodes inside the body of function.

  Every node (and every descendant found by `gast.walk`) that has a `lineno`
  attribute receives an `OriginInfo` annotation under `anno.Basic.ORIGIN`.

  Args:
    nodes: A single AST node, or a list or tuple of AST nodes, whose `lineno`
      values are 1-based line numbers within `source`.
    source: Text, the source code string for the function whose body nodes will
      be annotated.
    function: Callable, the function whose nodes are annotated. Its file path,
      starting line number and `__name__` are recorded in each OriginInfo.
      If it is None then only the line numbers and column offset will be set
      in the annotation, with the rest of the information being None.

  Returns:
    None. The nodes are annotated in place.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # TODO(mdan): Pull this to a separate utility.
  # Maps line numbers *within `source`* to the text of the comment on that
  # line (without the leading '#').
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno

      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # getsourcelines reports the file line of the first source line.
        # Assuming `source` is that same text (TODO confirm against callers),
        # line 1 of `source` maps to function_lineno, hence the `- 1`.
        source_lineno = function_lineno + lineno_in_body - 1
        function_name = function.__name__
      else:
        source_lineno = lineno_in_body
        function_name = None

      location = Location(function_filepath, source_lineno, n.col_offset)
      # comment_map is keyed by lines within `source`, not by file lines, so
      # it must be looked up with the body-relative line number.
      origin = OriginInfo(location, function_name,
                          source_code_line, comment_map.get(lineno_in_body))
      anno.setanno(n, anno.Basic.ORIGIN, origin)
Example #47
0
 def _block_statement_live_in(self, node, entry_node):
   """Annotates `node` with the variables live on entry to `entry_node`.

   If `entry_node` has a CFG node in the current analyzer's graph, the live-in
   set is that CFG node's `in_` state, frozen. Otherwise `entry_node` must be
   a block statement that already carries a LIVE_VARS_IN annotation, which is
   reused as-is.

   Returns:
     `node`, with anno.Static.LIVE_VARS_IN set.
   """
   cfg_index = self.current_analyzer.graph.index
   if entry_node not in cfg_index:
     assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
         'If not matching a CFG node, must be a block statement:'
         ' {}'.format(entry_node))
     live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
   else:
     live_in = frozenset(self.current_analyzer.in_[cfg_index[entry_node]])
   anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
   return node
Example #48
0
 def test_copy_clean_preserves_annotations(self):
   # copy_clean keeps only the annotation keys listed in preserve_annos.
   code = textwrap.dedent("""
     def f(a):
       return a + 1
   """)
   tree = parser.parse_str(code)
   fn_def = tree.body[0]
   anno.setanno(fn_def, 'foo', 'bar')
   anno.setanno(fn_def, 'baz', 1)

   copied_fn = ast_util.copy_clean(tree, preserve_annos={'foo'}).body[0]

   self.assertEqual(anno.getanno(copied_fn, 'foo'), 'bar')
   self.assertFalse(anno.hasanno(copied_fn, 'baz'))
Example #49
0
 def _process_function_arg(self, arg_node):
   """Registers a function argument in the current scope.

   The argument's qualified name is bound to `arg_node` in `self.scope`. When
   exactly one entity is being processed and a type for this argument is
   listed in `self.entity_info.arg_types`, the type object and its
   dot-split name are attached to the node as 'type' and 'type_fqn'.
   """
   qn = anno.getanno(arg_node, anno.Basic.QN)
   name = str(qn)
   self.scope.setval(qn, arg_node)

   if len(self.enclosing_entities) != 1:
     return
   if name not in self.entity_info.arg_types:
     return

   # Forge a node to hold the type information, so that method calls on
   # it can resolve the type.
   type_string, type_obj = self.entity_info.arg_types[name]
   anno.setanno(arg_node, 'type', type_obj)
   anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.')))
Example #50
0
def resolve(node, source, function=None):
  """Adds an origin information to node and its subnodes.

  This allows us to map the original source code line numbers to generated
  source code.

  Args:
    node: gast.AST node. Should be a gast.FunctionDef. This is the node we
        annotate with origin information.
    source: Text, the source code. Should satisfy relationship
        `node in iter_tree(gast.parse(source))`; otherwise the lineno will be
        unreliable.
    function: The original function. If it is None then only the line numbers
        and column offset will be set in the annotation, with the rest of the
        information being None.

  Returns:
    None. Each subnode with a `lineno` gets an OriginInfo stored under
    anno.Basic.ORIGIN in place.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # TODO(mdan): Pull this to a separate utility.
  # Maps line numbers *within `source`* to the comment text on that line.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for n in gast.walk(node):
    if not hasattr(n, 'lineno'):
      continue

    source_code_line = source_lines[n.lineno - 1]
    if function:
      # Offset relative to `node`, so `node` itself maps to function_lineno.
      # Computed only here: the root's lineno is not needed otherwise, and
      # roots without one (e.g. a Module) must still work without function.
      within_body_offset = n.lineno - node.lineno
      source_lineno = function_lineno + within_body_offset
      function_name = function.__name__
    else:
      source_lineno = n.lineno
      function_name = None

    location = Location(function_filepath, source_lineno, n.col_offset)
    # comment_map is keyed by lines within `source`; use n.lineno rather than
    # the file-relative source_lineno.
    origin = OriginInfo(location, function_name,
                        source_code_line, comment_map.get(n.lineno))
    anno.setanno(n, anno.Basic.ORIGIN, origin)
Example #51
0
  def visit_Call(self, node):
    """Marks decorator arguments and rewrites calls to known functions.

    Before visiting children: if the callee's 'live_val' annotation is listed
    in `strip_decorators`, the call's first positional argument is annotated
    as 'graph_ready'.

    After visiting children:
    - Calls with a 'live_val' are passed to `_rename_compilable_function` when
      the target is compilable.
    - Otherwise, if their 'fqn' is in KNOWN_NUMPY_FUNCTIONS, they are passed
      to `_wrap_to_py_func_single_return`.
    - Any other call with a 'live_val' raises.
    - Calls without a 'live_val' are returned unchanged for `dict(...)` and
      `super(...)`.
    - All remaining calls go to `_insert_dynamic_conversion` when recursive
      conversion is enabled.

    Args:
      node: The call node being visited.

    Returns:
      The (possibly replaced) node.

    Raises:
      ValueError: A stripped decorator is called with no positional arguments.
      NotImplementedError: The callee has a 'live_val' that is neither
        compilable nor a known numpy function.
    """
    # If the function call is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.ctx.program.options.strip_decorators:
        if not node.args:
          # A 1-tuple keeps %-formatting correct even if the entity is itself
          # a tuple (a bare tuple would be unpacked and raise TypeError).
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    # Re-read the annotation: generic_visit may have replaced node.func.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if anno.hasanno(node.func, 'fqn'):
        target_fqn = anno.getanno(node.func, 'fqn')
      else:
        target_fqn = None
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(
            node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
    else:
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. This
        # normally doesn't pose a problem, but the dict built-in doesn't
        # work with inspect.getargspec which is required for dynamic functions.
        # Note: expecting this is resilient to aliasing (e.g.
        # dict = an_evil_dict), because in those cases the regular mechanisms
        # process a simple user function.
        qn = anno.getanno(node.func, anno.Basic.QN)
        # Add items to this list as needed.
        if str(qn) in ('dict',):
          return node

      if ast_util.matches(node, 'super(_)'):
        # super() calls are preserved. The class conversion mechanism will
        # ensure that they return the correct value.
        return node

      if self.ctx.program.options.recursive:
        node = self._insert_dynamic_conversion(node)
    return node
  def test_basic(self):
    # A function that raises, carrying an ORIGIN annotation, should surface
    # a GraphConstructionError once transformed by error_handlers.

    def test_fn():
      raise ValueError()

    node, ctx = self.prepare(test_fn, {})
    origin = origin_info.OriginInfo(None, 'test_function_name', 'test_code',
                                    'test_comment')
    anno.setanno(node, anno.Basic.ORIGIN, origin)

    transformed = error_handlers.transform(node, ctx)
    with self.compiled(transformed, {}) as compiled_module:
      with self.assertRaises(errors.GraphConstructionError):
        # Here we just assert that the handler works. Its correctness is
        # verified by errors_test.py.
        compiled_module.test_fn()
Example #53
0
  def visit_If(self, node):
    """Visits an if statement and records which branches definitely return.

    Each branch goes through `_visit_statement_block`, which yields the new
    statements plus a flag saying whether that branch definitely returns.
    - If the body definitely returns, the node gets BODY_DEFINITELY_RETURNS.
    - If the orelse definitely returns, it gets ORELSE_DEFINITELY_RETURNS.
    - If both do, the current _Block state is marked as definitely returning.
    """
    node.test = self.visit(node.test)

    node.body, body_returns = self._visit_statement_block(node, node.body)
    if body_returns:
      anno.setanno(node, BODY_DEFINITELY_RETURNS, True)

    node.orelse, orelse_returns = self._visit_statement_block(
        node, node.orelse)
    if orelse_returns:
      anno.setanno(node, ORELSE_DEFINITELY_RETURNS, True)

    both_return = body_returns and orelse_returns
    if both_return:
      self.state[_Block].definitely_returns = True

    return node
  def visit_Name(self, node):
    """Annotates a name with the definitions that reach it.

    For a load, the definitions come from the current CFG node's incoming
    state (`analyzer.in_`). For any other context they come from its outgoing
    state (`analyzer.out`). Either way they are stored as a tuple under
    anno.Static.DEFINITIONS. Names visited with no current analyzer are
    returned untouched.
    """
    # Names may appear outside function defs - for example in class
    # definitions.
    if self.current_analyzer is None:
      return node

    analyzer = self.current_analyzer
    cfg_node = self.current_cfg_node
    assert cfg_node is not None, 'name node outside of any statement?'

    qn = anno.getanno(node, anno.Basic.QN)
    if isinstance(node.ctx, gast.Load):
      reaching_state = analyzer.in_[cfg_node]
    else:
      reaching_state = analyzer.out[cfg_node]
    definitions = tuple(reaching_state.value.get(qn, ()))
    anno.setanno(node, anno.Static.DEFINITIONS, definitions)

    return node
  def test_create_source_map(self):
    # An ORIGIN annotation on the parsed node should appear in the source map
    # at line 2 of the given filename.
    source = textwrap.dedent("""
        def test_fn(x):
          return x + 1
    """)
    node = parser.parse_str(source)

    fake_loc = origin_info.Location('fake_filename', 3, 7)
    fake_origin = origin_info.OriginInfo(
        loc=fake_loc,
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    anno.setanno(node, anno.Basic.ORIGIN, fake_origin)

    source_map = origin_info.create_source_map(node, source, 'test_filename')

    expected_loc = origin_info.LineLocation('test_filename', 2)
    self.assertIn(expected_loc, source_map)
    self.assertIs(source_map[expected_loc], fake_origin)
  def test_create_source_map(self):
    # An ORIGIN annotation on the first body statement of a parsed function
    # should be reported at line 2 of the generated code.

    def test_fn(x):
      return x + 1

    module_node, _ = parser.parse_entity(test_fn)
    fake_loc = origin_info.Location('fake_filename', 3, 7)
    fake_origin = origin_info.OriginInfo(
        loc=fake_loc,
        function_name='fake_function_name',
        source_code_line='fake source line',
        comment=None)
    fn_node = module_node.body[0]
    first_stmt = fn_node.body[0]
    anno.setanno(first_stmt, anno.Basic.ORIGIN, fake_origin)
    generated = compiler.ast_to_source(fn_node)

    source_map = origin_info.create_source_map(
        fn_node, generated, 'test_filename', [0])

    expected_loc = origin_info.LineLocation('test_filename', 2)
    self.assertIn(expected_loc, source_map)
    self.assertIs(source_map[expected_loc], fake_origin)
Example #57
0
  def visit_For(self, node):
    """Visits a for loop, folding the return flag into its extra_test.

    When the function's return state is marked as used after visiting the
    body, the loop's 'extra_test' annotation is set to
    `ag__.not_(<do_return var>)`. If an extra_test already exists, it is set
    to the `ag__.and_` of that expression and the existing test instead.
    """
    node.iter = self.visit(node.iter)
    node.target = self.visit(node.target)

    # Add the check for return to the loop condition.
    node.body = self._visit_statement_block(node, node.body)
    if self.state[_Return].used:
      previous_test = anno.getanno(node, 'extra_test', default=None)
      if previous_test is None:
        combined_test = templates.replace_as_expression(
            'ag__.not_(control_var)',
            control_var=self.state[_Function].do_return_var_name)
      else:
        combined_test = templates.replace_as_expression(
            'ag__.and_(lambda: ag__.not_(control_var), lambda: extra_test)',
            extra_test=previous_test,
            control_var=self.state[_Function].do_return_var_name)
      anno.setanno(node, 'extra_test', combined_test)

    node.orelse = self._visit_statement_block(node, node.orelse)
    return node
Example #58
0
  def visit_For(self, node):
    """Rewrites a for loop whose body uses `break` into flag-guarded form.

    The body is processed by `_process_body` with a fresh 'break_' symbol,
    which reports whether a break was used. If so, the loop is replaced by the
    statements produced from the template below:
    - the flag is initialized;
    - the orelse is guarded on the flag via `_guard_if_present`;
    - the expression `not <flag>` is attached to the new loop as its
      'extra_test' annotation.

    Returns:
      The original node if no break was used; otherwise the list of
      statements returned by `templates.replace`.
    """
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, break_var)
      loop_condition = templates.replace_as_expression(
          'not var_name', var_name=break_var)

      # The extra test is hidden in the AST, which will confuse the static
      # analysis. To mitigate that, we insert a no-op statement that ensures
      # the control variable is marked as used.
      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
      template = """
        var_name = tf.constant(False)
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
      replacement = templates.replace(
          template,
          var_name=break_var,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=guarded_orelse)

      # Index 1 is presumably the `for` statement from the template (index 0
      # being the flag assignment).
      anno.setanno(replacement[1], 'extra_test', loop_condition)
      node = replacement

    return node