def _visit_For(self, node):
    """Inspect a ``for`` loop whose iterable has one of three known shapes.

    Recognised iterables:

    * ``range(...)`` / ``enumerate(...)``: every call argument is visited.
    * ``<expr>.numpy()``: the call is handed to ``self._visit_Call``.
    * a bare name: the name node is visited.

    For any other iterable the method returns immediately. Otherwise every
    ``break``/``continue`` found anywhere under ``node`` is passed to
    ``self._visit_break_continue``.

    Args:
        node: The ``gast.For`` node to inspect.
    """
    assert isinstance(node, gast.For)
    iterable = node.iter

    if isinstance(iterable, gast.Name):
        # for ... in var
        self.visit(iterable)
    elif isinstance(iterable, gast.Call):
        callee = iterable.func
        if isinstance(callee, gast.Name):
            # for ... in range(var[0]|var.numpy()[0])
            # or for ... in enumerate(var|var.numpy())
            if callee.id not in ("range", "enumerate"):
                return
            for argument in iterable.args:
                self.visit(argument)
        elif isinstance(callee, gast.Attribute):
            # for ... in var.numpy()
            if callee.attr != 'numpy':
                return
            self._visit_Call(iterable)
        else:
            return
    else:
        return

    jump_types = (gast.Continue, gast.Break)
    for descendant in gast.walk(node):
        if isinstance(descendant, jump_types):
            self._visit_break_continue(descendant)
    return
def synchronize_lcds(self, node):
    """Append ``x = x._synchronize()`` for each carried name in ``node.body``.

    A name qualifies when it is the first target of an ``Assign`` statement
    in ``node.body`` and has been loaded in that statement or in an earlier
    one of the same body. ("LCD" presumably stands for loop-carried
    dependency -- confirm with the surrounding module.)

    Args:
        node: A statement node with a ``body`` list. It is run through
            ``FuseAttributes`` before the scan and ``SplitAttributes`` after.

    Returns:
        The transformed node with the synchronize assignments appended.

    Raises:
        NotImplementedError: If a qualifying name is assigned more than once.
    """
    node = FuseAttributes().visit(node)

    loaded = defaultdict(list)
    carried = set()
    for statement in node.body:
        for sub in gast.walk(statement):
            if isinstance(sub, gast.Name) and isinstance(sub.ctx, gast.Load):
                loaded[sub.id].append(sub)
        if not isinstance(statement, gast.Assign):
            continue
        target_name = statement.targets[0].id
        if target_name not in loaded:
            continue
        if target_name in carried:
            raise NotImplementedError("cannot process LCD "
                                      "stored to twice")
        carried.add(target_name)

    node = SplitAttributes().visit(node)

    # One `name = name._synchronize()` statement per carried name.
    node.body.extend([
        gast.Assign(
            [gast.Name(name, gast.Store(), None)],
            gast.Call(
                gast.Attribute(
                    gast.Name(name, gast.Load(), None),
                    gast.Name('_synchronize', gast.Load(), None),
                    None),
                [], []))
        for name in carried
    ])
    return node
def dead_code_elimination(node):
    """Remove definitions that `annotate.unused` reports as unused.

    Unused definitions are collected from `annotate.unused(node)`, skipping
    those whose defining node is a `gast.arguments` or `gast.For`. Any node
    reachable from a removed definition that carries a truthy 'push'
    annotation has the annotated push node scheduled for removal too, so a
    removed pop takes its push with it.

    Note that this *requires dead code elimination to be performed on the
    primal and adjoint simultaneously*.

    Args:
      node: The AST to optimize.

    Returns:
      The same AST with the selected nodes removed and its annotations
      cleared via `anno.clearanno`.
    """
    doomed = {
        definition[1]
        for definition in annotate.unused(node)
        if not isinstance(definition[1], (gast.arguments, gast.For))
    }
    # Iterate over a snapshot because the set grows while pushes are found.
    for removed in list(doomed):
        for succ in gast.walk(removed):
            if anno.getanno(succ, 'push', False):
                doomed.add(anno.getanno(succ, 'push'))
    transformers.Remove(doomed).visit(node)
    anno.clearanno(node)
    return node
