Example #1
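# NOTE: this snippet assumes that torch, pyro, pyro.distributions as dist, and
# torch.distributions.constraints / torch.distributions.transforms are imported,
# and that J (the number of groups) is defined in the enclosing module.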
def guide(data):
    loc_eta = torch.randn(J, 1)
    # note that we initialize our scales to be pretty narrow
    scale_eta = 0.1 * torch.rand(J, 1)
    loc_mu = torch.randn(1)
    scale_mu = 0.1 * torch.rand(1)
    loc_logtau = torch.randn(1)
    scale_logtau = 0.1 * torch.rand(1)

    # register learnable params in the param store
    m_eta_param = pyro.param("loc_eta", loc_eta)
    s_eta_param = pyro.param("scale_eta", scale_eta, constraint=constraints.positive)
    m_mu_param = pyro.param("loc_mu", loc_mu)
    s_mu_param = pyro.param("scale_mu", scale_mu, constraint=constraints.positive)
    m_logtau_param = pyro.param("loc_logtau", loc_logtau)
    s_logtau_param = pyro.param("scale_logtau", scale_logtau, constraint=constraints.positive)

    # guide distributions
    dist_eta = dist.Normal(m_eta_param, s_eta_param)
    dist_mu = dist.Normal(m_mu_param, s_mu_param)
    dist_tau = dist.TransformedDistribution(dist.Normal(m_logtau_param, s_logtau_param),
                                            transforms=transforms.ExpTransform())

    pyro.sample('eta', dist_eta)
    pyro.sample('mu', dist_mu)
    pyro.sample('tau', dist_tau)
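The tau site above stays positive because a Normal guide distribution is pushed through ExpTransform. That trick in isolation is just a log-normal; a minimal sketch using only torch (the loc/scale values below are made up for illustration):

import torch
import torch.distributions as td
from torch.distributions import transforms

loc, scale = torch.tensor(0.5), torch.tensor(0.3)

# a Normal in log-space pushed through exp(): a distribution over positive values
positive_dist = td.TransformedDistribution(td.Normal(loc, scale),
                                            transforms.ExpTransform())

# it matches torch's built-in LogNormal
reference = td.LogNormal(loc, scale)
x = torch.tensor([0.5, 1.0, 2.0])
assert torch.allclose(positive_dist.log_prob(x), reference.log_prob(x))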
Example #2
    def test_transform_forward(self):
        x = torch.rand(3, 3)
        t = transforms.ExpTransform()
        p = Param(x, transform=t)

        actual_forward = p.transform()
        assert all(
            [
                e.data.numpy() == pytest.approx(a.data.numpy())
                for e, a in zip(x.flatten(), actual_forward.flatten())
            ]
        )
Example #3
    def test_transform_inverse(self):
        x = torch.rand(3, 3)
        t = transforms.ExpTransform()
        p = Param(x, transform=t)

        expected_data = t.inv(x)
        actual_data = p.data
        assert all(
            [
                e.data.numpy() == pytest.approx(a.data.numpy())
                for e, a in zip(expected_data.flatten(), actual_data.flatten())
            ]
        )
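The two tests above pin down the contract of this project's Param wrapper: .data holds the value in unconstrained space (transform.inv of the input), and transform() maps it back to the original, constrained value. A rough, hypothetical re-implementation of that behaviour (the class below is a sketch for illustration, not the project's actual code):

import torch
from torch.distributions import transforms


class Param:
    """Sketch: a parameter stored in unconstrained space (hypothetical)."""

    def __init__(self, value, transform=None):
        self._t = transform if transform is not None else transforms.identity_transform
        # store the unconstrained representation, as test_transform_inverse expects
        self.data = self._t.inv(value)

    def transform(self):
        # map back to constrained space; round-trips the original value
        return self._t(self.data)


x = torch.rand(3, 3)
p = Param(x, transform=transforms.ExpTransform())
assert torch.allclose(p.transform(), x)   # forward recovers x
assert torch.allclose(p.data, x.log())    # data is stored as log(x)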
def _transform_to_less_than(constraint):
    return transforms.ComposeTransform([
        transforms.ExpTransform(),
        transforms.AffineTransform(constraint.upper_bound, -1)
    ])
def _transform_to_greater_than(constraint):
    return transforms.ComposeTransform([
        transforms.ExpTransform(),
        transforms.AffineTransform(constraint.lower_bound, 1)
    ])
def _transform_to_positive(constraint):
    return transforms.ExpTransform()
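All three factories follow the same recipe: ExpTransform maps the real line onto (0, inf), and an AffineTransform then shifts (and, for less_than, flips) the result into the constrained set. The registry lookups shipped with torch.distributions resolve these constraints the same way; a quick range check (input values arbitrary):

import torch
from torch.distributions import constraints, transform_to

x = torch.randn(5)  # unconstrained values

y_pos = transform_to(constraints.positive)(x)           # exp(x)
y_gt = transform_to(constraints.greater_than(2.0))(x)   # 2 + exp(x)
y_lt = transform_to(constraints.less_than(-1.0))(x)     # -1 - exp(x)

assert (y_pos > 0).all()
assert (y_gt > 2.0).all()
assert (y_lt < -1.0).all()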
Example #7
@pytest.mark.parametrize('size', [1, 2, 3])
def test_kl_independent_normal_mvn(batch_shape, size):
    loc = torch.randn(batch_shape + (size, ))
    scale = torch.randn(batch_shape + (size, )).exp()
    p1 = dist.Normal(loc, scale).to_event(1)
    p2 = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())

    loc = torch.randn(batch_shape + (size, ))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    actual = kl_divergence(p1, q)
    expected = kl_divergence(p2, q)
    assert_close(actual, expected)


@pytest.mark.parametrize('shape', [(5, ), (4, 5), (2, 3, 5)], ids=str)
@pytest.mark.parametrize('event_dim', [0, 1])
@pytest.mark.parametrize(
    'transform',
    [transforms.ExpTransform(),
     transforms.StickBreakingTransform()])
def test_kl_transformed_transformed(shape, event_dim, transform):
    p_base = dist.Normal(torch.zeros(shape),
                         torch.ones(shape)).to_event(event_dim)
    q_base = dist.Normal(torch.ones(shape) * 2,
                         torch.ones(shape)).to_event(event_dim)
    p = dist.TransformedDistribution(p_base, transform)
    q = dist.TransformedDistribution(q_base, transform)
    kl = kl_divergence(q, p)
    expected_shape = shape[:-1] if max(transform.event_dim,
                                       event_dim) == 1 else shape
    assert kl.shape == expected_shape
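test_kl_transformed_transformed leans on the fact that KL divergence is invariant under a shared bijection: the log-Jacobian terms cancel, so the KL of the two ExpTransform-ed distributions equals the KL of their base Normals. A small Monte Carlo sanity check of that fact, using only log_prob and the closed-form Normal KL (shapes, seed and tolerance chosen arbitrarily):

import torch
import torch.distributions as td
from torch.distributions import kl_divergence, transforms

torch.manual_seed(0)
p_base = td.Normal(torch.zeros(5), torch.ones(5))
q_base = td.Normal(torch.ones(5) * 2, torch.ones(5))
t = transforms.ExpTransform()
p = td.TransformedDistribution(p_base, t)
q = td.TransformedDistribution(q_base, t)

# Monte Carlo estimate of KL(q || p) in the transformed (positive) space ...
y = q.sample((100000,))
kl_mc = (q.log_prob(y) - p.log_prob(y)).mean(0)

# ... agrees with the closed-form KL of the base Normals (2.0 per element here)
kl_base = kl_divergence(q_base, p_base)
assert torch.allclose(kl_mc, kl_base, atol=0.05)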
class ConstraintRegistry(object):
    ...  # __init__() and register() elided

    def __call__(self, constraint):
        # look up the factory registered for this constraint type
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError('Cannot transform {} constraints'.format(
                type(constraint).__name__))
        return factory(constraint)


biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()

################################################################################
# Registration Table
################################################################################

biject_to.register(constraints.real, transforms.identity_transform)
transform_to.register(constraints.real, transforms.identity_transform)

biject_to.register(constraints.positive, transforms.ExpTransform())
transform_to.register(constraints.positive, transforms.ExpTransform())


@biject_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than)
def _transform_to_greater_than(constraint):
    loc = constraint.lower_bound
    scale = loc.new([1]).expand_as(loc)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])


@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    loc = constraint.upper_bound
    scale = loc.new([-1]).expand_as(loc)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])
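The point of the registry is that downstream code never hard-codes these transforms: given a constraint, transform_to hands back a transform from unconstrained space, which is the usual way to optimize a constrained parameter. A short sketch of that pattern (parameter, target values and optimizer settings are arbitrary):

import torch
from torch.distributions import constraints, transform_to

# optimize a positive "scale" by storing it unconstrained and mapping it
# through the registered transform (ExpTransform, per the table above)
unconstrained = torch.zeros(3, requires_grad=True)
opt = torch.optim.SGD([unconstrained], lr=0.1)
target = torch.tensor([0.5, 1.0, 2.0])

for _ in range(200):
    opt.zero_grad()
    scale = transform_to(constraints.positive)(unconstrained)
    loss = ((scale - target) ** 2).sum()
    loss.backward()
    opt.step()

assert torch.allclose(transform_to(constraints.positive)(unconstrained),
                      target, atol=1e-2)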
Example #10
def _transform_to_less_than(constraint):
    loc, scale = broadcast_all(constraint.upper_bound, -1)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])
Example #11
def _transform_to_greater_than(constraint):
    loc, scale = broadcast_all(constraint.lower_bound, 1)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(),
         transforms.AffineTransform(loc, scale)])
Example #12
@pytest.mark.parametrize("size", [1, 2, 3])
def test_kl_independent_normal_mvn(batch_shape, size):
    loc = torch.randn(batch_shape + (size,))
    scale = torch.randn(batch_shape + (size,)).exp()
    p1 = dist.Normal(loc, scale).to_event(1)
    p2 = dist.MultivariateNormal(loc, scale_tril=scale.diag_embed())

    loc = torch.randn(batch_shape + (size,))
    cov = torch.randn(batch_shape + (size, size))
    cov = cov @ cov.transpose(-1, -2) + 0.01 * torch.eye(size)
    q = dist.MultivariateNormal(loc, covariance_matrix=cov)

    actual = kl_divergence(p1, q)
    expected = kl_divergence(p2, q)
    assert_close(actual, expected)


@pytest.mark.parametrize("shape", [(5,), (4, 5), (2, 3, 5)], ids=str)
@pytest.mark.parametrize("event_dim", [0, 1])
@pytest.mark.parametrize(
    "transform", [transforms.ExpTransform(), transforms.StickBreakingTransform()]
)
def test_kl_transformed_transformed(shape, event_dim, transform):
    p_base = dist.Normal(torch.zeros(shape), torch.ones(shape)).to_event(event_dim)
    q_base = dist.Normal(torch.ones(shape) * 2, torch.ones(shape)).to_event(event_dim)
    p = dist.TransformedDistribution(p_base, transform)
    q = dist.TransformedDistribution(q_base, transform)
    kl = kl_divergence(q, p)
    expected_shape = shape[:-1] if max(transform.event_dim, event_dim) == 1 else shape
    assert kl.shape == expected_shape
Example #13
    def test_init(self):
        x = torch.eye(3) + torch.ones(3, 3)

        Param(x)
        Param(x, transform=transforms.ExpTransform())
        Param(x, transform=transforms.LowerCholeskyTransform())