def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
  """Compares joint sampling / log-prob with an equivalent TFP distribution.

  Samples from the Distrax `Transformed` distribution are checked against the
  inverted bijector applied to base samples drawn with the same seed. Their
  log-probabilities are checked against TFP's `TransformedDistribution` built
  from `tfb.Invert(tfb.Scale(2))`.
  """
  base = base_dist(mu, sigma)
  bijector = inverse.Inverse(tfb.Scale(2))
  dist = transformed.Transformed(base, bijector)

  def draw(seed, sample_shape):
    return dist.sample_and_log_prob(seed=seed, sample_shape=sample_shape)

  # `sample_shape` (argnum 1) is passed as a static argument.
  variant_draw = self.variant(
      draw, ignore_argnums=(1,), static_argnums=(1,))
  samples, log_prob = variant_draw(self.seed, sample_shape)

  base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
  expected_samples = bijector.forward(base_samples)

  reference_bijector = tfb.Invert(tfb.Scale(2))
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), reference_bijector)
  reference_samples = reference.sample(
      seed=self.seed, sample_shape=sample_shape)
  reference_log_prob = reference.log_prob(samples)

  chex.assert_equal_shape([samples, reference_samples])
  np.testing.assert_allclose(log_prob, reference_log_prob, rtol=RTOL)
  np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
def test_jittable(self):
  """Checks that an `Inverse` bijector can be passed through `jax.jit`.

  The bijector is given to the jitted function as an argument, and the traced
  `forward` result is compared with the expected value.
  """

  @jax.jit
  def f(x, b):
    return b.forward(x)

  # ScalarAffine(0, 1) is shift=0, scale=1, i.e. the identity map, so its
  # inverse is also the identity.
  bijector = inverse.Inverse(scalar_affine.ScalarAffine(0, 1))
  x = np.zeros(())
  y = f(x, bijector)
  # Previously the result was discarded; assert it so a wrong output fails.
  np.testing.assert_allclose(y, x)
def test_integer_inputs(self, inputs):
  """Checks `forward_and_log_det` of an inverted shift on integer inputs.

  The expected values are built as float32 arrays from `inputs`.
  """
  shift_by_one = scalar_affine.ScalarAffine(shift=1.0)
  bijector = inverse.Inverse(shift_by_one)
  out, logdet = self.variant(bijector.forward_and_log_det)(inputs)

  # Inverting "add one" subtracts one; a pure shift has zero log-det.
  want_out = jnp.array(inputs, dtype=jnp.float32) - 1.0
  want_logdet = jnp.zeros_like(inputs, dtype=jnp.float32)

  np.testing.assert_array_equal(out, want_out)
  np.testing.assert_array_equal(logdet, want_logdet)
def test_event_shape(self, mu, sigma, base_dist):
  """Checks that the event shape matches TFP's `TransformedDistribution`."""
  base = base_dist(mu, sigma)
  dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))

  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))

  assert dist.event_shape == reference.event_shape
def test_prob(self, mu, sigma, value, base_dist):
  """Checks `prob` at `value` against TFP's `TransformedDistribution`."""
  base = base_dist(mu, sigma)
  dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))
  got = self.variant(dist.prob)(value)

  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))
  want = reference.prob(value)

  np.testing.assert_allclose(got, want, atol=1e-9)
def test_method(self, function_string, mu, sigma, base_dist):
  """Checks a zero-argument method named `function_string` against TFP.

  The same method is looked up by name on the Distrax distribution and on the
  TFP reference, called with no arguments, and the results are compared.
  """
  base = base_dist(mu, sigma)
  dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))
  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))

  distrax_method = getattr(dist, function_string)
  got = self.variant(distrax_method)()
  want = getattr(reference, function_string)()
  np.testing.assert_allclose(got, want)
def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
  """Checks that sample shapes match TFP's `TransformedDistribution`."""
  base = base_dist(mu, sigma)
  dist = transformed.Transformed(base, inverse.Inverse(tfb.Scale(2)))

  def draw(seed, sample_shape):
    return dist.sample(seed=seed, sample_shape=sample_shape)

  # `sample_shape` (argnum 1) is passed as a static argument.
  variant_draw = self.variant(draw, ignore_argnums=(1,), static_argnums=(1,))
  samples = variant_draw(self.seed, sample_shape)

  reference = tfd.TransformedDistribution(
      conversion.to_tfp(base), tfb.Invert(tfb.Scale(2)))
  reference_samples = reference.sample(
      sample_shape=sample_shape, seed=self.seed)

  chex.assert_equal_shape([samples, reference_samples])
def test_properties(self):
  """Checks that `Inverse.bijector` is a Distrax `Bijector` for a TFP input."""
  inverted = inverse.Inverse(tfb.Scale(2))
  wrapped = inverted.bijector
  assert isinstance(wrapped, base_bijector.Bijector)