def test_dtype_is_consistent_with_tfp(self, dist_fn, bijector_fn):
    """Checks `Transformed` reports the same dtype as TFP's TransformedDistribution.

    Args:
      dist_fn: zero-argument factory for the base distribution.
      bijector_fn: zero-argument factory for the bijector.
    """
    base_distribution = dist_fn()
    bij = bijector_fn()
    ours = transformed.Transformed(base_distribution, bij)
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base_distribution), bij)
    assert ours.dtype == reference.dtype
def test_on_distrax_bijector(self):
    """`to_tfp` on a distrax `Tanh` returns an object that is still a `Tanh`."""
    distrax_bij = Tanh()
    converted = conversion.to_tfp(distrax_bij)
    assert isinstance(converted, Tanh)
    # The converted object's `forward` must agree with the original's at 0.
    converted_out = converted.forward(np.zeros(()))
    original_out = distrax_bij.forward(np.zeros(()))
    np.testing.assert_equal(converted_out, original_out)
def test_sample_and_log_prob(self, mu, sigma, sample_shape, base_dist):
    """Joint sample/log-prob under Inverse(Scale(2)) agrees with TFP's Invert.

    Samples are checked against pushing base samples (same seed) through the
    bijector; log-probs are checked against TFP evaluated on those samples.
    """
    base = base_dist(mu, sigma)
    bijector = inverse.Inverse(tfb.Scale(2))
    dist = transformed.Transformed(base, bijector)

    def _joint_draw(seed, sample_shape):
        return dist.sample_and_log_prob(seed=seed, sample_shape=sample_shape)

    # Positional argument 1 (the sample shape) is a Python tuple, so it is
    # marked static for the variant transformation.
    variant_draw = self.variant(
        _joint_draw, ignore_argnums=(1,), static_argnums=(1,))
    samples, log_prob = variant_draw(self.seed, sample_shape)

    raw_base_samples = base.sample(seed=self.seed, sample_shape=sample_shape)
    expected_samples = bijector.forward(raw_base_samples)

    reference_bijector = tfb.Invert(tfb.Scale(2))
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), reference_bijector)
    reference_samples = reference.sample(
        seed=self.seed, sample_shape=sample_shape)
    reference_log_prob = reference.log_prob(samples)

    chex.assert_equal_shape([samples, reference_samples])
    np.testing.assert_allclose(log_prob, reference_log_prob, rtol=RTOL)
    np.testing.assert_allclose(samples, expected_samples, rtol=RTOL)
def test_event_shape(self, mu, sigma, base_dist):
    """Event shape under a `tfb.Scale(2)` bijector agrees with TFP."""
    base = base_dist(mu, sigma)
    scaling = tfb.Scale(2)
    ours = transformed.Transformed(base, scaling)
    reference = tfd.TransformedDistribution(conversion.to_tfp(base), scaling)
    ours_shape = ours.event_shape
    reference_shape = reference.event_shape
    assert ours_shape == reference_shape
def test_method(self, function_string, mu, sigma, base_dist):
    """The zero-argument method named `function_string` agrees with TFP.

    Uses a `tfb.Scale(2)` bijector; the distrax side is run under
    `self.variant`.
    """
    base = base_dist(mu, sigma)
    scaling = tfb.Scale(2)
    ours = transformed.Transformed(base, scaling)
    reference = tfd.TransformedDistribution(conversion.to_tfp(base), scaling)

    ours_method = getattr(ours, function_string)
    reference_method = getattr(reference, function_string)
    actual = self.variant(ours_method)()
    expected = reference_method()
    np.testing.assert_allclose(actual, expected)
