def test_dtype_is_consistent_with_tfp(self, dist_fn, bijector_fn):
  """Checks that `Transformed.dtype` agrees with TFP's TransformedDistribution.

  Args:
    dist_fn: zero-argument factory returning the base distribution.
    bijector_fn: zero-argument factory returning the bijector.
  """
  base_distribution = dist_fn()
  bij = bijector_fn()
  dx_dist = transformed.Transformed(base_distribution, bij)
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base_distribution), bij)
  assert dx_dist.dtype == reference.dtype
def test_batched_bijector_shapes(self, batch_shape, sample_shape):
  """Checks shapes and log-probs when a batched bijector adds batch dims.

  A 3-dimensional `MultivariateNormalDiag` is pushed through
  `block.Block(tfb.Scale(ones(batch_shape + (3,))), 1)`. The result is
  expected to have batch shape `batch_shape`, samples of shape
  `sample_shape + batch_shape + (3,)` and log-probs of shape
  `sample_shape + batch_shape`. Its `sample_and_log_prob` must agree with
  `log_prob`.
  """
  dim = 3
  base = tfd.MultivariateNormalDiag(jnp.zeros(dim), jnp.ones(dim))
  batched_scale = block.Block(tfb.Scale(jnp.ones(batch_shape + (dim,))), 1)
  dist = transformed.Transformed(base, batched_scale)

  def draw_with_log_prob():
    # Each subtest that needs a (sample, log_prob) pair calls this itself.
    return dist.sample_and_log_prob(seed=self.seed, sample_shape=sample_shape)

  with self.subTest('batch_shape'):
    chex.assert_equal(dist.batch_shape, batch_shape)
  with self.subTest('sample.shape'):
    x = dist.sample(seed=self.seed, sample_shape=sample_shape)
    chex.assert_equal(x.shape, sample_shape + batch_shape + (dim,))
  with self.subTest('sample_and_log_prob sample.shape'):
    x, _ = draw_with_log_prob()
    chex.assert_equal(x.shape, sample_shape + batch_shape + (dim,))
  with self.subTest('sample_and_log_prob log_prob.shape'):
    _, lp = draw_with_log_prob()
    chex.assert_equal(lp.shape, sample_shape + batch_shape)
  with self.subTest('sample_and_log_prob log_prob value'):
    x, lp = draw_with_log_prob()
    np.testing.assert_allclose(lp, dist.log_prob(x))
def test_dtype_is_as_expected(self, dist_fn, bijector_fn, expected_dtype):
  """Checks `dtype` against both a drawn sample's dtype and `expected_dtype`."""
  dist = transformed.Transformed(dist_fn(), bijector_fn())
  draw = self.variant(dist.sample)(seed=self.seed)
  for wanted in (draw.dtype, expected_dtype):
    assert dist.dtype == wanted
def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
  """Checks `sample_and_log_prob` against a manual forward pass and TFP.

  The samples must equal `bijector.forward` applied to base samples drawn
  with the same seed. Their shape must match TFP's samples. The returned
  log-probs must match TFP's `log_prob` evaluated at the distrax samples.
  """
  def make_layers():
    # Fresh bijector objects on every call, so the distrax and TFP chains
    # share no instances.
    return [tfb.Scale(10), tfb.Tanh(), tfb.Scale(0.1)]

  base = base_dist(mu, sigma)
  dx_chain = chain.Chain(make_layers())
  dist = transformed.Transformed(base, dx_chain)

  def draw(seed, sample_shape):
    return dist.sample_and_log_prob(seed=seed, sample_shape=sample_shape)

  variant_draw = self.variant(draw, ignore_argnums=(1,), static_argnums=(1,))
  samples, log_prob = variant_draw(self.seed, sample_shape)
  base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
  expected_samples = dx_chain.forward(base_samples)

  tfp_dist = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Chain(make_layers()))
  tfp_samples = tfp_dist.sample(seed=self.seed, sample_shape=sample_shape)
  tfp_log_prob = tfp_dist.log_prob(samples)

  chex.assert_equal_shape([samples, tfp_samples])
  np.testing.assert_allclose(log_prob, tfp_log_prob, rtol=RTOL)
  np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
