예제 #1
0
    actual = transform.log_abs_det_jacobian(x, y)
    assert np.shape(actual) == batch_shape
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)


@pytest.mark.parametrize('transformed_dist', [
    # Single transform: ExpTransform over a Normal with a length-2 loc vector.
    dist.TransformedDistribution(
        dist.Normal(np.array([2., 3.]), 1.),
        transforms.ExpTransform(),
    ),
    # List of two transforms (PowerTransform, then AffineTransform) over an
    # Exponential with a length-2 rate vector.
    dist.TransformedDistribution(
        dist.Exponential(np.ones(2)),
        [
            transforms.PowerTransform(0.7),
            transforms.AffineTransform(0., np.ones(2) * 3),
        ],
    ),
])
def test_transformed_distribution_intermediates(transformed_dist):
    """Check log_prob with supplied intermediates against plain log_prob.

    One sample is drawn via ``sample_with_intermediates``. The intermediates
    it returns are passed back into ``log_prob``, and the result must match
    ``log_prob`` evaluated on the sample alone.
    """
    rng_key = random.PRNGKey(1)
    sample, intermediates = transformed_dist.sample_with_intermediates(rng_key)
    with_intermediates = transformed_dist.log_prob(sample, intermediates)
    without_intermediates = transformed_dist.log_prob(sample)
    assert_allclose(with_intermediates, without_intermediates)


def test_transformed_transformed_distribution():
    """Nest a TransformedDistribution inside another and inspect the base.

    ``dist1`` applies a PowerTransform to a Normal. ``dist2`` wraps ``dist1``
    with an AffineTransform. The assertion checks that ``dist2.base_dist`` is
    a ``Normal`` instance rather than the intermediate transformed
    distribution.
    """
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(dist.Normal(2, 3), transforms.PowerTransform(2.))
    # Pass the named loc/scale instead of repeating the literals so the affine
    # parameters cannot drift from the values declared above.
    dist2 = dist.TransformedDistribution(dist1, transforms.AffineTransform(loc, scale))
    assert isinstance(dist2.base_dist, dist.Normal)
예제 #2
0
    if len(shape) == transform.event_dim:
        if len(event_shape) == 1:
            expected = onp.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = onp.linalg.slogdet(jax.jacobian(
                transform.inv)(y))[1]
        else:
            expected = np.log(np.abs(grad(transform)(x)))
            inv_expected = np.log(np.abs(grad(transform.inv)(y)))

        assert_allclose(actual, expected, atol=1e-6)
        assert_allclose(actual, -inv_expected, atol=1e-6)


@pytest.mark.parametrize(
    'transformed_dist',
    [
        # ExpTransform over a Normal whose loc is a length-2 vector.
        dist.TransformedDistribution(dist.Normal(np.array([2., 3.]), 1.),
                                     transforms.ExpTransform()),
        # Two transforms given as a list over a length-2 Exponential.
        dist.TransformedDistribution(
            dist.Exponential(np.ones(2)),
            [transforms.PowerTransform(0.7),
             transforms.AffineTransform(0., np.ones(2) * 3)]),
    ],
)
def test_transformed_distribution_intermediates(transformed_dist):
    """Check that log_prob agrees with and without sampler intermediates.

    ``sample_with_intermediates`` returns a sample plus its intermediates.
    Feeding those intermediates to ``log_prob`` must give the same value as
    calling ``log_prob`` on the sample by itself.
    """
    key = random.PRNGKey(1)
    drawn = transformed_dist.sample_with_intermediates(key)
    sample, intermediates = drawn
    actual = transformed_dist.log_prob(sample, intermediates)
    expected = transformed_dist.log_prob(sample)
    assert_allclose(actual, expected)


def test_transformed_transformed_distribution():
    loc, scale = -2, 3