Esempio n. 1
0
  def test_div_inverse_ildj(self):
    """Inverts division with the input as numerator, then as denominator.

    Each recovered input is compared with the known preimage, and each ILDJ
    with the negated log-absolute derivative of the forward function there.
    """
    tolerance = dict(atol=1e-6, rtol=1e-6)

    # Input in the numerator: x / 2 == 2 at x == 4.
    def halve(x):
      return x / 2
    preimage, ildj_ = core.inverse_and_ildj(halve, 2.)(2.)
    onp.testing.assert_allclose(preimage, 4.)
    expected_ildj = -np.log(np.abs(jax.jacrev(halve)(4.)))
    onp.testing.assert_allclose(ildj_, expected_ildj, **tolerance)

    # Input in the denominator: 3 / x == 2 at x == 1.5.
    def three_over(x):
      return 3. / x
    preimage, ildj_ = core.inverse_and_ildj(three_over, 2.)(2.)
    onp.testing.assert_allclose(preimage, 1.5)
    expected_ildj = -np.log(np.abs(jax.jacrev(three_over)(1.5)))
    onp.testing.assert_allclose(ildj_, expected_ildj, **tolerance)
Esempio n. 2
0
  def test_mul_inverse_ildj(self):
    """Inverts multiplication by a constant with the input on either side.

    Each recovered input is compared with the known preimage, and each ILDJ
    with the negated log-absolute derivative of the same forward function
    that was inverted.
    """
    # Input as the left operand: x * 2 == 2 at x == 1.
    def f(x):
      return x * 2
    f_inv = core.inverse_and_ildj(f, 1.)
    x, ildj_ = f_inv(2.)
    onp.testing.assert_allclose(x, 1.)
    onp.testing.assert_allclose(ildj_, -np.log(np.abs(jax.jacrev(f)(1.))),
                                atol=1e-6, rtol=1e-6)

    # Input as the right operand: 2 * x == 2 at x == 1.
    def f2(x):
      return 2 * x
    f2_inv = core.inverse_and_ildj(f2, 1.)
    x, ildj_ = f2_inv(2.)
    onp.testing.assert_allclose(x, 1.)
    # Differentiate f2 (the function under test here), not f.
    onp.testing.assert_allclose(ildj_, -np.log(np.abs(jax.jacrev(f2)(1.))),
                                atol=1e-6, rtol=1e-6)
Esempio n. 3
0
  def test_inverse_of_jit(self):
    """Inverts functions whose bodies contain a `jax.jit`-wrapped call.

    Covers a jitted identity followed by a division outside the jit, and a
    division performed entirely inside the jit.
    """
    tolerance = dict(atol=1e-6, rtol=1e-6)

    # Jitted identity, then x / 2 outside the jit: output 2 maps back to 4.
    def jit_then_halve(x):
      passed_through = jax.jit(lambda x: x)(x)
      return passed_through / 2.
    preimage, ildj_ = core.inverse_and_ildj(jit_then_halve, 2.)(2.)
    onp.testing.assert_allclose(preimage, 4.)
    expected_ildj = -np.log(np.abs(jax.jacrev(jit_then_halve)(4.)))
    onp.testing.assert_allclose(ildj_, expected_ildj, **tolerance)

    # 3 / x computed inside the jit: output 2 maps back to 1.5.
    def jitted_three_over(x):
      return jax.jit(lambda x: 3. / x)(x)
    preimage, ildj_ = core.inverse_and_ildj(jitted_three_over, 2.)(2.)
    onp.testing.assert_allclose(preimage, 1.5)
    expected_ildj = -np.log(np.abs(jax.jacrev(jitted_three_over)(1.5)))
    onp.testing.assert_allclose(ildj_, expected_ildj, **tolerance)
Esempio n. 4
0
 def test_inverse_of_pmap(self):
   """Inverts an elementwise function applied through `jax.pmap`.

   The forward function is exp(x) + 2 mapped over a length-2 input, so an
   output of 4 must be recovered as log(2) in every position. The ILDJ is
   compared with the negated log-absolute-determinant of the forward
   Jacobian at the recovered input.
   """
   def f(x):
     return jax.pmap(lambda x: np.exp(x) + 2.)(x)
   f_inv = core.inverse_and_ildj(f, np.ones(2) * 4)
   x, ildj_ = f_inv(np.ones(2) * 4)
   # exp(x) + 2 == 4 exactly when x == log(2).
   x_expected = np.log(2.) * np.ones(2)
   onp.testing.assert_allclose(x, x_expected)
   # Use the determinant of the Jacobian rather than the sum of its entries:
   # the two agree for this input only because 2 + 2 == 2 * 2.
   # slogdet returns (sign, log|det|); only the second entry is needed.
   _, logabsdet = np.linalg.slogdet(jax.jacrev(f)(x_expected))
   onp.testing.assert_allclose(ildj_, -logabsdet, atol=1e-6, rtol=1e-6)
Esempio n. 5
0
 def test_lower_triangular_jacobian(self):
   """Inverts a two-input function whose Jacobian is lower triangular.

   `f` maps (x, y) to (x + 2, exp(x) + y); `f_vec` is the same map written
   over a single length-2 vector so that its Jacobian can be taken with
   `jax.jacrev`. The outputs (3, e + 1) must be recovered as (1, 1), and
   the ILDJ must equal the negated log-absolute-determinant of the
   Jacobian at that point.
   """
   def f(x, y):
     return x + 2., np.exp(x) + y
   def f_vec(x):
     return np.array([x[0] + 2., np.exp(x[0]) + x[1]])
   f_inv = core.inverse_and_ildj(f, 0., 0.)
   x, ildj_ = f_inv(3., np.exp(1.) + 1.)
   onp.testing.assert_allclose(x, (1., 1.))
   # slogdet returns (sign, log|det|). The reference ILDJ is the negated
   # second entry; taking the log of the sign would always yield zero and
   # make the check independent of the Jacobian.
   _, logabsdet = np.linalg.slogdet(jax.jacrev(f_vec)(np.ones(2)))
   onp.testing.assert_allclose(ildj_, -logabsdet, atol=1e-6, rtol=1e-6)
Esempio n. 6
0
 def test_pmap_forward(self):
   """Inverts a function in which `jax.pmap` only feeds a forward value.

   `f` maps (x, y) to (x + 2, exp(x) + y), with the exponential computed
   through `jax.pmap` over length-2 inputs. `f_vec` expresses the same map
   for one (x, y) pair packed into a length-2 vector so that its Jacobian
   can be taken with `jax.jacrev`. Outputs (2, 1) in every position must be
   recovered as (0, 0), and the ILDJ is compared with the negated
   log-absolute-determinant of the per-pair Jacobian.
   """
   def f(x, y):
     z = jax.pmap(np.exp)(x)
     return x + 2., z + y
   def f_vec(x):
     return np.array([x[0] + 2., np.exp(x[0]) + x[1]])
   f_inv = core.inverse_and_ildj(f, np.ones(2), np.ones(2))
   x, ildj_ = f_inv(2 * np.ones(2), np.ones(2))
   onp.testing.assert_allclose(x, (np.zeros(2), np.zeros(2)))
   # Evaluate the Jacobian at a recovered (x, y) pair, i.e. (0, 0), and use
   # the log|det| entry of slogdet's (sign, log|det|) result; the log of the
   # sign would always be zero regardless of the Jacobian.
   _, logabsdet = np.linalg.slogdet(jax.jacrev(f_vec)(np.zeros(2)))
   onp.testing.assert_allclose(ildj_, -logabsdet, atol=1e-6, rtol=1e-6)