def test_event_shape(self, mu, sigma, base_dist):
  """Checks `event_shape` under a `tfb.Scale(2)` bijector against TFP."""
  base = base_dist(mu, sigma)
  scale_by_two = tfb.Scale(2)
  dx_dist = transformed.Transformed(base, scale_by_two)
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), scale_by_two)
  assert dx_dist.event_shape == reference.event_shape
def test_prob(self, mu, sigma, value, base_dist):
  """Checks `prob` of a `tfb.Scale(2)`-transformed distribution against TFP.

  The distrax value is computed through `self.variant`, possibly under a
  transformation such as jit. The TFP reference is evaluated directly.
  Bit-for-bit agreement is therefore not guaranteed, so the comparison uses
  the same relative tolerance as the other TFP comparisons rather than exact
  equality.
  """
  base = base_dist(mu, sigma)
  bijector = tfb.Scale(2)
  dist = transformed.Transformed(base, bijector)
  actual = self.variant(dist.prob)(value)
  tfp_dist = tfd.TransformedDistribution(conversion.to_tfp(base), bijector)
  expected = tfp_dist.prob(value)
  np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_method(self, function_string, mu, sigma, base_dist):
  """Checks a zero-argument method, named by `function_string`, against TFP.

  Both distributions wrap the same base with a `tfb.Scale(2)` bijector.
  """
  base = base_dist(mu, sigma)
  scale_by_two = tfb.Scale(2)
  dx_dist = transformed.Transformed(base, scale_by_two)
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), scale_by_two)
  dx_method = getattr(dx_dist, function_string)
  tfp_method = getattr(reference, function_string)
  np.testing.assert_allclose(self.variant(dx_method)(), tfp_method())
def test_jittable(self):
  """Checks that a `Transformed` can be passed as an argument to a jitted fn.

  Tracing must succeed. In addition, the jitted `log_prob` must agree with
  the eager one, so the test checks a value rather than only the absence of
  an exception.
  """
  @jax.jit
  def f(x, d):
    return d.log_prob(x)

  base = normal.Normal(0, 1)
  bijector = scalar_affine.ScalarAffine(0, 1)
  dist = transformed.Transformed(base, bijector)
  x = np.zeros(())
  jitted_log_prob = f(x, dist)
  np.testing.assert_allclose(jitted_log_prob, dist.log_prob(x), rtol=RTOL)
def test_event_shape(self, mu, sigma, base_dist):
  """Checks `event_shape` for `Lambda(jnp.tanh)` against TFP's `tfb.Tanh`."""
  base = base_dist(mu, sigma)
  dx_dist = transformed.Transformed(base, lambda_bijector.Lambda(jnp.tanh))
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Tanh())
  assert dx_dist.event_shape == reference.event_shape
def test_event_shape(self, mu, sigma, base_dist):
  """Checks `event_shape` for a distrax `Chain` against TFP's `tfb.Chain`.

  Both sides use the layers `[Scale(2), Shift(3), Tanh()]`.
  """
  def layers():
    return [tfb.Scale(2), tfb.Shift(3), tfb.Tanh()]

  base = base_dist(mu, sigma)
  dx_dist = transformed.Transformed(base, chain.Chain(layers()))
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Chain(layers()))
  assert dx_dist.event_shape == reference.event_shape
def test_prob(self, mu, sigma, value, base_dist):
  """Checks `prob` under distrax `tanh.Tanh()` against TFP's `tfb.Tanh()`."""
  base = base_dist(mu, sigma)
  dx_dist = transformed.Transformed(base, tanh.Tanh())
  actual = self.variant(dx_dist.prob)(value)
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Tanh())
  np.testing.assert_allclose(actual, reference.prob(value), atol=1e-9)
