Ejemplo n.º 1
0
def test_gamma_gaussian_hmm_log_prob(sample_shape, batch_shape, num_steps,
                                     hidden_dim, obs_dim):
    """
    Check ``GammaGaussianHMM.log_prob`` against a brute-force density.

    The HMM's init, transition and observation factors are unrolled over time
    into one large joint GammaGaussian. The hidden states are contracted away
    with ``gamma_gaussian_tensordot``, and the flattened observations are
    scored under the resulting compound (Student-t) distribution.
    """
    step_shape = batch_shape + (num_steps, )

    # Random model parameters; keep this draw order fixed so the RNG stream
    # consumed by each draw does not change.
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(step_shape + (hidden_dim, hidden_dim))
    trans_dist = random_mvn(step_shape, hidden_dim)
    obs_mat = torch.randn(step_shape + (hidden_dim, obs_dim))
    obs_dist = random_mvn(step_shape, obs_dim)
    scale_dist = random_gamma(batch_shape)

    hmm = dist.GammaGaussianHMM(scale_dist, init_dist, trans_mat, trans_dist,
                                obs_mat, obs_dist)
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + hmm.shape()
    actual_log_prob = hmm.log_prob(data)

    # Reference density. The unrolled joint gamma-gaussians have event layout:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    T = num_steps
    init = gamma_and_mvn_to_gamma_gaussian(scale_dist, init_dist)
    trans = matrix_and_mvn_to_gamma_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gamma_gaussian(obs_mat, obs_dist)

    def unroll(factor, width):
        # Shift time slice t to event offset t * width, then sum all slices.
        shifted = [
            factor[..., t].event_pad(left=t * width,
                                     right=(T - 1 - t) * width)
            for t in range(T)
        ]
        return reduce(operator.add, shifted)

    obs_width = obs.dim()  # event size of one observation time slice
    unrolled_trans = unroll(trans, hidden_dim)
    unrolled_obs = unroll(obs, obs_width)

    # Reorder the observation event dims from HOHOHO to HHHOOO.
    hidden_idx = [torch.arange(hidden_dim) + t * obs_width for t in range(T)]
    observed_idx = [
        torch.arange(obs_dim) + hidden_dim + t * obs_width for t in range(T)
    ]
    unrolled_obs = unrolled_obs.event_permute(
        torch.cat(hidden_idx + observed_idx))
    flat_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)

    # Contract init into the transitions, then every hidden dim into the
    # (permuted) observation block, leaving only the observed dims.
    joint = gamma_gaussian_tensordot(init, unrolled_trans, hidden_dim)
    joint = gamma_gaussian_tensordot(joint, unrolled_obs, T * hidden_dim)
    # Score the data under the joint Student-t returned by compound().
    expected_log_prob = joint.compound().log_prob(flat_data)
    assert_close(actual_log_prob, expected_log_prob)
Ejemplo n.º 2
0
def test_gamma_gaussian_tensordot(dot_dims, x_batch_shape, x_dim, x_rank,
                                  y_batch_shape, y_dim, y_rank):
    """
    Check ``gamma_gaussian_tensordot`` against dense linear algebra and a
    Monte Carlo integral.

    With event layouts ``x = [a, b]`` and ``y = [b, c]`` (``na``, ``nb``,
    ``nc`` dims respectively), contracting over ``b`` must match:

    1. the mean/covariance of the zero-padded, summed joint restricted to
       ``[a, c]``, and
    2. a uniform Monte Carlo estimate of the integral over ``b`` at
       ``a = c = 0``.
    """
    x_rank = min(x_rank, x_dim)
    y_rank = min(y_rank, y_dim)
    x = random_gamma_gaussian(x_batch_shape, x_dim, x_rank)
    y = random_gamma_gaussian(y_batch_shape, y_dim, y_rank)
    na = x_dim - dot_dims
    nb = dot_dims
    nc = y_dim - dot_dims
    try:
        torch.linalg.cholesky(x.precision[..., na:, na:] +
                              y.precision[..., :nb, :nb])
    except RuntimeError:
        pytest.skip(
            "Cannot marginalize the common variables of two Gaussians.")

    z = gamma_gaussian_tensordot(x, y, dot_dims)
    assert z.dim() == x_dim + y_dim - 2 * dot_dims

    # Shift the precisions by 3 * I so that (assuming the random precisions
    # are positive semidefinite) they are safely positive definite, which the
    # inverses below require.
    x.precision = x.precision + 3 * torch.eye(x.dim())
    y.precision = y.precision + 3 * torch.eye(y.dim())
    z = gamma_gaussian_tensordot(x, y, dot_dims)
    # compare against broadcasting, adding, and marginalizing
    precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision,
                                                       (na, 0, na, 0))
    info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0))
    covariance = torch.inverse(precision)
    loc = (covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1)
           if info_vec.size(-1) > 0 else info_vec)
    z_covariance = torch.inverse(z.precision)
    # View info_vec as (..., d, 1), or (..., 0, 0) when z has no event dims,
    # so the matmul is well-formed in both cases.
    z_loc = z_covariance.matmul(
        z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0), ))).sum(-1)
    # The joint covariance restricted to the surviving [a, c] dims must match.
    assert_close(loc[..., :na], z_loc[..., :na])
    assert_close(loc[..., x_dim:], z_loc[..., na:])
    assert_close(covariance[..., :na, :na], z_covariance[..., :na, :na])
    assert_close(covariance[..., :na, x_dim:], z_covariance[..., :na, na:])
    assert_close(covariance[..., x_dim:, :na], z_covariance[..., na:, :na])
    assert_close(covariance[..., x_dim:, x_dim:], z_covariance[..., na:, na:])

    s = torch.randn(z.batch_shape).exp()
    # Assume a = c = 0, integrate out b
    num_samples = 200000
    scale = 10
    # Draw b uniformly from the box [-scale/2, scale/2)^nb, i.e. [-5, 5)^nb.
    # Its volume scale**nb is the Monte Carlo weight applied below.
    value_b = torch.rand((num_samples, ) + z.batch_shape +
                         (nb, )) * scale - scale / 2
    value_x = pad(value_b, (na, 0))
    value_y = pad(value_b, (0, nc))
    expect = torch.logsumexp(x.log_density(value_x, s) +
                             y.log_density(value_y, s),
                             dim=0)
    # log(volume / N): turns the sum over samples into an integral estimate.
    expect += math.log(scale**nb / num_samples)
    actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(), )), s)
    assert_close(actual.clamp(max=10.0),
                 expect.clamp(max=10.0),
                 atol=0.1,
                 rtol=0.1)
