def test_inverse_of_sow_is_identity(self):
  """Inverting a bare `harvest.sow` returns the input and a zero ILDJ."""

  def sow_only(value):
    # `sow` is the identity on its value, so its inverse is the identity too.
    return harvest.sow(value, name='x', tag='foo')

  recovered, log_det = inverse_and_ildj(sow_only, 1.)(1.)
  self.assertEqual(recovered, 1.)
  self.assertEqual(log_det, 0.)
def test_inverse_of_nest(self):
  """Inversion sees through a `harvest.nest` scope wrapping the identity."""

  def halve_in_scope(x):
    scoped_identity = harvest.nest(lambda v: v, scope='foo')
    return scoped_identity(x) / 2.

  recovered, ildj_ = inverse_and_ildj(halve_in_scope, 2.)(2.)
  onp.testing.assert_allclose(recovered, 4.)
  # Reference ILDJ: -log|f'(x)| at the recovered input.
  expected_ildj = -np.log(np.abs(jax.jacrev(halve_in_scope)(4.)))
  onp.testing.assert_allclose(ildj_, expected_ildj, atol=1e-6, rtol=1e-6)
def test_inverse_of_pmap(self):
  """Inverting a pmapped `exp(x) + 2` recovers `log 2` in every lane.

  The reference ILDJ is -log|det J|, where J is the full Jacobian of `f`
  evaluated at the recovered input. Summing J's entries would only match the
  determinant for special inputs like this one, where 2 + 2 == 2 * 2.
  """

  def f(x):
    return jax.pmap(lambda x: np.exp(x) + 2.)(x)

  f_inv = inverse_and_ildj(f, np.ones(2) * 4)
  x, ildj_ = f_inv(np.ones(2) * 4)
  x_expected = np.log(2.) * np.ones(2)
  onp.testing.assert_allclose(x, x_expected)
  # slogdet returns (sign, log|det|); the ILDJ needs the log-magnitude.
  _, log_abs_det = np.linalg.slogdet(jax.jacrev(f)(x_expected))
  onp.testing.assert_allclose(ildj_, -log_abs_det, atol=1e-6, rtol=1e-6)
def test_div_inverse_ildj(self):
  """Both `x / c` and `c / x` invert correctly with ILDJ = -log|f'(x)|.

  Each case inverts the output 2. and checks the recovered input along with
  the ILDJ computed from the Jacobian at that input.
  """
  cases = (
      (lambda x: x / 2, 4.),   # numerator is the variable
      (lambda x: 3. / x, 1.5),  # denominator is the variable
  )
  for forward, expected_x in cases:
    x, ildj_ = inverse_and_ildj(forward, 2.)(2.)
    onp.testing.assert_allclose(x, expected_x)
    onp.testing.assert_allclose(
        ildj_,
        -np.log(np.abs(jax.jacrev(forward)(expected_x))),
        atol=1e-6,
        rtol=1e-6)
def test_mul_inverse_ildj(self):
  """Multiplication by 2 inverts correctly whichever side the constant is on.

  Both `x * 2` and `2 * x` must map the output 2. back to 1. Each ILDJ must
  equal -log|f'(1.)| for the same function that is being inverted.
  """

  def times_two_right(x):
    return x * 2

  def times_two_left(x):
    return 2 * x

  for forward in (times_two_right, times_two_left):
    x, ildj_ = inverse_and_ildj(forward, 1.)(2.)
    onp.testing.assert_allclose(x, 1.)
    # Differentiate the function under test itself, not a sibling case.
    onp.testing.assert_allclose(
        ildj_,
        -np.log(np.abs(jax.jacrev(forward)(1.))),
        atol=1e-6,
        rtol=1e-6)
def test_inverse_of_jit(self):
  """Inversion works through `jax.jit`, both as a prefix and as the whole map.

  Each case inverts the output 2. and checks the recovered input along with
  the ILDJ computed from the Jacobian at that input.
  """

  def halve_after_jit(x):
    passthrough = jax.jit(lambda x: x)(x)
    return passthrough / 2.

  def jitted_reciprocal(x):
    return jax.jit(lambda x: 3. / x)(x)

  for forward, expected_x in ((halve_after_jit, 4.),
                              (jitted_reciprocal, 1.5)):
    x, ildj_ = inverse_and_ildj(forward, 2.)(2.)
    onp.testing.assert_allclose(x, expected_x)
    onp.testing.assert_allclose(
        ildj_,
        -np.log(np.abs(jax.jacrev(forward)(expected_x))),
        atol=1e-6,
        rtol=1e-6)
def test_lower_triangular_jacobian(self):
  """Inverts a two-input map whose Jacobian is lower triangular.

  The expected ILDJ is -log|det J| of the packed map at the recovered input
  (1., 1.).
  """

  def f(x, y):
    return x + 2., np.exp(x) + y

  def f_vec(x):
    # Same map as `f`, on a packed vector so its Jacobian is a 2x2 matrix.
    return np.array([x[0] + 2., np.exp(x[0]) + x[1]])

  f_inv = inverse_and_ildj(f, 0., 0.)
  x, ildj_ = f_inv(3., np.exp(1.) + 1.)
  onp.testing.assert_allclose(x, (1., 1.))
  # slogdet returns (sign, log|det|); the ILDJ needs the log-magnitude, not
  # the sign (taking log|sign| would always give 0).
  _, log_abs_det = np.linalg.slogdet(jax.jacrev(f_vec)(np.ones(2)))
  onp.testing.assert_allclose(ildj_, -log_abs_det, atol=1e-6, rtol=1e-6)
def test_pmap_forward(self):
  """Inverts a two-input map where one output depends on a pmapped `exp`.

  The expected ILDJ is -log|det J| of the full map over both lanes, evaluated
  at the recovered input (all zeros).
  """

  def f(x, y):
    z = jax.pmap(np.exp)(x)
    return x + 2., z + y

  def f_flat(v):
    # `f` without pmap, on the packed input (x_0, x_1, y_0, y_1). The reference
    # Jacobian then covers every lane (4x4) rather than a single lane (2x2).
    x, y = v[:2], v[2:]
    return np.concatenate([x + 2., np.exp(x) + y])

  f_inv = inverse_and_ildj(f, np.ones(2), np.ones(2))
  x, ildj_ = f_inv(2 * np.ones(2), np.ones(2))
  onp.testing.assert_allclose(x, (np.zeros(2), np.zeros(2)))
  # Evaluate at the recovered input. slogdet returns (sign, log|det|), and the
  # ILDJ needs the log-magnitude.
  _, log_abs_det = np.linalg.slogdet(jax.jacrev(f_flat)(np.zeros(4)))
  onp.testing.assert_allclose(ildj_, -log_abs_det, atol=1e-6, rtol=1e-6)