Example #1
0
def to_tfp(obj: Union[bijector.Bijector, tfb.Bijector,
                      distribution.Distribution, tfd.Distribution],
           name: Optional[str] = None):
    """Converts a distribution or bijector to a TFP-compatible equivalent object.

    The returned object is not necessarily of type `tfb.Bijector` or
    `tfd.Distribution`; rather, it is a Distrax object that implements TFP
    functionality so that it can be used in TFP.

    If the input is already of TFP type, it is returned unchanged and `name`
    is ignored.

    Args:
      obj: The distribution or bijector to be converted to TFP.
      name: The name of the resulting object. Only used when `obj` is a
        Distrax bijector or distribution that needs wrapping.

    Returns:
      A TFP-compatible equivalent distribution or bijector.

    Raises:
      TypeError: If `obj` is neither a Distrax nor a TFP bijector/distribution.
    """
    # TFP objects are checked first so they pass through untouched.
    if isinstance(obj, (tfb.Bijector, tfd.Distribution)):
        return obj
    if isinstance(obj, bijector.Bijector):
        return tfp_compatible_bijector.tfp_compatible_bijector(obj, name)
    if isinstance(obj, distribution.Distribution):
        return tfp_compatible_distribution.tfp_compatible_distribution(
            obj, name)
    # Only the last fragment interpolates a value, so only it is an f-string.
    raise TypeError(
        "`to_tfp` can only convert objects of type: `distrax.Bijector`,"
        " `tfb.Bijector`, `distrax.Distribution`, `tfd.Distribution`. Got type"
        f" `{type(obj)}`.")
    def test_with_different_event_ndims(self):
        """Checks a wrapped bijector whose input/output event ranks differ.

        The Distrax `Lambda` bijector reshapes a length-6 vector event
        (event_ndims_in=1) into a (2, 3) matrix event (event_ndims_out=2).
        Its log-det-Jacobian is the constant 0.
        """
        dx_bij = Lambda(forward=lambda x: x.reshape(x.shape[:-1] + (2, 3)),
                        inverse=lambda y: y.reshape(y.shape[:-2] + (6, )),
                        forward_log_det_jacobian=lambda _: 0,
                        inverse_log_det_jacobian=lambda _: 0,
                        is_constant_jacobian=True,
                        event_ndims_in=1,
                        event_ndims_out=2)
        tfp_bij = tfp_compatible_bijector(dx_bij)

        # Use assertEqual rather than bare `assert`: bare asserts are removed
        # under `python -O` and report no values on failure.
        with self.subTest('forward_event_ndims'):
            self.assertEqual(tfp_bij.forward_event_ndims(1), 2)
            self.assertEqual(tfp_bij.forward_event_ndims(2), 3)

        with self.subTest('inverse_event_ndims'):
            self.assertEqual(tfp_bij.inverse_event_ndims(2), 1)
            self.assertEqual(tfp_bij.inverse_event_ndims(3), 2)

        with self.subTest('forward_event_ndims with incorrect input'):
            # Fewer dims than event_ndims_in=1 must be rejected.
            with self.assertRaises(ValueError):
                tfp_bij.forward_event_ndims(0)

        with self.subTest('inverse_event_ndims with incorrect input'):
            # Fewer dims than event_ndims_out=2 must be rejected.
            with self.assertRaises(ValueError):
                tfp_bij.inverse_event_ndims(0)

            with self.assertRaises(ValueError):
                tfp_bij.inverse_event_ndims(1)

        with self.subTest('forward_event_shape'):
            y_shape = tfp_bij.forward_event_shape((6, ))
            y_shape_tensor = tfp_bij.forward_event_shape_tensor((6, ))
            self.assertEqual(y_shape, (2, 3))
            np.testing.assert_array_equal(y_shape_tensor, jnp.array((2, 3)))

        with self.subTest('inverse_event_shape'):
            x_shape = tfp_bij.inverse_event_shape((2, 3))
            x_shape_tensor = tfp_bij.inverse_event_shape_tensor((2, 3))
            self.assertEqual(x_shape, (6, ))
            np.testing.assert_array_equal(x_shape_tensor, jnp.array((6, )))

        with self.subTest('TransformedDistribution with correct event_ndims'):
            # The base event is a length-6 vector, matching event_ndims_in=1.
            base = tfd.MultivariateNormalDiag(np.zeros(6), np.ones(6))
            dist = tfd.TransformedDistribution(base, tfp_bij)
            chex.assert_equal(dist.event_shape, (2, 3))

            sample = dist.sample(seed=jax.random.PRNGKey(0))
            chex.assert_shape(sample, (2, 3))

            log_prob = dist.log_prob(sample)
            chex.assert_shape(log_prob, ())

        with self.subTest(
                'TransformedDistribution with incorrect event_ndims'):
            # tfd.Normal has a scalar event, which is too low-rank for the
            # bijector; the error surfaces when event_shape is queried.
            base = tfd.Normal(np.zeros(6), np.ones(6))
            dist = tfd.TransformedDistribution(base, tfp_bij)
            with self.assertRaises(ValueError):
                _ = dist.event_shape
    def test_batched_events(self, bij_fn, batch_shape):
        """A wrapped bijector should preserve batch dims of its base distribution.

        The base is a batch of 3-dimensional MultivariateNormalDiag
        distributions whose batch shape is `batch_shape`.
        """
        full_shape = batch_shape + (3, )
        transformed = tfd.TransformedDistribution(
            tfd.MultivariateNormalDiag(np.zeros(full_shape),
                                       np.ones(full_shape)),
            tfp_compatible_bijector(bij_fn()))
        key = jax.random.PRNGKey(0)

        with self.subTest('sample'):
            drawn = transformed.sample(seed=key)
            chex.assert_shape(drawn, full_shape)

        with self.subTest('log_prob'):
            drawn = transformed.sample(seed=key)
            # One log-density per batch member.
            chex.assert_shape(transformed.log_prob(drawn), batch_shape)
    def test_forward_and_inverse(self, dx_bijector_fn, tfp_bijector_fn, event):
        """Wrapped Distrax forward/inverse should agree with the TFP bijector."""
        wrapped = tfp_compatible_bijector(dx_bijector_fn())
        reference = tfp_bijector_fn()

        with self.subTest('forward'):
            np.testing.assert_allclose(
                wrapped.forward(event), reference.forward(event), rtol=RTOL)

        with self.subTest('inverse'):
            # Invert a point produced by the reference bijector so it lies in
            # the reference's output space.
            target = reference.forward(event)
            np.testing.assert_allclose(
                wrapped.inverse(target), reference.inverse(target), rtol=RTOL)
    def test_invert(self, dx_bijector_fn, tfb_bijector_fn, event):
        """A wrapped bijector inside `tfb.Invert` should match TFP's behaviour."""
        wrapped = tfp_compatible_bijector(dx_bijector_fn())
        reference = tfb_bijector_fn()
        inv_wrapped, inv_reference = tfb.Invert(wrapped), tfb.Invert(reference)

        with self.subTest('forward'):
            np.testing.assert_allclose(
                inv_wrapped.forward(event),
                inv_reference.forward(event),
                rtol=RTOL)

        with self.subTest('inverse'):
            # Round trip: the wrapped inverse should recover the input event.
            target = inv_reference.forward(event)
            recovered = inv_wrapped.inverse(target)
            np.testing.assert_allclose(recovered, event, rtol=RTOL)
    def test_chain(self, dx_bijector_fn, tfb_bijector_fn, event):
        """A wrapped bijector composed in `tfb.Chain` should behave like TFP's."""
        wrapped = tfp_compatible_bijector(dx_bijector_fn())
        reference = tfb_bijector_fn()

        def _with_affine_stages(inner):
            # Same Shift/Scale stages for both chains; only `inner` differs.
            return tfb.Chain([tfb.Shift(1.0), tfb.Scale(3.0), inner])

        chained_wrapped = _with_affine_stages(wrapped)
        chained_reference = _with_affine_stages(reference)

        with self.subTest('forward'):
            np.testing.assert_allclose(
                chained_wrapped.forward(event),
                chained_reference.forward(event),
                rtol=RTOL)

        with self.subTest('inverse'):
            target = chained_reference.forward(event)
            recovered = chained_wrapped.inverse(target)
            np.testing.assert_allclose(recovered, event, rtol=RTOL)
    def test_log_det_jacobian(self, dx_bijector_fn, tfp_bijector_fn, event):
        """Wrapped forward/inverse log-det-Jacobians should agree with TFP.

        The Distrax bijector's own `event_ndims_in`/`event_ndims_out` are passed
        as `event_ndims` to both implementations.
        """
        base_bij = dx_bijector_fn()
        wrapped = tfp_compatible_bijector(base_bij)
        reference = tfp_bijector_fn()

        with self.subTest('forward'):
            ndims = base_bij.event_ndims_in
            np.testing.assert_allclose(
                wrapped.forward_log_det_jacobian(event, event_ndims=ndims),
                reference.forward_log_det_jacobian(event, event_ndims=ndims),
                rtol=RTOL)

        with self.subTest('inverse'):
            target = reference.forward(event)
            ndims = base_bij.event_ndims_out
            np.testing.assert_allclose(
                wrapped.inverse_log_det_jacobian(target, event_ndims=ndims),
                reference.inverse_log_det_jacobian(target, event_ndims=ndims),
                rtol=RTOL)
  def test_transformed_distribution(
      self, dx_bijector_fn, tfp_bijector_fn, sample_shape):
    """TransformedDistribution with a wrapped Distrax bijector matches TFP.

    Both distributions share the same base distribution. They differ only in
    whether the bijector is the TFP-compatible Distrax wrapper or the native
    TFP bijector.
    """
    base_dist = tfd.MultivariateNormalDiag(np.zeros((3, 2)), np.ones((3, 2)))
    dx_bijector = dx_bijector_fn()
    wrapped_bijector = tfp_compatible_bijector(dx_bijector)
    tfp_bijector = tfp_bijector_fn()
    dist_with_wrapped = tfd.TransformedDistribution(base_dist, wrapped_bijector)
    dist_tfp_only = tfd.TransformedDistribution(base_dist, tfp_bijector)

    with self.subTest('sample'):
      # Assert on the result rather than only checking that sampling runs.
      # Assumes TransformedDistribution draws from the base with the given
      # seed and then applies bijector.forward, so equal seeds give
      # comparable samples.
      sample_wrapped = dist_with_wrapped.sample(
          seed=jax.random.PRNGKey(0), sample_shape=sample_shape)
      sample_tfp_only = dist_tfp_only.sample(
          seed=jax.random.PRNGKey(0), sample_shape=sample_shape)
      chex.assert_equal_shape([sample_wrapped, sample_tfp_only])
      np.testing.assert_allclose(sample_wrapped, sample_tfp_only, rtol=RTOL)

    with self.subTest('log_prob'):
      y = dist_tfp_only.sample(
          seed=jax.random.PRNGKey(0), sample_shape=sample_shape)
      log_prob_wrapped = dist_with_wrapped.log_prob(y)
      log_prob_tfp_only = dist_tfp_only.log_prob(y)
      np.testing.assert_allclose(log_prob_wrapped, log_prob_tfp_only, rtol=RTOL)