Пример #1
0
def test_shapes(batch_shape):
    """Check sample shape and ``log_prob`` of ``dist.Unit`` for a given batch shape.

    A ``Unit`` built from a random ``log_factor`` of shape ``batch_shape`` is
    expected to yield samples with an extra trailing dimension of size zero,
    and its ``log_prob`` on such a sample is expected to equal ``log_factor``
    elementwise.

    :param batch_shape: Shape used to draw the random ``log_factor``.
    """
    log_factor = torch.randn(batch_shape)
    unit = dist.Unit(log_factor=log_factor)

    value = unit.sample()
    expected_shape = batch_shape + (0,)
    assert value.shape == expected_shape

    log_prob = unit.log_prob(value)
    assert (log_prob == log_factor).all()
Пример #2
0
def test_expand(sample_shape, batch_shape):
    """Check that ``dist.Unit.expand`` produces the requested batch shape.

    The expanded distribution must report ``sample_shape + batch_shape`` as its
    batch shape, draw samples of that shape plus a trailing size-zero
    dimension, and agree with the base distribution when each scores the
    other's sample.

    :param sample_shape: Leading shape prepended by ``expand``.
    :param batch_shape: Shape used to draw the random ``log_factor``.
    """
    base = dist.Unit(torch.randn(batch_shape))
    base_value = base.sample()

    full_shape = sample_shape + batch_shape
    expanded = base.expand(full_shape)
    assert expanded.batch_shape == full_shape

    expanded_value = expanded.sample()
    assert expanded_value.shape == full_shape + (0,)

    # Cross-evaluate: each distribution scores the other's sample.
    assert_equal(base.log_prob(expanded_value), expanded.log_prob(base_value))
Пример #3
0
def factor(name, log_factor):
    """
    Add an arbitrary log probability factor to a probabilistic model.

    The factor is attached by issuing a ``sample`` statement on a
    ``dist.Unit`` distribution built from ``log_factor``, observed at a value
    drawn from that same distribution.

    :param str name: Name of the trivial sample
    :param torch.Tensor log_factor: A possibly batched log probability factor.
    """
    trivial_dist = dist.Unit(log_factor)
    sample(name, trivial_dist, obs=trivial_dist.sample())
Пример #4
0
def factor(name, log_factor, *, has_rsample=None):
    """
    Add an arbitrary log probability factor to a probabilistic model.

    The factor is attached by issuing a ``sample`` statement on a
    ``dist.Unit`` distribution built from ``log_factor``, observed at a value
    drawn from that same distribution and marked as auxiliary via ``infer``.

    .. warning:: Factor statements used in guides must declare where the
        factor came from. Set ``has_rsample=True`` when it originated from
        fully reparametrized sampling (e.g. the Jacobian determinant of a
        transformation of a reparametrized variable), and
        ``has_rsample=False`` when it originated from nonreparametrized
        sampling (e.g. discrete samples). This is needed only in guides, not
        in models.

    :param str name: Name of the trivial sample
    :param torch.Tensor log_factor: A possibly batched log probability factor.
    :param bool has_rsample: Whether the ``log_factor`` arose from a fully
        reparametrized distribution. Defaults to False when used in models, but
        must be specified for use in guides.
    """
    trivial_dist = dist.Unit(log_factor, has_rsample=has_rsample)
    infer_config = {"is_auxiliary": True}
    sample(name, trivial_dist, obs=trivial_dist.sample(), infer=infer_config)