def visit_While(self, node):
    """Rewrites a `while` loop so that `break` is driven by a control flag.

    A fresh `break_` symbol is reserved, then the loop test, body and `else`
    block are visited. `_process_body` returns the new body together with a
    flag telling whether the control variable was used. Only in that case is
    the loop rebuilt from a template that initializes the variable to False
    and adds `ag__.not_(variable)` to the loop condition.

    Args:
      node: the while-loop node to transform.

    Returns:
      The visited node when the control variable was not used; otherwise the
      result of `templates.replace`, whose item at index 1 is the new loop.
    """
    loop_before_rewrite = node
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.test = self.visit(node.test)
    node.body, did_break = self._process_body(node.body, break_var)
    # The else clause belongs to the containing scope as far as `break` is
    # concerned, so it is visited as a plain block.
    node.orelse = self.visit_block(node.orelse)

    if not did_break:
        return node

    # Python runs the else clause only when the loop finished without a
    # break, hence the guard on the control variable.
    else_block = self._guard_if_present(node.orelse, break_var)

    template = """
        var_name = False
        while ag__.and_(lambda: test, lambda: ag__.not_(var_name)):
            body
        else:
            orelse
    """
    replacement = templates.replace(
        template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=else_block)

    # Item 0 is the `var_name = False` assignment; item 1 is the new loop,
    # which inherits the directives attached to the original loop.
    rewritten_loop = replacement[1]
    anno.copyanno(loop_before_rewrite, rewritten_loop, anno.Basic.DIRECTIVES)
    return replacement
def _process(self, node):
    """Replaces `node` with a renamed `gast.Name` when its QN is mapped.

    Args:
      node: a node carrying an `anno.Basic.QN` annotation.

    Returns:
      A new `gast.Name` built from the mapped name, with every annotation of
      `node` copied onto it, when the QN is a key of `self.name_map`;
      otherwise the result of `self.generic_visit(node)`.
    """
    qualified_name = anno.getanno(node, anno.Basic.QN)
    if qualified_name not in self.name_map:
        return self.generic_visit(node)

    renamed = gast.Name(str(self.name_map[qualified_name]), node.ctx, None)
    # Carry over all annotations of the node being replaced.
    for key in anno.keys(node):
        anno.copyanno(node, renamed, key)
    return renamed
def test_copy(self):
    """Checks that `anno.copyanno` copies a present key and skips a missing one."""
    source = ast.Name()
    anno.setanno(source, 'foo', 3)

    target = ast.Name()
    # 'foo' is set on the source node; 'bar' is not.
    anno.copyanno(source, target, 'foo')
    anno.copyanno(source, target, 'bar')

    self.assertTrue(anno.hasanno(target, 'foo'))
    self.assertFalse(anno.hasanno(target, 'bar'))
def test_source_map(self):
    """Checks `origin_info.source_map` on a function whose AST was modified.

    Two statements are inserted at the top of `test_fn`'s body: one carrying
    an ORIGIN annotation copied from the first original statement, and one
    with no annotation. The source map built from the regenerated source is
    then expected to contain the annotated lines and to omit the line of the
    statement without an annotation.
    """

    def test_fn(x):
        if x > 0:
            x += 1
        return x

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]
    origin_info.resolve(fn_node, source)

    # Insert a traced line.
    # Its ORIGIN annotation is copied from the `if` statement, which is the
    # first statement of the body at this point.
    new_node = parser.parse_str('x = abs(x)').body[0]
    anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
    fn_node.body.insert(0, new_node)

    # Insert an untraced line.
    # It goes in front of the traced one, so the body now starts with
    # `x = 0`, then `x = abs(x)`, then the original `if`.
    fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

    modified_source = compiler.ast_to_source(fn_node)

    source_map = origin_info.source_map(fn_node, modified_source,
                                        'test_filename', [0])

    # Line 1 of the generated source is the function signature.
    loc = origin_info.LineLocation('test_filename', 1)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
    self.assertEqual(origin.loc.lineno, 1)

    # The untraced line, inserted second.
    loc = origin_info.LineLocation('test_filename', 2)
    self.assertFalse(loc in source_map)

    # The traced line, inserted first.
    # It reports the origin of the `if` statement it was annotated from.
    # NOTE(review): the expected text is compared verbatim; its leading
    # whitespace presumably has to match the indentation of `test_fn`'s body
    # in the parsed source -- confirm against the file's actual indentation.
    loc = origin_info.LineLocation('test_filename', 3)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, ' if x > 0:')
    self.assertEqual(origin.loc.lineno, 2)

    # The original `if` statement, now on line 4, maps to the same origin.
    loc = origin_info.LineLocation('test_filename', 4)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, ' if x > 0:')
    self.assertEqual(origin.loc.lineno, 2)