def visit_loop(self, node, update_mask=None):
    """Rewrite carried assignments in a loop body and append synchronizes.

    For each ``Assign`` in ``node.body`` whose first target name has been
    loaded in that statement or an earlier one of the same body, the
    assignment ``$var = $expr`` becomes ``$var = $var._update($expr, mask)``.
    After the scan, ``$var = $var._synchronize()`` is appended to the body
    for every rewritten name.

    Args:
        node: A loop node with a ``body`` list. It is run through
            ``FuseAttributes`` before the scan and ``SplitAttributes`` after.
        update_mask: AST expression passed as the second argument of
            ``_update``. When omitted (``None``), a fresh
            ``gast.NameConstant(value=None)`` is built for each rewritten
            assignment. The previous default was a single node created at
            definition time and shared by every call and every generated
            ``_update`` call.

    Returns:
        The transformed node.

    Raises:
        NotImplementedError: If an assignment has several targets, or a
            carried name is assigned more than once.
    """
    node = FuseAttributes().visit(node)
    loads, stores = defaultdict(list), set()
    for child in node.body:
        for n in gast.walk(child):
            if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load):
                loads[n.id].append(n)
        if isinstance(child, gast.Assign):
            if len(child.targets) > 1:
                raise NotImplementedError("cannot process LCD that is "
                                          "part of multiple assignment")
            name = child.targets[0].id
            if name in loads:
                if name in stores:
                    raise NotImplementedError("cannot process LCD "
                                              "stored to twice")
                # Build a new constant per use so generated trees never
                # share a default node; a caller-supplied mask is used as is.
                if update_mask is None:
                    mask = gast.NameConstant(value=None)
                else:
                    mask = update_mask
                # $var = $expr -> $var = $var._update($expr, mask)
                child.value = gast.Call(
                    gast.Attribute(gast.Name(name, gast.Load(), None),
                                   gast.Name('_update', gast.Load(), None),
                                   None),
                    [child.value, mask], [])
                stores.add(name)
    node = SplitAttributes().visit(node)
    synchronizes = []
    for name in stores:
        synchronize = gast.Assign(
            [gast.Name(name, gast.Store(), None)],
            gast.Call(
                gast.Attribute(
                    gast.Name(name, gast.Load(), None),
                    gast.Name('_synchronize', gast.Load(), None),
                    None),
                [], []))
        synchronizes.append(synchronize)
    node.body.extend(synchronizes)
    return node
def compute_same_identifier_edges(tree, ast_to_node_id):
    """Compute EXTRA_SAME_IDENTIFIER edges from an AST.

    These edges connect any two `Name` nodes with the same identifier, in
    both directions, including an edge from each `Name` node to itself.

    Args:
      tree: The AST to construct an example for.
      ast_to_node_id: Dictionary that maps AST node ids to their graph node
        id.

    Returns:
      List of same-identifier edges as `(source, dest, edge_type)` tuples.
    """
    edges = []
    seen_by_identifier = collections.defaultdict(list)
    for ast_node in gast.walk(tree):
        if not isinstance(ast_node, gast.Name):
            continue
        current = ast_to_node_id[id(ast_node)]
        earlier = seen_by_identifier[ast_node.id]  # pytype: disable=attribute-error
        # Link this occurrence with every earlier one, both ways.
        for other in earlier:
            edges.extend((
                (current, other, SAME_IDENTIFIER_EDGE_TYPE),
                (other, current, SAME_IDENTIFIER_EDGE_TYPE),
            ))
        earlier.append(current)
        # Self edge.
        edges.append((current, current, SAME_IDENTIFIER_EDGE_TYPE))
    return edges
def test_nested_loop_vars(self):
    """Check per-loop variable names for the nested-loop fixture function."""
    source = inspect.getsource(self.nested_for_loop_func)
    gast_root = gast.parse(source)
    name_visitor = NameVisitor(gast_root)

    # Expected results, in the order gast.walk yields the loops.
    self.loop_var_names = [{"j", "two"}, {"i", "three", "b"}, {"i", "j"}]
    self.create_var_names = [set(), {"b"}, set()]

    loop_nodes = (n for n in gast.walk(gast_root)
                  if isinstance(n, (gast.While, gast.For)))
    for i, node in enumerate(loop_nodes):
        loop_var_names, create_var_names = name_visitor.get_loop_var_names(
            node)
        expected_loop_vars = self.loop_var_names[i]
        expected_created = self.create_var_names[i]
        self.assertEqual(
            loop_var_names,
            expected_loop_vars,
            msg="loop_var_names : {}, \nexpected loop_var_names : {}".format(
                loop_var_names, expected_loop_vars))
        self.assertEqual(
            create_var_names,
            expected_created,
            msg="i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}"
            .format(i, create_var_names, expected_created))
