def test_works_with_tfp_caching(self, tfp_bij_name, batch_shape_in,
                                event_shape_in, batch_shape_out,
                                event_shape_out):
    """Checks single-output and joint methods agree after a round trip.

    For each direction, a random sample is first pushed through the opposite
    method (so TFP's bijector cache may be populated). Then the separately
    computed output and log-det-Jacobian must match, in shape and in value,
    the pair returned by the corresponding joint `*_and_log_det` method.
    """
    wrapped = bijector_from_tfp.BijectorFromTFP(
        self._test_bijectors[tfp_bij_name])
    fwd_key, inv_key = jax.random.split(jax.random.PRNGKey(42))

    def check_consistency(single_fn, log_det_fn, joint_fn, inputs):
        # Separate calls first, then the joint call; compare shapes, then values.
        out_separate = self.variant(single_fn)(inputs)
        log_det_separate = self.variant(log_det_fn)(inputs)
        out_joint, log_det_joint = self.variant(joint_fn)(inputs)
        self.assertEqual(out_separate.shape, out_joint.shape)
        self.assertEqual(log_det_separate.shape, log_det_joint.shape)
        np.testing.assert_allclose(out_separate, out_joint, atol=1e-8)
        np.testing.assert_allclose(log_det_separate, log_det_joint, atol=1e-8)

    # Forward caching: produce y via `forward`, then check the inverse methods.
    sample_in = jax.random.uniform(fwd_key, batch_shape_in + event_shape_in)
    pushed_forward = self.variant(wrapped.forward)(sample_in)
    check_consistency(wrapped.inverse, wrapped.inverse_log_det_jacobian,
                      wrapped.inverse_and_log_det, pushed_forward)

    # Inverse caching: produce x via `inverse`, then check the forward methods.
    sample_out = jax.random.uniform(inv_key, batch_shape_out + event_shape_out)
    pulled_back = self.variant(wrapped.inverse)(sample_out)
    check_consistency(wrapped.forward, wrapped.forward_log_det_jacobian,
                      wrapped.forward_and_log_det, pulled_back)
Example #2
0
def as_bijector(obj: BijectorLike) -> bijector.BijectorT:
    """Converts a bijector-like object to a Distrax bijector.

    Accepted inputs are Distrax bijectors, TFP bijectors, and callables:

    * a `distrax.Bijector` is returned unchanged;
    * a `tfb.Bijector` is wrapped in `BijectorFromTFP`;
    * exactly `jax.nn.sigmoid` or `jnp.tanh` map to the dedicated `Sigmoid`
      and `Tanh` bijectors;
    * any other callable is wrapped in `Lambda`.

    Args:
      obj: The bijector-like object to be converted.

    Returns:
      A Distrax bijector.

    Raises:
      TypeError: If `obj` is neither a bijector nor callable.
    """
    if isinstance(obj, bijector.Bijector):
        return obj
    if isinstance(obj, tfb.Bijector):
        return bijector_from_tfp.BijectorFromTFP(obj)
    # Identity (not equality) checks: only these exact function objects have
    # explicit implementations; every other callable falls through to Lambda.
    if obj is jax.nn.sigmoid:
        return sigmoid.Sigmoid()
    if obj is jnp.tanh:
        return tanh.Tanh()
    if not callable(obj):
        raise TypeError(
            f"A bijector-like object can be a `distrax.Bijector`, a `tfb.Bijector`,"
            f" or a callable. Got type `{type(obj)}`.")
    return lambda_bijector.Lambda(obj)
    def test_jittable(self):
        """Checks a `BijectorFromTFP` can be passed as an argument to `jax.jit`."""
        def apply_forward(inputs, bij):
            return bij.forward(inputs)

        jitted_forward = jax.jit(apply_forward)
        tanh_bij = bijector_from_tfp.BijectorFromTFP(tfb.Tanh())
        scalar_input = np.zeros(())
        # Only tracing/compilation is under test; the result is not inspected.
        jitted_forward(scalar_input, tanh_bij)
 def test_access_properties_tfp_bijector(self):
     """Checks TFP attributes can be read through the Distrax wrapper."""
     tfp_chain = self._test_bijectors['BatchedChain']
     wrapped = bijector_from_tfp.BijectorFromTFP(tfp_chain)
     # Read the `bijectors` attribute through the wrapper and compare each
     # inner parameter with the one read directly from the TFP object.
     attribute_getters = (
         lambda b: b.bijectors[0].shift,
         lambda b: b.bijectors[1].scale.diag,
     )
     for get_attr in attribute_getters:
         np.testing.assert_allclose(get_attr(wrapped), get_attr(tfp_chain),
                                    atol=1e-8)
    def test_forward_methods_are_correct(self, tfp_bij_name, batch_shape_in,
                                         event_shape_in, batch_shape_out,
                                         event_shape_out):
        """Compares the wrapper's forward methods against the raw TFP bijector.

        Checks the output and log-det-Jacobian shapes against the expected
        batch/event shapes, and their values against TFP's own results.
        """
        reference = self._test_bijectors[tfp_bij_name]
        wrapped = bijector_from_tfp.BijectorFromTFP(reference)
        inputs = jax.random.uniform(jax.random.PRNGKey(42),
                                    batch_shape_in + event_shape_in)

        outputs = self.variant(wrapped.forward)(inputs)
        log_det = self.variant(wrapped.forward_log_det_jacobian)(inputs)
        expected_outputs = reference.forward(inputs)
        # TFP takes the event rank explicitly; its log-det is broadcast up to
        # the expected batch shape before comparison.
        expected_log_det = jnp.broadcast_to(
            reference.forward_log_det_jacobian(inputs, len(event_shape_in)),
            batch_shape_out)

        self.assertEqual(outputs.shape, batch_shape_out + event_shape_out)
        self.assertEqual(log_det.shape, batch_shape_out)
        np.testing.assert_allclose(outputs, expected_outputs, atol=1e-8)
        np.testing.assert_allclose(log_det, expected_log_det, atol=1e-4)