def test_trivial_noninvertible(self):
    """A constant function discards its input, so inverting it must fail."""

    def constant(unused):
        del unused
        return 1.

    with self.assertRaises(ValueError):
        # Kept inside the context manager so a failure while building the
        # inverse is also treated as the expected ValueError.
        core.inverse(constant)(1.)
示例#2
0
  def test_inverse_with_tuple_inputs(self):
    """Custom inverse rules work when the first input is a tuple of params.

    Checks the automatically derived inverse first, then two registered
    `def_inverse_and_ildj` rules. The second rule is intentionally not the
    true inverse, so its distinctive results prove it is the rule in use.
    """

    @custom_inverse.custom_inverse
    def dense(params, x):
      w, b = params
      return w * x + b

    def dense_ildj(invals, out, out_ildj):
      (w, b), _ = invals
      # Solving `w * x + b == out` for `x` needs both parameters.
      if w is None or b is None:
        raise custom_inverse.NonInvertibleError()
      in_ildj = core.ildj(lambda x: w * x + b)(out)
      return ((w, b), (out - b) / w), ((jnp.zeros_like(w), jnp.zeros_like(b)),
                                       out_ildj + in_ildj)

    dense_apply = functools.partial(dense, (2., 1.))
    # No rule registered yet: the inverse is derived from the forward pass.
    self.assertEqual(core.inverse(dense_apply)(5.), 2.)
    dense.def_inverse_and_ildj(dense_ildj)
    self.assertEqual(core.inverse(dense_apply)(5.), 2.)

    def dense_ildj2(invals, out, out_ildj):
      (w, b), _ = invals
      # Same guard as `dense_ildj`. With `and`, a single missing parameter
      # would slip through and fail on `w * out` below instead of being
      # reported as non-invertible.
      if w is None or b is None:
        raise custom_inverse.NonInvertibleError()
      in_ildj = jnp.ones_like(out)
      return ((w, b), w * out), ((jnp.zeros_like(w), jnp.zeros_like(b)),
                                 out_ildj + in_ildj)

    dense.def_inverse_and_ildj(dense_ildj2)
    # 2. * 5. == 10. and an ILDJ of exactly 1. only come from `dense_ildj2`.
    self.assertEqual(core.inverse(dense_apply)(5.), 10.)
    self.assertEqual(core.ildj(dense_apply)(5.), 1.)
示例#3
0
    def test_pow_inverse(self):
        """`lax.pow` is invertible in either argument when the other is fixed."""
        # Exponent fixed at 2: x ** 2 == 2 gives x == sqrt(2).
        base_inv = core.inverse(lambda base: lax.pow(base, 2.))
        onp.testing.assert_allclose(base_inv(2.), np.sqrt(2.))

        # Base fixed at 2: 2 ** y == 3 gives y == log(3) / log(2).
        exponent_inv = core.inverse(lambda exponent: lax.pow(2., exponent))
        expected_exponent = np.log(3.) / np.log(2.)
        onp.testing.assert_allclose(exponent_inv(3.), expected_exponent)
示例#4
0
  def test_div_inverse(self):
    """Dividing by a constant and dividing a constant are both invertible.

    For both `x / 2.` and `2. / x`, the input that produces 1.0 is 2.
    """
    for divide in (lambda x: x / 2., lambda x: 2. / x):
      onp.testing.assert_allclose(core.inverse(divide)(1.0), 2.)
示例#5
0
  def test_simple_inverse(self):
    """`exp` inverts to `log` for both scalar and length-2 array inputs."""
    # The extra argument to core.inverse is presumably an example input
    # that fixes the input shape; confirm against core.inverse.
    scalar_inv = core.inverse(lambda x: np.exp(x), 0.1)
    onp.testing.assert_allclose(scalar_inv(1.0), 0.)

    vector_inv = core.inverse(lambda x: np.exp(x), np.zeros(2))
    onp.testing.assert_allclose(vector_inv(np.ones(2)), np.zeros(2))
示例#6
0
  def test_trivial_inverse(self):
    """Identity functions, of one or two inputs, invert to themselves."""
    single_identity_inv = core.inverse(lambda x: x)
    onp.testing.assert_allclose(single_identity_inv(1.0), 1.0)

    pair_identity_inv = core.inverse(lambda x, y: (x, y))
    onp.testing.assert_allclose(pair_identity_inv(1.0, 2.0), (1.0, 2.0))
