def testHessian(self): # test based on code from sindhwani@google def fun(x, t): return np.sum(np.power(np.maximum(x, 0.0), 2)) + t x = onp.array([-1., -0.5, 0., 0.5, 1.0]) ans = hessian(lambda x: fun(x, 0.0))(x) expected = onp.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0.5, 0., 0.], [0., 0., 0., 2., 0.], [0., 0., 0., 0., 2.]]) self.assertAllClose(ans, expected, check_dtypes=False)
def testIssue387(self): # https://github.com/google/jax/issues/387 R = np.random.RandomState(0).rand(100, 2) def dist_sq(R): dR = R[:, jnp.newaxis, :] - R[jnp.newaxis, :, :] zero = jnp.zeros_like(dR) dR = dR - jnp.where(jnp.abs(dR) < 0.5, zero, 0.5 * jnp.sign(dR)) return jnp.sum(dR ** 2, axis=2) @jit def f(R): _ = dist_sq(R) return jnp.sum(R ** 2) _ = hessian(f)(R) # don't crash on UnshapedArray