def test_prob(self, mu, sigma, value, base_dist):
    """`prob` under a `tfb.Scale(2)` bijector matches TFP's TransformedDistribution.

    The comparison uses a relative tolerance rather than bitwise equality.
    The distrax value is computed under `self.variant` (presumably a
    jit/pmap-style transformation — confirm against the test class), which
    may reorder floating-point ops relative to eager TFP. Exact equality is
    therefore brittle. This mirrors the other TFP comparisons in this file,
    which use `rtol=RTOL`.
    """
    base = base_dist(mu, sigma)
    bijector = tfb.Scale(2)
    dist = transformed.Transformed(base, bijector)
    actual = self.variant(dist.prob)(value)

    tfp_dist = tfd.TransformedDistribution(conversion.to_tfp(base), bijector)
    expected = tfp_dist.prob(value)
    np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_event_shape(self, mu, sigma, base_dist):
    """A `Lambda(jnp.tanh)` bijector yields the same event shape as TFP's `Tanh`."""
    base = base_dist(mu, sigma)
    ours = transformed.Transformed(base, lambda_bijector.Lambda(jnp.tanh))
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    assert ours.event_shape == reference.event_shape
def test_event_shape(self, mu, sigma, base_dist):
    """A distrax `Chain` of TFP bijectors has the same event shape as `tfb.Chain`."""

    def _fresh_links():
        # New bijector instances per call, one list for each side.
        return [tfb.Scale(2), tfb.Shift(3), tfb.Tanh()]

    base = base_dist(mu, sigma)
    ours = transformed.Transformed(base, chain.Chain(_fresh_links()))
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Chain(_fresh_links()))
    assert ours.event_shape == reference.event_shape
def test_prob(self, mu, sigma, value, base_dist):
    """`prob` under the distrax `Tanh` bijector matches TFP's `Tanh` (atol 1e-9)."""
    base = base_dist(mu, sigma)
    ours = transformed.Transformed(base, tanh.Tanh())
    variant_prob = self.variant(ours.prob)
    actual = variant_prob(value)

    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Tanh())
    np.testing.assert_allclose(actual, reference.prob(value), atol=1e-9)
def test_method(self, function_string, mu, sigma, base_dist):
    """A `Lambda(x + 3)` bijector matches `tfb.Shift(3)` for the named method.

    `function_string` names a zero-argument method on both distributions.
    """
    base = base_dist(mu, sigma)
    add_three = lambda_bijector.Lambda(lambda x: x + 3)
    ours = transformed.Transformed(base, add_three)
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfb.Shift(3))

    actual = self.variant(getattr(ours, function_string))()
    expected = getattr(reference, function_string)()
    np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_sample_shape(self, mu, sigma, sample_shape, base_dist):
    """Samples drawn via `self.variant` have the same shape as TFP samples."""
    base = base_dist(mu, sigma)
    scaling = tfb.Scale(2)
    ours = transformed.Transformed(base, scaling)

    def _draw(seed, sample_shape):
        return ours.sample(seed=seed, sample_shape=sample_shape)

    # Argument 1 (the sample shape) is a Python tuple, hence static.
    variant_draw = self.variant(_draw, ignore_argnums=(1,), static_argnums=1)
    drawn = variant_draw(self.seed, sample_shape)

    reference = tfd.TransformedDistribution(conversion.to_tfp(base), scaling)
    reference_drawn = reference.sample(sample_shape=sample_shape, seed=self.seed)
    chex.assert_equal_shape([drawn, reference_drawn])
def test_auto_lambda(self, forward_fn, tfp_bijector, base_dist):
    """A bare callable given as the bijector matches the equivalent TFP bijector.

    `forward_fn` is passed directly to `Transformed` (presumably wrapped as a
    Lambda bijector internally — confirm). `tfp_bijector` is a factory for
    the reference TFP bijector. Both `log_prob` and `prob` are compared at
    fixed float32 points.
    """
    loc = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
    scale = np.array([0.5, 1.0, 2.5], dtype=np.float32)
    base = base_dist(loc, scale)
    ours = transformed.Transformed(base, forward_fn)
    reference = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfp_bijector())

    points = np.array([0.05, 0.3, 0.95], dtype=np.float32)
    for method_name in ('log_prob', 'prob'):
        expected = getattr(reference, method_name)(points)
        actual = self.variant(getattr(ours, method_name))(points)
        np.testing.assert_allclose(actual, expected, rtol=RTOL)
