def test_kl_divergence_raises_on_event_shape(self):
    """Expects `kl_divergence` to raise `ValueError` for this pair.

    The first distribution wraps a `MultivariateNormalDiag` base with an
    identity `Lambda` blocked with `ndims=1`; the second wraps a scalar
    `Normal` base with a plain identity `Lambda`.
    """
    # Each bijector gets its own lambda object on purpose.
    blocked_identity = block.Block(
        lambda_bijector.Lambda(lambda x: x), ndims=1)
    plain_identity = lambda_bijector.Lambda(lambda x: x)

    vector_dist = transformed.Transformed(
        tfd.MultivariateNormalDiag([0.1, 0.5, 0.9], [0.1, 1.1, 2.5]),
        blocked_identity)
    scalar_dist = transformed.Transformed(
        tfd.Normal(-0.1, 1.5), plain_identity)

    self.assertRaises(ValueError, vector_dist.kl_divergence, scalar_dist)
def as_bijector(obj: BijectorLike) -> bijector.BijectorT:
    """Returns a Distrax bijector built from a bijector-like object.

    Accepted inputs are Distrax bijectors, TFP bijectors, and callables:

    * a Distrax bijector is handed back as is;
    * a TFP bijector is wrapped in `BijectorFromTFP`;
    * `jax.nn.sigmoid` and `jnp.tanh` map to the explicit `Sigmoid` and
      `Tanh` bijectors;
    * any other callable is wrapped in `Lambda`.

    Args:
      obj: The bijector-like object to convert.

    Returns:
      A Distrax bijector.

    Raises:
      TypeError: If `obj` is none of the supported kinds.
    """
    # Already a Distrax bijector: nothing to do.
    if isinstance(obj, bijector.Bijector):
        return obj

    if isinstance(obj, tfb.Bijector):
        return bijector_from_tfp.BijectorFromTFP(obj)

    # Known callables with a dedicated implementation are checked before the
    # generic callable fallback.
    if obj is jax.nn.sigmoid:
        return sigmoid.Sigmoid()
    if obj is jnp.tanh:
        return tanh.Tanh()

    if not callable(obj):
        raise TypeError(
            f"A bijector-like object can be a `distrax.Bijector`, a `tfb.Bijector`,"
            f" or a callable. Got type `{type(obj)}`.")

    return lambda_bijector.Lambda(obj)
def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
    """Compares `sample_and_log_prob` of a `Lambda`-transformed dist to TFP.

    The bijector is `10 * tanh(0.1 * x)`; the TFP reference chains
    `Scale(10)`, `Tanh()` and `Scale(0.1)`. Three things are checked:
    sample shapes agree with TFP, log-probs match TFP's log-prob evaluated
    at the Distrax samples, and the samples equal the bijector applied to
    base samples drawn with the same seed.
    """
    base = base_dist(mu, sigma)
    bijector = lambda_bijector.Lambda(lambda x: 10 * jnp.tanh(0.1 * x))
    dist = transformed.Transformed(base, bijector)

    def draw_with_log_prob(key, shape):
        return dist.sample_and_log_prob(seed=key, sample_shape=shape)

    variant_fn = self.variant(
        draw_with_log_prob, ignore_argnums=(1, ), static_argnums=(1, ))
    samples, log_prob = variant_fn(self.seed, sample_shape)

    # Reference samples: push base samples through the bijector by hand.
    base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
    expected_samples = bijector.forward(base_samples)

    # Reference distribution built from the equivalent TFP bijector chain.
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base),
        tfb.Chain([tfb.Scale(10), tfb.Tanh(), tfb.Scale(0.1)]))
    tfp_samples = tfp_dist.sample(seed=self.seed, sample_shape=sample_shape)
    tfp_log_prob = tfp_dist.log_prob(samples)

    chex.assert_equal_shape([samples, tfp_samples])
    np.testing.assert_allclose(log_prob, tfp_log_prob, rtol=RTOL)
    np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
