def visit_While(self, node):
        """Lowers `break` statements inside a `while` loop to a boolean flag.

        The test, body and else clause are visited first. If `_process_body`
        reports that the break flag is needed, the loop is re-emitted via
        `templates.replace` so that it also stops once the flag is set, and the
        else clause is guarded by the flag. Otherwise `node` is returned after
        its children have been visited in place.

        Args:
          node: the `while` loop AST node.

        Returns:
          `node` when no break was rewritten; otherwise the statements produced
          by `templates.replace` (the flag initialization followed by the new
          loop, which carries the original node's DIRECTIVES annotation).
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        # `referenced` is passed presumably so the fresh flag name does not
        # collide with symbols used in the loop body.
        break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

        node.test = self.visit(node.test)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            return node

        # Python's else clause only runs when the loop exited without a break,
        # i.e. while the flag is still False.
        orelse_if_no_break = self._guard_if_present(node.orelse, break_var)

        template = """
        var_name = False
        while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
          body
        else:
          orelse
      """
        statements = templates.replace(
            template,
            var_name=break_var,
            test=node.test,
            body=node.body,
            orelse=orelse_if_no_break)

        # statements[0] initializes the flag; statements[1] is the new loop.
        anno.copyanno(node, statements[1], anno.Basic.DIRECTIVES)
        return statements
예제 #2
0
 def _process(self, node):
     """Renames a node whose qualified name appears in `self.name_map`.

     Args:
       node: an AST node carrying an `anno.Basic.QN` annotation.

     Returns:
       A new `gast.Name` holding the mapped name, with every annotation of
       `node` copied onto it, when the node's QN is a key of `self.name_map`;
       otherwise the result of `self.generic_visit(node)`.
     """
     qn = anno.getanno(node, anno.Basic.QN)
     if qn not in self.name_map:
         return self.generic_visit(node)

     replacement = gast.Name(str(self.name_map[qn]), node.ctx, None)
     # Carry every annotation over so later passes still see them.
     for key in anno.keys(node):
         anno.copyanno(node, replacement, key)
     return replacement
예제 #3
0
 def _process(self, node):
   """Swaps in a renamed `gast.Name` for nodes listed in `self.name_map`.

   Args:
     node: an AST node carrying an `anno.Basic.QN` annotation.

   Returns:
     The result of `self.generic_visit(node)` when the node's QN is not a key
     of `self.name_map`; otherwise a fresh `gast.Name` with the mapped name
     that inherits all of `node`'s annotations.
   """
   qualified_name = anno.getanno(node, anno.Basic.QN)
   if qualified_name not in self.name_map:
     return self.generic_visit(node)

   new_name = str(self.name_map[qualified_name])
   renamed = gast.Name(new_name, node.ctx, None)
   # Every annotation is transferred to the replacement node.
   for anno_key in anno.keys(node):
     anno.copyanno(node, renamed, anno_key)
   return renamed
예제 #4
0
  def test_copy(self):
    """copyanno transfers existing annotations and ignores missing ones."""
    node_1 = ast.Name()
    anno.setanno(node_1, 'foo', 3)

    node_2 = ast.Name()
    anno.copyanno(node_1, node_2, 'foo')
    # node_1 has no 'bar' annotation; copying it must not raise and must not
    # create one on the destination.
    anno.copyanno(node_1, node_2, 'bar')

    self.assertTrue(anno.hasanno(node_2, 'foo'))
    # Presence alone does not prove the copy; check the value came across too.
    self.assertEqual(anno.getanno(node_2, 'foo'), 3)
    self.assertFalse(anno.hasanno(node_2, 'bar'))
    def test_source_map(self):
        """source_map maps generated lines back to original source lines.

        Lines carrying an ORIGIN annotation (including one copied onto an
        inserted statement) appear in the map; an inserted line without one
        does not.
        """
        def test_fn(x):
            if x > 0:
                x += 1
            return x

        node, source = parser.parse_entity(test_fn)
        fn_node = node.body[0]
        origin_info.resolve(fn_node, source)

        # Insert a traced line: it borrows the ORIGIN of the statement it is
        # placed in front of (the `if`).
        new_node = parser.parse_str('x = abs(x)').body[0]
        anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
        fn_node.body.insert(0, new_node)

        # Insert an untraced line (no ORIGIN annotation) ahead of it.
        fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

        modified_source = compiler.ast_to_source(fn_node)

        source_map = origin_info.source_map(fn_node, modified_source,
                                            'test_filename', [0])

        def assert_maps_to(lineno, expected_line, expected_lineno):
            loc = origin_info.LineLocation('test_filename', lineno)
            origin = source_map[loc]
            self.assertEqual(origin.source_code_line, expected_line)
            self.assertEqual(origin.loc.lineno, expected_lineno)

        # The function header.
        assert_maps_to(1, 'def test_fn(x):', 1)

        # The untraced line, inserted second, has no entry.
        self.assertNotIn(origin_info.LineLocation('test_filename', 2),
                         source_map)

        # The traced line, inserted first, maps to the `if` it borrowed from.
        assert_maps_to(3, '  if x > 0:', 2)

        # The original `if` statement itself.
        assert_maps_to(4, '  if x > 0:', 2)
    def visit_For(self, node):
        """Lowers `break` statements inside a `for` loop to a boolean flag.

        Target, iterable, body and else clause are visited first. If
        `_process_body` reports that the break flag is needed, the loop is
        re-emitted via `templates.replace`: the flag is initialized to False,
        the else clause is guarded by it, and the condition `not flag` is
        attached to the new loop as its EXTRA_LOOP_TEST annotation.

        Args:
          node: the `for` loop AST node.

        Returns:
          `node` when no break was rewritten; otherwise the statements produced
          by `templates.replace` (the flag initialization followed by the new
          loop, which also carries the original node's DIRECTIVES annotation).
        """
        body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        # `referenced` is passed presumably so the fresh flag name does not
        # collide with symbols used in the loop body.
        break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            return node

        # Python's else clause only runs when the loop exited without a break,
        # i.e. while the flag is still False.
        orelse_if_no_break = self._guard_if_present(node.orelse, break_var)
        extra_test = templates.replace_as_expression('ag__.not_(var_name)',
                                                     var_name=break_var)

        # The extra test lives only in an annotation, which static analysis
        # cannot see. The no-op `(var_name,)` statement keeps the flag marked
        # as used inside the loop.
        # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
        template = """
        var_name = False
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
        statements = templates.replace(
            template,
            var_name=break_var,
            iter_=node.iter,
            target=node.target,
            body=node.body,
            orelse=orelse_if_no_break)

        # statements[0] initializes the flag; statements[1] is the new loop.
        loop = statements[1]
        anno.setanno(loop, anno.Basic.EXTRA_LOOP_TEST, extra_test)
        anno.copyanno(node, loop, anno.Basic.DIRECTIVES)
        return statements
예제 #7
0
    def copy(self, node):
        """Returns a deep copy of node (excluding some fields, see copy_clean).

        Lists and tuples are copied element by element. gast/ast nodes are
        rebuilt from their `_fields`, skipping names starting with '__' and
        fields the node does not actually have. Any other value is returned
        as-is. Annotation keys listed in `self.preserve_annos` (if it is set)
        are copied onto each rebuilt node.
        """
        if isinstance(node, list):
            return [self.copy(item) for item in node]
        if isinstance(node, tuple):
            return tuple(self.copy(item) for item in node)
        if not isinstance(node, (gast.AST, ast.AST)):
            # Anything that is not an AST, list or tuple is assumed to be a
            # value type that may simply be shared.
            return node

        copied_fields = {
            field: self.copy(getattr(node, field))
            for field in node._fields
            if not field.startswith('__') and hasattr(node, field)
        }
        duplicate = type(node)(**copied_fields)

        for key in self.preserve_annos or ():
            anno.copyanno(node, duplicate, key)
        return duplicate
예제 #8
0
  def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean).

    Sequences (lists and tuples) keep their container type and have each
    element copied recursively. gast/ast nodes are reconstructed from the
    subset of `_fields` that do not start with '__' and are present on the
    node. Everything else is returned unchanged. When `self.preserve_annos`
    is truthy, those annotation keys are transferred to the reconstructed node.
    """
    if isinstance(node, (list, tuple)):
      items = [self.copy(n) for n in node]
      return items if isinstance(node, list) else tuple(items)
    if not isinstance(node, (gast.AST, ast.AST)):
      # Non-AST, non-sequence values are treated as value types and shared.
      return node

    new_fields = {}
    for name in node._fields:
      if name.startswith('__') or not hasattr(node, name):
        continue
      new_fields[name] = self.copy(getattr(node, name))
    result = type(node)(**new_fields)

    preserved = self.preserve_annos
    if preserved:
      for key in preserved:
        anno.copyanno(node, result, key)
    return result
예제 #9
0
    def visit_Name(self, node):
        """Propagates type annotations to a name from its known definition.

        Children are visited first. A name in a parameter context is passed to
        `_process_function_arg`. A name being loaded whose qualified name has a
        value in `self.scope` receives that value's 'type', 'type_fqn',
        'element_type' and 'element_shape' annotations.

        Returns:
          `node` itself.
        """
        self.generic_visit(node)

        if isinstance(node.ctx, gast.Param):
            self._process_function_arg(node)
            return node
        if not isinstance(node.ctx, gast.Load):
            return node

        qn = anno.getanno(node, anno.Basic.QN)
        if not self.scope.hasval(qn):
            return node

        # E.g. after `a = b`, later references to `a` should have
        # definition = `b`.
        definition = self.scope.getval(qn)
        # TODO(mdan): Drop the element_* keys when the directives module is in.
        for key in ('type', 'type_fqn', 'element_type', 'element_shape'):
            anno.copyanno(definition, node, key)
        return node
예제 #10
0
  def visit_Name(self, node):
    """Copies type-related annotations from a name's definition onto it.

    After visiting children, parameter names are forwarded to
    `_process_function_arg`; loaded names whose QN resolves in `self.scope`
    inherit the definition's 'type', 'type_fqn', 'element_type' and
    'element_shape' annotations. Always returns `node`.
    """
    self.generic_visit(node)
    ctx = node.ctx
    if isinstance(ctx, gast.Param):
      self._process_function_arg(node)
    elif isinstance(ctx, gast.Load):
      qn = anno.getanno(node, anno.Basic.QN)
      if self.scope.hasval(qn):
        # After e.g. `a = b`, later loads of `a` take `b` as their definition.
        definition = self.scope.getval(qn)
        inherited = (
            'type',
            'type_fqn',
            # TODO(mdan): Remove these when the directives module is in.
            'element_type',
            'element_shape',
        )
        for key in inherited:
          anno.copyanno(definition, node, key)
    return node