def test_transformed_potential_energy():
    """Check ``transformed_potential_energy`` against a TransformedDistribution.

    A Beta(1, 1) base is pushed through ``x -> 3 + 4 * x``.  Its potential,
    evaluated through the inverse affine map ``y -> -0.75 + 0.25 * y``, must
    match the negative log-density of the equivalent
    ``TransformedDistribution`` at the same points.
    """
    beta_dist = dist.Beta(np.ones(5), np.ones(5))
    transform = transforms.AffineTransform(3, 4)
    # Exact inverse of ``transform``: (y - 3) / 4 == -0.75 + 0.25 * y.
    inv_transform = transforms.AffineTransform(-0.75, 0.25)

    # Evaluate inside the support (3, 7) of the transformed Beta.  A
    # standard-normal draw would mostly map through ``inv_transform`` to
    # negative values, i.e. outside Beta's (0, 1) support, where the true
    # density is zero and the comparison below would check nothing meaningful
    # (NaN on both sides may also pass ``assert_allclose``).
    z = random.uniform(random.PRNGKey(0), (5,), minval=3., maxval=7.)
    pe_expected = -dist.TransformedDistribution(beta_dist, transform).log_prob(z)
    potential_fn = lambda x: -beta_dist.log_prob(x)  # noqa: E731
    pe_actual = transformed_potential_energy(potential_fn, inv_transform, z)
    assert_allclose(pe_actual, pe_expected)
def actual_model(data):
    """Model whose ``loc`` site is reparameterized with ``TransformReparam``.

    ``alpha`` is drawn from Uniform(0, 1) and used as the scale of an affine
    transform applied to a Uniform(0, 1) base to form the ``loc`` prior.
    ``obs`` is Normal(loc, 0.1), conditioned on ``data``.
    """
    alpha = numpyro.sample("alpha", dist.Uniform(0, 1))

    scaled_uniform = dist.TransformedDistribution(
        dist.Uniform(0, 1),
        transforms.AffineTransform(0, alpha),
    )
    reparam_config = {"loc": TransformReparam()}

    with numpyro.handlers.reparam(config=reparam_config):
        loc = numpyro.sample("loc", scaled_uniform)

    numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
def test_transformed_transformed_distribution():
    """Nesting TransformedDistributions flattens to one base plus a transform chain."""
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(dist.Normal(2, 3), transforms.PowerTransform(2.))
    # Build the affine step from loc/scale (not repeated literals) so the
    # sample comparison below stays in sync with the transform.
    dist2 = dist.TransformedDistribution(dist1, transforms.AffineTransform(loc, scale))
    assert isinstance(dist2.base_dist, dist.Normal)
    assert len(dist2.transforms) == 2
    assert isinstance(dist2.transforms[0], transforms.PowerTransform)
    assert isinstance(dist2.transforms[1], transforms.AffineTransform)

    rng_key = random.PRNGKey(0)
    assert_allclose(loc + scale * dist1.sample(rng_key), dist2.sample(rng_key))

    # sample_with_intermediates returns a (sample, intermediates) pair, so
    # len() of the raw return value is trivially 2.  Unpack it and check the
    # intermediates themselves: one entry is expected per transform.
    sample, intermediates = dist2.sample_with_intermediates(rng_key)
    assert_allclose(sample, dist2.sample(rng_key))
    assert len(intermediates) == len(dist2.transforms) == 2
if len(event_shape) == 1: expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1] inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1] else: expected = np.log(np.abs(grad(transform)(x))) inv_expected = np.log(np.abs(grad(transform.inv)(y))) assert_allclose(actual, expected, atol=1e-6) assert_allclose(actual, -inv_expected, atol=1e-6) @pytest.mark.parametrize('transformed_dist', [ dist.TransformedDistribution(dist.Normal(np.array([2., 3.]), 1.), transforms.ExpTransform()), dist.TransformedDistribution(dist.Exponential(np.ones(2)), [ transforms.PowerTransform(0.7), transforms.AffineTransform(0., np.ones(2) * 3) ]), ]) def test_transformed_distribution_intermediates(transformed_dist): sample, intermediates = transformed_dist.sample_with_intermediates(random.PRNGKey(1)) assert_allclose(transformed_dist.log_prob(sample, intermediates), transformed_dist.log_prob(sample)) def test_transformed_transformed_distribution(): loc, scale = -2, 3 dist1 = dist.TransformedDistribution(dist.Normal(2, 3), transforms.PowerTransform(2.)) dist2 = dist.TransformedDistribution(dist1, transforms.AffineTransform(-2, 3)) assert isinstance(dist2.base_dist, dist.Normal) assert len(dist2.transforms) == 2 assert isinstance(dist2.transforms[0], transforms.PowerTransform) assert isinstance(dist2.transforms[1], transforms.AffineTransform)
inv_expected = onp.linalg.slogdet(jax.jacobian( transform.inv)(y))[1] else: expected = np.log(np.abs(grad(transform)(x))) inv_expected = np.log(np.abs(grad(transform.inv)(y))) assert_allclose(actual, expected, atol=1e-6) assert_allclose(actual, -inv_expected, atol=1e-6) @pytest.mark.parametrize('transformed_dist', [ dist.TransformedDistribution(dist.Normal(np.array([2., 3.]), 1.), transforms.ExpTransform()), dist.TransformedDistribution(dist.Exponential(np.ones(2)), [ transforms.PowerTransform(0.7), transforms.AffineTransform(0., np.ones(2) * 3) ]), ]) def test_transformed_distribution_intermediates(transformed_dist): sample, intermediates = transformed_dist.sample_with_intermediates( random.PRNGKey(1)) assert_allclose(transformed_dist.log_prob(sample, intermediates), transformed_dist.log_prob(sample)) def test_transformed_transformed_distribution(): loc, scale = -2, 3 dist1 = dist.TransformedDistribution(dist.Normal(2, 3), transforms.PowerTransform(2.)) dist2 = dist.TransformedDistribution(dist1, transforms.AffineTransform(-2, 3))
def actual_model(data):
    """Sample ``alpha``, a reparameterized ``loc``, and observe ``data``.

    The ``loc`` prior is Uniform(0, 1) scaled by ``alpha`` through an affine
    transform; the ``reparam`` handler maps that site with ``TransformReparam``.
    The observation likelihood is Normal(loc, 0.1).
    """
    alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
    affine = transforms.AffineTransform(0, alpha)
    with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
        loc = numpyro.sample(
            'loc', dist.TransformedDistribution(dist.Uniform(0, 1), affine))
    numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)