def _transform_var_shape_if_necessary(self, cond):
    """Replace variable-shape references under `cond` with converted nodes.

    Walks every node under `cond`. An `Attribute` accepted by
    `self.is_var_shape`, or a `Name` whose id is a key of
    `self.name_to_var_shape`, is replaced in its parent by the result of
    `create_convert_shape_node`. The parent is found through
    `self.node_to_wrapper_map`.

    Args:
        cond: Root AST node of the expression to rewrite in place.

    Returns:
        True if at least one shape reference was found, otherwise False.
    """
    need_transformed = False
    # NOTE(review): the tree is modified while gast.walk is iterating it.
    for child_node in gast.walk(cond):
        var_shape_node = None
        if isinstance(child_node, (gast.Attribute)):
            if self.is_var_shape(child_node):
                var_shape_node = child_node
        elif isinstance(child_node, (gast.Name)):
            if child_node.id in self.name_to_var_shape:
                var_shape_node = self.name_to_var_shape[child_node.id]

        if var_shape_node:
            need_transformed = True
            # NOTE(review): .get() yields None for an unmapped node, which
            # would fail on the attribute access below -- confirm every
            # visited node is registered in node_to_wrapper_map.
            wrapper_node = self.node_to_wrapper_map.get(child_node)
            parent_node = wrapper_node.parent.node
            for field, value in gast.iter_fields(parent_node):
                # Case 1: the child is stored directly in a parent field.
                if child_node is value:
                    setattr(parent_node, field,
                            create_convert_shape_node(var_shape_node))
                    break
                # Some child_node may be in a list such as gast.Compare
                if isinstance(value, list):
                    has_converted_shape = False
                    for i, v in enumerate(value):
                        if child_node is v:
                            value[i] = create_convert_shape_node(
                                var_shape_node)
                            has_converted_shape = True
                            break
                    # Stop scanning fields once the child has been replaced.
                    if has_converted_shape:
                        break
    return need_transformed
def _without_context(node, lines, minl, maxl):
    """Returns a clean node and source code without indenting and context.

    Every `lineno` / `end_lineno` found under `node` is reduced by `minl`.
    The source is cut to lines `minl`..`maxl` (1-based, inclusive), then the
    last line is truncated at the node's `end_col_offset` and the first line
    is cropped at its `col_offset`, when those are available.

    Args:
      node: Root AST node; its line attributes are modified in place.
      lines: Sequence of source lines.
      minl: First source line (1-based) covered by the node.
      maxl: Last source line (1-based) covered by the node.

    Returns:
      A tuple `(node, code_block)` where `code_block` is the trimmed source
      with trailing whitespace removed from each line.
    """
    for descendant in gast.walk(node):
        for attr in ('lineno', 'end_lineno'):
            current = getattr(descendant, attr, None)
            if current is not None:
                setattr(descendant, attr, current - minl)

    snippet = lines[minl - 1:maxl]

    # Attempt to clean up surrounding context code.
    last_col = getattr(node, 'end_col_offset', None)
    if last_col is not None:
        # This is only available in 3.8.
        snippet[-1] = snippet[-1][:last_col]

    first_col = getattr(node, 'col_offset', None)
    if first_col is None:
        # Older Python: try to find the "lambda" token. This is brittle.
        lambda_token = re.search(r'(?<!\w)lambda(?!\w)', snippet[0])
        if lambda_token is not None:
            first_col = lambda_token.start(0)
    if first_col is not None:
        snippet[0] = snippet[0][first_col:]

    return node, '\n'.join(line.rstrip() for line in snippet)
def explicit_loop_indexes(node):
    """Apply `ExplicitLoopIndexes` and drop the 'active_*' annotations.

    Args:
      node: The AST to transform.

    Returns:
      The transformed AST, with the 'active_in', 'active_out', 'active_gen'
      and 'active_kill' annotations removed from every node that had them.
    """
    node = ExplicitLoopIndexes().visit(node)
    stale_keys = ('active_in', 'active_out', 'active_gen', 'active_kill')
    for descendant in gast.walk(node):
        for key in stale_keys:
            if anno.hasanno(descendant, key):
                anno.delanno(descendant, key)
    return node
def run_analyses(node, analyses):
    """Perform dataflow analysis on all functions within an AST.

    A CFG is built for every `gast.FunctionDef` under `node`. Backward
    analyses visit the CFG exit and forward analyses visit the CFG entry.
    Afterwards each analysis is propagated over the whole tree with
    `PropagateAnalysis`.

    Args:
      node: An AST node on which to run dataflow analysis.
      analyses: Either an instance of the Forward or Backward dataflow
        analysis class, or a list or tuple of them.

    Returns:
      node: The node, but now with annotations on the AST nodes containing
      the results of the dataflow analyses.

    Raises:
      TypeError: If any element of `analyses` is neither a Forward nor a
        Backward instance.
    """
    if not isinstance(analyses, (tuple, list)):
        analyses = (analyses,)
    for analysis in analyses:
        if not isinstance(analysis, (Forward, Backward)):
            # Both directions are accepted, so the message names both.
            raise TypeError('not a valid forward or backward analysis object')

    for child_node in gast.walk(node):
        if not isinstance(child_node, gast.FunctionDef):
            continue
        cfg_obj = CfgBuilder().build_cfg(child_node)
        for analysis in analyses:
            if isinstance(analysis, Backward):
                analysis.visit(cfg_obj.exit)
            elif isinstance(analysis, Forward):
                analysis.visit(cfg_obj.entry)

    for analysis in analyses:
        PropagateAnalysis(analysis).visit(node)
    return node