def test_kl_divergence(self, mode_string):
    """Checks KL divergence of identity-transformed normals, both directions.

    Both bijectors are identities (`tfb.Identity()` and an identity
    `Lambda`), so the expected values are the KL divergences between the
    base distributions themselves.

    Args:
      mode_string: Which pairing to exercise. One of 'distrax_to_distrax',
        'distrax_to_tfp' or 'tfp_to_distrax'.

    Raises:
      ValueError: If `mode_string` is not one of the supported values.
    """
    base_dist1 = tfd.Normal([0.1, 0.5, 0.9], [0.1, 1.1, 2.5])
    base_dist2 = tfd.Normal(-0.1, 1.5)

    bij_tfp1 = tfb.Identity()
    bij_tfp2 = tfb.Identity()
    # The first Distrax distribution reuses the TFP bijector object; the
    # second uses a Distrax `Lambda` identity.
    bij_distrax1 = bij_tfp1
    bij_distrax2 = lambda_bijector.Lambda(lambda x: x)

    tfp_dist1 = tfd.TransformedDistribution(base_dist1, bij_tfp1)
    tfp_dist2 = tfd.TransformedDistribution(base_dist2, bij_tfp2)
    distrax_dist1 = transformed.Transformed(base_dist1, bij_distrax1)
    distrax_dist2 = transformed.Transformed(base_dist2, bij_distrax2)

    expected_result_fwd = base_dist1.kl_divergence(base_dist2)
    expected_result_inv = base_dist2.kl_divergence(base_dist1)

    distrax_fn1 = self.variant(distrax_dist1.kl_divergence)
    distrax_fn2 = self.variant(distrax_dist2.kl_divergence)

    if mode_string == 'distrax_to_distrax':
        result_fwd = distrax_fn1(distrax_dist2)
        result_inv = distrax_fn2(distrax_dist1)
    elif mode_string == 'distrax_to_tfp':
        result_fwd = distrax_fn1(tfp_dist2)
        result_inv = distrax_fn2(tfp_dist1)
    elif mode_string == 'tfp_to_distrax':
        result_fwd = tfp_dist1.kl_divergence(distrax_dist2)
        result_inv = tfp_dist2.kl_divergence(distrax_dist1)
    else:
        # Without this branch an unknown mode would surface later as an
        # UnboundLocalError on `result_fwd`, hiding the real cause.
        raise ValueError(f'Unsupported mode_string: {mode_string!r}')

    np.testing.assert_allclose(result_fwd, expected_result_fwd, rtol=RTOL)
    np.testing.assert_allclose(result_inv, expected_result_inv, rtol=RTOL)
def test_jittable(self):
    """Runs a jitted function that takes a `Lambda` bijector as an argument.

    The test passes if the call completes without raising.
    """
    def apply_forward(x, b):
        return b.forward(x)

    jitted_forward = jax.jit(apply_forward)
    identity_bijector = lambda_bijector.Lambda(lambda x: x)
    jitted_forward(np.zeros(()), identity_bijector)
def test_kl_divergence_raises_on_different_bijectors(self):
    """Expects `NotImplementedError` when the two bijectors differ.

    One distribution uses an identity `Lambda`, the other a `Sigmoid`.
    """
    identity_dist = transformed.Transformed(
        tfd.Normal([0.1, 0.5, 0.9], [0.1, 1.1, 2.5]),
        lambda_bijector.Lambda(lambda x: x))
    sigmoid_dist = transformed.Transformed(
        tfd.Normal(-0.1, 1.5), sigmoid.Sigmoid())

    self.assertRaises(
        NotImplementedError, identity_dist.kl_divergence, sigmoid_dist)
def test_event_shape(self, mu, sigma, base_dist):
    """Checks that `event_shape` of a tanh-`Lambda` dist matches TFP's."""
    base = base_dist(mu, sigma)

    # Reference: the same base distribution under TFP's own Tanh bijector.
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    dist = transformed.Transformed(base, lambda_bijector.Lambda(jnp.tanh))

    assert dist.event_shape == tfp_dist.event_shape
def test_prob(self, mu, sigma, value, base_dist):
    """Compares `prob(value)` of a tanh-`Lambda` dist with TFP's result."""
    base = base_dist(mu, sigma)

    # Reference value from TFP's transformed distribution.
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    expected = tfp_dist.prob(value)

    dist = transformed.Transformed(base, lambda_bijector.Lambda(jnp.tanh))
    prob_fn = self.variant(dist.prob)
    actual = prob_fn(value)

    np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_log_dets(self, lambda_bjct, tfp_bijector_fn):
    """Compares forward/inverse log-det-Jacobians of `Lambda` against TFP.

    Args:
      lambda_bjct: Object wrapped by `Lambda` to build the bijector under
        test.
      tfp_bijector_fn: Zero-argument callable returning the TFP reference
        bijector.
    """
    bijector = lambda_bijector.Lambda(lambda_bjct)
    tfp_bijector = tfp_bijector_fn()
    x = np.array([0.05, 0.3, 0.45], dtype=np.float32)

    # Forward direction, evaluated at x.
    expected_fldj = tfp_bijector.forward_log_det_jacobian(x, event_ndims=0)
    actual_fldj = self.variant(bijector.forward_log_det_jacobian)(x)
    np.testing.assert_allclose(actual_fldj, expected_fldj, rtol=RTOL)

    # Inverse direction, evaluated at y = forward(x).
    y = bijector.forward(x)
    expected_ildj = tfp_bijector.inverse_log_det_jacobian(y, event_ndims=0)
    actual_ildj = self.variant(bijector.inverse_log_det_jacobian)(y)
    np.testing.assert_allclose(actual_ildj, expected_ildj, rtol=RTOL)
