예제 #1
0
def test_transformed_potential_energy():
    """Check ``transformed_potential_energy`` against a TransformedDistribution.

    The base is a Beta distribution pushed through ``x -> 3 + 4 * x``;
    ``inv_transform`` is its inverse ``y -> -0.75 + 0.25 * y``. The negative
    log density of the transformed distribution at ``z`` must equal the
    potential energy obtained by pulling ``z`` back through ``inv_transform``.
    """
    # Non-uniform concentrations so the base density varies with x.
    # Beta(1, 1) is constant, so it would only exercise the Jacobian term.
    beta_dist = dist.Beta(2. * np.ones(5), 3. * np.ones(5))
    transform = transforms.AffineTransform(3, 4)
    inv_transform = transforms.AffineTransform(-0.75, 0.25)

    # Draw z strictly inside the support (3, 7) of the transformed distribution.
    # Standard-normal draws land outside it, where the Beta log density is
    # ill-defined and may be NaN on both sides. numpy's assert_allclose treats
    # matching NaNs as equal by default, so the check could pass vacuously.
    z = random.uniform(random.PRNGKey(0), (5,), minval=3.5, maxval=6.5)
    pe_expected = -dist.TransformedDistribution(beta_dist, transform).log_prob(z)
    potential_fn = lambda x: -beta_dist.log_prob(x)  # noqa: E731
    pe_actual = transformed_potential_energy(potential_fn, inv_transform, z)
    assert_allclose(pe_actual, pe_expected)
예제 #2
0
 def actual_model(data):
     """Model whose ``loc`` site is reparameterized with TransformReparam.

     ``alpha`` is drawn from Uniform(0, 1). ``loc`` is Uniform(0, 1) pushed
     through ``u -> 0 + alpha * u``, sampled under the reparam handler.
     ``data`` is observed as Normal(loc, 0.1).
     """
     alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
     reparam_config = {"loc": TransformReparam()}
     with numpyro.handlers.reparam(config=reparam_config):
         loc = numpyro.sample("loc", dist.TransformedDistribution(
             dist.Uniform(0, 1),
             transforms.AffineTransform(0, alpha),
         ))
     numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)
예제 #3
0
def test_transformed_transformed_distribution():
    """Wrapping a TransformedDistribution in another flattens the nesting.

    ``dist2`` wraps ``dist1`` (a Normal pushed through a power transform) with
    an affine transform. It should expose the Normal as its base and both
    transforms in application order. Its samples should match ``dist1``'s
    samples followed by the affine map.
    """
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(dist.Normal(2, 3), transforms.PowerTransform(2.))
    dist2 = dist.TransformedDistribution(dist1, transforms.AffineTransform(loc, scale))
    assert isinstance(dist2.base_dist, dist.Normal)
    assert len(dist2.transforms) == 2
    assert isinstance(dist2.transforms[0], transforms.PowerTransform)
    assert isinstance(dist2.transforms[1], transforms.AffineTransform)

    rng_key = random.PRNGKey(0)
    assert_allclose(loc + scale * dist1.sample(rng_key), dist2.sample(rng_key))
    # sample_with_intermediates returns a (sample, intermediates) pair. Unpack
    # it so the length check applies to the intermediates themselves, not to
    # the pair (which always has length 2).
    sample, intermediates = dist2.sample_with_intermediates(rng_key)
    assert len(intermediates) == 2  # expected: one entry per transform
    assert_allclose(sample, dist2.sample(rng_key))
예제 #4
0
        if len(event_shape) == 1:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)


@pytest.mark.parametrize('transformed_dist', [
    dist.TransformedDistribution(
        dist.Normal(np.array([2., 3.]), 1.),
        transforms.ExpTransform(),
    ),
    dist.TransformedDistribution(
        dist.Exponential(np.ones(2)),
        [
            transforms.PowerTransform(0.7),
            transforms.AffineTransform(0., np.ones(2) * 3),
        ],
    ),
])
def test_transformed_distribution_intermediates(transformed_dist):
    """Passing intermediates to log_prob must not change its result.

    A sample is drawn with ``sample_with_intermediates``. ``log_prob`` is then
    evaluated with the returned intermediates and again from the sample alone,
    and the two results are compared.
    """
    key = random.PRNGKey(1)
    x, x_intermediates = transformed_dist.sample_with_intermediates(key)
    lp_with_intermediates = transformed_dist.log_prob(x, x_intermediates)
    lp_recomputed = transformed_dist.log_prob(x)
    assert_allclose(lp_with_intermediates, lp_recomputed)