def visit_For(self, node):
    """Rewrites a `for` loop so that `break` is driven by a control flag.

    A fresh `break_` symbol is reserved, then the loop target, iterable, body
    and `else` block are visited. `_process_body` returns the new body
    together with a flag telling whether the control variable was used. Only
    in that case is the loop rebuilt from a template, and an
    `ag__.not_(variable)` expression is attached to the new loop as its
    `anno.Basic.EXTRA_LOOP_TEST` annotation.

    Args:
      node: the for-loop node to transform.

    Returns:
      The visited node when the control variable was not used; otherwise the
      result of `templates.replace`, whose item at index 1 is the new loop.
    """
    loop_before_rewrite = node
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', body_scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, did_break = self._process_body(node.body, break_var)
    # The else clause belongs to the containing scope as far as `break` is
    # concerned, so it is visited as a plain block.
    node.orelse = self.visit_block(node.orelse)

    if not did_break:
        return node

    # Python runs the else clause only when the loop finished without a
    # break, hence the guard on the control variable.
    else_block = self._guard_if_present(node.orelse, break_var)
    loop_condition = templates.replace_as_expression(
        'ag__.not_(var_name)', var_name=break_var)

    # The extra condition lives in an annotation rather than in the AST,
    # which would confuse static analysis. The no-op `(var_name,)` statement
    # in the template makes sure the control variable is marked as used.
    # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
    template = """
        var_name = False
        for target in iter_:
            (var_name,)
            body
        else:
            orelse
    """
    replacement = templates.replace(
        template,
        var_name=break_var,
        iter_=node.iter,
        target=node.target,
        body=node.body,
        orelse=else_block)

    # Item 0 is the `var_name = False` assignment; item 1 is the new loop.
    rewritten_loop = replacement[1]
    anno.setanno(rewritten_loop, anno.Basic.EXTRA_LOOP_TEST, loop_condition)
    anno.copyanno(loop_before_rewrite, rewritten_loop, anno.Basic.DIRECTIVES)
    return replacement
def copy(self, node):
    """Returns a deep copy of node (excluding some fields, see copy_clean).

    Lists and tuples are copied element by element. Anything that is neither
    a list, a tuple nor an AST node is treated as a value type and returned
    unchanged. AST nodes are rebuilt from their `_fields`, skipping fields
    whose name starts with a double underscore or that are not set on the
    node. Annotations listed in `self.preserve_annos` are then copied over.

    Args:
      node: an AST node, a list or tuple of such, or a plain value.

    Returns:
      The copied structure.
    """
    if isinstance(node, list):
        return [self.copy(item) for item in node]
    if isinstance(node, tuple):
        return tuple(self.copy(item) for item in node)

    ast_types = (gast.AST, ast.AST)
    if not isinstance(node, ast_types):
        # Whatever is not an AST node, list or tuple is assumed to be a value
        # type that may simply be assigned.
        return node

    assert isinstance(node, ast_types)
    copied_fields = {
        name: self.copy(getattr(node, name))
        for name in node._fields
        if not name.startswith('__') and hasattr(node, name)
    }
    duplicate = type(node)(**copied_fields)

    # An unset or empty `preserve_annos` means no annotation is carried over.
    for key in self.preserve_annos or ():
        anno.copyanno(node, duplicate, key)
    return duplicate
def visit_Name(self, node):
    """Processes a name node according to its context.

    After the generic visit, names in `Param` context are passed to
    `_process_function_arg`. Names in `Load` context whose QN has a value in
    the current scope receive the type-related annotations of that value.
    Names in any other context are left untouched.

    Args:
      node: the name node being visited.

    Returns:
      The same node.
    """
    self.generic_visit(node)

    ctx = node.ctx
    if isinstance(ctx, gast.Param):
        self._process_function_arg(node)
        return node
    if not isinstance(ctx, gast.Load):
        return node

    qualified_name = anno.getanno(node, anno.Basic.QN)
    if not self.scope.hasval(qualified_name):
        return node

    # E.g. if we had
    #   a = b
    # then for future references to `a` we should have definition = `b`
    definition = self.scope.getval(qualified_name)
    inherited_keys = (
        'type',
        'type_fqn',
        # TODO(mdan): Remove this when the directives module is in.
        'element_type',
        'element_shape',
    )
    for key in inherited_keys:
        anno.copyanno(definition, node, key)
    return node