def test_batched_bijector_shapes(self, batch_shape, sample_shape):
  """Checks shapes and log-probs when a batched bijector is wrapped in Block.

  A `tfb.Scale` whose scales have shape `batch_shape + (3,)` is wrapped in
  `Block(..., 1)` and applied to a 3-dimensional diagonal Gaussian. The test
  checks the batch shape, the shapes of samples and log-probs, and that
  `sample_and_log_prob` agrees with `log_prob`.
  """
  event_shape = (3,)
  expected_sample_shape = sample_shape + batch_shape + event_shape
  expected_log_prob_shape = sample_shape + batch_shape

  base_dist = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
  scale_bijector = tfb.Scale(jnp.ones(batch_shape + event_shape))
  dist = transformed.Transformed(base_dist, block.Block(scale_bijector, 1))

  with self.subTest('batch_shape'):
    chex.assert_equal(dist.batch_shape, batch_shape)

  with self.subTest('sample.shape'):
    samples = dist.sample(seed=self.seed, sample_shape=sample_shape)
    chex.assert_equal(samples.shape, expected_sample_shape)

  with self.subTest('sample_and_log_prob sample.shape'):
    samples, _ = dist.sample_and_log_prob(
        seed=self.seed, sample_shape=sample_shape)
    chex.assert_equal(samples.shape, expected_sample_shape)

  with self.subTest('sample_and_log_prob log_prob.shape'):
    _, log_probs = dist.sample_and_log_prob(
        seed=self.seed, sample_shape=sample_shape)
    chex.assert_equal(log_probs.shape, expected_log_prob_shape)

  with self.subTest('sample_and_log_prob log_prob value'):
    samples, log_probs = dist.sample_and_log_prob(
        seed=self.seed, sample_shape=sample_shape)
    # The joint call must agree with evaluating log_prob on its own samples.
    np.testing.assert_allclose(log_probs, dist.log_prob(samples))
def test_jittable(self):
  """Checks that a `Block` bijector can be used as an argument to `jax.jit`.

  The bijector is passed as an argument of the jitted function, not closed
  over. This exercises its handling as a JAX-compatible argument. The input
  is all zeros and the affine shift is 0. The forward output must therefore
  also be all zeros with the same shape, whatever the scale is, so the
  result is checked rather than discarded.
  """

  @jax.jit
  def f(x, b):
    return b.forward(x)

  bijector = block_bijector.Block(scalar_affine.ScalarAffine(0), 1)
  x = np.zeros((2, 3))
  y = f(x, bijector)
  # `assert_array_equal` checks both shape and values.
  np.testing.assert_array_equal(y, x)
def test_raises_on_invalid_input_shape(self):
  """Every transform method of a rank-1 `Block` rejects a 0-d input."""
  identity_block = block_bijector.Block(lambda x: x, 1)
  methods = (
      identity_block.forward,
      identity_block.inverse,
      identity_block.forward_log_det_jacobian,
      identity_block.inverse_log_det_jacobian,
      identity_block.forward_and_log_det,
      identity_block.inverse_and_log_det,
  )
  for method in methods:
    with self.assertRaises(ValueError):
      method(jnp.array(0))
def test_kl_divergence_raises_on_event_shape(self):
  """KL divergence between transformed dists of differing event rank raises.

  One side is a 3-dimensional Gaussian transformed by a rank-1 `Block` of an
  identity `Lambda`. The other is a scalar `Normal` transformed by a plain
  identity `Lambda`.
  """
  vector_base = tfd.MultivariateNormalDiag([0.1, 0.5, 0.9], [0.1, 1.1, 2.5])
  scalar_base = tfd.Normal(-0.1, 1.5)
  vector_bijector = block.Block(
      lambda_bijector.Lambda(lambda x: x), ndims=1)
  scalar_bijector = lambda_bijector.Lambda(lambda x: x)

  vector_dist = transformed.Transformed(vector_base, vector_bijector)
  scalar_dist = transformed.Transformed(scalar_base, scalar_bijector)

  with self.assertRaises(ValueError):
    vector_dist.kl_divergence(scalar_dist)
def test_raises_if_inner_bijector_is_not_scalar(self):
  """`MaskedCoupling` rejects an inner bijector that is a rank-1 `Block`."""
  event_shape = (2, 3)
  rng = jax.random.PRNGKey(101)
  random_mask = jax.random.choice(rng, jnp.array([True, False]), event_shape)
  coupling = masked_coupling.MaskedCoupling(
      mask=random_mask,
      conditioner=lambda x: x,
      bijector=lambda _: block.Block(lambda x: x, 1))

  for method in (coupling.forward_and_log_det, coupling.inverse_and_log_det):
    with self.assertRaises(ValueError):
      method(jnp.zeros(event_shape))
