Ejemplo n.º 1
0
def test_expand(extra_shape, log_normalizer_shape, info_vec_shape,
                precision_shape, alpha_shape, beta_shape, dim):
    """
    Check that ``GammaGaussian.expand`` produces the requested batch shape.

    Every parameter is sampled with its own batch shape. The target shape is
    ``extra_shape`` prepended to the broadcast of all parameter batch shapes,
    and the expanded object's ``batch_shape`` must equal that target.
    """
    log_normalizer = torch.randn(log_normalizer_shape)
    info_vec = torch.randn(info_vec_shape + (dim,))

    # Multiply a random (dim, 2 * dim) factor by its own transpose so the
    # resulting (dim, dim) precision matrix is symmetric.
    factor = torch.randn(precision_shape + (dim, 2 * dim))
    precision = factor.matmul(factor.transpose(-1, -2))

    # Exponentiating keeps alpha and beta strictly positive.
    alpha = torch.randn(alpha_shape).exp()
    beta = torch.randn(beta_shape).exp()

    gamma_gaussian = GammaGaussian(
        log_normalizer, info_vec, precision, alpha, beta)

    param_batch_shape = broadcast_shape(
        log_normalizer_shape,
        info_vec_shape,
        precision_shape,
        alpha_shape,
        beta_shape,
    )
    expected_shape = extra_shape + param_batch_shape

    expanded = gamma_gaussian.expand(expected_shape)
    assert expanded.batch_shape == expected_shape
Ejemplo n.º 2
0
def random_gamma_gaussian(batch_shape, dim, rank=None):
    """
    Build a randomly parameterized ``GammaGaussian`` for use in tests.

    :param batch_shape: batch shape shared by all sampled parameters; it is
        concatenated with tuples, so it must support ``+ (dim,)``.
    :param int dim: size of the event dimension of the result.
    :param int rank: number of columns in the random factor whose outer
        product forms the precision matrix. Defaults to ``2 * dim``.
    :return: a ``GammaGaussian`` for which ``dim()`` equals ``dim`` and
        ``batch_shape`` equals ``batch_shape`` (both are asserted).
    """
    rank = dim + dim if rank is None else rank

    log_normalizer = torch.randn(batch_shape)
    loc = torch.randn(batch_shape + (dim,))

    # precision = factor @ factor^T, which is symmetric by construction.
    factor = torch.randn(batch_shape + (dim, rank))
    precision = factor.matmul(factor.transpose(-2, -1))

    # info_vec = precision @ loc; with dim == 0 the product is skipped and
    # the (empty) loc is reused directly.
    info_vec = (precision.matmul(loc.unsqueeze(-1)).squeeze(-1)
                if dim > 0 else loc)

    alpha = torch.randn(batch_shape).exp() + 0.5 * dim - 1
    beta = torch.randn(batch_shape).exp() + 0.5 * (info_vec * loc).sum(-1)

    gamma_gaussian = GammaGaussian(
        log_normalizer, info_vec, precision, alpha, beta)
    assert gamma_gaussian.dim() == dim
    assert gamma_gaussian.batch_shape == batch_shape
    return gamma_gaussian
Ejemplo n.º 3
0
def test_cat(shape, cat_dim, split, dim):
    """
    Check that slicing a ``GammaGaussian`` along one batch dimension and
    re-joining the slices with ``GammaGaussian.cat`` reproduces the original.

    ``split`` lists the consecutive slice widths along ``cat_dim`` and must
    sum to ``shape[cat_dim]``. Only ``cat_dim`` values of -1, -2 and 1 are
    handled; any other value raises ``ValueError`` on the first slice.
    """
    assert sum(split) == shape[cat_dim]
    whole = random_gamma_gaussian(shape, dim)

    pieces = []
    start = 0
    for width in split:
        stop = start + width
        window = slice(start, stop)
        # Build the same index tuple that the subscript syntax would produce.
        if cat_dim == -1:
            index = (Ellipsis, window)
        elif cat_dim == -2:
            index = (Ellipsis, window, slice(None))
        elif cat_dim == 1:
            index = (slice(None), window)
        else:
            raise ValueError
        pieces.append(whole[index])
        start = stop

    rejoined = GammaGaussian.cat(pieces, cat_dim)
    assert_close_gamma_gaussian(rejoined, whole)
Ejemplo n.º 4
0
def _sequential_gamma_gaussian_tensordot(gamma_gaussian):
    """
    Contracts a GammaGaussian ``x`` along its rightmost batch dimension,
    which is treated as time, computing::

        x[..., 0] @ x[..., 1] @ ... @ x[..., T-1]

    Neighbouring time steps are paired up and contracted with
    ``gamma_gaussian_tensordot`` over ``dim() // 2`` dimensions. This is
    repeated until a single time step remains. When the number of steps is
    odd, the last step is carried over unchanged to the next round.

    :param GammaGaussian gamma_gaussian: input whose ``dim()`` is even.
    :return: the single remaining time step, i.e. the result with the
        rightmost batch dimension indexed away.
    """
    assert isinstance(gamma_gaussian, GammaGaussian)
    assert gamma_gaussian.dim() % 2 == 0, "dim is not even"

    lead_shape = gamma_gaussian.batch_shape[:-1]
    half_dim = gamma_gaussian.dim() // 2

    current = gamma_gaussian
    while current.batch_shape[-1] > 1:
        n_pairs, leftover = divmod(current.batch_shape[-1], 2)

        # Fold the first 2 * n_pairs steps into (n_pairs, 2) so that the
        # trailing axis selects the left/right member of each pair.
        paired = current[..., :2 * n_pairs].reshape(lead_shape + (n_pairs, 2))
        left = paired[..., 0]
        right = paired[..., 1]
        merged = gamma_gaussian_tensordot(left, right, half_dim)

        if leftover:
            # Odd length: append the unpaired final step for the next round.
            merged = GammaGaussian.cat((merged, current[..., -1:]), dim=-1)
        current = merged

    return current[..., 0]