示例#7
0
  def test_mul_inverse(self):
    """Multiplying by 2. on either side inverts to halving."""
    for double in (lambda x: x * 2., lambda x: 2. * x):
      onp.testing.assert_allclose(core.inverse(double)(1.0), 0.5)
示例#8
0
    def test_integer_pow_inverse(self):
        """`lax.integer_pow` inverts to the matching square or cube root."""
        square_inv = core.inverse(lambda x: lax.integer_pow(x, 2))
        onp.testing.assert_allclose(square_inv(2.), np.sqrt(2.))

        cube_inv = core.inverse(lambda x: lax.integer_pow(x, 3))
        onp.testing.assert_allclose(cube_inv(2.), onp.cbrt(2.))
示例#9
0
  def test_noninvertible_error_should_cause_unary_inverse_to_fail(self):
    """A unary inverse that raises NonInvertibleError surfaces as ValueError."""

    @custom_inverse.custom_inverse
    def add_one(x):
      return x + 1.

    def refuse_to_invert(_):
      raise custom_inverse.NonInvertibleError()

    add_one.def_inverse_unary(f_inv=refuse_to_invert)

    with self.assertRaises(ValueError):
      core.inverse(add_one)(1.)
示例#10
0
  def test_unary_inverse(self):
    """With no registered rule, a custom_inverse function is auto-inverted."""

    @custom_inverse.custom_inverse
    def add_one(x):
      return x + 1.

    recovered = core.inverse(add_one)(1.)
    self.assertEqual(recovered, 0.)
    def test_advanced_inverse_two(self):
        """Invert the coupled map (x, y) -> (exp(x), x**2 + y)."""
        inverse_fn = core.inverse(lambda x, y: (np.exp(x), x**2 + y), 0.1, 0.2)
        # exp(x) == 2 gives x == log(2); then y == 2 - x**2.
        expected_x = np.log(2.)
        expected_y = 2 - np.log(2.)**2
        onp.testing.assert_allclose(inverse_fn(2.0, 2.0),
                                    (expected_x, expected_y))
示例#12
0
 def test_advanced_inverse_three(self):
   """Invert (x, y, z) -> (exp(x), x**2 + y, exp(z + y)) by back-substitution."""
   inverse_fn = core.inverse(
       lambda x, y, z: (np.exp(x), x ** 2 + y, np.exp(z + y)), 0., 0., 0.)
   expected_x = np.log(2.)
   expected_y = 2 - np.log(2.) ** 2
   expected_z = np.log(2.0) - (2 - np.log(2.) ** 2)
   onp.testing.assert_allclose(inverse_fn(2.0, 2.0, 2.0),
                               (expected_x, expected_y, expected_z))
示例#13
0
 def test_logit_ildj(self):
   """The naive logit cannot be inverted at -100.; the library one can.

   Inverting the naive formula at -100. raises ValueError for both the
   inverse and the ILDJ. `jax.scipy.special.logit` inverts to `expit`, and
   its ILDJ matches tfb.Sigmoid's forward log-det-Jacobian.
   """

   def naive_logit(x):
     # This is the default JAX implementation of logit.
     return np.log(x / (1. - x))

   naive_inv = core.inverse(naive_logit)
   naive_ildj = core.ildj(naive_logit)
   for naive_transform in (naive_inv, naive_ildj):
     with self.assertRaises(ValueError):
       naive_transform(-100.)

   logit = jax.scipy.special.logit
   logit_inv = core.inverse(logit)
   logit_ildj = core.ildj(logit)
   onp.testing.assert_allclose(logit_inv(-100.),
                               jax.scipy.special.expit(-100.))
   expected_ildj = tfb.Sigmoid().forward_log_det_jacobian(-100., 0)
   onp.testing.assert_allclose(logit_ildj(-100.), expected_ildj)
示例#14
0
  def test_sigmoid_ildj(self):
    """The naive sigmoid misses the references at 0.9999; jax.nn.sigmoid matches.

    The references are `jax.scipy.special.logit` for the inverse and
    tfb.Sigmoid's inverse log-det-Jacobian for the ILDJ.
    """

    def naive_sigmoid(x):
      # This is the default JAX implementation of sigmoid.
      return 1. / (1 + np.exp(-x))

    point = 0.9999
    expected_inv = jax.scipy.special.logit(point)
    expected_ildj = tfb.Sigmoid().inverse_log_det_jacobian(point, 0)

    naive_inv = core.inverse(naive_sigmoid)
    naive_ildj = core.ildj(naive_sigmoid)
    # The naive formula is expected NOT to be close to the references.
    with self.assertRaises(AssertionError):
      onp.testing.assert_allclose(naive_inv(point), expected_inv)
    with self.assertRaises(AssertionError):
      onp.testing.assert_allclose(naive_ildj(point), expected_ildj)

    sigmoid_inv = core.inverse(jax.nn.sigmoid)
    sigmoid_ildj = core.ildj(jax.nn.sigmoid)
    onp.testing.assert_allclose(sigmoid_inv(point), expected_inv)
    onp.testing.assert_allclose(sigmoid_ildj(point), expected_ildj)