def test_method(self, function_string, mu, sigma, base_dist):
    """Compares a zero-argument method of a shifted dist with TFP.

    The Distrax distribution uses `Lambda(lambda x: x + 3)` and the TFP
    reference uses `Shift(3)`. The method named `function_string` is looked
    up on both and called without arguments.
    """
    base = base_dist(mu, sigma)
    dist = transformed.Transformed(
        base, lambda_bijector.Lambda(lambda x: x + 3))
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Shift(3))

    actual = self.variant(getattr(dist, function_string))()
    expected = getattr(tfp_dist, function_string)()
    np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_raises_on_invalid_input_shape(self):
    """Expects every bijector method to raise `ValueError` on scalar input.

    The bijector is built with `event_ndims_in=1` and each method is called
    with `jnp.array(0)`.
    """
    bij = lambda_bijector.Lambda(
        forward=lambda x: x,
        inverse=lambda y: y,
        forward_log_det_jacobian=lambda x: jnp.zeros_like(x[:-1]),
        inverse_log_det_jacobian=lambda y: jnp.zeros_like(y[:-1]),
        event_ndims_in=1)

    method_names = (
        'forward',
        'inverse',
        'forward_log_det_jacobian',
        'inverse_log_det_jacobian',
        'forward_and_log_det',
        'inverse_and_log_det',
    )
    for method_name in method_names:
        method = getattr(bij, method_name)
        self.assertRaises(ValueError, method, jnp.array(0))
def test_against_tfp_bijectors(self, lambda_bjct, tfp_bijector, base_dist):
    """Compares `log_prob` and `prob` of a `Lambda` dist against TFP.

    Args:
      lambda_bjct: Object wrapped by `Lambda` to build the bijector under
        test.
      tfp_bijector: Zero-argument callable returning the TFP reference
        bijector.
      base_dist: Callable building the base distribution from `mu` and
        `sigma`.
    """
    loc = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
    scale = np.array([0.5, 1.0, 2.5], dtype=np.float32)
    base = base_dist(loc, scale)

    dist = transformed.Transformed(base, lambda_bijector.Lambda(lambda_bjct))
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfp_bijector())

    y = np.array([0.05, 0.3, 0.95], dtype=np.float32)

    # Log-probabilities.
    expected_log_prob = tfp_dist.log_prob(y)
    actual_log_prob = self.variant(dist.log_prob)(y)
    np.testing.assert_allclose(actual_log_prob, expected_log_prob, rtol=RTOL)

    # Probabilities.
    expected_prob = tfp_dist.prob(y)
    actual_prob = self.variant(dist.prob)(y)
    np.testing.assert_allclose(actual_prob, expected_prob, rtol=RTOL)
def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
    """Checks that samples from a tanh-`Lambda` dist have TFP's shape."""
    base = base_dist(mu, sigma)
    dist = transformed.Transformed(base, lambda_bijector.Lambda(jnp.tanh))

    def draw(key, shape):
        return dist.sample(seed=key, sample_shape=shape)

    variant_draw = self.variant(
        draw, ignore_argnums=(1, ), static_argnums=1)
    samples = variant_draw(self.seed, sample_shape)

    # Reference samples from TFP with the same seed and sample shape.
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    tfp_samples = tfp_dist.sample(sample_shape=sample_shape, seed=self.seed)

    chex.assert_equal_shape([samples, tfp_samples])
def test_raises_on_event_ndims_without_log_det(self, ndims_in, ndims_out):
    """Expects `ValueError` when event ndims are given with only `forward`.

    `Lambda` is constructed with `forward`, `event_ndims_in` and
    `event_ndims_out`, and no log-det-Jacobian function.
    """
    ctor_kwargs = dict(
        forward=lambda x: x,
        event_ndims_in=ndims_in,
        event_ndims_out=ndims_out,
    )
    with self.assertRaises(ValueError):
        lambda_bijector.Lambda(**ctor_kwargs)
def test_raises_on_log_det_without_event_ndims(self):
    """Expects `ValueError` for a log-det function with no event ndims.

    `Lambda` is constructed with `forward`, an explicit
    `forward_log_det_jacobian` and `event_ndims_in=None`.
    """
    ctor_kwargs = dict(
        forward=lambda x: x,
        forward_log_det_jacobian=lambda x: jnp.zeros_like(x[:-1]),
        event_ndims_in=None,
    )
    with self.assertRaises(ValueError):
        lambda_bijector.Lambda(**ctor_kwargs)
def test_raises_on_both_none(self):
    """Expects `ValueError` when `forward` and `inverse` are both `None`."""
    self.assertRaises(
        ValueError, lambda_bijector.Lambda, forward=None, inverse=None)