def test_batched_bijector_against_tfp(
        self, bijector_fn, block_ndims, bijector_shape, params_shape):
    """Compares a `block.Block`-wrapped batched bijector against plain TFP.

    Builds an MVN-diag base with batch/event shape from `params_shape` and a
    bijector whose parameters have shape `bijector_shape`. It then checks,
    in order: event shapes, sample shapes, log-probs of each side's samples
    under both distributions, and self-consistency of `sample_and_log_prob`.

    Note: later subTests reuse locals assigned inside earlier ones
    (`dx_sample`, `tfp_sample`, `dx_logp_dx`). Statement order inside each
    subTest is therefore significant.
    """
    base = tfd.MultivariateNormalDiag(
        jnp.zeros(params_shape), jnp.ones(params_shape))
    tfp_bijector = bijector_fn(jnp.ones(bijector_shape))
    # Same TFP bijector, wrapped with an explicit block_ndims for distrax.
    dx_bijector = block.Block(tfp_bijector, block_ndims)
    dx_dist = transformed.Transformed(base, dx_bijector)
    tfp_dist = tfd.TransformedDistribution(
        conversion.to_tfp(base), tfp_bijector)

    with self.subTest('event_shape property matches TFP'):
        np.testing.assert_equal(dx_dist.event_shape, tfp_dist.event_shape)

    with self.subTest('sample shape matches TFP'):
        # Both sides draw with the same seed; only shapes are compared here.
        dx_sample = self.variant(dx_dist.sample)(seed=self.seed)
        tfp_sample = self.variant(tfp_dist.sample)(seed=self.seed)
        chex.assert_equal_shape([dx_sample, tfp_sample])

    with self.subTest('log_prob(dx_sample) matches TFP'):
        # dx_logp_dx is reused in the final subTest for a shape check.
        dx_logp_dx = self.variant(dx_dist.log_prob)(dx_sample)
        tfp_logp_dx = self.variant(tfp_dist.log_prob)(dx_sample)
        np.testing.assert_allclose(dx_logp_dx, tfp_logp_dx, rtol=RTOL)

    with self.subTest('log_prob(tfp_sample) matches TFP'):
        dx_logp_tfp = self.variant(dx_dist.log_prob)(tfp_sample)
        tfp_logp_tfp = self.variant(tfp_dist.log_prob)(tfp_sample)
        np.testing.assert_allclose(dx_logp_tfp, tfp_logp_tfp, rtol=RTOL)

    with self.subTest('sample/lp shape is self-consistent'):
        # sample_and_log_prob must agree in shape with separate sample/log_prob.
        second_sample, log_prob = self.variant(dx_dist.sample_and_log_prob)(
            seed=self.seed)
        chex.assert_equal_shape([dx_sample, second_sample])
        chex.assert_equal_shape([dx_logp_dx, log_prob])
def test_on_tfp_bijector(self):
    """`to_tfp` returns a TFP bijector unchanged (the very same object)."""
    exp_bijector = tfb.Exp()
    self.assertIs(conversion.to_tfp(exp_bijector), exp_bijector)
def test_on_tfp_distribution(self):
    """`to_tfp` returns a TFP distribution unchanged (the very same object)."""
    standard_normal = tfd.Normal(0., 1.)
    self.assertIs(conversion.to_tfp(standard_normal), standard_normal)
def test_on_distrax_distribution(self):
    """`to_tfp` on a distrax `Normal` returns an object that is still a `Normal`."""
    distrax_normal = Normal(loc=0., scale=1.)
    converted = conversion.to_tfp(distrax_normal)
    assert isinstance(converted, Normal)
    # Attribute access through the converted object must still reach `loc`.
    np.testing.assert_almost_equal(converted.loc, 0.)