def test_J_array(xs):
    # Differentiate the product of all elements of `xs` through J and return
    # (the un-transformed forward result, everything bprop yields for seed 1.0).
    def prod(values):
        # Fold every element together with `*`. The `()` argument is presumably
        # the target shape (a 0-d result) — confirm against array_reduce.
        reduced = array_reduce(lambda acc, item: acc * item, values, ())
        return array_to_scalar(reduced)

    # The lambda deliberately stays inside `prod` so that `prod` captures no
    # extra locals of this function (bprop's output may depend on captures).
    jprod = J(prod)
    out, backprop = jprod(J(xs))
    grads = backprop(1.0)
    return Jinv(out), grads
def f(x, y):
    # Gradient of a closure: `mula` captures the local `a` computed from `x`.
    a = x * x

    def mula(z):
        return a * z

    _, backprop = J(mula)(J(y))
    grads = backprop(1)
    # Only position 1 is returned; it presumably corresponds to the argument
    # `z` (position 0 being the function/closure slot) — confirm.
    return grads[1]
def f(x, y):
    # Gradient through nested closures: `mulsqx` calls `sqx`, which captures `x`.
    # `sqx` is intentionally kept as a separate closure rather than inlined.
    def sqx():
        return x * x

    def mulsqx(z):
        return sqx() * z

    _, backprop = J(mulsqx)(J(y))
    grads = backprop(1)
    # Position 1 presumably corresponds to the argument `z` — confirm.
    return grads[1]
def test_J(x):
    # Apply J to a squaring function and return the raw J-transformed output
    # along with both entries produced by bprop for seed 1.0.
    def square(v):
        return v * v

    out, backprop = J(square)(J(x))
    # Two entries are unpacked from bprop: presumably the gradient for the
    # function itself, then the gradient for its argument — confirm.
    d_fn, d_x = backprop(1.0)
    return out, d_fn, d_x
def test_J_return_function(x):
    # J applied to a function (`g`) that returns another function (`f`).
    # Differentiate the J-transformed `f` obtained from it.
    def f(y):
        return y * y

    def g():
        return f

    # Calling J(g) with no arguments yields a pair; its bprop is unused.
    j_square, _ = J(g)()
    out, backprop = j_square(J(x))
    _, d_y = backprop(1.0)
    return Jinv(out), d_y
def test_J_bprop_invalid(x):
    # Intentionally calls bprop with two arguments although the function has a
    # single output. Given the test name, the caller presumably expects this
    # call to fail — confirm against the test harness.
    def f(v):
        return v * v

    _, backprop = J(f)(J(x))
    seed = 1.0
    return backprop(seed, seed)
def test_Jinv2(x):
    # Round-trip a function through J and then Jinv, then call the result
    # on the plain (untransformed) input `x`.
    def square(v):
        return v * v

    return Jinv(J(square))(x)