def reverse_ad(node, wrt, preserve_result):
    """Perform reverse-mode AD on an AST.

    This function analyses the AST to determine which variables are active and
    proceeds by taking the naive derivative. Before returning the primal and
    adjoint it annotates push and pop statements as such.

    Args:
      node: A `FunctionDef` AST node.
      wrt: A tuple of argument indices with respect to which we take the
          derivative.
      preserve_result: A boolean indicating whether the generated derivative
          function should also return the original return value.

    Returns:
      mod: A `Module` node containing the naive primal and adjoint of the
          function which can be fed to the `split` and `joint` functions.
      required: A list of tuples of functions and argument indices. These
          functions were called by the function but did not have an adjoint.
      stack: The `stack` attribute of the `ReverseAD` visitor after it has
          visited `node`. NOTE(review): presumably the stack that the
          generated pushes and pops operate on; confirm against `ReverseAD`.

    Raises:
      TypeError: If `node` is not a `gast.FunctionDef`.
    """
    if not isinstance(node, gast.FunctionDef):
        # Same exception type as before, now with a message naming the
        # offending type so the failure is diagnosable from the traceback.
        raise TypeError(
            'reverse_ad expects a gast.FunctionDef node, got %s'
            % type(node).__name__)

    # Activity analysis with respect to the requested arguments.
    cfg.forward(node, cfg.Active(wrt))

    ad = ReverseAD(wrt, preserve_result)
    pri, adj = ad.visit(node)
    mod = gast.Module(body=[pri, adj])
    # Annotate the push and pop statements before handing the module back.
    mod = annotate.find_stacks(mod)
    return mod, ad.required, ad.stack
def test_defined():
    """Check the `Defined` forward analysis annotations on function `g`."""
    tree = tangent.quoting.parse_function(g)
    cfg.forward(tree, cfg.Defined())
    statements = tree.body[0].body

    # Only x is certain to be defined at the end.
    defined_in = anno.getanno(statements[1], 'defined_in')
    assert len(defined_in) == 1

    # At the end of the if body both x and y are defined.
    branch = statements[0].body
    defined_out = anno.getanno(branch[0], 'defined_out')
    assert len(defined_out) == 2
def test_reaching():
    """Check the `ReachingDefinitions` annotations on function `f`."""

    def reaching_count(statement):
        # Number of definitions annotated as reaching the given statement.
        return len(anno.getanno(statement, 'definitions_in'))

    tree = tangent.quoting.parse_function(f)
    cfg.forward(tree, cfg.ReachingDefinitions())
    statements = tree.body[0].body

    # Only the argument reaches the first expression.
    assert reaching_count(statements[0]) == 1

    loop_statements = statements[1].body
    # Here x is either the argument or the definition from the previous
    # iteration of the loop.
    assert reaching_count(loop_statements[0]) == 2
    # Here x can only come from the previous line.
    assert reaching_count(loop_statements[1]) == 1

    # After the loop x is either the argument or the last definition made in
    # the while body.
    assert reaching_count(statements[2]) == 2
def read_counts(node):
    """Count how often each variable definition is read.

    Reaching definitions analysis is run on the AST first; a `ReadCounts`
    visitor then walks the tree and its `n_read` attribute is returned.

    Args:
      node: An AST to analyze.

    Returns:
      A dictionary from assignment nodes to the number of times the assigned
      to variable was used.
    """
    cfg.forward(node, cfg.ReachingDefinitions())
    counter = ReadCounts()
    counter.visit(node)
    return counter.n_read
def unused(node):
    """Find unused definitions that can be removed.

    This runs reaching definitions analysis followed by a walk over the AST
    to find all variable definitions that are not used later on.

    Args:
      node: The AST of e.g. a function body to find unused variable
          definitions.

    Returns:
      unused: The `unused` attribute of the `Unused` visitor after it has
          visited all the nodes: a set of definitions in the form of
          `(variable_name, node)` pairs which are unused in this AST.
    """
    cfg.forward(node, cfg.ReachingDefinitions())
    finder = Unused()
    finder.visit(node)
    return finder.unused
def _fix(node):
    """Fix the naive construction of the adjoint.

    See `fixes.py` for details.

    This function also returns the result of reaching definitions analysis so
    that `split` mode can use this to carry over the state from primal to
    adjoint.

    Args:
      node: A module with the primal and adjoint function definitions as
          returned by `reverse_ad`.

    Returns:
      node: A module with the primal and adjoint function with additional
          variable definitions and such added so that pushes onto the stack
          and gradient accumulations are all valid.
      defined: The variables defined at the end of the primal.
      reaching: The variable definitions that reach the end of the primal.
    """
    # Analyse the primal (node.body[0]) through an explicitly built CFG so
    # that the exit state of each analysis can be returned to the caller.
    primal_graph = cfg.CFG.build_cfg(node.body[0])
    defined_analysis = cfg.Defined()
    defined_analysis.visit(primal_graph.entry)
    reaching_analysis = cfg.ReachingDefinitions()
    reaching_analysis.visit(primal_graph.entry)

    # Run the same two analyses on the adjoint (node.body[1]).
    for analysis_cls in (cfg.Defined, cfg.ReachingDefinitions):
        cfg.forward(node.body[1], analysis_cls())

    # Remove pushes of variables that were never defined.
    fixes.CleanStack().visit(node)
    # node.body is indexed again after each fixer rather than cached, because
    # the visitors may modify the module.
    fixes.FixStack().visit(node.body[0])

    # Change accumulation into definition if possible.
    fixes.CleanGrad().visit(node.body[1])
    # Define gradients that might or might not be defined.
    fixes.FixGrad().visit(node.body[1])

    return node, defined_analysis.exit, reaching_analysis.exit
def test_active2():
    """Check the `Active` analysis on function `i` with respect to argument 1."""
    tree = tangent.quoting.parse_function(i)
    cfg.forward(tree, cfg.Active(wrt=(1, )))
    last_statement = tree.body[0].body[-1]

    # Through y both x and z are now active.
    active_out = anno.getanno(last_statement, 'active_out')
    assert len(active_out) == 3
def test_active():
    """Check the `Active` analysis on function `h` with respect to argument 1."""
    tree = tangent.quoting.parse_function(h)
    cfg.forward(tree, cfg.Active(wrt=(1, )))
    last_statement = tree.body[0].body[-1]

    # y has been overwritten here, so nothing is active anymore.
    active_out = anno.getanno(last_statement, 'active_out')
    assert not active_out
def visit_FunctionDef(self, node):
    """Prepare a function definition for visiting and visit its children.

    Activity analysis is run with respect to every argument index of the
    function, a namer built from the function is stored on `self.namer`, and
    the node is then handed to `generic_visit`.

    Args:
      node: The function definition node being visited.

    Returns:
      Whatever `generic_visit` returns for `node`.
    """
    # Treat every entry of node.args.args as an argument to differentiate
    # with respect to.
    every_argument = range(len(node.args.args))
    cfg.forward(node, cfg.Active(every_argument))

    self.namer = naming.Namer.build(node)
    return self.generic_visit(node)