Esempio n. 1
0
    def test_logit_ildj(self):
        """Checks logit inversion at -100, where the naive formula fails.

        The naive `log(x / (1 - x))` form is expected to raise ValueError when
        inverted (or when its ILDJ is taken) at -100, while
        `jax.scipy.special.logit` must invert to `expit(-100)` and have an
        ILDJ equal to the Sigmoid bijector's forward log-det-Jacobian there.
        """
        def unstable_logit(p):
            # This is the default JAX implementation of logit.
            return np.log(p / (1. - p))

        unstable_inverse = core.inverse(unstable_logit)
        unstable_ildj = core.ildj(unstable_logit)
        # Both derived transforms of the naive formula should raise at -100.
        for transform in (unstable_inverse, unstable_ildj):
            with self.assertRaises(ValueError):
                transform(-100.)

        logit = jax.scipy.special.logit
        logit_inverse = core.inverse(logit)
        logit_ildj = core.ildj(logit)
        recovered = logit_inverse(-100.)
        onp.testing.assert_allclose(recovered, jax.scipy.special.expit(-100.))
        actual_ildj = logit_ildj(-100.)
        onp.testing.assert_allclose(
            actual_ildj, tfb.Sigmoid().forward_log_det_jacobian(-100., 0))
Esempio n. 2
0
  def test_sigmoid_ildj(self):
    """Compares naive and built-in sigmoid inverses/ILDJs at 0.9999.

    The naive `1 / (1 + exp(-x))` form is expected to DISAGREE with the
    reference values (logit, and the TFP Sigmoid bijector's ILDJ), whereas
    `jax.nn.sigmoid` is expected to match them.
    """
    def unstable_sigmoid(t):
      # This is the default JAX implementation of sigmoid.
      return 1. / (1 + np.exp(-t))

    y = 0.9999
    unstable_inverse = core.inverse(unstable_sigmoid)
    unstable_ildj = core.ildj(unstable_sigmoid)

    # The naive results must NOT be close to the references.
    with self.assertRaises(AssertionError):
      onp.testing.assert_allclose(unstable_inverse(y),
                                  jax.scipy.special.logit(y))
    with self.assertRaises(AssertionError):
      onp.testing.assert_allclose(
          unstable_ildj(y), tfb.Sigmoid().inverse_log_det_jacobian(y, 0))

    sigmoid_inverse = core.inverse(jax.nn.sigmoid)
    sigmoid_ildj = core.ildj(jax.nn.sigmoid)
    onp.testing.assert_allclose(sigmoid_inverse(y), jax.scipy.special.logit(y))
    onp.testing.assert_allclose(
        sigmoid_ildj(y), tfb.Sigmoid().inverse_log_det_jacobian(y, 0))
Esempio n. 3
0
  def test_binary_inverse_and_ildj(self):
    """Exercises custom inverse/ILDJ rules on a two-argument function.

    Checks that `add` can be inverted before any rule is registered, that a
    registered rule is used for either argument position, and that
    registering a second rule replaces the first.
    """

    @custom_inverse.custom_inverse
    def add(x, y):
      return x + y

    def exact_rule(invals, z, z_ildj):
      # Recover whichever operand is unknown as z minus the known one.
      x, y = invals
      if z is None:
        raise custom_inverse.NonInvertibleError()
      x_known, y_known = x is not None, y is not None
      if x_known and not y_known:
        return (x, z - x), (jnp.zeros_like(z_ildj), z_ildj)
      if y_known and not x_known:
        return (z - y, y), (z_ildj, jnp.zeros_like(z_ildj))
      return (None, None), (0., 0.)

    def add_one_left(x):
      return add(1., x)

    def add_one_right(x):
      return add(x, 1.)

    # No rule registered yet.
    self.assertEqual(core.inverse(add_one_left)(2.), 1.)
    add.def_inverse_and_ildj(exact_rule)
    self.assertEqual(core.inverse(add_one_left)(2.), 1.)
    self.assertEqual(core.inverse(add_one_right)(2.), 1.)

    def skewed_rule(invals, z, z_ildj):
      # Not the true inverse of addition (z**2, z / 4, constant ILDJ
      # offsets), so the assertions below can tell this rule is in effect.
      x, y = invals
      if z is None:
        raise custom_inverse.NonInvertibleError()
      x_known, y_known = x is not None, y is not None
      if x_known and not y_known:
        return (x, z**2), (jnp.zeros_like(z_ildj), 3. + z_ildj)
      if y_known and not x_known:
        return (z / 4, y), (2. + z_ildj, jnp.zeros_like(z_ildj))
      return (None, None), (0., 0.)

    add.def_inverse_and_ildj(skewed_rule)
    self.assertEqual(core.inverse(add_one_left)(2.), 4.)
    self.assertEqual(core.ildj(add_one_left)(2.), 3.)

    self.assertEqual(core.inverse(add_one_right)(2.), 0.5)
    self.assertEqual(core.ildj(add_one_right)(2.), 2.)
