Example #1
0
 def test_kl_divergence_raises_on_event_shape(self):
   """Expects `ValueError` from KL between the two distributions built here.

   The first distribution wraps a `MultivariateNormalDiag` base with an
   identity `Lambda` inside `Block(..., ndims=1)`; the second wraps a `Normal`
   base with a plain identity `Lambda`.
   """
   blocked_dist = transformed.Transformed(
       tfd.MultivariateNormalDiag([0.1, 0.5, 0.9], [0.1, 1.1, 2.5]),
       block.Block(lambda_bijector.Lambda(lambda x: x), ndims=1))
   plain_dist = transformed.Transformed(
       tfd.Normal(-0.1, 1.5), lambda_bijector.Lambda(lambda x: x))
   self.assertRaises(ValueError, blocked_dist.kl_divergence, plain_dist)
Example #2
0
def as_bijector(obj: BijectorLike) -> bijector.BijectorT:
    """Turns a bijector-like object into a Distrax bijector.

    How each accepted input is handled:
      * a Distrax bijector is handed back as-is;
      * a TFP bijector is wrapped in `BijectorFromTFP`;
      * `jax.nn.sigmoid` and `jnp.tanh` (matched by identity) map to the
        dedicated `Sigmoid` and `Tanh` bijectors;
      * any other callable is wrapped in `Lambda`.

    Args:
      obj: The bijector-like object to convert.

    Returns:
      A Distrax bijector.

    Raises:
      TypeError: If `obj` is neither a bijector nor a callable.
    """
    if isinstance(obj, bijector.Bijector):
        return obj
    if isinstance(obj, tfb.Bijector):
        return bijector_from_tfp.BijectorFromTFP(obj)
    # Identity comparisons: only these exact function objects are mapped to
    # their explicit bijector implementations.
    if obj is jax.nn.sigmoid:
        return sigmoid.Sigmoid()
    if obj is jnp.tanh:
        return tanh.Tanh()
    if not callable(obj):
        raise TypeError(
            f"A bijector-like object can be a `distrax.Bijector`, a `tfb.Bijector`,"
            f" or a callable. Got type `{type(obj)}`.")
    return lambda_bijector.Lambda(obj)
    def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
        """Checks `sample_and_log_prob` of a Lambda-transformed distribution.

        Samples must equal the bijector's forward pass over base samples drawn
        with the same seed, and must have the same shape as samples from a TFP
        `TransformedDistribution`; log-probs must match the TFP log-probs
        evaluated at the Distrax samples.

        Args:
          mu: First argument passed to `base_dist`.
          sigma: Second argument passed to `base_dist`.
          sample_shape: Shape forwarded to the sampling calls.
          base_dist: Callable building the base distribution from `mu`, `sigma`.
        """
        base = base_dist(mu, sigma)
        scaled_tanh = lambda_bijector.Lambda(lambda x: 10 * jnp.tanh(0.1 * x))
        dist = transformed.Transformed(base, scaled_tanh)

        def draw_with_log_prob(seed, shape):
            return dist.sample_and_log_prob(seed=seed, sample_shape=shape)

        variant_fn = self.variant(
            draw_with_log_prob, ignore_argnums=(1, ), static_argnums=(1, ))
        samples, log_prob = variant_fn(self.seed, sample_shape)

        base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
        expected_samples = scaled_tanh.forward(base_samples)

        # TFP reference built from Scale/Tanh/Scale, meant to mirror the
        # `10 * tanh(0.1 * x)` lambda used above.
        reference_bijector = tfb.Chain(
            [tfb.Scale(10), tfb.Tanh(), tfb.Scale(0.1)])
        reference_dist = tfd.TransformedDistribution(
            conversion.to_tfp(base), reference_bijector)
        reference_samples = reference_dist.sample(
            seed=self.seed, sample_shape=sample_shape)
        reference_log_prob = reference_dist.log_prob(samples)

        chex.assert_equal_shape([samples, reference_samples])
        np.testing.assert_allclose(log_prob, reference_log_prob, rtol=RTOL)
        np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
