def autodiff_ast(func, wrt, motion, mode, preserve_result, check_dims, verbose):
    """Perform AD on a single function and return the AST.

    Args:
      See `grad`. In particular, `mode` must be 'reverse' or 'forward'. In
      reverse mode, `motion == 'split'` selects `reverse_ad.split`, and any
      other value selects `reverse_ad.joint`. Intermediate source is printed
      when `verbose >= 2` (ANF output, plus the raw and post-motion
      reverse-mode output).

    Returns:
      node: The AST of a module containing the adjoint and primal function
          definitions.
      required: A list of non-built in functions that this function called,
          and of which the primals and adjoints need to be made available in
          order for the returned function to run.

    Raises:
      ValueError: If `mode` is neither 'reverse' nor 'forward'.
    """
    # Reject an unknown mode before doing any work. Without this check the
    # transforms would run and the function would then fail with an
    # UnboundLocalError on `required` at the return statement.
    if mode not in ('reverse', 'forward'):
        raise ValueError(
            "Unknown AD mode %r; expected 'reverse' or 'forward'" % (mode,))

    node = annotate.resolve_calls(func)
    node = desugar.explicit_loop_indexes(node)
    # The raw source of `func` is handed to the validator alongside the AST
    # (presumably for error reporting -- confirm in `fence.validate`).
    fence.validate(node, inspect.getsource(func))
    node = anf_.anf(node)
    if verbose >= 2:
        print('ANF')
        print(quoting.to_source(node))

    if mode == 'reverse':
        # assumes node.body[0] is the function definition being differentiated
        node, required, stack = reverse_ad.reverse_ad(
            node.body[0], wrt, preserve_result, check_dims)
        if verbose >= 2:
            print('RAW')
            print(quoting.to_source(node))
        if motion == 'split':
            node = reverse_ad.split(node, stack)
        else:
            node = reverse_ad.joint(node)
        if verbose >= 2:
            print('MOTION')
            print(quoting.to_source(node))
    else:
        # mode == 'forward' is guaranteed by the validation above.
        node, required = forward_ad.forward_ad(
            node.body[0], wrt, preserve_result, check_dims)
    return node, required
def test_anf():
    """Check both the text and the runtime behaviour of the ANF transform."""
    # An aliased function applied to another call: the inner call must be
    # hoisted into its own temporary, and the result must be unchanged.
    def g(x):
        return x * 2

    h = g

    def f(x):
        y = g(h(x))
        return y

    nested_lines = anf_lines(f)
    assert nested_lines[1].strip() == "h_x = h(x)"
    assert anf_function(f, locals())(2) == 8

    # A chained product is broken into temporaries, so the final line is a
    # bare return without any multiplication left in it.
    def f(x):
        return x * x * x

    final_line = anf_lines(f)[-1]
    assert 'return' in final_line and '*' not in final_line
    assert anf_function(f)(2) == 8

    # Subscripts, attributes, tuples, unary minus and augmented assignment
    # together; this only checks that the transform produces a result.
    def f(x):
        y = [(x.y[0], ), 3]
        y += x * f(x[g(x)].b, (3, x / -2))

    transformed = anf.anf(quoting.parse_function(f))
    assert transformed
def anf_function(f, globals_=None):
    """Compile the ANF-transformed version of `f` and return it.

    Args:
      f: A Python function whose source can be parsed by
          `quoting.parse_function`.
      globals_: Optional mapping used as the global namespace of the compiled
          function. Defaults to `f.__globals__`. The mapping is copied, so
          the caller's dictionary is never modified.

    Returns:
      The newly compiled function object named `f.__name__`, i.e. the
      ANF-transformed function rather than `f` itself.
    """
    module = gast.gast_to_ast(anf.anf(quoting.parse_function(f)))
    module = gast.fix_missing_locations(module)
    # Execute into a private copy of the namespace. This way the new
    # definition neither clobbers a same-named global nor mutates `globals_`.
    namespace = dict(f.__globals__ if globals_ is None else globals_)
    exec(compile(module, '<string>', 'exec'), namespace)
    # Return the freshly compiled (transformed) function, not the input `f`.
    # Assumes the parsed source defines a function named `f.__name__`.
    return namespace[f.__name__]
def anf_lines(f):
    """Give back the ANF-transformed source of `f` as a list of lines.

    The text is split on newline characters with `str.split`, not
    `splitlines`. A trailing newline in the generated source therefore yields
    a final empty string.
    """
    parsed = quoting.parse_function(f)
    source = quoting.unquote(anf.anf(parsed))
    return source.split('\n')