Example #1
  def test_grad(self, ad="simple"):
    call_tf = CALL_TF_IMPLEMENTATIONS[ad]

    def f_jax(x):
      return 3. * jnp.sin(2. * x)

    def f_outside(x):
      return 3. * call_tf(tf.math.sin, 2. * x, result_shape=x)

    x = 4.
    self.assertAllClose(f_jax(x), f_outside(x))

    grad_f = api.grad(f_outside)(x)
    self.assertAllClose(api.grad(f_jax)(x), grad_f)
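A minimal standalone sketch of the same round trip through TensorFlow, using the public jax2tf.call_tf wrapper instead of the test helper above (assumes jax and tensorflow are installed; this API does not need a result_shape argument):

import jax
import tensorflow as tf
from jax.experimental import jax2tf

def f_outside(x):
  # tf.math.sin is called from JAX; jax.grad differentiates through it
  # using TensorFlow's own autodiff under the hood.
  return 3. * jax2tf.call_tf(tf.math.sin)(2. * x)

f_outside(4.)            # approximately 3 * sin(8)
jax.grad(f_outside)(4.)  # approximately 6 * cos(8)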
Example #2
  def test_higher_order_grad(self, degree=4):
    call_tf = call_tf_full_ad

    def f_jax(x):
      return 2. * x * x * x

    def f_outside(x):
      return 2. * call_tf(lambda y: y * y * y, x,
                          result_shape=x)

    grad_jax = f_jax
    grad_outside = f_outside
    for i in range(degree):
      grad_jax = api.grad(grad_jax)
      grad_outside = api.grad(grad_outside)

    res_jax = grad_jax(5.)
    self.assertAllClose(res_jax, grad_outside(5.))
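The same idea outside the test harness: composing jax.grad repeatedly gives higher-order derivatives (a minimal pure-JAX sketch):

import jax

f = lambda x: 2. * x ** 3
d3f = jax.grad(jax.grad(jax.grad(f)))
d3f(5.)            # 12.0, since d^3/dx^3 (2 x^3) = 12
jax.grad(d3f)(5.)  # 0.0, the fourth derivative vanishes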
Example #3
  def test_grad_pytree(self):
    call_tf = call_tf_full_ad

    def f_jax(xy):
      dict_ab = dict(a=2. * xy[0], b=xy[0] * xy[1])
      return 3. * dict_ab["a"] + 4. * dict_ab["b"]

    def f_outside(xy):
      dict_ab = call_tf(
          lambda xy: dict(a=2. * xy[0], b=xy[0] * xy[1]),
          xy,
          result_shape=dict(a=xy[0], b=xy[1]))
      return 3. * dict_ab["a"] + 4. * dict_ab["b"]

    xy = (5., 6.)
    self.assertAllClose(f_jax(xy), f_outside(xy))
    res_jax = api.grad(f_jax)(xy)
    self.assertAllClose(res_jax, api.grad(f_outside)(xy))
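For reference, jax.grad handles pytree arguments directly; a pure-JAX sketch of the same function, where the gradient comes back with the same tuple structure as the input:

import jax

def f(xy):
  a, b = 2. * xy[0], xy[0] * xy[1]
  return 3. * a + 4. * b

jax.grad(f)((5., 6.))  # (30.0, 20.0): df/dx = 6 + 4*y, df/dy = 4*x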
Example #4
  def testJVPOfGradOfIndexing(self):
    # Should return a value, even though we didn't pass a symbolic zero as the
    # index tangent.
    x = jnp.ones((3, 4), jnp.float32)
    i = jnp.ones((3,), jnp.int32)
    f = lambda x, i: jnp.sum(x[i])
    primals, tangents = api.jvp(api.grad(f), (x, i),
                                (x, np.zeros(i.shape, dtypes.float0)))
    expected = np.broadcast_to(
        np.array([0, 3, 0], dtype=np.float32)[:, None], (3, 4))
    self.assertAllClose(expected, primals)
    self.assertAllClose(np.zeros_like(x), tangents)
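A trimmed standalone version of the same check, showing the float0 convention: integer-valued arguments to jax.jvp take an all-zeros tangent with the special jax.dtypes.float0 dtype (a sketch, assuming a recent JAX):

import numpy as np
import jax
import jax.numpy as jnp
from jax import dtypes

f = lambda x, i: jnp.sum(x[i])
x = jnp.ones((3, 4), jnp.float32)
i = jnp.ones((3,), jnp.int32)
# The index argument is an integer array, so its tangent must have dtype float0.
primals, tangents = jax.jvp(jax.grad(f), (x, i),
                            (x, np.zeros(i.shape, dtypes.float0)))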
Example #5
  def testGammaGrad(self, alpha):
    rng = random.PRNGKey(0)
    alphas = np.full((100,), alpha)
    z = random.gamma(rng, alphas)
    actual_grad = api.grad(lambda x: random.gamma(rng, x).sum())(alphas)

    eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
    cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
               - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
    pdf = scipy.stats.gamma.pdf(z, alpha)
    expected_grad = -cdf_dot / pdf

    self.assertAllClose(actual_grad, expected_grad, check_dtypes=True,
                        rtol=2e-2 if jtu.device_under_test() == "tpu" else 7e-4)
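The expected gradient here comes from implicit reparameterization: for a sample z ~ Gamma(alpha), dz/dalpha = -(dF(z; alpha)/dalpha) / p(z; alpha), with the CDF derivative estimated by central differences. A small SciPy-only sketch of that formula for a single, arbitrarily chosen alpha and sample:

import scipy.stats

alpha, z, eps = 2.0, 1.5, 1e-4  # arbitrary illustrative values
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
           - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
expected_grad = -cdf_dot / scipy.stats.gamma.pdf(z, alpha)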
Example #6
  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    expected = np.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    with jax.enable_checks(False):
      with self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)
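A minimal standalone sketch of what stop_gradient does here: the stopped factor is treated as a constant during differentiation:

import jax
from jax import lax

g = jax.grad(lambda x: lax.sin(x) * lax.cos(lax.stop_gradient(x)))
g(3.14)  # approximately cos(3.14) ** 2; no term from differentiating the cos factor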
Example #7
  def testGradOfXlog1pyAtZero(self):
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(partial_xlog1py)(-1.),
                        0.,
                        check_dtypes=False)
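A standalone sketch of the same check: with x fixed at 0, xlog1py(0., y) is defined to be 0 even at y = -1 where log1p diverges, so the derivative in y is 0 rather than nan:

import jax
from jax.scipy import special as lsp_special

dy = jax.grad(lambda y: lsp_special.xlog1py(0., y))
dy(-1.)  # 0.0, as asserted in the test above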