示例#15
0
  def test_noninvertible_error_should_cause_binary_inverse_to_fail(self):
    """A binary rule raising NonInvertibleError makes `core.inverse` fail.

    The rule below can solve for `y` when `x` is known, but refuses to solve
    for `x` when only `y` is known. That refusal surfaces as a ValueError
    from `core.inverse`.
    """

    @custom_inverse.custom_inverse
    def add(x, y):
      return x + y

    def add_ildj(invals, z, z_ildj):
      x, y = invals
      if z is None:
        raise custom_inverse.NonInvertibleError()
      if x is not None and y is None:
        return (x, z - x), (jnp.zeros_like(z_ildj), z_ildj)
      elif x is None and y is not None:
        # Cannot invert if we don't know x
        raise custom_inverse.NonInvertibleError()
      return (None, None), (0., 0.)

    add.def_inverse_and_ildj(add_ildj)

    # Known `x`: the supported branch must return the right value
    # (1. + y == 2. gives y == 1.), not merely avoid raising.
    self.assertEqual(core.inverse(lambda x: add(1., x))(2.), 1.)
    # Known `y` only: the rule raises, so inversion fails.
    with self.assertRaises(ValueError):
      core.inverse(lambda x: add(x, 1.))(2.)
示例#16
0
  def test_unary_ildj(self):
    """A custom ILDJ is used, while the inverse itself is still derived."""

    @custom_inverse.custom_inverse
    def add_one(x):
      return x + 1.

    def constant_ildj(_):
      return 4.

    add_one.def_inverse_unary(f_ildj=constant_ildj)

    recovered = core.inverse(add_one)(2.)
    self.assertEqual(recovered, 1.)
    log_det = core.ildj(add_one)(2.)
    self.assertEqual(log_det, 4.)
示例#17
0
  def test_binary_inverse_and_ildj(self):
    """Binary rules solve for whichever input is unknown, and can be replaced.

    The second rule is intentionally not the true inverse of `add`, so its
    distinctive outputs show that re-registering overrides the first rule.
    """

    @custom_inverse.custom_inverse
    def add(x, y):
      return x + y

    def exact_rule(invals, z, z_ildj):
      x, y = invals
      if z is None:
        raise custom_inverse.NonInvertibleError()
      if x is not None and y is None:
        return (x, z - x), (jnp.zeros_like(z_ildj), z_ildj)
      if x is None and y is not None:
        return (z - y, y), (z_ildj, jnp.zeros_like(z_ildj))
      return (None, None), (0., 0.)

    add_one_left = functools.partial(add, 1.)
    add_one_right = lambda x: add(x, 1.)

    # Derived inverse before any rule is registered.
    self.assertEqual(core.inverse(add_one_left)(2.), 1.)
    add.def_inverse_and_ildj(exact_rule)
    self.assertEqual(core.inverse(add_one_left)(2.), 1.)
    self.assertEqual(core.inverse(add_one_right)(2.), 1.)

    def marker_rule(invals, z, z_ildj):
      x, y = invals
      if z is None:
        raise custom_inverse.NonInvertibleError()
      if x is not None and y is None:
        return (x, z**2), (jnp.zeros_like(z_ildj), 3. + z_ildj)
      if x is None and y is not None:
        return (z / 4, y), (2. + z_ildj, jnp.zeros_like(z_ildj))
      return (None, None), (0., 0.)

    add.def_inverse_and_ildj(marker_rule)
    # Known left input: y == z**2 == 4., ILDJ offset 3.
    self.assertEqual(core.inverse(add_one_left)(2.), 4.)
    self.assertEqual(core.ildj(add_one_left)(2.), 3.)
    # Known right input: x == z / 4 == 0.5, ILDJ offset 2.
    self.assertEqual(core.inverse(add_one_right)(2.), 0.5)
    self.assertEqual(core.ildj(add_one_right)(2.), 2.)