def test_transformed_transformed_distribution():
    """Wrapping a TransformedDistribution in another flattens the nesting.

    ``dist2`` wraps ``dist1`` (a Normal pushed through a power transform) with
    an affine ``loc + scale * x`` transform. It should expose the Normal as its
    base and both transforms in application order.
    """
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(dist.Normal(2, 3), transforms.PowerTransform(2.))
    dist2 = dist.TransformedDistribution(dist1, transforms.AffineTransform(loc, scale))
    assert isinstance(dist2.base_dist, dist.Normal)
    assert len(dist2.transforms) == 2
    assert isinstance(dist2.transforms[0], transforms.PowerTransform)
    assert isinstance(dist2.transforms[1], transforms.AffineTransform)

    # The flattened chain must reproduce dist1's sample followed by the affine map.
    rng_key = random.PRNGKey(0)
    assert_allclose(loc + scale * dist1.sample(rng_key), dist2.sample(rng_key))
예제 #5
0
            inv_expected = onp.linalg.slogdet(jax.jacobian(
                transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)


@pytest.mark.parametrize('transformed_dist', [
    dist.TransformedDistribution(
        dist.Normal(np.array([2., 3.]), 1.),
        transforms.ExpTransform(),
    ),
    dist.TransformedDistribution(
        dist.Exponential(np.ones(2)),
        [transforms.PowerTransform(0.7),
         transforms.AffineTransform(0., np.ones(2) * 3)],
    ),
])
def test_transformed_distribution_intermediates(transformed_dist):
    """log_prob with and without supplied intermediates must agree.

    The intermediates come from ``sample_with_intermediates`` on the same
    sample whose log density is being evaluated.
    """
    value, intermediates = transformed_dist.sample_with_intermediates(
        random.PRNGKey(1))
    # First entry uses the supplied intermediates; second recomputes them.
    log_probs = [
        transformed_dist.log_prob(value, intermediates),
        transformed_dist.log_prob(value),
    ]
    assert_allclose(*log_probs)


def test_transformed_transformed_distribution():
    """Wrapping a TransformedDistribution in another flattens the nesting.

    ``dist2`` wraps ``dist1`` (a Normal pushed through a power transform) with
    an affine ``loc + scale * x`` transform. It should expose the Normal as its
    base and both transforms in application order, and sample consistently.
    """
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(dist.Normal(2, 3),
                                         transforms.PowerTransform(2.))
    affine = transforms.AffineTransform(loc, scale)
    dist2 = dist.TransformedDistribution(dist1, affine)
    assert isinstance(dist2.base_dist, dist.Normal)
    assert len(dist2.transforms) == 2
    assert isinstance(dist2.transforms[0], transforms.PowerTransform)
    assert isinstance(dist2.transforms[1], transforms.AffineTransform)

    # The flattened chain must reproduce dist1's sample followed by the affine
    # map when driven by the same key.
    rng_key = random.PRNGKey(0)
    assert_allclose(loc + scale * dist1.sample(rng_key),
                    dist2.sample(rng_key))
예제 #6
0
 def actual_model(data):
     """Model whose 'loc' site is sampled under a TransformReparam handler.

     ``loc`` follows Uniform(0, 1) scaled by ``alpha`` (itself Uniform(0, 1)),
     and ``data`` is observed as Normal(loc, 0.1).
     """
     alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
     base = dist.Uniform(0, 1)
     scale_by_alpha = transforms.AffineTransform(0, alpha)
     loc_prior = dist.TransformedDistribution(base, scale_by_alpha)
     with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
         loc = numpyro.sample('loc', loc_prior)
     likelihood = dist.Normal(loc, 0.1)
     numpyro.sample('obs', likelihood, obs=data)