def as_bijector(obj: BijectorLike) -> bijector.BijectorT:
  """Returns a Distrax bijector equivalent to the bijector-like `obj`.

  The accepted kinds of input are tried in this order:

  * a Distrax bijector, which is returned as-is;
  * a TFP bijector, which is wrapped in `BijectorFromTFP`;
  * exactly `jax.nn.sigmoid` or `jnp.tanh` (matched by identity), which map
    to the dedicated `Sigmoid` / `Tanh` bijectors instead of a generic wrapper;
  * any other callable, which is wrapped in `distrax.Lambda`.

  Args:
    obj: The bijector-like object to convert.

  Returns:
    A Distrax bijector.

  Raises:
    TypeError: If `obj` is neither a bijector nor callable.
  """
  # Order matters: every specific case must be tried before the generic
  # `callable` fallback at the end.
  if isinstance(obj, bijector.Bijector):
    return obj
  if isinstance(obj, tfb.Bijector):
    return bijector_from_tfp.BijectorFromTFP(obj)
  # Identity checks, so only these exact function objects are special-cased.
  if obj is jax.nn.sigmoid:
    return sigmoid.Sigmoid()
  if obj is jnp.tanh:
    return tanh.Tanh()
  if not callable(obj):
    raise TypeError(
        f"A bijector-like object can be a `distrax.Bijector`, a `tfb.Bijector`,"
        f" or a callable. Got type `{type(obj)}`.")
  return lambda_bijector.Lambda(obj)
def test_jittable(self):
  """Checks that a `Tanh` bijector can be passed into a `jax.jit` function."""

  @jax.jit
  def apply_forward(inputs, bij):
    return bij.forward(inputs)

  tanh_bijector = tanh.Tanh()
  # Only checks that the jitted call succeeds; the result is not inspected.
  apply_forward(np.zeros(()), tanh_bijector)
def test_integer_inputs(self, inputs):
  """Checks `Tanh.forward_and_log_det` on integer-typed inputs."""
  forward_and_log_det = self.variant(tanh.Tanh().forward_and_log_det)
  actual_out, actual_log_det = forward_and_log_det(inputs)

  # Outputs are expected as float32 even though the inputs are integers.
  want_out = jnp.tanh(inputs).astype(jnp.float32)
  # NOTE(review): log|d tanh(x)/dx| = log(1 - tanh(x)^2) is zero only at
  # x = 0. This all-zeros expectation therefore assumes every entry of the
  # parameterized `inputs` is 0 — confirm against the parameterization.
  want_log_det = jnp.zeros_like(inputs, dtype=jnp.float32)

  np.testing.assert_array_equal(actual_out, want_out)
  np.testing.assert_array_equal(actual_log_det, want_log_det)
def test_event_shape(self, mu, sigma, base_dist):
  """Checks that a Tanh-transformed distribution's event shape matches TFP's."""
  base = base_dist(mu, sigma)
  distrax_dist = transformed.Transformed(base, tanh.Tanh())
  reference_dist = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Tanh())
  assert distrax_dist.event_shape == reference_dist.event_shape
def test_prob(self, mu, sigma, value, base_dist):
  """Compares `prob` of a Tanh-transformed distribution against TFP."""
  base = base_dist(mu, sigma)
  distrax_dist = transformed.Transformed(base, tanh.Tanh())
  actual = self.variant(distrax_dist.prob)(value)

  reference_dist = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Tanh())
  expected = reference_dist.prob(value)
  # atol=1e-9 is added on top of assert_allclose's default rtol (1e-7).
  np.testing.assert_allclose(actual, expected, atol=1e-9)
def test_stability(self):
  """Checks Tanh log-det Jacobians against TFP, including large-|x| points."""
  distrax_bij = tanh.Tanh()
  tfp_bij = tfb.Tanh()
  # At +/-10, tanh(x) is within float32 rounding of +/-1, which stresses the
  # numerical stability of both log-det computations.
  x = np.array([-10.0, -3.3, 0.0, 3.3, 10.0], dtype=np.float32)

  # Forward direction.
  want_fldj = tfp_bij.forward_log_det_jacobian(x, event_ndims=0)
  got_fldj = self.variant(distrax_bij.forward_log_det_jacobian)(x)
  np.testing.assert_allclose(got_fldj, want_fldj, rtol=RTOL)

  # Inverse direction, evaluated at the images of the same points.
  y = distrax_bij.forward(x)
  want_ildj = tfp_bij.inverse_log_det_jacobian(y, event_ndims=0)
  got_ildj = self.variant(distrax_bij.inverse_log_det_jacobian)(y)
  np.testing.assert_allclose(got_ildj, want_ildj, rtol=RTOL)
def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
  """Checks that Distrax samples have the same shape as equivalent TFP ones."""
  base = base_dist(mu, sigma)
  distrax_dist = transformed.Transformed(base, tanh.Tanh())

  def draw(seed, shape):
    return distrax_dist.sample(seed=seed, sample_shape=shape)

  # Argument 1 (the shape) is marked both ignored and static, presumably so
  # the chex variants treat it as a plain Python value rather than an array.
  draw_variant = self.variant(draw, ignore_argnums=(1,), static_argnums=1)
  samples = draw_variant(self.seed, sample_shape)

  reference_dist = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Tanh())
  reference_samples = reference_dist.sample(
      sample_shape=sample_shape, seed=self.seed)
  chex.assert_equal_shape([samples, reference_samples])
def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
  """Checks `sample_and_log_prob` against TFP and against the base samples."""
  base = base_dist(mu, sigma)
  tanh_bij = tanh.Tanh()
  distrax_dist = transformed.Transformed(base, tanh_bij)

  def draw_with_log_prob(seed, shape):
    return distrax_dist.sample_and_log_prob(seed=seed, sample_shape=shape)

  draw_variant = self.variant(
      draw_with_log_prob, ignore_argnums=(1,), static_argnums=(1,))
  samples, log_prob = draw_variant(self.seed, sample_shape)
  # With the same seed, the transformed samples are expected to equal tanh
  # applied to the base distribution's own samples.
  expected_samples = tanh_bij.forward(
      base.sample(seed=self.seed, sample_shape=sample_shape))

  reference_dist = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Tanh())
  reference_samples = reference_dist.sample(
      seed=self.seed, sample_shape=sample_shape)
  # TFP's density is evaluated at the Distrax samples, so the log-probs are
  # compared at identical points.
  reference_log_prob = reference_dist.log_prob(samples)

  chex.assert_equal_shape([samples, reference_samples])
  np.testing.assert_allclose(log_prob, reference_log_prob, rtol=RTOL)
  np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)