Example #4
0
  def test_kl_divergence(self, mode_string):
    """Checks KL between identity-transformed distributions in both directions.

    Both transformed distributions use identity bijectors, so the expected
    values are the KL divergences between the base distributions themselves.

    Args:
      mode_string: Selects which pair of implementations is exercised. One of
        'distrax_to_distrax', 'distrax_to_tfp' or 'tfp_to_distrax'.

    Raises:
      ValueError: If `mode_string` is not one of the supported modes.
    """
    base_dist1 = tfd.Normal([0.1, 0.5, 0.9], [0.1, 1.1, 2.5])
    base_dist2 = tfd.Normal(-0.1, 1.5)
    bij_tfp1 = tfb.Identity()
    bij_tfp2 = tfb.Identity()
    # The first Distrax distribution is given the TFP identity bijector
    # directly, while the second uses a Distrax `Lambda` identity.
    bij_distrax1 = bij_tfp1
    bij_distrax2 = lambda_bijector.Lambda(lambda x: x)
    tfp_dist1 = tfd.TransformedDistribution(base_dist1, bij_tfp1)
    tfp_dist2 = tfd.TransformedDistribution(base_dist2, bij_tfp2)
    distrax_dist1 = transformed.Transformed(base_dist1, bij_distrax1)
    distrax_dist2 = transformed.Transformed(base_dist2, bij_distrax2)

    expected_result_fwd = base_dist1.kl_divergence(base_dist2)
    expected_result_inv = base_dist2.kl_divergence(base_dist1)

    distrax_fn1 = self.variant(distrax_dist1.kl_divergence)
    distrax_fn2 = self.variant(distrax_dist2.kl_divergence)

    if mode_string == 'distrax_to_distrax':
      result_fwd = distrax_fn1(distrax_dist2)
      result_inv = distrax_fn2(distrax_dist1)
    elif mode_string == 'distrax_to_tfp':
      result_fwd = distrax_fn1(tfp_dist2)
      result_inv = distrax_fn2(tfp_dist1)
    elif mode_string == 'tfp_to_distrax':
      result_fwd = tfp_dist1.kl_divergence(distrax_dist2)
      result_inv = tfp_dist2.kl_divergence(distrax_dist1)
    else:
      # Without this branch an unknown mode would leave the results unbound
      # and surface as an opaque UnboundLocalError at the assertions below.
      raise ValueError(f'Unsupported mode_string: {mode_string!r}')

    np.testing.assert_allclose(result_fwd, expected_result_fwd, rtol=RTOL)
    np.testing.assert_allclose(result_inv, expected_result_inv, rtol=RTOL)
    def test_jittable(self):
        """Passes a `Lambda` bijector as an argument to a `jax.jit` function.

        The test succeeds if calling the jitted function does not raise.
        """
        identity = lambda_bijector.Lambda(lambda x: x)

        @jax.jit
        def apply_forward(inputs, bij):
            return bij.forward(inputs)

        apply_forward(np.zeros(()), identity)