def test_against_tfp_semantics(self, tfp_bijector_fn, ndims):
  """`Block` log-dets match TFP log-dets computed with `ndims` extra dims."""
  tfp_bij = tfp_bijector_fn()
  x = jax.random.normal(self.seed, [2, 3, 4, 5, 6])
  y = tfp_bij(x)
  # TFP takes the total event rank explicitly. It is the bijector's minimum
  # event rank plus the `ndims` that the Block wrapper adds.
  total_fwd_ndims = ndims + tfp_bij.forward_min_event_ndims
  total_inv_ndims = ndims + tfp_bij.inverse_min_event_ndims
  dx_block = block_bijector.Block(tfp_bij, ndims)

  fldj_expected = tfp_bij.forward_log_det_jacobian(x, total_fwd_ndims)
  fldj_actual = self.variant(dx_block.forward_log_det_jacobian)(x)
  np.testing.assert_allclose(fldj_expected, fldj_actual, atol=2e-5)

  ildj_expected = tfp_bij.inverse_log_det_jacobian(y, total_inv_ndims)
  ildj_actual = self.variant(dx_block.inverse_log_det_jacobian)(y)
  np.testing.assert_allclose(ildj_expected, ildj_actual, atol=2e-5)
def test_raises_on_invalid_inner_bijector(self):
  """`SplitCoupling` rejects an inner bijector with too many event dims."""
  event_shape = (2, 3)
  event_ndims = len(event_shape)
  coupling = split_coupling.SplitCoupling(
      split_index=event_shape[-1] // 2,
      event_ndims=event_ndims,
      conditioner=lambda x: x,
      # One more event dimension than the coupling bijector itself has.
      bijector=lambda _: block.Block(lambda x: x, event_ndims + 1))

  methods = (
      coupling.forward,
      coupling.inverse,
      coupling.forward_log_det_jacobian,
      coupling.inverse_log_det_jacobian,
      coupling.forward_and_log_det,
      coupling.inverse_and_log_det,
  )
  for method in methods:
    with self.assertRaises(ValueError):
      method(jnp.zeros(event_shape))
def test_forward_inverse_work_as_expected(self, bijector_fn, ndims):
  """Wrapping in `Block` leaves the transformed values unchanged."""
  inner = conversion.as_bijector(bijector_fn())
  x = jax.random.normal(self.seed, [2, 3])
  wrapped = block_bijector.Block(inner, ndims)

  # Plain transforms return the value directly.
  for method_name in ('forward', 'inverse'):
    np.testing.assert_array_equal(
        self.variant(getattr(inner, method_name))(x),
        self.variant(getattr(wrapped, method_name))(x))

  # Joint transforms return a tuple; only its first element (the value) is
  # compared here.
  for method_name in ('forward_and_log_det', 'inverse_and_log_det'):
    np.testing.assert_array_equal(
        self.variant(getattr(inner, method_name))(x)[0],
        self.variant(getattr(wrapped, method_name))(x)[0])
def _inner_bijector(self, params: BijectorParams) -> base.Bijector:
  """Builds the inner bijector for `params`, matched to this bijector's rank.

  Args:
    params: value passed to `self._bijector` to construct the inner bijector.

  Returns:
    The inner bijector, converted via `conversion.as_bijector`. It is wrapped
    in a `block.Block` when it has fewer event dimensions than
    `self.event_ndims_in`.

  Raises:
    ValueError: if the inner bijector's input and output event ranks differ,
      or if it has more event dimensions than this bijector.
  """
  inner = conversion.as_bijector(self._bijector(params))
  inner_ndims = inner.event_ndims_in

  if inner_ndims != inner.event_ndims_out:
    raise ValueError(
        'The inner bijector must have `event_ndims_in==event_ndims_out`. '
        f'Instead, it has `event_ndims_in=={inner_ndims}` and '
        f'`event_ndims_out=={inner.event_ndims_out}`.')

  extra_ndims = self.event_ndims_in - inner_ndims
  if extra_ndims < 0:
    raise ValueError(
        "The inner bijector can't have more event dimensions than the "
        f'coupling bijector. Got {inner_ndims} for the inner '
        f'bijector and {self.event_ndims_in} for the coupling bijector.'
    )

  # Lift the inner bijector so that it reduces over the extra trailing dims.
  if extra_ndims > 0:
    inner = block.Block(inner, extra_ndims)
  return inner