def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
  """Checks that sample shapes agree with TFP under a `tfb.Scale(2)` bijector."""
  base = base_dist(mu, sigma)
  scale_by_two = tfb.Scale(2)
  dx_dist = transformed.Transformed(base, scale_by_two)

  def draw(seed, sample_shape):
    return dx_dist.sample(seed=seed, sample_shape=sample_shape)

  variant_draw = self.variant(draw, ignore_argnums=(1,), static_argnums=1)
  dx_samples = variant_draw(self.seed, sample_shape)
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), scale_by_two)
  tfp_samples = reference.sample(sample_shape=sample_shape, seed=self.seed)
  chex.assert_equal_shape([dx_samples, tfp_samples])
def test_method(self, function_string, mu, sigma, base_dist):
  """Checks a zero-argument method of a shift-by-3 `Lambda` against TFP.

  The distrax side wraps a plain `x + 3` function in `lambda_bijector.Lambda`.
  The TFP reference uses `tfb.Shift(3)`. `function_string` names the method
  to compare.
  """
  def add_three(x):
    return x + 3

  base = base_dist(mu, sigma)
  dx_dist = transformed.Transformed(base, lambda_bijector.Lambda(add_three))
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Shift(3))
  actual = self.variant(getattr(dx_dist, function_string))()
  expected = getattr(reference, function_string)()
  np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_integer_inputs(self, inputs, base_dist):
  """Checks `log_prob` on `inputs` (integer-valued, per the test name).

  The base has float32 zero locations and unit scales shaped like `inputs`.
  It is wrapped in `ScalarAffine(shift=0.0)`. The expected value is the
  standard-normal log-density at zero, so `inputs` is presumably all zeros
  (confirm against the test parameters).
  """
  loc = jnp.zeros_like(inputs, dtype=jnp.float32)
  scale = jnp.ones_like(inputs, dtype=jnp.float32)
  dist = transformed.Transformed(
      base_dist(loc, scale), scalar_affine.ScalarAffine(shift=0.0))
  log_prob = self.variant(dist.log_prob)(inputs)
  # -0.5 * log(2 * pi), i.e. log N(0 | 0, 1), rounded to float32 precision.
  expected = jnp.full_like(inputs, -0.9189385, dtype=jnp.float32)
  np.testing.assert_array_equal(log_prob, expected)
def test_auto_lambda(self, forward_fn, tfp_bijector, base_dist):
  """Checks that a plain callable can be passed directly as the bijector.

  `forward_fn` is handed to `Transformed` as-is. Its `log_prob` and `prob`
  are compared with a TFP distribution built from `tfp_bijector()`.
  """
  loc = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
  scale = np.array([0.5, 1.0, 2.5], dtype=np.float32)
  base = base_dist(loc, scale)
  dx_dist = transformed.Transformed(base, forward_fn)
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfp_bijector())
  y = np.array([0.05, 0.3, 0.95], dtype=np.float32)
  for method_name in ('log_prob', 'prob'):
    expected = getattr(reference, method_name)(y)
    actual = self.variant(getattr(dx_dist, method_name))(y)
    np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_bijector_that_assumes_batch_dimensions(self):
  """Checks a `MaskedCoupling` whose conditioner expects a leading batch dim.

  The Haiku conditioner keeps the first axis and flattens the rest
  (`hk.Flatten(preserve_dims=1)`), so it assumes a batch dimension is present.
  The test exercises the shape/dtype properties, sampling and `log_prob`.
  It also checks that `log_prob` returns one value per batch element instead
  of discarding the result.
  """
  # Create a Haiku conditioner that assumes a single batch dimension.
  def forward(x):
    network = hk.Sequential([hk.Flatten(preserve_dims=1), hk.Linear(3)])
    return network(x)

  init, apply = hk.transform(forward)
  params = init(self.seed, jnp.ones((2, 3)))
  conditioner = functools.partial(apply, params, self.seed)

  bijector = masked_coupling.MaskedCoupling(
      jnp.ones(3) > 0, conditioner, tfb.Scale)
  base = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
  dist = transformed.Transformed(base, bijector)

  # Exercise the trace-based functions.
  assert dist.batch_shape == (2,)
  assert dist.event_shape == (3,)
  assert dist.dtype == jnp.float32
  sample = self.variant(dist.sample)(seed=self.seed)
  assert sample.dtype == dist.dtype
  log_prob = self.variant(dist.log_prob)(sample)
  # One log-density per batch element (batch_shape is (2,), asserted above).
  assert log_prob.shape == (2,)