def run_analyses(node, analyses):
    """Perform dataflow analysis on all functions within an AST.

    For each `gast.FunctionDef` found under `node` a CFG is built. Every
    Backward analysis is started from the CFG exit and every Forward
    analysis from the CFG entry. Each analysis is then propagated across
    the tree through `PropagateAnalysis`.

    Args:
      node: An AST node on which to run dataflow analysis.
      analyses: Either an instance of the Forward or Backward dataflow
        analysis class, or a list or tuple of them.

    Returns:
      node: The node, but now with annotations on the AST nodes containing
      the results of the dataflow analyses.

    Raises:
      TypeError: If an entry of `analyses` is not a Forward or Backward
        instance.
    """
    if not isinstance(analyses, (tuple, list)):
        analyses = (analyses,)
    for analysis in analyses:
        if not isinstance(analysis, (Forward, Backward)):
            # The old text mentioned only "forward", yet Backward is valid.
            raise TypeError('not a valid forward or backward analysis object')

    for child_node in gast.walk(node):
        if isinstance(child_node, gast.FunctionDef):
            cfg_obj = CfgBuilder().build_cfg(child_node)
            for analysis in analyses:
                if isinstance(analysis, Backward):
                    analysis.visit(cfg_obj.exit)
                elif isinstance(analysis, Forward):
                    analysis.visit(cfg_obj.entry)

    for analysis in analyses:
        PropagateAnalysis(analysis).visit(node)
    return node
def test_compute_jumps_out_edges(self):
    """Each return/break/continue must point at the construct it leaves."""
    tree = gast.parse(
        textwrap.dedent("""\
            def foo(): # tree.body[0]
              return # tree.body[0].body[0]
              while True: # tree.body[0].body[1]
                break # tree.body[0].body[1].body[0]
                continue # tree.body[0].body[1].body[1]
                return # tree.body[0].body[1].body[2]
                while True: # tree.body[0].body[1].body[3]
                  break # tree.body[0].body[1].body[3].body[0]
                  return 4 # tree.body[0].body[1].body[3].body[1]
            """))

    edge_type = graph_edge_util.JUMPS_OUT_OF_EDGE_TYPE
    foo_def = tree.body[0]
    outer_loop = foo_def.body[1]
    inner_loop = outer_loop.body[3]
    expected_targets = [
        (foo_def.body[0], foo_def, edge_type),
        (outer_loop.body[0], outer_loop, edge_type),
        (outer_loop.body[1], outer_loop, edge_type),
        (outer_loop.body[2], foo_def, edge_type),
        (inner_loop.body[0], inner_loop, edge_type),
        (inner_loop.body[1], foo_def, edge_type),
        (inner_loop.body[1].value, foo_def, edge_type),
    ]

    # For this test, we pretend that the AST nodes are the node ids.
    nodes_as_ids = {id(x): x for x in gast.walk(tree)}
    targets = graph_edge_util.compute_jumps_out_edges(tree, nodes_as_ids)

    self.assertCountEqual(targets, expected_targets)
def visit_Assign(self, node):
    """Visit every call on the right-hand side of an assignment.

    Returns None when `self._update_class_node_dict(node)` is truthy;
    otherwise returns `node` after passing each `gast.Call` under
    `node.value` to `self._visit_Call`.
    """
    if self._update_class_node_dict(node):
        return None

    calls = (sub for sub in gast.walk(node.value)
             if isinstance(sub, gast.Call))
    for call in calls:
        self._visit_Call(call)
    return node
def is_active(self, node):
    """Tell whether `node` loads any name listed in its 'active_in' annotation.

    Returns:
      True if some `gast.Name` with a `Load` context under `node` has an id
      contained in the 'active_in' annotation of `node`, otherwise False.
    """
    active_variables = anno.getanno(node, 'active_in')
    return any(
        isinstance(succ, gast.Name)
        and isinstance(succ.ctx, gast.Load)
        and succ.id in active_variables
        for succ in gast.walk(node))
def test_walk(self):
    """`gast.walk` must yield all six nodes of the expression ``x + 1``."""
    tree = gast.parse('x + 1', mode='eval')
    expected_dump = ("Expression(body=BinOp(left=Name(id='x', ctx=Load(), "
                     "annotation=None), op=Add(), right=Num(n=1)))")
    self.assertEqual(gast.dump(tree), expected_dump)

    node_count = sum(1 for _ in gast.walk(tree))
    self.assertEqual(node_count, 6)
def _visit_While(self, node):
    """Visit a ``while`` loop's test and its break/continue statements.

    The children of ``node.test`` are visited via ``generic_visit``; then
    every ``break``/``continue`` found anywhere under ``node`` is passed to
    ``self._visit_break_continue``.
    """
    assert isinstance(node, gast.While)
    self.generic_visit(node.test)

    jump_types = (gast.Continue, gast.Break)
    for descendant in gast.walk(node):
        if isinstance(descendant, jump_types):
            self._visit_break_continue(descendant)
    return
