Example 1
def test_gaussian_tensordot(dot_dims, x_batch_shape, x_dim, x_rank,
                            y_batch_shape, y_dim, y_rank):
    x_rank = min(x_rank, x_dim)
    y_rank = min(y_rank, y_dim)
    x = random_gaussian(x_batch_shape, x_dim, x_rank)
    y = random_gaussian(y_batch_shape, y_dim, y_rank)
    na = x_dim - dot_dims
    nb = dot_dims
    nc = y_dim - dot_dims
    try:
        torch.linalg.cholesky(x.precision[..., na:, na:] +
                              y.precision[..., :nb, :nb])
    except RuntimeError:
        pytest.skip(
            "Cannot marginalize the common variables of two Gaussians.")

    z = gaussian_tensordot(x, y, dot_dims)
    assert z.dim() == x_dim + y_dim - 2 * dot_dims

    # We make these precision matrices positive definite to test the math
    x.precision = x.precision + 1e-1 * torch.eye(x.dim())
    y.precision = y.precision + 1e-1 * torch.eye(y.dim())
    z = gaussian_tensordot(x, y, dot_dims)
    # compare against broadcasting, adding, and marginalizing
    precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision,
                                                       (na, 0, na, 0))
    info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0))
    covariance = torch.inverse(precision)
    loc = (covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1)
           if info_vec.size(-1) > 0 else info_vec)
    z_covariance = torch.inverse(z.precision)
    # The extra view dimension is 1 normally and 0 when z.dim() == 0, so this
    # also works for Gaussians with an empty event space.
    z_loc = z_covariance.matmul(
        z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0), ))).sum(-1)
    assert_close(loc[..., :na], z_loc[..., :na])
    assert_close(loc[..., x_dim:], z_loc[..., na:])
    assert_close(covariance[..., :na, :na], z_covariance[..., :na, :na])
    assert_close(covariance[..., :na, x_dim:], z_covariance[..., :na, na:])
    assert_close(covariance[..., x_dim:, :na], z_covariance[..., na:, :na])
    assert_close(covariance[..., x_dim:, x_dim:], z_covariance[..., na:, na:])

    # Assume a = c = 0, integrate out b
    # FIXME: this might not be a stable way to compute the integral
    num_samples = 200000
    scale = 20
    # generate samples in [-10, 10]
    value_b = torch.rand((num_samples, ) + z.batch_shape +
                         (nb, )) * scale - scale / 2
    value_x = pad(value_b, (na, 0))
    value_y = pad(value_b, (0, nc))
    expect = torch.logsumexp(x.log_density(value_x) + y.log_density(value_y),
                             dim=0)
    expect += math.log(scale**nb / num_samples)
    actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(), )))
    # TODO(fehiepsi): find a condition that makes this test stable, so we can
    # compare large-value log densities.
    assert_close(actual.clamp(max=10.0),
                 expect.clamp(max=10.0),
                 atol=0.1,
                 rtol=0.1)
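The pad arithmetic above is the heart of the hand computation: ``x`` covers event coordinates (a, b), ``y`` covers (b, c), and padding embeds both precision matrices in the joint (a+b+c)-dimensional space so the shared b-block is summed. A minimal standalone sketch of that alignment (toy values, not the test's parametrization):

import torch
from torch.nn.functional import pad

# x covers coordinates (a, b); y covers (b, c). Padding embeds both
# precision matrices in the joint (a + b + c)-dim space so that only the
# shared b-block receives contributions from both factors.
na, nb, nc = 2, 1, 2
x_prec = torch.ones(na + nb, na + nb)
y_prec = 2 * torch.ones(nb + nc, nb + nc)
joint = pad(x_prec, (0, nc, 0, nc)) + pad(y_prec, (na, 0, na, 0))
assert (joint[:na, :na] == 1).all()                # a-block: x only
assert (joint[na:na + nb, na:na + nb] == 3).all()  # b-block: x + y
assert (joint[-nc:, -nc:] == 2).all()              # c-block: y only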
Example 2
def test_gaussian_mrf_log_prob(sample_shape, batch_shape, num_steps,
                               hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_dist = random_mvn(batch_shape + (num_steps, ),
                            hidden_dim + hidden_dim)
    obs_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim + obs_dim)
    d = dist.GaussianMRF(init_dist, trans_dist, obs_dist)
    data = obs_dist.sample(sample_shape)[..., hidden_dim:]
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = mvn_to_gaussian(trans_dist)
    obs = mvn_to_gaussian(obs_dist)

    unrolled_trans = reduce(
        operator.add,
        [
            trans[..., t].event_pad(left=t * hidden_dim,
                                    right=(T - t - 1) * hidden_dim)
            for t in range(T)
        ],
    )
    unrolled_obs = reduce(
        operator.add,
        [
            obs[..., t].event_pad(left=t * obs.dim(),
                                  right=(T - t - 1) * obs.dim())
            for t in range(T)
        ],
    )
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp_h = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp_oh = gaussian_tensordot(logp_h, unrolled_obs, T * hidden_dim)
    logp_h += unrolled_obs.marginalize(right=T * obs_dim)
    expected_log_prob = (logp_oh.log_density(unrolled_data)
                         - logp_h.event_logsumexp())
    assert_close(actual_log_prob, expected_log_prob)
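The transition factor's event_pad bookkeeping is easy to get wrong; a tiny standalone check of the index arithmetic (local toy values): factor t covers (x_t, x_{t+1}), i.e. 2*H coordinates starting at t*H, inside an unrolled event space of (1+T)*H coordinates, so the left and right pads always account for the rest.

# Standalone check of the event_pad arithmetic used above.
T, H = 3, 2
for t in range(T):
    left, right = t * H, (T - t - 1) * H
    assert left + 2 * H + right == (1 + T) * H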
Example 3
def test_gaussian_hmm_log_prob(diag, sample_shape, batch_shape, num_steps, hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps,), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps,), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist)
    if diag:
        obs_mvn = dist.MultivariateNormal(obs_dist.base_dist.loc,
                                          scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim, right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(), right=(T - t - 1) * obs.dim())
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat([torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
                     [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim,))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)
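The HOHOHO-to-HHHOOO permutation is worth seeing concretely. A minimal sketch with small, local toy dimensions:

import torch

# With T = 3, hidden_dim = 2, obs_dim = 1, each per-time block is (H, H, O),
# so obs.dim() == 3; perm gathers all hidden coordinates first, then all
# observed ones.
T, hidden_dim, obs_dim = 3, 2, 1
block = hidden_dim + obs_dim
perm = torch.cat([torch.arange(hidden_dim) + t * block for t in range(T)] +
                 [torch.arange(obs_dim) + hidden_dim + t * block for t in range(T)])
assert perm.tolist() == [0, 1, 3, 4, 6, 7, 2, 5, 8]  # HHHHHH then OOO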
Example 4
    def log_prob(self, value):
        # We compute a normalized distribution as p(obs,hidden) / p(hidden).
        logp_oh = self._trans
        logp_h = self._trans

        # Combine observation and transition factors.
        logp_oh += self._obs.condition(value).event_pad(left=self.hidden_dim)
        logp_h += self._obs.marginalize(right=self.obs_dim).event_pad(
            left=self.hidden_dim)

        # Concatenate p(obs,hidden) and p(hidden) into a single Gaussian.
        batch_dim = 1 + max(len(self._init.batch_shape) + 1,
                            len(logp_oh.batch_shape))
        batch_shape = ((1,) * (batch_dim - len(logp_oh.batch_shape))
                       + logp_oh.batch_shape)
        logp = Gaussian.cat(
            [logp_oh.expand(batch_shape),
             logp_h.expand(batch_shape)])

        # Eliminate time dimension.
        logp = _sequential_gaussian_tensordot(logp)

        # Combine initial factor.
        logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Marginalize out final state.
        logp_oh, logp_h = logp.event_logsumexp()
        return logp_oh - logp_h  # = log( p(obs,hidden) / p(hidden) )
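The returned difference implements log p(obs) = log ∫ f(hidden, obs) d hidden − log ∬ f(hidden, obs) d hidden d obs for the unnormalized MRF factor f. The same identity in one dimension, checked by crude quadrature (a hand-rolled factor, not Pyro's exact Gaussian algebra):

import torch
from torch.distributions import Normal

# f(h, o) is an unnormalized Gaussian factor: h ~ N(0, 1), o | h ~ N(0.7 h, 1)
# up to a constant, so analytically p(o) = Normal(0, sqrt(1 + 0.7**2)).
def log_f(h, o):
    return -0.5 * (h ** 2 + (o - 0.7 * h) ** 2)

h = torch.linspace(-10, 10, 2001)
o = torch.linspace(-10, 10, 2001)
dh, do = h[1] - h[0], o[1] - o[0]
obs = torch.tensor(1.3)
log_num = torch.logsumexp(log_f(h, obs), 0) + dh.log()
log_den = (torch.logsumexp(log_f(h[:, None], o[None, :]).reshape(-1), 0)
           + dh.log() + do.log())
assert torch.allclose(log_num - log_den,
                      Normal(0.0, (1 + 0.7 ** 2) ** 0.5).log_prob(obs),
                      atol=1e-3)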
Example 5
    def filter(self, value):
        """
        Compute posterior over final state given a sequence of observations.

        :param ~torch.Tensor value: A sequence of observations.
        :return: A posterior distribution over latent states at the final
            time step. ``result`` can then be used as ``initial_dist`` in a
            sequential Pyro model for prediction.
        :rtype: ~pyro.distributions.MultivariateNormal
        """
        # Combine observation and transition factors.
        logp = self._trans + self._obs.condition(value).event_pad(
            left=self.hidden_dim)

        # Eliminate time dimension.
        logp = _sequential_gaussian_tensordot(logp.expand(logp.batch_shape))

        # Combine initial factor.
        logp = gaussian_tensordot(self._init, logp, dims=self.hidden_dim)

        # Convert to a distribution
        precision = logp.precision
        loc = logp.info_vec.unsqueeze(-1).cholesky_solve(
            torch.linalg.cholesky(precision)).squeeze(-1)
        return MultivariateNormal(loc,
                                  precision_matrix=precision,
                                  validate_args=self._validate_args)
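A usage sketch for ``filter`` with toy shapes (the constructor arguments follow GaussianHMM's documented order; the specific dimensions here are arbitrary):

import torch
import pyro.distributions as dist

hidden_dim, obs_dim, T = 2, 3, 5
init_dist = dist.MultivariateNormal(torch.zeros(hidden_dim),
                                    torch.eye(hidden_dim))
trans_mat = 0.9 * torch.eye(hidden_dim)
trans_dist = dist.MultivariateNormal(torch.zeros(hidden_dim),
                                     0.1 * torch.eye(hidden_dim))
obs_mat = torch.randn(hidden_dim, obs_dim)
obs_dist = dist.MultivariateNormal(torch.zeros(obs_dim), torch.eye(obs_dim))
hmm = dist.GaussianHMM(init_dist, trans_mat, trans_dist, obs_mat, obs_dist,
                       duration=T)
data = hmm.sample()           # shape (T, obs_dim)
posterior = hmm.filter(data)  # MultivariateNormal over the final hidden state
assert posterior.loc.shape == (hidden_dim,)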
Example 6
def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps):
    g = random_gaussian(batch_shape + (num_steps, ), state_dim + state_dim)
    actual = _sequential_gaussian_tensordot(g)
    assert actual.dim() == g.dim()
    assert actual.batch_shape == batch_shape

    # Check against hand computation.
    expected = g[..., 0]
    for t in range(1, num_steps):
        expected = gaussian_tensordot(expected, g[..., t], state_dim)
    assert_close_gaussian(actual, expected)
Example 7
    def log_prob(self, value):
        # Combine observation and transition factors.
        result = self._trans + self._obs.condition(value).event_pad(
            left=self.hidden_dim)

        # Eliminate time dimension.
        result = _sequential_gaussian_tensordot(
            result.expand(result.batch_shape))

        # Combine initial factor.
        result = gaussian_tensordot(self._init, result, dims=self.hidden_dim)

        # Marginalize out final state.
        result = result.event_logsumexp()
        return result
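What ``event_logsumexp`` returns here is the log of a Gaussian integral. A standalone sketch of that closed form under local names (Pyro's Gaussian stores ``log_normalizer``, ``info_vec``, and ``precision``):

import math
import torch

# For an unnormalized factor exp(c + b.x - x.P.x / 2) with P positive
# definite, the integral over x is
#     exp(c + b.P^-1.b / 2) * sqrt(det(2 * pi * P^-1)),
# whose log is the quantity event_logsumexp computes.
n = 3
P = torch.eye(n) + 0.1 * torch.ones(n, n)  # a positive-definite precision
b = torch.randn(n)
c = 0.5
chol = torch.linalg.cholesky(P)
quad = b @ torch.cholesky_solve(b.unsqueeze(-1), chol).squeeze(-1)
logdet_P = 2 * chol.diagonal().log().sum()
log_Z = c + 0.5 * quad + 0.5 * (n * math.log(2 * math.pi) - logdet_P)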
Example 8
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)
    Mx_cov = matrix.transpose(-2, -1).matmul(x_mvn.covariance_matrix).matmul(matrix)
    Mx_loc = matrix.transpose(-2, -1).matmul(x_mvn.loc.unsqueeze(-1)).squeeze(-1)
    mvn = dist.MultivariateNormal(Mx_loc + y_mvn.loc,
                                  Mx_cov + y_mvn.covariance_matrix)
    expected = mvn_to_gaussian(mvn)

    actual = gaussian_tensordot(mvn_to_gaussian(x_mvn),
                                matrix_and_mvn_to_gaussian(matrix, y_mvn),
                                dims=x_dim)
    assert_close_gaussian(expected, actual)
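The marginalization identity this test exercises, checked by brute-force sampling (a sketch with local toy values, not part of the test suite):

import torch

# If x ~ N(m, S) and y = matrix^T x + e with e ~ N(m_y, S_y), then
# y ~ N(matrix^T m + m_y, matrix^T S matrix + S_y) -- the mvn the test
# constructs as `expected`.
torch.manual_seed(0)
x_dim, y_dim, n = 3, 2, 200000
matrix = torch.randn(x_dim, y_dim)
x = torch.distributions.MultivariateNormal(torch.randn(x_dim), torch.eye(x_dim))
e = torch.distributions.MultivariateNormal(torch.randn(y_dim), torch.eye(y_dim))
y = x.sample((n,)) @ matrix + e.sample((n,))
assert torch.allclose(y.mean(0), x.loc @ matrix + e.loc, atol=0.05)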
Example 9
def _sequential_gaussian_tensordot(gaussian):
    """
    Integrates a Gaussian ``x`` whose rightmost batch dimension is time, computing::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]
    """
    assert isinstance(gaussian, Gaussian)
    assert gaussian.dim() % 2 == 0, "dim is not even"
    batch_shape = gaussian.batch_shape[:-1]
    state_dim = gaussian.dim() // 2
    while gaussian.batch_shape[-1] > 1:
        time = gaussian.batch_shape[-1]
        even_time = time // 2 * 2
        even_part = gaussian[..., :even_time]
        x_y = even_part.reshape(batch_shape + (even_time // 2, 2))
        x, y = x_y[..., 0], x_y[..., 1]
        contracted = gaussian_tensordot(x, y, state_dim)
        if time > even_time:
            contracted = Gaussian.cat((contracted, gaussian[..., -1:]), dim=-1)
        gaussian = contracted
    return gaussian[..., 0]
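The while-loop halves the time dimension each pass, so the chain contracts in O(log T) tensordot calls rather than T-1 sequential ones. The same doubling scheme on ordinary matrices (a sketch mirroring the code above, not Pyro's API):

import operator
from functools import reduce

import torch

def sequential_matmul(x):
    # x has shape (T, n, n); returns x[0] @ x[1] @ ... @ x[T-1] by pairing
    # neighbors and contracting all pairs at once, exactly like the Gaussian
    # version above.
    while x.shape[0] > 1:
        time = x.shape[0]
        even_time = time // 2 * 2
        pairs = x[:even_time].reshape(even_time // 2, 2, *x.shape[1:])
        contracted = pairs[:, 0] @ pairs[:, 1]
        if time > even_time:  # carry the odd leftover factor forward
            contracted = torch.cat([contracted, x[-1:]])
        x = contracted
    return x[0]

mats = torch.randn(5, 3, 3)
expected = reduce(operator.matmul, mats.unbind(0))
assert torch.allclose(sequential_matmul(mats), expected, atol=1e-5)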
Example 10
def test_gaussian_hmm_distribution(diag, sample_shape, batch_shape, num_steps,
                                   hidden_dim, obs_dim):
    init_dist = random_mvn(batch_shape, hidden_dim)
    trans_mat = torch.randn(batch_shape + (num_steps, hidden_dim, hidden_dim))
    trans_dist = random_mvn(batch_shape + (num_steps, ), hidden_dim)
    obs_mat = torch.randn(batch_shape + (num_steps, hidden_dim, obs_dim))
    obs_dist = random_mvn(batch_shape + (num_steps, ), obs_dim)
    if diag:
        scale = obs_dist.scale_tril.diagonal(dim1=-2, dim2=-1)
        obs_dist = dist.Normal(obs_dist.loc, scale).to_event(1)
    d = dist.GaussianHMM(init_dist,
                         trans_mat,
                         trans_dist,
                         obs_mat,
                         obs_dist,
                         duration=num_steps)
    if diag:
        obs_mvn = dist.MultivariateNormal(
            obs_dist.base_dist.loc,
            scale_tril=obs_dist.base_dist.scale.diag_embed())
    else:
        obs_mvn = obs_dist
    data = obs_dist.sample(sample_shape)
    assert data.shape == sample_shape + d.shape()
    actual_log_prob = d.log_prob(data)

    # Compare against hand-computed density.
    # We will construct enormous unrolled joint gaussians with shapes:
    #       t | 0 1 2 3 1 2 3      T = 3 in this example
    #   ------+-----------------------------------------
    #    init | H
    #   trans | H H H H            H = hidden
    #     obs |   H H H O O O      O = observed
    #    like |         O O O
    # and then combine these using gaussian_tensordot().
    T = num_steps
    init = mvn_to_gaussian(init_dist)
    trans = matrix_and_mvn_to_gaussian(trans_mat, trans_dist)
    obs = matrix_and_mvn_to_gaussian(obs_mat, obs_mvn)
    like_dist = dist.Normal(torch.randn(data.shape), 1).to_event(2)
    like = mvn_to_gaussian(like_dist)

    unrolled_trans = reduce(operator.add, [
        trans[..., t].event_pad(left=t * hidden_dim,
                                right=(T - t - 1) * hidden_dim)
        for t in range(T)
    ])
    unrolled_obs = reduce(operator.add, [
        obs[..., t].event_pad(left=t * obs.dim(),
                              right=(T - t - 1) * obs.dim()) for t in range(T)
    ])
    unrolled_like = reduce(operator.add, [
        like[..., t].event_pad(left=t * obs_dim, right=(T - t - 1) * obs_dim)
        for t in range(T)
    ])
    # Permute obs from HOHOHO to HHHOOO.
    perm = torch.cat(
        [torch.arange(hidden_dim) + t * obs.dim() for t in range(T)] +
        [torch.arange(obs_dim) + hidden_dim + t * obs.dim() for t in range(T)])
    unrolled_obs = unrolled_obs.event_permute(perm)
    unrolled_data = data.reshape(data.shape[:-2] + (T * obs_dim, ))

    assert init.dim() == hidden_dim
    assert unrolled_trans.dim() == (1 + T) * hidden_dim
    assert unrolled_obs.dim() == T * (hidden_dim + obs_dim)
    logp = gaussian_tensordot(init, unrolled_trans, hidden_dim)
    logp = gaussian_tensordot(logp, unrolled_obs, T * hidden_dim)
    expected_log_prob = logp.log_density(unrolled_data)
    assert_close(actual_log_prob, expected_log_prob)

    d_posterior, log_normalizer = d.conjugate_update(like_dist)
    assert_close(
        d.log_prob(data) + like_dist.log_prob(data),
        d_posterior.log_prob(data) + log_normalizer)

    if batch_shape or sample_shape:
        return

    # Test mean and covariance.
    prior = "prior", d, logp
    posterior = "posterior", d_posterior, logp + unrolled_like
    for name, d, g in [prior, posterior]:
        logging.info("testing {} moments".format(name))
        with torch.no_grad():
            num_samples = 100000
            samples = d.sample([num_samples]).reshape(num_samples, T * obs_dim)
            actual_mean = samples.mean(0)
            delta = samples - actual_mean
            actual_cov = (delta.unsqueeze(-1) * delta.unsqueeze(-2)).mean(0)
            actual_std = actual_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            actual_corr = actual_cov / (actual_std.unsqueeze(-1) *
                                        actual_std.unsqueeze(-2))

            expected_cov = torch.linalg.cholesky(g.precision).cholesky_inverse()
            expected_mean = expected_cov.matmul(
                g.info_vec.unsqueeze(-1)).squeeze(-1)
            expected_std = expected_cov.diagonal(dim1=-2, dim2=-1).sqrt()
            expected_corr = expected_cov / (expected_std.unsqueeze(-1) *
                                            expected_std.unsqueeze(-2))

            assert_close(actual_mean, expected_mean, atol=0.05, rtol=0.02)
            assert_close(actual_std, expected_std, atol=0.05, rtol=0.02)
            assert_close(actual_corr, expected_corr, atol=0.02)
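The conjugate_update check relies on prior(x) · like(x) = posterior(x) · normalizer holding pointwise. The same identity in one dimension (a hand-rolled sketch, not Pyro's conjugate_update):

import torch
from torch.distributions import Normal

prior, like = Normal(0.0, 1.0), Normal(1.0, 2.0)
# Product of Gaussian densities: precisions add, information vectors add.
prec = prior.scale ** -2 + like.scale ** -2
loc = (prior.loc / prior.scale ** 2 + like.loc / like.scale ** 2) / prec
posterior = Normal(loc, prec.rsqrt())
x = torch.tensor(0.3)
log_normalizer = prior.log_prob(x) + like.log_prob(x) - posterior.log_prob(x)
# The normalizer is constant in x, which is what the assertion above checks:
y = torch.tensor(-1.7)
assert torch.allclose(
    prior.log_prob(y) + like.log_prob(y) - posterior.log_prob(y),
    log_normalizer)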