def test_reaching():
  def n_reaching(stmt):
    # Number of definitions annotated as reaching `stmt`.
    return len(anno.getanno(stmt, 'definitions_in'))

  tree = tangent.quoting.parse_function(f)
  cfg.forward(tree, cfg.ReachingDefinitions())
  stmts = tree.body[0].body

  # Only the argument reaches the first expression.
  assert n_reaching(stmts[0]) == 1

  loop_stmts = stmts[1].body
  # Inside the while body, x is either the argument or the value from the
  # previous iteration.
  assert n_reaching(loop_stmts[0]) == 2
  # The next statement only sees the definition on the line before it.
  assert n_reaching(loop_stmts[1]) == 1

  # After the loop, x is either the argument or the last definition made in
  # the while body.
  assert n_reaching(stmts[2]) == 2
def _fix(node):
  """Repair the naive construction of the adjoint.

  See `fixes.py` for the individual fixes that are applied. This function
  also returns the results of the reaching definitions analysis on the
  primal. `split` mode uses them to carry state over from the primal to the
  adjoint.

  Args:
    node: A module containing the primal and adjoint function definitions,
        as returned by `reverse_ad`.

  Returns:
    node: The module with variable definitions and the like added, so that
        all pushes onto the stack and all gradient accumulations are valid.
    defined: The variables defined at the end of the primal.
    reaching: The variable definitions that reach the end of the primal.
  """
  # Analyse the primal through an explicitly built CFG and keep the analysis
  # objects, so that their `exit` states can be returned below.
  primal_graph = cfg.CFG.build_cfg(node.body[0])
  defined_analysis = cfg.Defined()
  defined_analysis.visit(primal_graph.entry)
  reaching_analysis = cfg.ReachingDefinitions()
  reaching_analysis.visit(primal_graph.entry)

  # The adjoint only needs to be annotated; run each analysis in turn.
  for analysis_cls in (cfg.Defined, cfg.ReachingDefinitions):
    cfg.forward(node.body[1], analysis_cls())

  # Drop pushes of variables that are never defined.
  # `node.body` is re-indexed at each step because the passes operate on the
  # tree in place.
  fixes.CleanStack().visit(node)
  fixes.FixStack().visit(node.body[0])

  # Turn gradient accumulations into plain definitions where possible.
  fixes.CleanGrad().visit(node.body[1])
  # Define gradients that may or may not have been defined.
  fixes.FixGrad().visit(node.body[1])

  return node, defined_analysis.exit, reaching_analysis.exit
def read_counts(node):
  """Count how many times each variable definition is used.

  Args:
    node: The AST to analyze.

  Returns:
    A dictionary mapping assignment nodes to the number of times the
    variable they assign to was used.
  """
  # ReadCounts presumably relies on the reaching-definitions annotations, so
  # the analysis is run over the tree first.
  cfg.forward(node, cfg.ReachingDefinitions())
  visitor = ReadCounts()
  visitor.visit(node)
  counts = visitor.n_read
  return counts
def unused(node):
  """Find unused definitions that can be removed.

  Runs reaching definitions analysis on `node`, then walks the AST to collect
  every variable definition that is not used later on.

  Args:
    node: The AST (e.g. of a function body) in which to look for unused
        variable definitions.

  Returns:
    The set of unused definitions in this AST, as `(variable_name, node)`
    pairs.
  """
  # Annotate reaching definitions before the Unused visitor walks the tree.
  cfg.forward(node, cfg.ReachingDefinitions())
  finder = Unused()
  finder.visit(node)
  found = finder.unused
  return found