def forward(node, analysis):
    """Perform a given analysis on all functions within an AST.

    Args:
      node: The AST to search for `gast.FunctionDef` nodes.
      analysis: A `Forward` instance; it visits the entry of the CFG built
        for each function.

    Returns:
      The `node` that was passed in.

    Raises:
      TypeError: If `analysis` is not a `Forward` instance.
    """
    if not isinstance(analysis, Forward):
        raise TypeError('not a valid forward analysis object')

    function_defs = (n for n in gast.walk(node)
                     if isinstance(n, gast.FunctionDef))
    for function_def in function_defs:
        analysis.visit(CFG.build_cfg(function_def).entry)
    return node
def clearanno(node):
    """Reset annotations on `node` and its descendants, keeping fixed ones.

    Every node that has the `ANNOTATION_FIELD` attribute gets it replaced by
    a new dict holding only the annotations named in `FIXED_ANNOTATIONS`
    that the node currently has.

    Returns:
      The `node` that was passed in.
    """
    for succ in gast.walk(node):
        if not hasattr(succ, ANNOTATION_FIELD):
            continue
        kept = {
            key: getanno(succ, key)
            for key in FIXED_ANNOTATIONS
            if hasanno(succ, key)
        }
        setattr(succ, ANNOTATION_FIELD, kept)
    return node
def visit_Expr(self, node):
    """Visit the calls of an expression statement.

    Calls under `node.value` are examined in `gast.walk` order. The first
    call for which `is_dygraph_api` is truthy ends the method with a None
    result; every other call is passed to `self._visit_Call`.

    Returns:
      None if a dygraph API call was found, otherwise `node`.
    """
    calls = (sub for sub in gast.walk(node.value)
             if isinstance(sub, gast.Call))
    for call in calls:
        if is_dygraph_api(call):
            # NOTE(review): if this class is a NodeTransformer, returning
            # None removes the statement -- confirm that is intended.
            return None
        self._visit_Call(call)
    return node
def copy_origin(from_node, to_node):
    """Copies the origin info from a node to another, recursively.

    Args:
      from_node: Node whose `anno.Basic.ORIGIN` annotation is read.
      to_node: A node, or a list/tuple of nodes. The origin is set on each
        of them and on all their descendants.

    Does nothing when `from_node` carries no origin annotation.
    """
    origin = anno.Basic.ORIGIN.of(from_node, default=None)
    if origin is not None:
        is_sequence = isinstance(to_node, (list, tuple))
        targets = to_node if is_sequence else (to_node,)
        for target in targets:
            for descendant in gast.walk(target):
                anno.setanno(descendant, anno.Basic.ORIGIN, origin)
def visit_Compare(self, node):
    """Visit a comparison and record it if it raised the control-flow count.

    Comparisons accepted by `compare_with_none` are not descended into. For
    all others the children are visited and every `gast.Subscript` under
    the node is passed to `self._visit_Subscript`. If
    `self.is_control_flow_num` grew meanwhile, the node is added to
    `self._compare_node_tenor_set`.

    Returns:
      The `node` that was passed in.
    """
    count_before = self.is_control_flow_num

    if not compare_with_none(node):
        self.generic_visit(node)
        subscripts = (sub for sub in gast.walk(node)
                      if isinstance(sub, gast.Subscript))
        for subscript in subscripts:
            self._visit_Subscript(subscript)

    if self.is_control_flow_num > count_before:
        self._compare_node_tenor_set.add(node)
    return node
