예제 #1
0
  def testJacobianIssue54(self):
    # test modeling the code in https://github.com/google/jax/issues/54

    def func(xs):
      return jnp.array(list(xs))

    xs = jnp.ones((5, 1))
    jacrev(func)(xs)  # don't crash
    jacfwd(func)(xs)  # don't crash
예제 #2
0
    def test_periodic_general_deform_grad(self, spatial_dimension, dtype,
                                          box_format):
        """Compare derivatives of two energy functions under a deformed box.

        Builds a 16-particle test system and shrinks its box by a factor of
        0.9. It then asserts that ``grad`` and ``jacfwd`` of ``E_gf``
        (evaluated at ``R_f`` with ``box=``) match those of ``E_g`` (evaluated
        at ``R`` with ``new_box=``).

        NOTE(review): ``R_f`` is presumably the fractional-coordinate version
        of ``R``. Confirm against ``make_periodic_general_test_system``.
        """
        num_particles = 16
        system = make_periodic_general_test_system(
            num_particles, spatial_dimension, dtype, box_format)
        # Only the positions, the box and the two energies under test are
        # used. The (s, E) style pairs still have to unpack as 2-tuples.
        R_f, R, box, (_, _), (_, E_gf), (_, E_g) = system
        deformed_box = box * 0.9

        for derivative in (grad, jacfwd):
            self.assertAllClose(
                derivative(E_gf)(R_f, box=deformed_box),
                derivative(E_g)(R, new_box=deformed_box))
예제 #3
0
    def testIssue757(self):
        """Regression test for https://github.com/google/jax/issues/757.

        Takes a jit-compiled forward-mode Jacobian of a function whose body is
        a ``lax.scan`` that calls ``api.grad`` inside the scanned step. It only
        checks that tracing and compilation do not raise.
        """
        # code from https://github.com/google/jax/issues/757
        def fn(a):
            return np.cos(a)

        def step(carry, idx):
            # grad with a tuple argnums returns a tuple; keep its only entry.
            new_carry = api.grad(fn, argnums=(0, ))(carry)[0]
            return new_carry, idx

        def loop(val):
            steps = np.arange(10)
            carry_out, _ = lax.scan(step, val, steps)
            return carry_out

        jac_of_loop = api.jacfwd(loop, argnums=(0, ))
        api.jit(jac_of_loop)(0.5)  # doesn't crash
예제 #4
0
  def test_remat_vmap(self):
    """Check that a ``remat``-wrapped function composes with vmap and Jacobians.

    ``g(x) = sin(sin(x))`` is elementwise, so its Jacobian is diagonal with
    entries ``cos(sin(x)) * cos(x)``. The test checks vmap, jacfwd and jacrev
    against NumPy references.
    """
    @api.remat
    def g(x):
      return lax.sin(lax.sin(x))

    x = onp.arange(3.)

    self.assertAllClose(api.vmap(g)(x), onp.sin(onp.sin(x)),
                        check_dtypes=False)

    # Chain rule: d/dx sin(sin(x)) = cos(sin(x)) * cos(x), on the diagonal.
    jac_expected = onp.diag(onp.cos(onp.sin(x)) * onp.cos(x))
    for jacobian in (api.jacfwd, api.jacrev):
      self.assertAllClose(jacobian(g)(x), jac_expected, check_dtypes=False)
예제 #5
0
  def testJacobians(self):
    def jacbwd(f, x):
      y, pullback = vjp(f, x)
      std_basis = np.eye(np.size(y)).reshape((-1,) + np.shape(y))
      jac_flat, = vmap(pullback, out_axes=np.ndim(y))(std_basis)
      return jac_flat.reshape(np.shape(y) + np.shape(x))

    def jacfwd(f, x):
      pushfwd = lambda v: jvp(f, (x,), (v,))
      std_basis = np.eye(np.size(x)).reshape((-1,) + np.shape(x))
      y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis)
      return jac_flat.reshape(np.shape(y) + np.shape(x))

    R = np.random.RandomState(0).randn

    A = R(4, 3)
    b = R(4)
    f = lambda x: jnp.tanh(jnp.dot(A, x) + b)

    x = R(3)
    self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)
예제 #6
0
 def _objective_hess(self) -> Callable[[DeviceArray], DeviceArray]:
     """Return a jit-compiled Hessian of ``self._objective``.

     The second derivative is formed as forward mode over forward mode
     (``jacfwd`` nested twice) and wrapped in ``jit``. The original comment
     describes the objective as the negative log posterior; that cannot be
     confirmed from this method alone.

     NOTE(review): a new jitted callable is built on every call, so repeated
     calls may re-trace. Confirm whether callers cache the result.
     """
     objective = self._objective
     second_derivative = jacfwd(jacfwd(objective))
     return jit(second_derivative)