Example #6
0
 def test_kl_divergence_raises_on_different_bijectors(self):
   """Expects `NotImplementedError` from KL when the bijectors differ.

   One distribution uses an identity `Lambda` bijector, the other a `Sigmoid`
   bijector; both wrap `Normal` base distributions.
   """
   identity_dist = transformed.Transformed(
       tfd.Normal([0.1, 0.5, 0.9], [0.1, 1.1, 2.5]),
       lambda_bijector.Lambda(lambda x: x))
   sigmoid_dist = transformed.Transformed(
       tfd.Normal(-0.1, 1.5), sigmoid.Sigmoid())
   self.assertRaises(
       NotImplementedError, identity_dist.kl_divergence, sigmoid_dist)
    def test_event_shape(self, mu, sigma, base_dist):
        """Compares `event_shape` of a tanh-transformed distribution with TFP.

        Args:
          mu: First argument passed to `base_dist`.
          sigma: Second argument passed to `base_dist`.
          base_dist: Callable building the base distribution from `mu`, `sigma`.
        """
        base = base_dist(mu, sigma)
        distrax_dist = transformed.Transformed(
            base, lambda_bijector.Lambda(jnp.tanh))
        reference_dist = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Tanh())

        assert distrax_dist.event_shape == reference_dist.event_shape
    def test_prob(self, mu, sigma, value, base_dist):
        """Compares `prob` of a tanh-transformed distribution with TFP.

        Args:
          mu: First argument passed to `base_dist`.
          sigma: Second argument passed to `base_dist`.
          value: Point(s) at which `prob` is evaluated.
          base_dist: Callable building the base distribution from `mu`, `sigma`.
        """
        base = base_dist(mu, sigma)
        distrax_dist = transformed.Transformed(
            base, lambda_bijector.Lambda(jnp.tanh))
        prob_fn = self.variant(distrax_dist.prob)
        actual = prob_fn(value)

        reference_dist = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Tanh())
        expected = reference_dist.prob(value)

        np.testing.assert_allclose(actual, expected, rtol=RTOL)
    def test_log_dets(self, lambda_bjct, tfp_bijector_fn):
        """Compares forward and inverse log-det-Jacobians with a TFP bijector.

        Args:
          lambda_bjct: Callable wrapped in a Distrax `Lambda` bijector.
          tfp_bijector_fn: Zero-argument callable returning the TFP bijector
            used as the reference.
        """
        distrax_bij = lambda_bijector.Lambda(lambda_bjct)
        reference_bij = tfp_bijector_fn()

        # Forward direction, evaluated on fixed float32 inputs.
        inputs = np.array([0.05, 0.3, 0.45], dtype=np.float32)
        expected_fldj = reference_bij.forward_log_det_jacobian(
            inputs, event_ndims=0)
        forward_log_det_fn = self.variant(distrax_bij.forward_log_det_jacobian)
        actual_fldj = forward_log_det_fn(inputs)
        np.testing.assert_allclose(actual_fldj, expected_fldj, rtol=RTOL)

        # Inverse direction, evaluated on the forward image of the inputs.
        outputs = distrax_bij.forward(inputs)
        expected_ildj = reference_bij.inverse_log_det_jacobian(
            outputs, event_ndims=0)
        inverse_log_det_fn = self.variant(distrax_bij.inverse_log_det_jacobian)
        actual_ildj = inverse_log_det_fn(outputs)
        np.testing.assert_allclose(actual_ildj, expected_ildj, rtol=RTOL)
    def test_method(self, function_string, mu, sigma, base_dist):
        """Compares a zero-argument method of a shifted distribution with TFP.

        The Distrax distribution uses `Lambda(lambda x: x + 3)` and the TFP
        reference uses `tfb.Shift(3)`.

        Args:
          function_string: Name of the method looked up on both distributions
            and called without arguments.
          mu: First argument passed to `base_dist`.
          sigma: Second argument passed to `base_dist`.
          base_dist: Callable building the base distribution from `mu`, `sigma`.
        """
        base = base_dist(mu, sigma)
        distrax_dist = transformed.Transformed(
            base, lambda_bijector.Lambda(lambda x: x + 3))
        reference_dist = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Shift(3))

        distrax_method = self.variant(getattr(distrax_dist, function_string))
        actual = distrax_method()
        expected = getattr(reference_dist, function_string)()
        np.testing.assert_allclose(actual, expected, rtol=RTOL)
 def test_raises_on_invalid_input_shape(self):
     """Expects `ValueError` when a 0-d array is fed to a 1-event-dim bijector.

     The `Lambda` bijector is built with `event_ndims_in=1`; each of its
     forward/inverse/log-det methods is then called with `jnp.array(0)`.
     """
     bij = lambda_bijector.Lambda(
         forward=lambda x: x,
         inverse=lambda y: y,
         forward_log_det_jacobian=lambda x: jnp.zeros_like(x[:-1]),
         inverse_log_det_jacobian=lambda y: jnp.zeros_like(y[:-1]),
         event_ndims_in=1)
     checked_methods = [
         bij.forward,
         bij.inverse,
         bij.forward_log_det_jacobian,
         bij.inverse_log_det_jacobian,
         bij.forward_and_log_det,
         bij.inverse_and_log_det,
     ]
     for method in checked_methods:
         self.assertRaises(ValueError, method, jnp.array(0))
    def test_against_tfp_bijectors(self, lambda_bjct, tfp_bijector, base_dist):
        """Compares `log_prob` and `prob` of a Lambda-transformed dist with TFP.

        Args:
          lambda_bjct: Callable wrapped in a Distrax `Lambda` bijector.
          tfp_bijector: Zero-argument callable returning the reference TFP
            bijector.
          base_dist: Callable building the base distribution from fixed
            float32 `mu` and `sigma` arrays.
        """
        base = base_dist(
            np.array([-1.0, 0.0, 1.0], dtype=np.float32),
            np.array([0.5, 1.0, 2.5], dtype=np.float32))

        distrax_dist = transformed.Transformed(
            base, lambda_bijector.Lambda(lambda_bjct))
        reference_dist = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfp_bijector())

        points = np.array([0.05, 0.3, 0.95], dtype=np.float32)

        expected_log_prob = reference_dist.log_prob(points)
        actual_log_prob = self.variant(distrax_dist.log_prob)(points)
        np.testing.assert_allclose(
            actual_log_prob, expected_log_prob, rtol=RTOL)

        expected_prob = reference_dist.prob(points)
        actual_prob = self.variant(distrax_dist.prob)(points)
        np.testing.assert_allclose(actual_prob, expected_prob, rtol=RTOL)
    def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
        """Checks that sample shapes agree with a TFP `TransformedDistribution`.

        Args:
          mu: First argument passed to `base_dist`.
          sigma: Second argument passed to `base_dist`.
          sample_shape: Shape forwarded to both `sample` calls.
          base_dist: Callable building the base distribution from `mu`, `sigma`.
        """
        base = base_dist(mu, sigma)
        distrax_dist = transformed.Transformed(
            base, lambda_bijector.Lambda(jnp.tanh))

        def draw(seed, shape):
            return distrax_dist.sample(seed=seed, sample_shape=shape)

        variant_draw = self.variant(
            draw, ignore_argnums=(1, ), static_argnums=1)
        samples = variant_draw(self.seed, sample_shape)

        reference_dist = tfd.TransformedDistribution(
            conversion.to_tfp(base), tfb.Tanh())
        reference_samples = reference_dist.sample(
            sample_shape=sample_shape, seed=self.seed)

        chex.assert_equal_shape([samples, reference_samples])
 def test_raises_on_event_ndims_without_log_det(self, ndims_in, ndims_out):
     """Expects `ValueError` when event ndims are given but no log-det is.

     Args:
       ndims_in: Value passed as `event_ndims_in`.
       ndims_out: Value passed as `event_ndims_out`.
     """
     self.assertRaises(
         ValueError,
         lambda_bijector.Lambda,
         forward=lambda x: x,
         event_ndims_in=ndims_in,
         event_ndims_out=ndims_out)
 def test_raises_on_log_det_without_event_ndims(self):
     """Expects `ValueError` when a log-det is given with `event_ndims_in=None`."""
     self.assertRaises(
         ValueError,
         lambda_bijector.Lambda,
         forward=lambda x: x,
         forward_log_det_jacobian=lambda x: jnp.zeros_like(x[:-1]),
         event_ndims_in=None)
 def test_raises_on_both_none(self):
     """Expects `ValueError` when `forward` and `inverse` are both `None`."""
     self.assertRaises(
         ValueError, lambda_bijector.Lambda, forward=None, inverse=None)