Esempio n. 4
0
  def test_unary_ildj(self):
    """Registers only a custom ILDJ for a unary function.

    No inverse is supplied, so it falls back to inverting the forward
    function. The constant ILDJ of 4 is what `core.ildj` must report.
    """

    @custom_inverse.custom_inverse
    def add_one(x):
      return x + 1.

    def constant_ildj(_):
      return 4.

    add_one.def_inverse_unary(f_ildj=constant_ildj)
    inverse_at_two = core.inverse(add_one)(2.)
    self.assertEqual(inverse_at_two, 1.)
    ildj_at_two = core.ildj(add_one)(2.)
    self.assertEqual(ildj_at_two, 4.)
Esempio n. 5
0
  def test_unary_inverse_and_ildj(self):
    """Registers both a custom inverse and a custom ILDJ for a unary function.

    The supplied inverse (exp) is not the true inverse of `add_one`.
    Getting exp(2) back therefore shows that the custom rule takes precedence.
    """

    @custom_inverse.custom_inverse
    def add_one(x):
      return x + 1.

    def fake_inverse(out):
      return jnp.exp(out)

    def constant_ildj(_):
      return 4.

    add_one.def_inverse_unary(fake_inverse, constant_ildj)
    inverse_at_two = core.inverse(add_one)(2.)
    self.assertEqual(inverse_at_two, jnp.exp(2.))
    ildj_at_two = core.ildj(add_one)(2.)
    self.assertEqual(ildj_at_two, 4.)
Esempio n. 6
0
    def def_inverse_unary(self, f_inv=None, f_ildj=None):
        """Registers an inverse rule for a single-argument function.

        Args:
          f_inv: Optional one-argument callable mapping an output back to the
            input. If omitted (or falsy), `core.inverse(self.func)` is used.
          f_ildj: Optional one-argument callable returning the inverse
            log-det-Jacobian at an output. If omitted (or falsy), it is
            derived with `core.ildj(..., reduce_ildj=False)` from the
            function whose inverse is the chosen `f_inv`.

        The registered rule only handles the case where the output is known
        and the input is not. Every other case raises NonInvertibleError.
        """
        inverse_fn = f_inv if f_inv else core.inverse(self.func)
        if f_ildj:
            ildj_fn = f_ildj
        else:
            # core.inverse(inverse_fn) recovers a forward map whose ILDJ is
            # computed by core; reduce_ildj=False keeps it unreduced.
            ildj_fn = core.ildj(core.inverse(inverse_fn), reduce_ildj=False)

        def unary_rule(invals, outval, outildj):
            inval = invals[0]
            if outval is not None and inval is None:
                # Propagate the incoming ILDJ by adding this step's term.
                return (inverse_fn(outval),), (ildj_fn(outval) + outildj,)
            raise NonInvertibleError()

        self.def_inverse_and_ildj(unary_rule)
Esempio n. 7
0
 def wrapped(*args):
     """Evaluates `target_log_prob` at `mapping_fn(*args)`, minus the summed ILDJ.

     The ILDJ of `mapping_fn` is evaluated at the mapped point, with `args`
     supplied as example inputs to `inverse.ildj`. All of its leaves are
     summed into a single correction term.
     """
     transformed = mapping_fn(*args)
     ildj_tree = inverse.ildj(mapping_fn, *args)(transformed)
     log_prob = target_log_prob(transformed)
     ildj_leaves = tree_util.tree_leaves(ildj_tree)
     return log_prob - np.sum(np.array(ildj_leaves))
Esempio n. 8
0
    def test_sqrt_ildj(self):
        """ILDJ of sqrt at y is log|d(y**2)/dy| = log(2y); checked at y = 3."""
        def square_root(v):
            return np.sqrt(v)

        ildj_fn = core.ildj(square_root)
        actual = ildj_fn(3.)
        onp.testing.assert_allclose(actual, np.log(2.) + np.log(3.))
Esempio n. 9
0
    def test_reciprocal_ildj(self):
        """ILDJ of 1/x at y is log|d(1/y)/dy| = log(1 / y**2); checked at y = 2."""
        def reciprocal(v):
            return np.reciprocal(v)

        ildj_fn = core.ildj(reciprocal)
        actual = ildj_fn(2.)
        onp.testing.assert_allclose(actual, onp.log(1 / 4.))