Example 1
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    return transforms.StickBreakingTransform()
Example 2
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    loc = constraint.upper_bound
    scale = loc.new([-1]).expand_as(loc)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])


biject_to.register(constraints.unit_interval, transforms.SigmoidTransform())
transform_to.register(constraints.unit_interval, transforms.SigmoidTransform())


@biject_to.register(constraints.interval)
@transform_to.register(constraints.interval)
def _transform_to_interval(constraint):
    loc = constraint.lower_bound
    scale = constraint.upper_bound - constraint.lower_bound
    return transforms.ComposeTransform([
        transforms.SigmoidTransform(),
        transforms.AffineTransform(loc, scale)
    ])


biject_to.register(constraints.simplex, transforms.StickBreakingTransform())
transform_to.register(constraints.simplex, transforms.BoltzmannTransform())

# TODO define a bijection for LowerCholeskyTransform
transform_to.register(constraints.lower_cholesky,
                      transforms.LowerCholeskyTransform())
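
The registrations above are consumed by calling the registries themselves. A minimal usage sketch, assuming the public biject_to/transform_to registries exported by torch.distributions rather than this exact snippet:

import torch
from torch.distributions import constraints, biject_to, transform_to

# biject_to(constraints.simplex) resolves to the StickBreakingTransform
# registered above: a length-3 unconstrained vector maps to 4 non-negative
# entries that sum to one.
t = biject_to(constraints.simplex)
y = t(torch.randn(3))
assert torch.allclose(y.sum(-1), torch.tensor(1.0))

# transform_to(constraints.unit_interval) resolves to the SigmoidTransform
# registered above, squashing unconstrained values into (0, 1).
u = transform_to(constraints.unit_interval)(torch.randn(5))
assert ((u > 0) & (u < 1)).all()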
Example 3
@pytest.mark.parametrize('size', [1, 2, 3])
def test_kl_independent_normal_mvn(batch_shape, size):
    loc = torch.randn(batch_shape + (size, ))
    scale = torch.randn(batch_shape + (size, )).exp()
    p1 = dist.Normal(loc, scale).to_event(1)
    p2 = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())

    loc = torch.randn(batch_shape + (size, ))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    actual = kl_divergence(p1, q)
    expected = kl_divergence(p2, q)
    assert_close(actual, expected)


@pytest.mark.parametrize('shape', [(5, ), (4, 5), (2, 3, 5)], ids=str)
@pytest.mark.parametrize('event_dim', [0, 1])
@pytest.mark.parametrize(
    'transform',
    [transforms.ExpTransform(),
     transforms.StickBreakingTransform()])
def test_kl_transformed_transformed(shape, event_dim, transform):
    p_base = dist.Normal(torch.zeros(shape),
                         torch.ones(shape)).to_event(event_dim)
    q_base = dist.Normal(torch.ones(shape) * 2,
                         torch.ones(shape)).to_event(event_dim)
    p = dist.TransformedDistribution(p_base, transform)
    q = dist.TransformedDistribution(q_base, transform)
    kl = kl_divergence(q, p)
    expected_shape = shape[:-1] if max(transform.event_dim,
                                       event_dim) == 1 else shape
    assert kl.shape == expected_shape
Example 4
@pytest.mark.parametrize("size", [1, 2, 3])
def test_kl_independent_normal_mvn(batch_shape, size):
    loc = torch.randn(batch_shape + (size,))
    scale = torch.randn(batch_shape + (size,)).exp()
    p1 = dist.Normal(loc, scale).to_event(1)
    p2 = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())

    loc = torch.randn(batch_shape + (size,))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    actual = kl_divergence(p1, q)
    expected = kl_divergence(p2, q)
    assert_close(actual, expected)


@pytest.mark.parametrize("shape", [(5,), (4, 5), (2, 3, 5)], ids=str)
@pytest.mark.parametrize("event_dim", [0, 1])
@pytest.mark.parametrize(
    "transform", [transforms.ExpTransform(), transforms.StickBreakingTransform()]
)
def test_kl_transformed_transformed(shape, event_dim, transform):
    p_base = dist.Normal(torch.zeros(shape), torch.ones(shape)).to_event(event_dim)
    q_base = dist.Normal(torch.ones(shape) * 2, torch.ones(shape)).to_event(event_dim)
    p = dist.TransformedDistribution(p_base, transform)
    q = dist.TransformedDistribution(q_base, transform)
    kl = kl_divergence(q, p)
    expected_shape = shape[:-1] if max(transform.event_dim, event_dim) == 1 else shape
    assert kl.shape == expected_shape
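
What test_kl_transformed_transformed checks, in miniature: when the same bijection is applied to both distributions, the change-of-variables log-det terms cancel and the KL reduces to the KL of the base distributions. A hedged sketch under the same assumptions as the tests above (dist is pyro.distributions, and a kl_divergence implementation is available for pairs of TransformedDistributions sharing the same transform, as the test exercises):

import torch
from torch.distributions import kl_divergence, transforms
import pyro.distributions as dist

p_base = dist.Normal(torch.zeros(5), torch.ones(5))
q_base = dist.Normal(torch.ones(5) * 2, torch.ones(5))
t = transforms.ExpTransform()  # event_dim == 0, so shapes are unchanged
p = dist.TransformedDistribution(p_base, t)
q = dist.TransformedDistribution(q_base, t)

# The log|det J| terms cancel, so KL(q || p) matches the base KL elementwise.
assert torch.allclose(kl_divergence(q, p), kl_divergence(q_base, p_base))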