Пример #1
0
def autodiff_ast(func, wrt, motion, mode, preserve_result, check_dims,
                 verbose):
    """Perform AD on a single function and return the AST.

    Args:
      See `grad`.

    Returns:
      node: The AST of a module containing the adjoint and primal function
          definitions.
      required: A list of non-built in functions that this function called, and
          of which the primals and adjoints need to be made available in order
          for the returned function to run.

    Raises:
      ValueError: If `mode` is neither 'reverse' nor 'forward'.
    """
    # Reject unknown modes before doing any work; otherwise neither branch
    # below would bind `required` and the return would fail obscurely.
    if mode not in ('reverse', 'forward'):
        raise ValueError(
            "mode must be 'reverse' or 'forward', got %r" % (mode,))

    node = annotate.resolve_calls(func)
    node = desugar.explicit_loop_indexes(node)
    # The original source is passed alongside the AST, presumably so that
    # validation errors can point at the user's code.
    fence.validate(node, inspect.getsource(func))
    node = anf_.anf(node)
    if verbose >= 2:
        print('ANF')
        print(quoting.to_source(node))

    if mode == 'reverse':
        node, required, stack = reverse_ad.reverse_ad(node.body[0], wrt,
                                                      preserve_result,
                                                      check_dims)
        if verbose >= 2:
            print('RAW')
            print(quoting.to_source(node))
        # Any motion other than 'split' is treated as joint motion.
        if motion == 'split':
            node = reverse_ad.split(node, stack)
        else:
            node = reverse_ad.joint(node)
        if verbose >= 2:
            print('MOTION')
            print(quoting.to_source(node))
    else:  # mode == 'forward', guaranteed by the check above.
        node, required = forward_ad.forward_ad(node.body[0], wrt,
                                               preserve_result, check_dims)
    return node, required
Пример #2
0
def test_anf():
    """Exercise the ANF transform on call chains and nested arithmetic.

    Also checks that a syntactically dense function can be transformed.
    """
    def g(x):
        return x * 2

    h = g

    def f(x):
        y = g(h(x))
        return y

    # The innermost call is hoisted into its own temporary first.
    assert anf_lines(f)[1].strip() == "h_x = h(x)"
    assert anf_function(f, locals())(2) == 8

    def f(x):
        return x * x * x

    # After the transform, the return statement only sees a reduced name.
    final_line = anf_lines(f)[-1]
    assert 'return' in final_line
    assert '*' not in final_line
    assert anf_function(f)(2) == 8

    def f(x):
        y = [(x.y[0], ), 3]
        y += x * f(x[g(x)].b, (3, x / -2))

    # This function is never run; it only has to survive the transform.
    transformed = anf.anf(quoting.parse_function(f))
    assert transformed
Пример #3
0
def anf_function(f, globals_=None):
    """Compile the ANF-transformed version of ``f`` and return it.

    Args:
      f: The Python function to transform.
      globals_: Optional mapping used as the global namespace when the
          transformed module is executed. Use it to make names that ``f``
          refers to (e.g. enclosing-scope functions) resolvable. Defaults to
          ``f.__globals__``. The mapping is copied, so the caller's namespace
          is never modified.

    Returns:
      The function object defined by executing the transformed module, i.e.
      the ANF version of ``f`` rather than ``f`` itself.
    """
    m = gast.gast_to_ast(anf.anf(quoting.parse_function(f)))
    m = gast.fix_missing_locations(m)
    # Execute into a private copy so that the new definition does not
    # overwrite names in the caller's (or f's module's) namespace, and so we
    # can fetch the freshly defined function back out by name.
    source_globals = f.__globals__ if globals_ is None else globals_
    namespace = dict(source_globals)
    exec(compile(m, '<string>', 'exec'), namespace)
    return namespace[f.__name__]
Пример #4
0
def anf_lines(f):
    """Transform ``f`` to A-normal form and return its source as a line list."""
    transformed = anf.anf(quoting.parse_function(f))
    source = quoting.unquote(transformed)
    return source.split('\n')