def resolve(nodes, source, function=None):
    """Adds an origin information to all nodes inside the body of function.

    Every node under `nodes` that has a `lineno` receives an `OriginInfo`
    annotation under the key `anno.Basic.ORIGIN`. The annotation holds the
    node's location, the function name, the matching line of `source` and
    the comment found on that line, if any.

    Args:
      nodes: Union[ast.AST, Iterable[ast.AST, ...]]
      source: Text, the source code string for the function whose body nodes
        will be annotated. Node line numbers are used as 1-based indices
        into this text.
      function: Callable, the function that will have all nodes inside of it
        annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.
        If it is None then only the line numbers and column offset will be
        set in the annotation, with the rest of the information being None.

    Returns:
      None. The nodes are annotated in place.
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes,)

    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None

    # TODO(mdan): Pull this to a separate utility.
    # Map each row of `source` that carries a comment to the comment text.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for node in nodes:
        for n in gast.walk(node):
            if not hasattr(n, 'lineno'):
                continue

            lineno_in_body = n.lineno
            source_code_line = source_lines[lineno_in_body - 1]
            if function:
                source_lineno = function_lineno + lineno_in_body
                function_name = function.__name__
            else:
                source_lineno = lineno_in_body
                function_name = None

            location = Location(function_filepath, source_lineno,
                                n.col_offset)
            # comment_map is keyed by row within `source`, so it must be
            # looked up with the in-source line, not the file-absolute one.
            origin = OriginInfo(location, function_name, source_code_line,
                                comment_map.get(lineno_in_body))
            anno.setanno(n, anno.Basic.ORIGIN, origin)
def resolve(nodes, source, function=None):
    """Adds an origin information to all nodes inside the body of function.

    Each node under `nodes` that has a `lineno` is annotated, under the key
    `anno.Basic.ORIGIN`, with an `OriginInfo` built from its location, the
    function name, the corresponding line of `source` and any comment that
    appears on that line.

    Args:
      nodes: Union[ast.AST, Iterable[ast.AST, ...]]
      source: Text, the source code string for the function whose body nodes
        will be annotated. Node line numbers index into this text (1-based).
      function: Callable, the function that will have all nodes inside of it
        annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.
        If it is None then only the line numbers and column offset will be
        set in the annotation, with the rest of the information being None.

    Returns:
      None. Annotations are attached to the nodes in place.
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes,)

    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None

    # TODO(mdan): Pull this to a separate utility.
    # Collect comments by the row of `source` on which they start.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for node in nodes:
        for n in gast.walk(node):
            if not hasattr(n, 'lineno'):
                continue

            lineno_in_body = n.lineno
            source_code_line = source_lines[lineno_in_body - 1]
            if function:
                source_lineno = function_lineno + lineno_in_body
                function_name = function.__name__
            else:
                source_lineno = lineno_in_body
                function_name = None

            location = Location(function_filepath, source_lineno,
                                n.col_offset)
            # Use the in-source row: the comment map knows nothing about
            # the function's offset within its file.
            comment = comment_map.get(lineno_in_body)
            origin = OriginInfo(location, function_name, source_code_line,
                                comment)
            anno.setanno(n, anno.Basic.ORIGIN, origin)
def visit_Expr(self, node):
    """Visit the calls of an expression statement.

    Calls under `node.value` are examined in `gast.walk` order. As soon as
    one satisfies `is_dygraph_api` the method returns None; all other calls
    are passed to `self._visit_Call`.

    Returns:
      None if a dygraph API call was found, otherwise `node`.
    """
    for sub in gast.walk(node.value):
        if not isinstance(sub, gast.Call):
            continue
        # TODO(liym27):
        #  Considers that a dygraph api which modifies the input or has a output.
        if is_dygraph_api(sub):
            return None
        self._visit_Call(sub)
    return node
def __call__(self, codeobj):
    """Return the cached AST entry for `codeobj`, parsing its file on a miss.

    The cache key comes from `self.get_file_info(codeobj)`; its first
    element is used as the file name. On a miss the file is parsed, the
    module AST is stored under `(fname, 0)` and each `gast.FunctionDef`
    under `(fname, lineno)`.

    Raises:
      KeyError: If the key is still absent after the file was parsed.
    """
    cache = self.cache
    key = self.get_file_info(codeobj)

    cached = cache.get(key)
    if cached is not None:
        return cached

    fname = key[0]
    mod_ast = gast.ast_to_gast(self.parse_file(fname))
    cache[(fname, 0)] = mod_ast
    function_defs = (n for n in gast.walk(mod_ast)
                     if isinstance(n, gast.FunctionDef))
    for function_def in function_defs:
        cache[(fname, function_def.lineno)] = function_def
    return cache[key]
def test_loop_vars(self):
    """Check loop/created variable names for every function in `loop_funcs`."""
    for i, func in enumerate(self.loop_funcs):
        gast_root = gast.parse(inspect.getsource(func))
        name_visitor = NameVisitor(gast_root)
        for node in gast.walk(gast_root):
            if not isinstance(node, (gast.While, gast.For)):
                continue
            loop_var_names, create_var_names = (
                name_visitor.get_loop_var_names(node))
            self.assertEqual(loop_var_names, self.loop_var_names[i])
            self.assertEqual(create_var_names, self.create_var_names[i])
def _yields_in_own_scope(fn_ast):
    """Return True if `fn_ast` has a `yield`/`yield from` in its own scope."""
    # A yield inside a nested function, lambda or class belongs to that
    # inner scope and does not make the enclosing function a generator.
    # (Decorators and defaults of nested definitions are skipped as well.)
    scope_boundaries = (
        gast.FunctionDef,
        gast.AsyncFunctionDef,
        gast.Lambda,
        gast.ClassDef,
    )
    pending = list(gast.iter_child_nodes(fn_ast))
    while pending:
        current = pending.pop()
        if isinstance(current, (gast.Yield, gast.YieldFrom)):
            return True
        if isinstance(current, scope_boundaries):
            continue
        pending.extend(gast.iter_child_nodes(current))
    return False


