def test_works_with_tfp_caching(self, tfp_bij_name, batch_shape_in,
                                event_shape_in, batch_shape_out,
                                event_shape_out):
  """Checks that separate and fused inverse/forward calls agree on cached values.

  For a wrapped TFP bijector, computing `inverse` and the log-det separately
  must match `inverse_and_log_det` (and symmetrically for the forward
  direction), exercising any forward/inverse caching in the wrapper.
  """
  tfp_bij = self._test_bijectors[tfp_bij_name]
  bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)
  fwd_key, inv_key = jax.random.split(jax.random.PRNGKey(42))

  # Forward caching: push x through forward, then invert it both ways.
  x = jax.random.uniform(fwd_key, batch_shape_in + event_shape_in)
  y = self.variant(bij.forward)(x)
  x_separate = self.variant(bij.inverse)(y)
  logdet_separate = self.variant(bij.inverse_log_det_jacobian)(y)
  x_fused, logdet_fused = self.variant(bij.inverse_and_log_det)(y)
  self.assertEqual(x_separate.shape, x_fused.shape)
  self.assertEqual(logdet_separate.shape, logdet_fused.shape)
  np.testing.assert_allclose(x_separate, x_fused, atol=1e-8)
  np.testing.assert_allclose(logdet_separate, logdet_fused, atol=1e-8)

  # Inverse caching: pull y back through inverse, then map it forward both ways.
  y = jax.random.uniform(inv_key, batch_shape_out + event_shape_out)
  x = self.variant(bij.inverse)(y)
  y_separate = self.variant(bij.forward)(x)
  logdet_separate = self.variant(bij.forward_log_det_jacobian)(x)
  y_fused, logdet_fused = self.variant(bij.forward_and_log_det)(x)
  self.assertEqual(y_separate.shape, y_fused.shape)
  self.assertEqual(logdet_separate.shape, logdet_fused.shape)
  np.testing.assert_allclose(y_separate, y_fused, atol=1e-8)
  np.testing.assert_allclose(logdet_separate, logdet_fused, atol=1e-8)
def as_bijector(obj: BijectorLike) -> bijector.BijectorT:
  """Converts a bijector-like object to a Distrax bijector.

  Bijector-like objects are: Distrax bijectors, TFP bijectors, and callables.
  Distrax bijectors are returned unchanged. TFP bijectors are converted to a
  Distrax equivalent. Callables are wrapped by `distrax.Lambda`, with a few
  exceptions where an explicit implementation already exists and is returned.

  Args:
    obj: The bijector-like object to be converted.

  Returns:
    A Distrax bijector.
  """
  # Already a Distrax bijector: nothing to do.
  if isinstance(obj, bijector.Bijector):
    return obj
  # A TFP bijector gets wrapped in the Distrax adapter.
  if isinstance(obj, tfb.Bijector):
    return bijector_from_tfp.BijectorFromTFP(obj)
  # Well-known callables with dedicated, more efficient implementations.
  if obj is jax.nn.sigmoid:
    return sigmoid.Sigmoid()
  if obj is jnp.tanh:
    return tanh.Tanh()
  # Any other callable becomes a Lambda bijector.
  if callable(obj):
    return lambda_bijector.Lambda(obj)
  raise TypeError(
      f"A bijector-like object can be a `distrax.Bijector`, a `tfb.Bijector`,"
      f" or a callable. Got type `{type(obj)}`.")
def test_jittable(self):
  """Checks that the TFP-bijector wrapper can be passed through `jax.jit`.

  The wrapped bijector is given to a jitted function as an argument, which
  requires it to be registered as a JAX pytree; the test passes if tracing
  and calling `forward` succeed.
  """

  @jax.jit
  def f(x, b):
    return b.forward(x)

  # Named `bij` (not `bijector`) to avoid shadowing the `bijector` module and
  # to stay consistent with the naming used by the sibling tests.
  bij = bijector_from_tfp.BijectorFromTFP(tfb.Tanh())
  x = np.zeros(())
  f(x, bij)
def test_access_properties_tfp_bijector(self):
  """Checks that attributes of the wrapped TFP bijector remain reachable."""
  tfp_bij = self._test_bijectors['BatchedChain']
  bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)
  # Reading `bijectors` on the wrapper must yield the inner TFP bijectors.
  wrapped_inner = bij.bijectors
  tfp_inner = tfp_bij.bijectors
  np.testing.assert_allclose(
      wrapped_inner[0].shift, tfp_inner[0].shift, atol=1e-8)
  np.testing.assert_allclose(
      wrapped_inner[1].scale.diag, tfp_inner[1].scale.diag, atol=1e-8)
def test_forward_methods_are_correct(self, tfp_bij_name, batch_shape_in,
                                     event_shape_in, batch_shape_out,
                                     event_shape_out):
  """Checks the wrapper's forward methods against the underlying TFP bijector.

  Both the forward transform and the forward log-det-Jacobian of the wrapper
  are compared against the raw TFP results, along with their output shapes.
  """
  tfp_bij = self._test_bijectors[tfp_bij_name]
  bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)
  x = jax.random.uniform(jax.random.PRNGKey(42),
                         batch_shape_in + event_shape_in)

  # Wrapper outputs.
  y = self.variant(bij.forward)(x)
  logdet = self.variant(bij.forward_log_det_jacobian)(x)

  # Reference outputs from the raw TFP bijector; the TFP log-det may be
  # unbatched, so broadcast it to the full batch shape before comparing.
  y_tfp = tfp_bij.forward(x)
  logdet_tfp = jnp.broadcast_to(
      tfp_bij.forward_log_det_jacobian(x, len(event_shape_in)),
      batch_shape_out)

  self.assertEqual(y.shape, batch_shape_out + event_shape_out)
  self.assertEqual(logdet.shape, batch_shape_out)
  np.testing.assert_allclose(y, y_tfp, atol=1e-8)
  np.testing.assert_allclose(logdet, logdet_tfp, atol=1e-4)