Example #1
import torch

from pyro.distributions.util import broadcast_shape
from pyro.ops.gaussian import Gaussian


def test_expand(extra_shape, log_normalizer_shape, info_vec_shape,
                precision_shape, dim):
    rank = dim + dim  # rank >= dim makes the random precision matrix full rank a.s.
    log_normalizer = torch.randn(log_normalizer_shape)
    info_vec = torch.randn(info_vec_shape + (dim, ))
    precision = torch.randn(precision_shape + (dim, rank))
    precision = precision.matmul(precision.transpose(-1, -2))
    gaussian = Gaussian(log_normalizer, info_vec, precision)

    expected_shape = extra_shape + broadcast_shape(
        log_normalizer_shape, info_vec_shape, precision_shape)
    actual = gaussian.expand(expected_shape)
    assert actual.batch_shape == expected_shape
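
For context, ``broadcast_shape`` here is Pyro's NumPy-style shape-broadcasting helper; a minimal sketch of its behavior (assuming pyro is installed):

from pyro.distributions.util import broadcast_shape

# Shapes broadcast from the rightmost dim, with size-1 dims expanding.
assert broadcast_shape((2, 1), (3,)) == (2, 3)
assert broadcast_shape((), (5,)) == (5,)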
Example #2
import torch

from pyro.ops.gaussian import Gaussian


def random_gaussian(batch_shape, dim, rank=None):
    """
    Generate a random Gaussian for testing.
    """
    if rank is None:
        rank = dim + dim
    log_normalizer = torch.randn(batch_shape)
    info_vec = torch.randn(batch_shape + (dim, ))
    samples = torch.randn(batch_shape + (dim, rank))
    precision = torch.matmul(samples, samples.transpose(-2, -1))  # PSD by construction
    result = Gaussian(log_normalizer, info_vec, precision)
    assert result.dim() == dim
    assert result.batch_shape == batch_shape
    return result
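
A minimal usage sketch for this helper (shapes chosen arbitrarily):

g = random_gaussian((5,), dim=2)
log_Z = g.event_logsumexp()  # integrate the factor over its 2 event dims
assert log_Z.shape == (5,)   # one log-normalizer per batch element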
Example #3
    def log_prob(self, value):
        # NOTE: this method lives on an MRF-style distribution class; it assumes
        # Gaussian, gaussian_tensordot, and the _sequential_gaussian_tensordot
        # helper (Example #5) are imported at module level.
        # We compute a normalized distribution as p(obs,hidden) / p(hidden).
        logp_oh = self._trans
        logp_h = self._trans

        # Combine observation and transition factors.
        logp_oh += self._obs.condition(value).event_pad(left=self.hidden_dim)
        logp_h += self._obs.marginalize(right=self.obs_dim).event_pad(
            left=self.hidden_dim)

        # Concatenate p(obs,hidden) and p(hidden) into a single Gaussian.
        batch_dim = 1 + max(
            len(self._init.batch_shape) + 1, len(logp_oh.batch_shape))
        batch_shape = (1, ) * (batch_dim -
                               len(logp_oh.batch_shape)) + logp_oh.batch_shape
        logp = Gaussian.cat(
            [logp_oh.expand(batch_shape),
             logp_h.expand(batch_shape)])

        # Eliminate time dimension.
        logp = _sequential_gaussian_tensordot(logp)

        # Combine initial factor.
        logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Marginalize out final state.
        logp_oh, logp_h = logp.event_logsumexp()
        return logp_oh - logp_h  # = log( p(obs,hidden) / p(hidden) )
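
In equation form, with p̃ denoting the unnormalized product of the initial, transition, and observation factors, the returned value is

    log p(value) = log ∫ p̃(value, h) dh  -  log ∬ p̃(o, h) do dh

i.e. the joint factor evaluated at the observation, normalized by the partition function obtained by also integrating out the observations.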
Example #4
from pyro.ops.gaussian import Gaussian


def test_cat(shape, cat_dim, split, dim):
    # random_gaussian (Example #2) and assert_close_gaussian are test helpers.
    assert sum(split) == shape[cat_dim]
    gaussian = random_gaussian(shape, dim)
    parts = []
    end = 0
    for size in split:
        beg, end = end, end + size
        if cat_dim == -1:
            part = gaussian[..., beg:end]
        elif cat_dim == -2:
            part = gaussian[..., beg:end, :]
        elif cat_dim == 1:
            part = gaussian[:, beg:end]
        else:
            raise ValueError(f"unsupported cat_dim: {cat_dim}")
        parts.append(part)

    actual = Gaussian.cat(parts, cat_dim)
    assert_close_gaussian(actual, gaussian)
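
A quick round-trip sketch of the slicing and ``Gaussian.cat`` API exercised above (hypothetical shapes, reusing random_gaussian from Example #2):

g = random_gaussian((4, 3), dim=2)
left, right = g[:, :1], g[:, 1:]             # split the batch along dim 1
merged = Gaussian.cat([left, right], dim=1)  # reassemble the parts
assert merged.batch_shape == (4, 3)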
Example #5
from pyro.ops.gaussian import Gaussian, gaussian_tensordot


def _sequential_gaussian_tensordot(gaussian):
    """
    Integrates out the time dimension of a Gaussian ``x`` whose rightmost batch
    dimension is time, i.e. computes::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]
    """
    assert isinstance(gaussian, Gaussian)
    assert gaussian.dim() % 2 == 0, "dim is not even"
    batch_shape = gaussian.batch_shape[:-1]
    state_dim = gaussian.dim() // 2
    while gaussian.batch_shape[-1] > 1:
        time = gaussian.batch_shape[-1]
        even_time = time // 2 * 2
        even_part = gaussian[..., :even_time]
        x_y = even_part.reshape(batch_shape + (even_time // 2, 2))
        x, y = x_y[..., 0], x_y[..., 1]
        contracted = gaussian_tensordot(x, y, state_dim)
        if time > even_time:
            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
        gaussian = contracted
    return gaussian[..., 0]
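
The loop above contracts adjacent time steps in pairs, so T steps collapse in ceil(log2(T)) passes instead of T - 1 sequential contractions. A small sketch of that doubling schedule (pure Python, no pyro needed):

import math

T, passes = 7, 0
while T > 1:
    T = T // 2 + T % 2  # even pairs contract; an odd trailing step is carried over
    passes += 1
assert passes == math.ceil(math.log2(7)) == 3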
Example #6
from collections import OrderedDict

import torch
from torch.distributions import constraints, transform_to

import pyro.distributions as dist
from pyro.ops.gaussian import Gaussian
from tests.common import assert_close  # pyro test-suite helper


def test_gaussian_funsor(batch_shape):
    # This tests the sample distribution, rsample gradients, log_prob, and
    # log_prob gradients for both Pyro's and Funsor's Gaussian.
    import funsor

    funsor.set_backend("torch")
    num_samples = 100000

    # Declare unconstrained parameters.
    loc = torch.randn(batch_shape + (3, )).requires_grad_()
    t = transform_to(constraints.positive_definite)
    m = torch.randn(batch_shape + (3, 3))
    precision_unconstrained = t.inv(m @ m.transpose(-1, -2)).requires_grad_()

    # Transform to constrained space.
    log_normalizer = torch.zeros(batch_shape)
    precision = t(precision_unconstrained)
    info_vec = (precision @ loc[..., None])[..., 0]

    def check_equal(actual, expected, atol=0.01, rtol=0):
        assert_close(actual.data, expected.data, atol=atol, rtol=rtol)
        grads = torch.autograd.grad(
            (actual - expected).abs().sum(),
            [loc, precision_unconstrained],
            retain_graph=True,
        )
        for grad in grads:
            assert grad.abs().max() < atol

    entropy = dist.MultivariateNormal(loc,
                                      precision_matrix=precision).entropy()

    # Monte Carlo estimate of the entropy via Pyro.
    p_gaussian = Gaussian(log_normalizer, info_vec, precision)
    p_log_Z = p_gaussian.event_logsumexp()
    p_rsamples = p_gaussian.rsample((num_samples, ))
    pp_entropy = (p_log_Z - p_gaussian.log_density(p_rsamples)).mean(0)
    check_equal(pp_entropy, entropy)

    # Monte Carlo estimate of the entropy via Funsor.
    inputs = OrderedDict([(k, funsor.Bint[v])
                          for k, v in zip("ij", batch_shape)])
    inputs["x"] = funsor.Reals[3]
    f_gaussian = funsor.gaussian.Gaussian(mean=loc,
                                          precision=precision,
                                          inputs=inputs)
    f_log_Z = f_gaussian.reduce(funsor.ops.logaddexp, "x")
    sample_inputs = OrderedDict(particle=funsor.Bint[num_samples])
    deltas = f_gaussian.sample("x", sample_inputs)
    f_rsamples = funsor.montecarlo.extract_samples(deltas)["x"]
    ff_entropy = (f_log_Z - f_gaussian(x=f_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(ff_entropy.data, entropy)

    # Check Funsor's .rsample against Pyro's .log_prob.
    pf_entropy = (p_log_Z - p_gaussian.log_density(f_rsamples.data)).mean(0)
    check_equal(pf_entropy, entropy)

    # Check Pyro's .rsample against Funsor's .log_prob.
    fp_rsamples = funsor.Tensor(p_rsamples)["particle"]
    for i in "ij"[:len(batch_shape)]:
        fp_rsamples = fp_rsamples[i]
    fp_entropy = (f_log_Z - f_gaussian(x=fp_rsamples)).reduce(
        funsor.ops.mean, "particle")
    check_equal(fp_entropy.data, entropy)
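
All four entropy checks rely on the same Monte Carlo identity: for an unnormalized density p̃ with log-partition log Z (the event_logsumexp / logaddexp reduction above), the normalized density is p(x) = p̃(x) / Z, so

    H[p] = E_p[-log p(x)] ≈ mean_i (log Z - log p̃(x_i))

over samples x_i ~ p, which is exactly the (log_Z - log_density) average each branch computes.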