def test_defined():
    """Check the annotations left by a `cfg.Defined` forward pass over `g`."""
    tree = tangent.quoting.parse_function(g)
    cfg.forward(tree, cfg.Defined())
    statements = tree.body[0].body

    # only x is for sure defined at the end
    defined_at_end = anno.getanno(statements[1], 'defined_in')
    assert len(defined_at_end) == 1

    # at the end of the if body both x and y are defined
    branch = statements[0].body
    defined_after_branch = anno.getanno(branch[0], 'defined_out')
    assert len(defined_after_branch) == 2
def _fix(node):
    """Fix the naive construction of the adjoint.

    See `fixes.py` for details.

    Besides repairing the module, this function hands back the results of the
    defined-variable and reaching-definitions analyses of the primal, so that
    `split` mode can use them to carry state over from primal to adjoint.

    Args:
      node: A module with the primal and adjoint function definitions as
          returned by `reverse_ad`.

    Returns:
      node: A module with the primal and adjoint function with additional
          variable definitions and such added so that pushes onto the stack
          and gradient accumulations are all valid.
      defined: The variables defined at the end of the primal.
      reaching: The variable definitions that reach the end of the primal.
    """
    # NOTE: `node.body[i]` is looked up afresh at every step instead of being
    # cached in a local, in case a visitor replaces entries of the module body.

    # Analyse the primal through an explicit control-flow graph. The analysis
    # objects are kept around because their `exit` results are returned.
    primal_graph = cfg.CFG.build_cfg(node.body[0])
    defined_analysis = cfg.Defined()
    defined_analysis.visit(primal_graph.entry)
    reaching_analysis = cfg.ReachingDefinitions()
    reaching_analysis.visit(primal_graph.entry)

    # Run the same two analyses, in the same order, over the adjoint.
    for analysis_cls in (cfg.Defined, cfg.ReachingDefinitions):
        cfg.forward(node.body[1], analysis_cls())

    # Remove pushes of variables that were never defined
    fixes.CleanStack().visit(node)
    fixes.FixStack().visit(node.body[0])

    # CleanGrad: change accumulation into definition if possible.
    # FixGrad: define gradients that might or might not be defined.
    for grad_fix_cls in (fixes.CleanGrad, fixes.FixGrad):
        grad_fix_cls().visit(node.body[1])

    return node, defined_analysis.exit, reaching_analysis.exit