Ejemplo n.º 3
0
def test_sequential_gamma_gaussian_tensordot(batch_shape, state_dim,
                                             num_steps):
    """
    Check ``_sequential_gamma_gaussian_tensordot`` against a naive
    left-to-right fold of ``gamma_gaussian_tensordot`` over time.
    """
    factors = random_gamma_gaussian(batch_shape + (num_steps, ),
                                    2 * state_dim)
    actual = _sequential_gamma_gaussian_tensordot(factors)
    # The time dimension is consumed; the event size is unchanged.
    assert actual.dim() == factors.dim()
    assert actual.batch_shape == batch_shape

    # Reference: contract time slices one at a time, left to right.
    expected = factors[..., 0]
    for step in range(1, num_steps):
        expected = gamma_gaussian_tensordot(expected, factors[..., step],
                                            state_dim)
    assert_close_gamma_gaussian(actual, expected)
Ejemplo n.º 4
0
    def log_prob(self, value):
        """
        Log-density of the observation sequence ``value``, with the hidden
        states and the multiplier integrated out.

        :param ~torch.Tensor value: A sequence of observations.
        :return: The log probability of ``value``.
        """
        # Per-step factors: transition plus the observation factor conditioned
        # on ``value``, left-padded by hidden_dim to align the event dims.
        observed = self._obs.condition(value).event_pad(left=self.hidden_dim)
        step_factors = self._trans + observed

        # Contract the time dimension away. The self-expand presumably
        # materializes any broadcast batch dims -- confirm against callers.
        chained = _sequential_gamma_gaussian_tensordot(
            step_factors.expand(step_factors.batch_shape))

        # Attach the initial-state factor.
        joint = gamma_gaussian_tensordot(self._init,
                                         chained,
                                         dims=self.hidden_dim)

        # Integrate out the final hidden state, then the multiplier.
        return joint.event_logsumexp().logsumexp()
Ejemplo n.º 5
0
def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
    """
    Contract a GammaGaussian ``x`` along its rightmost batch dimension (time),
    computing::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]

    Each slice has an even event size ``2 * state_dim``. Neighbouring slices
    are joined by contracting the trailing ``state_dim`` of the left slice with
    the leading ``state_dim`` of the right one. Contraction proceeds pairwise
    in rounds, and each round roughly halves the length of the time dimension.
    """
    assert isinstance(gamma_gaussian, GammaGaussian)
    assert gamma_gaussian.dim() % 2 == 0, "dim is not even"
    outer_shape = gamma_gaussian.batch_shape[:-1]
    half = gamma_gaussian.dim() // 2

    current = gamma_gaussian
    while current.batch_shape[-1] > 1:
        num_pairs, leftover = divmod(current.batch_shape[-1], 2)
        # Group slices (0, 1), (2, 3), ... into a trailing axis of size 2.
        paired = current[..., :2 * num_pairs].reshape(outer_shape +
                                                      (num_pairs, 2))
        merged = gamma_gaussian_tensordot(paired[..., 0], paired[..., 1],
                                          half)
        if leftover:
            # An odd trailing slice is carried into the next round unchanged.
            merged = GammaGaussian.cat((merged, current[..., -1:]), dim=-1)
        current = merged
    return current[..., 0]
Ejemplo n.º 6
0
    def filter(self, value):
        """
        Compute posteriors over the multiplier and the final state
        given a sequence of observations.

        The joint posterior is a GammaGaussian. It is returned factorized as
        a Gamma posterior over the multiplier and a MultivariateNormal over
        the final latent state, conditioned on a unit multiplier.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A pair of posterior distributions over the mixing and the latent
            state at the final time step.
        :rtype: a tuple of ~pyro.distributions.Gamma and ~pyro.distributions.MultivariateNormal
        """
        import torch

        # Combine observation and transition factors.
        logp = self._trans + self._obs.condition(value).event_pad(
            left=self.hidden_dim)

        # Eliminate time dimension.
        logp = _sequential_gamma_gaussian_tensordot(
            logp.expand(logp.batch_shape))

        # Combine initial factor.
        logp = gamma_gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Posterior of the scale
        gamma_dist = logp.event_logsumexp()
        scale_post = Gamma(gamma_dist.concentration,
                           gamma_dist.rate,
                           validate_args=self._validate_args)

        # Conditional of the last state on unit scale: a Gaussian with
        # precision `precision` and mean solving precision @ loc = info_vec.
        precision = logp.precision
        precision_tril = torch.linalg.cholesky(precision)
        loc = logp.info_vec.unsqueeze(-1).cholesky_solve(
            precision_tril).squeeze(-1)
        # The Cholesky factor of the *precision* is not a valid scale_tril:
        # passing it as one would make the covariance equal to the precision.
        # Hand the precision to MultivariateNormal directly instead.
        mvn = MultivariateNormal(loc,
                                 precision_matrix=precision,
                                 validate_args=self._validate_args)
        return scale_post, mvn