def test_batched_bijector_against_tfp(
    self, bijector_fn, block_ndims, bijector_shape, params_shape):
  """Compares a `Block`-wrapped batched TFP bijector with TFP itself.

  One TFP bijector is built from `bijector_fn(jnp.ones(bijector_shape))`.
  TFP uses it directly, while distrax wraps it in
  `block.Block(..., block_ndims)`. The two are compared on:
    - event shape;
    - sample shape;
    - log-probs at both sets of samples.
  `sample_and_log_prob` must also be shape-consistent with `sample` and
  `log_prob`.

  Later subtests reuse samples/log-probs bound in earlier ones. An early
  failure therefore also makes the dependent subtests fail.
  """
  base = tfd.MultivariateNormalDiag(
      jnp.zeros(params_shape), jnp.ones(params_shape))
  tfp_bij = bijector_fn(jnp.ones(bijector_shape))
  dx_dist = transformed.Transformed(base, block.Block(tfp_bij, block_ndims))
  tfp_dist = tfd.TransformedDistribution(conversion.to_tfp(base), tfp_bij)

  with self.subTest('event_shape property matches TFP'):
    np.testing.assert_equal(dx_dist.event_shape, tfp_dist.event_shape)

  with self.subTest('sample shape matches TFP'):
    x_dx = self.variant(dx_dist.sample)(seed=self.seed)
    x_tfp = self.variant(tfp_dist.sample)(seed=self.seed)
    chex.assert_equal_shape([x_dx, x_tfp])

  with self.subTest('log_prob(dx_sample) matches TFP'):
    lp_dx_at_dx = self.variant(dx_dist.log_prob)(x_dx)
    lp_tfp_at_dx = self.variant(tfp_dist.log_prob)(x_dx)
    np.testing.assert_allclose(lp_dx_at_dx, lp_tfp_at_dx, rtol=RTOL)

  with self.subTest('log_prob(tfp_sample) matches TFP'):
    lp_dx_at_tfp = self.variant(dx_dist.log_prob)(x_tfp)
    lp_tfp_at_tfp = self.variant(tfp_dist.log_prob)(x_tfp)
    np.testing.assert_allclose(lp_dx_at_tfp, lp_tfp_at_tfp, rtol=RTOL)

  with self.subTest('sample/lp shape is self-consistent'):
    x_again, lp_again = self.variant(dx_dist.sample_and_log_prob)(
        seed=self.seed)
    chex.assert_equal_shape([x_dx, x_again])
    chex.assert_equal_shape([lp_dx_at_dx, lp_again])
def test_raises_on_incorrect_shape(self, block_dims):
  """Checks that `Transformed` raises ValueError for an ill-shaped bijector.

  The base has batch shape (2,) and event shape (3,). The `tfb.Scale`
  parameter has shape (1, 2, 3) and is wrapped in `block.Block(...,
  block_dims)`.
  """
  base = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
  ill_shaped = block.Block(tfb.Scale(jnp.ones((1, 2, 3))), block_dims)
  with self.assertRaises(ValueError):
    transformed.Transformed(base, ill_shaped)