def analyze_func(ast: AST) -> AST:
    """Annotate `FunctionDef` nodes in the given AST with additional information.

    Walks through the `FunctionDef` nodes in the given AST and annotates
    each node with the following info:
    - `n_args`: number of entries in the function's `args.args` list.
    - `docstr`: the function's docstring if present, otherwise None.
    - `is_empty`: whether the body consists only of `pass` statements
        and/or bare string constants (such as a docstring).
    - `is_generator`: whether the function is a generator, i.e. its own
        scope contains `yield` or `yield from`. Yields in nested functions,
        lambdas or classes are not counted.

    Args:
        ast: AST to scan for and annotate `FunctionDef` in.

    Returns:
        The given AST with the `FunctionDef` nodes annotated in place.
    """

    def is_filler(statement):
        # `pass`, or an expression statement holding only a string constant.
        if isinstance(statement, Pass):
            return True
        return (
            isinstance(statement, Expr)
            and isinstance(statement.value, Constant)
            and isinstance(statement.value.value, str)
        )

    # Collect first, so that annotating never interferes with the walk.
    fn_asts = [n for n in gast.walk(ast) if isinstance(n, FunctionDef)]
    for fn_ast in fn_asts:
        fn_ast.n_args = len(fn_ast.args.args)
        fn_ast.docstr = gast.get_docstring(fn_ast)
        fn_ast.is_empty = all(is_filler(stmt) for stmt in fn_ast.body)
        fn_ast.is_generator = _yields_in_own_scope(fn_ast)

    return ast
def visit_Compare(self, node):
    """Visit a comparison and record it if it raised the control-flow count.

    Unless `compare_with_none(node)` is truthy, the children are visited
    and every `gast.Subscript` under the node is passed to
    `self._visit_Subscript`. The node is added to
    `self._compare_node_tenor_set` when `self.is_control_flow_num` grew
    during the visit.

    Returns:
      The `node` that was passed in.
    """
    # Ignores child node with `if x` or `if x is None`
    # TODO(Aurelius84): `if tensor` will be supported in dygraph
    # and should be considered as is_control_flow.
    count_before = self.is_control_flow_num

    skip_children = compare_with_none(node)
    if not skip_children:
        self.generic_visit(node)
        for sub in gast.walk(node):
            if isinstance(sub, gast.Subscript):
                self._visit_Subscript(sub)

    grew = self.is_control_flow_num > count_before
    if grew:
        self._compare_node_tenor_set.add(node)
    return node
def prepend_uninitialized_grads(self, node):
    """Insert initializers for adjoint names that `node` loads undefined.

    Only acts when `node` has a 'defined_in' annotation. For each loaded
    name under `node` that is annotated 'adjoint_var' or 'temp_adjoint_var',
    is absent from 'defined_in' and is not yet in `self.added`, the id is
    recorded in `self.added` and `self._init(use)` is passed to
    `self.insert_top`.

    Returns:
      The `node` that was passed in.
    """
    if not anno.hasanno(node, 'defined_in'):
        return node

    for use in gast.walk(node):
        is_load = (isinstance(use, gast.Name)
                   and isinstance(use.ctx, gast.Load))
        if not is_load:
            continue
        is_adjoint = (anno.hasanno(use, 'adjoint_var')
                      or anno.hasanno(use, 'temp_adjoint_var'))
        if not is_adjoint:
            continue
        if use.id in anno.getanno(node, 'defined_in'):
            continue
        if use.id in self.added:
            continue
        self.added.add(use.id)
        self.insert_top(self._init(use))
    return node
def resolve(node, source, function=None):
    """Adds an origin information to node and its subnodes.

    This allows us to map the original source code line numbers to generated
    source code.

    Args:
      node: gast.AST node. Should be a gast.FunctionDef. This is the node we
        annotate with origin information.
      source: Text, the source code. Should satisfy relationship
        `node in iter_tree(gast.parse(source))`; otherwise the lineno will
        be unreliable.
      function: The original function. If it is None then only the line
        numbers and column offset will be set in the annotation, with the
        rest of the information being None.

    Returns:
      None. `node` and its descendants are annotated in place under the key
      `anno.Basic.ORIGIN`.
    """
    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None

    # TODO(mdan): Pull this to a separate utility.
    # Map each row of `source` that carries a comment to the comment text.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for n in gast.walk(node):
        if not hasattr(n, 'lineno'):
            continue

        within_body_offset = n.lineno - node.lineno
        source_code_line = source_lines[n.lineno - 1]
        if function:
            source_lineno = function_lineno + within_body_offset
            function_name = function.__name__
        else:
            source_lineno = n.lineno
            function_name = None

        location = Location(function_filepath, source_lineno, n.col_offset)
        # comment_map is keyed by row within `source`, the same index used
        # for source_lines, so it is looked up with n.lineno rather than
        # the file-absolute source_lineno.
        origin = OriginInfo(location, function_name, source_code_line,
                            comment_map.get(n.lineno))
        anno.setanno(n, anno.Basic.ORIGIN, origin)
