def reverse_ad(node, wrt, preserve_result=False):
  """Perform reverse-mode AD on an AST.

  This function analyses the AST to determine which variables are active and
  proceeds by taking the naive derivative. Before returning the primal and
  adjoint it annotates push and pop statements as such.

  Args:
    node: A `FunctionDef` AST node.
    wrt: A tuple of argument indices with respect to which we take the
        derivative.
    preserve_result: A boolean indicating whether the generated derivative
        function should also return the original return value. Defaults to
        False.

  Returns:
    mod: A `Module` node containing the naive primal and adjoint of the
        function which can be fed to the `split` and `joint` functions.
    required: A list of tuples of functions and argument indices. These
        functions were called by the function but did not have an adjoint.
    stack: The `stack` attribute of the `ReverseAD` transformer after it has
        visited `node` (presumably the stack that the generated push/pop
        statements operate on -- TODO confirm against `ReverseAD`).

  Raises:
    TypeError: If `node` is not a `gast.FunctionDef`.
  """
  if not isinstance(node, gast.FunctionDef):
    raise TypeError(
        'reverse_ad expects a gast.FunctionDef node, got %s' %
        type(node).__name__)

  # Activity analysis with respect to the `wrt` argument indices.
  cfg.forward(node, cfg.Active(wrt))

  # Build the naive primal and adjoint functions and wrap them in a module.
  ad = ReverseAD(wrt, preserve_result)
  pri, adj = ad.visit(node)
  mod = gast.Module(body=[pri, adj])

  # Annotate push and pop statements.
  mod = annotate.find_stacks(mod)
  return mod, ad.required, ad.stack
def forward_ad(node, wrt, preserve_result=False, check_dims=True):
  """Perform forward-mode AD on an AST.

  This function analyses the AST to determine which variables are active and
  proceeds by taking the naive derivative. Before returning the generated
  function it annotates push and pop statements as such.

  Args:
    node: A `FunctionDef` AST node.
    wrt: A tuple of argument indices with respect to which we take the
        derivative.
    preserve_result: A boolean indicating whether the original
        non-differentiated function value should be returned.
    check_dims: A boolean indicating whether the provided derivatives should
        have the same shape as their corresponding arguments.

  Returns:
    node: A `Module` node wrapping the naive forward-mode function produced
        by the `ForwardAD` transformer.
    required: A list of tuples of functions and argument indices. These
        functions were called by the function but did not have a derivative.

  Raises:
    TypeError: If `node` is not a `gast.FunctionDef`.
  """
  if not isinstance(node, gast.FunctionDef):
    raise TypeError(
        'forward_ad expects a gast.FunctionDef node, got %s' %
        type(node).__name__)

  # Activity analysis.
  # NOTE(review): every positional argument is marked active here, not only
  # the `wrt` indices -- presumably intentional; confirm before changing.
  cfg_obj = cfg.CFG.build_cfg(node)
  cfg.Active(range(len(node.args.args))).visit(cfg_obj.entry)

  # Build forward mode function
  fad = ForwardAD(wrt, preserve_result, check_dims)
  node = fad.visit(node)

  # Annotate stacks
  node = annotate.find_stacks(node)

  # Clean up the naive forward-mode code: wrap it in a module and clear the
  # annotations attached by the passes above.
  node = gast.Module([node])
  anno.clearanno(node)

  return node, fad.required
def test_active2():
  """Activity analysis w.r.t. argument 1 of `i` leaves three active names."""
  func_ast = tangent.quoting.parse_function(i)
  cfg.forward(func_ast, cfg.Active(wrt=(1,)))
  last_stmt = func_ast.body[0].body[-1]
  active_after = anno.getanno(last_stmt, 'active_out')
  # through y both x and z are now active
  assert len(active_after) == 3
def test_active():
  """Activity analysis w.r.t. argument 1 of `h` leaves nothing active."""
  func_ast = tangent.quoting.parse_function(h)
  cfg.forward(func_ast, cfg.Active(wrt=(1,)))
  last_stmt = func_ast.body[0].body[-1]
  active_after = anno.getanno(last_stmt, 'active_out')
  # y has been overwritten here, so nothing is active anymore
  assert not active_after
def visit_FunctionDef(self, node):
  """Run activity analysis on `node`, set up a namer, then visit children.

  Every positional argument of the function is treated as active. A fresh
  `naming.Namer` built from `node` is stored on `self.namer` before the
  child nodes are visited.

  Returns:
    The result of `self.generic_visit(node)`.
  """
  every_arg_index = range(len(node.args.args))
  cfg.forward(node, cfg.Active(every_arg_index))
  self.namer = naming.Namer.build(node)
  return self.generic_visit(node)