Exemplo n.º 1
0
    def test_inverse_of_sow_is_identity(self):
        """Inverting a bare `harvest.sow` returns the input with zero ILDJ."""
        def tagged(value):
            # `sow` only tags `value` here, so the map is the identity.
            return harvest.sow(value, name='x', tag='foo')

        inverse_fn = inverse_and_ildj(tagged, 1.)
        recovered, log_det = inverse_fn(1.)
        self.assertEqual(recovered, 1.)
        self.assertEqual(log_det, 0.)
Exemplo n.º 2
0
    def test_inverse_of_nest(self):
        """An identity wrapped in `harvest.nest` does not block inversion.

        Only the trailing `/ 2.` contributes to the inverse and the ILDJ.
        """
        def halve_in_scope(value):
            scoped_identity = harvest.nest(lambda v: v, scope='foo')
            return scoped_identity(value) / 2.

        recovered, log_det = inverse_and_ildj(halve_in_scope, 2.)(2.)
        onp.testing.assert_allclose(recovered, 4.)
        # ILDJ of a scalar map is -log|f'(x)| at the recovered point.
        expected_ildj = -np.log(np.abs(jax.jacrev(halve_in_scope)(4.)))
        onp.testing.assert_allclose(log_det, expected_ildj,
                                    atol=1e-6, rtol=1e-6)
Exemplo n.º 3
0
    def test_inverse_of_pmap(self):
        """Inverting an elementwise `pmap` recovers the inputs and the ILDJ.

        `f` applies `exp(x) + 2.` to each element independently, so its
        Jacobian is diagonal. The expected ILDJ is minus the log of the
        absolute Jacobian determinant at the recovered point.
        """
        def f(x):
            return jax.pmap(lambda x: np.exp(x) + 2.)(x)

        f_inv = inverse_and_ildj(f, np.ones(2) * 4)
        x, ildj_ = f_inv(np.ones(2) * 4)
        x_expected = np.log(2.) * np.ones(2)
        onp.testing.assert_allclose(x, x_expected)
        # Use log|det J|, not log|sum(J)|. The two agree here only because
        # both diagonal entries equal 2 (2 + 2 == 2 * 2).
        _, logabsdet = np.linalg.slogdet(jax.jacrev(f)(x_expected))
        onp.testing.assert_allclose(ildj_, -logabsdet, atol=1e-6, rtol=1e-6)
Exemplo n.º 4
0
    def test_div_inverse_ildj(self):
        """Division by a constant and of a constant both invert correctly."""
        # x / 2 sends the preimage 4. to 2.
        def halve(x):
            return x / 2

        # 3 / x sends the preimage 1.5 to 2.
        def three_over(x):
            return 3. / x

        cases = (
            (halve, 4.),
            (three_over, 1.5),
        )
        for forward, preimage in cases:
            recovered, log_det = inverse_and_ildj(forward, 2.)(2.)
            onp.testing.assert_allclose(recovered, preimage)
            # Scalar ILDJ: -log|f'(x)| evaluated at the recovered point.
            onp.testing.assert_allclose(
                log_det,
                -np.log(np.abs(jax.jacrev(forward)(preimage))),
                atol=1e-6,
                rtol=1e-6)
Exemplo n.º 5
0
    def test_mul_inverse_ildj(self):
        """Multiplication by a constant inverts correctly from either side.

        Both `x * 2` and `2 * x` must invert 2. back to 1. Each ILDJ must
        equal `-log|f'(x)|` of that same function at the recovered point.
        """
        def f(x):
            return x * 2

        f_inv = inverse_and_ildj(f, 1.)
        x, ildj_ = f_inv(2.)
        onp.testing.assert_allclose(x, 1.)
        onp.testing.assert_allclose(ildj_,
                                    -np.log(np.abs(jax.jacrev(f)(1.))),
                                    atol=1e-6,
                                    rtol=1e-6)

        def f2(x):
            return 2 * x

        f2_inv = inverse_and_ildj(f2, 1.)
        x, ildj_ = f2_inv(2.)
        onp.testing.assert_allclose(x, 1.)
        # Differentiate f2 itself so this case checks the left-multiplication
        # path independently of f.
        onp.testing.assert_allclose(ildj_,
                                    -np.log(np.abs(jax.jacrev(f2)(1.))),
                                    atol=1e-6,
                                    rtol=1e-6)
Exemplo n.º 6
0
    def test_inverse_of_jit(self):
        """Computations wrapped in `jax.jit` are inverted through the jit."""
        def jit_identity_then_halve(x):
            # Only the identity is jitted; the division happens outside it.
            return jax.jit(lambda v: v)(x) / 2.

        def jitted_three_over(x):
            # The whole computation lives inside the jitted function.
            return jax.jit(lambda v: 3. / v)(x)

        for forward, preimage in ((jit_identity_then_halve, 4.),
                                  (jitted_three_over, 1.5)):
            recovered, log_det = inverse_and_ildj(forward, 2.)(2.)
            onp.testing.assert_allclose(recovered, preimage)
            onp.testing.assert_allclose(
                log_det,
                -np.log(np.abs(jax.jacrev(forward)(preimage))),
                atol=1e-6,
                rtol=1e-6)
Exemplo n.º 7
0
    def test_lower_triangular_jacobian(self):
        """A two-input map with a lower-triangular Jacobian inverts correctly.

        `f(x, y) = (x + 2, exp(x) + y)`. `f_vec` is the same map on a stacked
        vector, so its Jacobian can be formed with `jax.jacrev`.
        """
        def f(x, y):
            return x + 2., np.exp(x) + y

        def f_vec(x):
            return np.array([x[0] + 2., np.exp(x[0]) + x[1]])

        f_inv = inverse_and_ildj(f, 0., 0.)
        x, ildj_ = f_inv(3., np.exp(1.) + 1.)
        onp.testing.assert_allclose(x, (1., 1.))
        # slogdet returns (sign, log|det|), and the ILDJ is -log|det J|.
        # Index 0 (the sign) would make the expected value 0 for every
        # nonsingular J, so the check could never fail.
        _, logabsdet = np.linalg.slogdet(jax.jacrev(f_vec)(np.ones(2)))
        onp.testing.assert_allclose(ildj_, -logabsdet, atol=1e-6, rtol=1e-6)
Exemplo n.º 8
0
    def test_pmap_forward(self):
        """A map whose forward pass uses `pmap` is inverted correctly.

        `f(x, y) = (x + 2, exp(x) + y)` on length-2 vectors, with `exp`
        computed under `pmap`. The expected ILDJ comes from the Jacobian of
        the same map over all four scalars, flattened as `[x, y]`.
        """
        def f(x, y):
            z = jax.pmap(np.exp)(x)
            return x + 2., z + y

        def f_flat(xy):
            # Same computation as `f` (without pmap) on the concatenated
            # [x, y] vector, so the full 4x4 Jacobian can be formed.
            x, y = xy[:2], xy[2:]
            return np.concatenate([x + 2., np.exp(x) + y])

        f_inv = inverse_and_ildj(f, np.ones(2), np.ones(2))
        x, ildj_ = f_inv(2 * np.ones(2), np.ones(2))
        onp.testing.assert_allclose(x, (np.zeros(2), np.zeros(2)))
        # Use log|det J| (slogdet index 1), evaluated at the recovered point.
        # Index 0 is only the sign, which makes the expected ILDJ always 0.
        _, logabsdet = np.linalg.slogdet(jax.jacrev(f_flat)(np.zeros(4)))
        onp.testing.assert_allclose(ildj_, -logabsdet, atol=1e-6, rtol=1e-6)