def test_log_det_jacobian_works_as_expected(self, bijector_fn, ndims):
  """`Block` log-dets equal the inner log-dets summed over the last `ndims`."""
  inner = conversion.as_bijector(bijector_fn())
  x = jax.random.normal(self.seed, [2, 3])
  wrapped = block_bijector.Block(inner, ndims)
  # With ndims == 0 this is the empty tuple, and `.sum(())` is a no-op.
  reduce_axes = tuple(range(-ndims, 0))

  for method_name in ('forward_log_det_jacobian', 'inverse_log_det_jacobian'):
    np.testing.assert_allclose(
        self.variant(getattr(inner, method_name))(x).sum(reduce_axes),
        self.variant(getattr(wrapped, method_name))(x),
        rtol=RTOL)

  # Joint transforms return a tuple; element 1 is the log-det.
  for method_name in ('forward_and_log_det', 'inverse_and_log_det'):
    np.testing.assert_allclose(
        self.variant(getattr(inner, method_name))(x)[1].sum(reduce_axes),
        self.variant(getattr(wrapped, method_name))(x)[1],
        rtol=RTOL)
def test_batched_bijector_against_tfp(
    self, bijector_fn, block_ndims, bijector_shape, params_shape):
  """A Transformed with a batched Block bijector agrees with TFP's version.

  The same TFP bijector is used two ways: wrapped in `block.Block` inside a
  distrax `Transformed`, and unwrapped inside `tfd.TransformedDistribution`.
  """
  base_dist = tfd.MultivariateNormalDiag(
      jnp.zeros(params_shape), jnp.ones(params_shape))
  tfp_bij = bijector_fn(jnp.ones(bijector_shape))
  dx_dist = transformed.Transformed(
      base_dist, block.Block(tfp_bij, block_ndims))
  tfp_dist = tfd.TransformedDistribution(
      conversion.to_tfp(base_dist), tfp_bij)

  with self.subTest('event_shape property matches TFP'):
    np.testing.assert_equal(dx_dist.event_shape, tfp_dist.event_shape)

  # Later subtests reuse `dx_sample`, `tfp_sample` and `dx_logp_dx` bound in
  # earlier ones.
  with self.subTest('sample shape matches TFP'):
    dx_sample = self.variant(dx_dist.sample)(seed=self.seed)
    tfp_sample = self.variant(tfp_dist.sample)(seed=self.seed)
    chex.assert_equal_shape([dx_sample, tfp_sample])

  with self.subTest('log_prob(dx_sample) matches TFP'):
    dx_logp_dx = self.variant(dx_dist.log_prob)(dx_sample)
    tfp_logp_dx = self.variant(tfp_dist.log_prob)(dx_sample)
    np.testing.assert_allclose(dx_logp_dx, tfp_logp_dx, rtol=RTOL)

  with self.subTest('log_prob(tfp_sample) matches TFP'):
    dx_logp_tfp = self.variant(dx_dist.log_prob)(tfp_sample)
    tfp_logp_tfp = self.variant(tfp_dist.log_prob)(tfp_sample)
    np.testing.assert_allclose(dx_logp_tfp, tfp_logp_tfp, rtol=RTOL)

  with self.subTest('sample/lp shape is self-consistent'):
    joint_sample, joint_log_prob = self.variant(dx_dist.sample_and_log_prob)(
        seed=self.seed)
    chex.assert_equal_shape([dx_sample, joint_sample])
    chex.assert_equal_shape([dx_logp_dx, joint_log_prob])
def test_invalid_properties(self):
  """Constructing a `Block` with a negative `ndims` raises ValueError."""
  tanh_bijector = conversion.as_bijector(jnp.tanh)
  negative_ndims = -1
  with self.assertRaises(ValueError):
    block_bijector.Block(tanh_bijector, negative_ndims)
def test_properties(self):
  """`Block` exposes the `ndims` it was built with and a `base.Bijector`.

  Uses TestCase assertion methods rather than bare `assert`. Bare asserts are
  removed under `python -O` and give uninformative failure messages.
  """
  tanh_bijector = conversion.as_bijector(jnp.tanh)
  # Avoid naming the local `block`, which shadows a module used in this file.
  wrapped = block_bijector.Block(tanh_bijector, 1)
  self.assertEqual(wrapped.ndims, 1)
  self.assertIsInstance(wrapped.bijector, base.Bijector)
def test_raises_on_incorrect_shape(self, block_dims):
  """`Transformed` rejects a Block-wrapped `Scale` of incompatible shape.

  The base has batch shape (2,) and event shape (3,). The scale parameters
  have shape (1, 2, 3), and the scale is wrapped with `block_dims` extra
  event dimensions.
  """
  # Locals are named to avoid shadowing the `base` and `block_bijector`
  # modules referenced elsewhere in this file.
  base_dist = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
  scale = tfb.Scale(jnp.ones((1, 2, 3)))
  wrapped = block.Block(scale, block_dims)
  with self.assertRaises(ValueError):
    transformed.Transformed(base_dist, wrapped)
def test_raises_on_incompatible_dimensions(self):
  """Chaining `jnp.log` with a rank-1 `Block(jnp.exp)` raises ValueError."""
  with self.assertRaises(ValueError):
    exp_block = block.Block(jnp.exp, 1)
    chain.Chain([jnp.log, exp_block])