示例#18
0
    def test_pow_ildj(self):
        """ILDJ of `lax.pow` in each argument matches independent references."""
        square_ildj = core.ildj(lambda x: lax.pow(x, 2.))
        reference = tfb.Power(2.).inverse_log_det_jacobian(3.)
        onp.testing.assert_allclose(square_ildj(3.), reference)

        exp2_ildj = core.ildj(lambda y: lax.pow(2., y))
        exp2_inv = core.inverse(lambda y: lax.pow(2., y))
        preimage = exp2_inv(3.)
        # The ILDJ at z should equal -log|f'(f^-1(z))| ...
        forward_slope = jax.grad(lambda y: lax.pow(2., y))(preimage)
        onp.testing.assert_allclose(
            exp2_ildj(3.), -np.log(np.abs(forward_slope)))
        # ... and equivalently log|(f^-1)'(z)|.
        inverse_slope = jax.grad(exp2_inv)(3.)
        onp.testing.assert_allclose(exp2_ildj(3.),
                                    np.log(np.abs(inverse_slope)))
示例#19
0
    def def_inverse_unary(self, f_inv=None, f_ildj=None):
        """Defines a unary inverse rule.

        Args:
          f_inv: An optional unary function that returns the inverse of this
            function. If not provided, the forward function (`self.func`) is
            inverted automatically with `core.inverse`.
          f_ildj: An optional unary function that computes the ILDJ of this
            function. If not provided, it is computed from the inverse as
            `core.ildj(core.inverse(inverse), reduce_ildj=False)`.
        """
        # Truthiness fallback, matching `f_inv or ...` / `f_ildj or ...`.
        inverse_fn = f_inv if f_inv else core.inverse(self.func)
        ildj_fn = f_ildj if f_ildj else core.ildj(
            core.inverse(inverse_fn), reduce_ildj=False)

        def unary_rule(invals, outval, outildj):
            inval = invals[0]
            # Only "output known, sole input unknown" can be solved here.
            if outval is None or inval is not None:
                raise NonInvertibleError()
            # The inverse is computed before the ILDJ term, as before.
            return (inverse_fn(outval),), (ildj_fn(outval) + outildj,)

        self.def_inverse_and_ildj(unary_rule)
示例#20
0
  def test_unary_inverse_and_ildj(self):
    """Explicit unary inverse and ILDJ rules replace the derived ones."""

    @custom_inverse.custom_inverse
    def add_one(x):
      return x + 1.

    # Intentionally not the true inverse, so the assertions can tell the
    # registered rules apart from the automatically derived ones.
    add_one.def_inverse_unary(f_inv=lambda y: jnp.exp(y),
                              f_ildj=lambda _: 4.)

    self.assertEqual(core.inverse(add_one)(2.), jnp.exp(2.))
    self.assertEqual(core.ildj(add_one)(2.), 4.)
    def test_conditional_inverse(self):
        """Invert (x, y) -> (x + 1, exp(x + 1) + y), whose outputs share x."""
        inverse_fn = core.inverse(
            lambda x, y: (x + 1., np.exp(x + 1.) + y), 0., 2.)
        # x + 1 == 0 gives x == -1; then exp(0) + y == 2 gives y == 1.
        onp.testing.assert_allclose(inverse_fn(0., 2.), (-1., 1.))
示例#22
0
    def test_sqrt_inverse(self):
        """`sqrt` inverts to squaring: sqrt(x) == 2 gives x == 4."""
        sqrt_inv = core.inverse(lambda x: np.sqrt(x))
        onp.testing.assert_allclose(sqrt_inv(2.), 4.)
示例#23
0
    def test_reciprocal_inverse(self):
        """`reciprocal` is its own inverse: 1 / x == 2 gives x == 0.5."""
        reciprocal_inv = core.inverse(lambda x: np.reciprocal(x))
        onp.testing.assert_allclose(reciprocal_inv(2.), 0.5)
    def test_sow_happens_in_forward_pass(self):
        """A value sown in the function can be reaped when its inverse runs."""

        def f(x, y):
            sown_x = harvest.sow(x, name='x', tag='foo')
            return x, sown_x * y

        reaped = harvest.reap(core.inverse(f), tag='foo')(1., 1.)
        self.assertDictEqual(reaped, {'x': 1.})
    def test_noninvertible(self):
        """Two outputs that are both x + y lose information, so inversion fails."""

        def duplicated_sum(x, y):
            # Two separate additions, exactly as in the function under test.
            return x + y, x + y

        with self.assertRaises(ValueError):
            core.inverse(duplicated_sum)(1., 2.)