def testJacobianIssue54(self): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): return jnp.array(list(xs)) xs = jnp.ones((5, 1)) jacrev(func)(xs) # don't crash jacfwd(func)(xs) # don't crash
def test_periodic_general_deform_grad(self, spatial_dimension, dtype, box_format):
    """Derivatives of two periodic-general energies agree under a deformed box.

    Builds the 16-particle test system, scales its box by 0.9, and checks that
    both ``grad`` and ``jacfwd`` of ``E_gf`` (called on ``R_f`` with ``box=``)
    match those of ``E_g`` (called on ``R`` with ``new_box=``).
    NOTE(review): ``R_f`` is presumably the fractional-coordinate counterpart
    of ``R`` -- confirm against make_periodic_general_test_system.
    """
    particle_count = 16
    system = make_periodic_general_test_system(
        particle_count, spatial_dimension, dtype, box_format)
    R_f, R, box, (_, _), (_, E_gf), (_, E_g) = system
    deformed_box = box * 0.9

    # Same comparison for reverse-mode gradient and forward-mode Jacobian.
    for derivative in (grad, jacfwd):
        self.assertAllClose(
            derivative(E_gf)(R_f, box=deformed_box),
            derivative(E_g)(R, new_box=deformed_box))
def testIssue757(self): # code from https://github.com/google/jax/issues/757 def fn(a): return np.cos(a) def loop(val): iterations = 10 def apply_carry(x, i): return api.grad(fn, argnums=(0, ))(x)[0], i final_val, _ = lax.scan(apply_carry, val, np.arange(iterations)) return final_val arg = 0.5 api.jit(api.jacfwd(loop, argnums=(0, )))(arg) # doesn't crash
def test_remat_vmap(self): @api.remat def g(x): return lax.sin(lax.sin(x)) x = onp.arange(3.) ans = api.vmap(g)(x) expected = onp.sin(onp.sin(x)) self.assertAllClose(ans, expected, check_dtypes=False) ans = api.jacfwd(g)(x) expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x)) self.assertAllClose(ans, expected, check_dtypes=False) ans = api.jacrev(g)(x) expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x)) self.assertAllClose(ans, expected, check_dtypes=False)
def testJacobians(self): def jacbwd(f, x): y, pullback = vjp(f, x) std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y)) jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis) return jac_flat.reshape(np.shape(y) + np.shape(x)) def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,)) std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x)) y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis) return jac_flat.reshape(np.shape(y) + np.shape(x)) R = np.random.RandomState(0).randn A = R(4, 3) b = R(4) f = lambda x: jnp.tanh(jnp.dot(A, x) + b) x = R(3) self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)
def _objective_hess(self) -> Callable[[DeviceArray], DeviceArray]:
    """Return a jit-compiled Hessian of ``self._objective``.

    The Hessian is obtained by applying ``jacfwd`` twice (forward-over-forward)
    and wrapping the result in ``jit``. Per the original comment,
    ``self._objective`` is the negative log posterior.
    NOTE(review): a new jit-wrapped function is built on every call; if this
    is called repeatedly, consider caching the returned callable.
    """
    gradient_fn = jacfwd(self._objective)
    hess_fn = jacfwd(gradient_fn)
    return jit(hess_fn)