def test_transformed_transformed_distribution():
    """Check a TransformedDistribution built on top of another one.

    A Normal(2, 3) is wrapped with a PowerTransform, and the result is wrapped
    again with an AffineTransform(loc, scale).  The test asserts that the
    outer distribution reports the Normal as ``base_dist``, lists exactly two
    transforms (power first, affine second), and that sampling it with a given
    key matches ``loc + scale * inner_sample`` drawn with the same key.
    """
    loc, scale = -2, 3
    inner = dist.TransformedDistribution(
        dist.Normal(2, 3),
        transforms.PowerTransform(2.),
    )
    outer = dist.TransformedDistribution(
        inner,
        transforms.AffineTransform(loc, scale),
    )

    # Structural checks on the outer distribution.
    assert isinstance(outer.base_dist, dist.Normal)
    assert len(outer.transforms) == 2
    expected_types = (transforms.PowerTransform, transforms.AffineTransform)
    for position, expected_type in enumerate(expected_types):
        assert isinstance(outer.transforms[position], expected_type)

    # The same key is reused so both distributions draw from the same base sample.
    key = random.PRNGKey(0)
    inner_draw = inner.sample(key)
    outer_draw = outer.sample(key)
    assert_allclose(loc + scale * inner_draw, outer_draw)

    # Only the length of the returned object is checked here.
    result = outer.sample_with_intermediates(key)
    assert len(result) == 2
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)


@pytest.mark.parametrize('transformed_dist', [
    dist.TransformedDistribution(
        dist.Normal(np.array([2., 3.]), 1.),
        transforms.ExpTransform(),
    ),
    dist.TransformedDistribution(
        dist.Exponential(np.ones(2)),
        [
            transforms.PowerTransform(0.7),
            transforms.AffineTransform(0., np.ones(2) * 3),
        ],
    ),
])
def test_transformed_distribution_intermediates(transformed_dist):
    """Check that log_prob agrees with and without supplied intermediates.

    A sample and its intermediates are drawn together via
    ``sample_with_intermediates``; ``log_prob`` called with those
    intermediates must be close to ``log_prob`` called with the sample alone.
    """
    key = random.PRNGKey(1)
    value, cached = transformed_dist.sample_with_intermediates(key)
    with_intermediates = transformed_dist.log_prob(value, cached)
    without_intermediates = transformed_dist.log_prob(value)
    assert_allclose(with_intermediates, without_intermediates)


def test_transformed_transformed_distribution():
    """Check the structure of a TransformedDistribution wrapping another one.

    A Normal(2, 3) is wrapped with a PowerTransform and then with an
    AffineTransform(loc, scale).  The visible assertions verify that the outer
    distribution reports the Normal as ``base_dist``, holds two transforms,
    and that the first of them is the PowerTransform.
    """
    loc, scale = -2, 3
    inner = dist.TransformedDistribution(
        dist.Normal(2, 3),
        transforms.PowerTransform(2.),
    )
    outer = dist.TransformedDistribution(
        inner,
        transforms.AffineTransform(loc, scale),
    )

    assert isinstance(outer.base_dist, dist.Normal)
    composed = outer.transforms
    assert len(composed) == 2
    assert isinstance(composed[0], transforms.PowerTransform)
示例#3
0
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(
                transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)


@pytest.mark.parametrize('transformed_dist', [
    dist.TransformedDistribution(
        dist.Normal(np.array([2., 3.]), 1.),
        transforms.ExpTransform(),
    ),
    dist.TransformedDistribution(
        dist.Exponential(np.ones(2)),
        [
            transforms.PowerTransform(0.7),
            transforms.AffineTransform(0., np.ones(2) * 3),
        ],
    ),
])
def test_transformed_distribution_intermediates(transformed_dist):
    """Compare log_prob computed from cached intermediates against a fresh one.

    The sample and intermediates come from a single
    ``sample_with_intermediates`` call; passing the intermediates to
    ``log_prob`` must give a result close to calling ``log_prob`` on the
    sample by itself.
    """
    rng_key = random.PRNGKey(1)
    drawn, extras = transformed_dist.sample_with_intermediates(rng_key)
    log_prob_cached = transformed_dist.log_prob(drawn, extras)
    log_prob_fresh = transformed_dist.log_prob(drawn)
    assert_allclose(log_prob_cached, log_prob_fresh)


def test_transformed_transformed_distribution():
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(dist.Normal(2, 3),
                                         transforms.PowerTransform(2.))