def test_J_array(xs):
    # Differentiate a product-reduction over an array through J/Jinv.
    # Returns the un-lifted forward result and whatever the backpropagator
    # yields when seeded with 1.0.
    def product(arr):
        # Multiply all elements together, reducing over every axis (`()`
        # target shape), then unwrap the 0-d array into a scalar.
        total = array_reduce(lambda a, b: a * b, arr, ())
        return array_to_scalar(total)

    lifted_product = J(product)
    out, backprop = lifted_product(J(xs))
    grads = backprop(1.0)
    return Jinv(out), grads
def test_J_return_function(x):
    # Check J on a function whose return value is itself a function:
    # J(get_square)() hands back a lifted `square`, which is then applied
    # to J(x) and backpropagated with a seed of 1.0.
    def square(v):
        return v * v

    def get_square():
        return square

    j_square, _ = J(get_square)()
    j_out, backprop = j_square(J(x))
    # The backpropagator returns a pair; only the second element is kept.
    _, dx = backprop(1.0)
    return Jinv(j_out), dx
def test_Jinv4(x):
    # Apply Jinv directly to the primitive `scalar_add` (which was never
    # passed through J), then call the result on x.
    inverted = Jinv(scalar_add)
    return inverted(x)
def test_Jinv3(x):
    # Apply Jinv to a plain local function that was never lifted with J,
    # then call the result on x.
    def square(v):
        return v * v

    inverted = Jinv(square)
    return inverted(x)
def test_Jinv2(x):
    # Round-trip a local function through J and then Jinv, and call the
    # result on x.
    def square(v):
        return v * v

    lifted = J(square)
    round_tripped = Jinv(lifted)
    return round_tripped(x)
def test_Jinv(x):
    # Apply Jinv directly to the argument (which has not been passed
    # through J here) and return the result unchanged.
    unlifted = Jinv(x)
    return unlifted