def dup(node, copy_map, field_name='___pyct_anno'):
    """Recursively copies annotations in an AST tree.

    Args:
      node: ast.AST
      copy_map: Dict[Hashable, Hashable], maps a source anno key to a
        destination key. All annotations with the source key will be copied
        to identical annotations with the destination key.
      field_name: str, passed through to `hasanno`, `getanno` and `setanno`.
    """
    for descendant in gast.walk(node):
        for source_key, destination_key in copy_map.items():
            if not hasanno(descendant, source_key, field_name):
                continue
            value = getanno(descendant, source_key, field_name)
            setanno(descendant, destination_key, value, field_name)
def _visit_For(self, node):
    """Inspect a ``for ... in range(...)`` loop.

    Loops whose iterable is not a call to the plain name ``range`` are
    ignored. Otherwise every argument of the ``range`` call is visited and
    each ``break``/``continue`` found under ``node`` is passed to
    ``self._visit_break_continue``.
    """
    assert isinstance(node, gast.For)
    # TODO
    # self.is_control_flow_num += 1
    iterable = node.iter
    if not isinstance(iterable, gast.Call):
        return
    callee = iterable.func
    if not isinstance(callee, gast.Name):
        return
    if callee.id != "range":
        return

    for argument in iterable.args:
        self.visit(argument)

    jump_types = (gast.Continue, gast.Break)
    for descendant in gast.walk(node):
        if isinstance(descendant, jump_types):
            self._visit_break_continue(descendant)
    return
def resolve(nodes, source, function=None):
    """Adds an origin information to all nodes inside the body of function.

    Every node under `nodes` that has a `lineno` is annotated, under the key
    `anno.Basic.ORIGIN`, with an `OriginInfo` built from its location, the
    function name and the corresponding line of `source`.

    Args:
      nodes: Union[ast.AST, Iterable[ast.AST, ...]]
      source: Text, the source code string for the function whose body nodes
        will be annotated.
      function: Callable, the function that will have all nodes inside of it
        annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.
        If it is None then only the line numbers and column offset will be
        set in the annotation, with the rest of the information being None.

    Returns:
      None. The nodes are annotated in place.
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes,)

    function_lineno = None
    function_filepath = None
    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)

    source_lines = source.split('\n')
    for root in nodes:
        for n in gast.walk(root):
            if not hasattr(n, 'lineno'):
                continue

            lineno_in_body = n.lineno
            source_code_line = source_lines[lineno_in_body - 1]
            if function:
                source_lineno = function_lineno + lineno_in_body
                function_name = function.__name__
            else:
                source_lineno = lineno_in_body
                function_name = None

            location = Location(function_filepath, source_lineno,
                                n.col_offset)
            anno.setanno(
                n, anno.Basic.ORIGIN,
                OriginInfo(location, function_name, source_code_line))
def resolve(node, source, function=None):
    """Adds an origin information to all nodes inside the body of function.

    Every node under `node` that has a `lineno` is annotated, under the key
    `anno.Basic.ORIGIN`, with an `OriginInfo` built from the file path, the
    function name, the line number, the column offset and the corresponding
    line of `source`.

    Args:
      node: The AST node for the function whose body nodes will be
        annotated.
      source: Text, the source code string for the function whose body nodes
        will be annotated.
      function: Callable, the function that will have all nodes inside of it
        annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.
        If it is None then only the line numbers and column offset will be
        set in the annotation, with the rest of the information being None.

    Returns:
      None. The nodes are annotated in place.
    """
    function_lineno = None
    function_filepath = None
    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)

    source_lines = source.split('\n')
    for n in gast.walk(node):
        if not hasattr(n, 'lineno'):
            continue

        # n.lineno is relative to the start of the enclosing function, so
        # need to offset it by the line of the function.
        source_code_line = source_lines[n.lineno - 1]
        if function:
            source_lineno = n.lineno + function_lineno - 1
            function_name = function.__name__
        else:
            source_lineno = n.lineno
            function_name = None

        origin = OriginInfo(function_filepath, function_name, source_lineno,
                            n.col_offset, source_code_line)
        anno.setanno(n, anno.Basic.ORIGIN, origin)
def add_filename_field(node, filename):
    """Set a ``filename`` attribute on ``node`` and all of its descendants.

    Args:
        node: Root ``ast`` node; the tree is modified in place.
        filename: Value stored in each node's ``filename`` attribute.
    """
    for visited in ast.walk(node):
        setattr(visited, "filename", filename)
def contains_return(node):
    """Return True if `node` or any of its descendants is a `gast.Return`."""
    return any(isinstance(n, gast.Return) for n in gast.walk(node))