Example #1
def test_gaussian_tensordot(dot_dims, x_batch_shape, x_dim, x_rank,
                            y_batch_shape, y_dim, y_rank):
    x_rank = min(x_rank, x_dim)
    y_rank = min(y_rank, y_dim)
    x = random_gaussian(x_batch_shape, x_dim, x_rank)
    y = random_gaussian(y_batch_shape, y_dim, y_rank)
    na = x_dim - dot_dims
    nb = dot_dims
    nc = y_dim - dot_dims
    try:
        torch.linalg.cholesky(x.precision[..., na:, na:] +
                              y.precision[..., :nb, :nb])
    except RuntimeError:
        pytest.skip(
            "Cannot marginalize the common variables of two Gaussians.")

    z = gaussian_tensordot(x, y, dot_dims)
    assert z.dim() == x_dim + y_dim - 2 * dot_dims

    # We make these precision matrices positive definite to test the math
    x.precision = x.precision + 1e-1 * torch.eye(x.dim())
    y.precision = y.precision + 1e-1 * torch.eye(y.dim())
    z = gaussian_tensordot(x, y, dot_dims)
    # compare against broadcasting, adding, and marginalizing
    precision = pad(x.precision, (0, nc, 0, nc)) + pad(y.precision,
                                                       (na, 0, na, 0))
    info_vec = pad(x.info_vec, (0, nc)) + pad(y.info_vec, (na, 0))
    covariance = torch.inverse(precision)
    loc = (covariance.matmul(info_vec.unsqueeze(-1)).squeeze(-1)
           if info_vec.size(-1) > 0 else info_vec)
    z_covariance = torch.inverse(z.precision)
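    # view() appends a trailing axis (size 1 normally, size 0 when z.dim() == 0,
    # so the degenerate empty case still broadcasts through matmul); sum(-1)
    # then removes that axis again.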
    z_loc = z_covariance.matmul(
        z.info_vec.view(z.info_vec.shape + (int(z.dim() > 0), ))).sum(-1)
    assert_close(loc[..., :na], z_loc[..., :na])
    assert_close(loc[..., x_dim:], z_loc[..., na:])
    assert_close(covariance[..., :na, :na], z_covariance[..., :na, :na])
    assert_close(covariance[..., :na, x_dim:], z_covariance[..., :na, na:])
    assert_close(covariance[..., x_dim:, :na], z_covariance[..., na:, :na])
    assert_close(covariance[..., x_dim:, x_dim:], z_covariance[..., na:, na:])

    # Assume a = c = 0, integrate out b
    # FIXME: this might not be a stable way to compute the integral
    num_samples = 200000
    scale = 20
    # generate samples in [-10, 10]
    value_b = torch.rand((num_samples, ) + z.batch_shape +
                         (nb, )) * scale - scale / 2
    value_x = pad(value_b, (na, 0))
    value_y = pad(value_b, (0, nc))
    expect = torch.logsumexp(x.log_density(value_x) + y.log_density(value_y),
                             dim=0)
    expect += math.log(scale**nb / num_samples)
    actual = z.log_density(torch.zeros(z.batch_shape + (z.dim(), )))
    # TODO(fehiepsi): find some condition to make this test stable, so we can
    # compare large log-density values.
    assert_close(actual.clamp(max=10.0),
                 expect.clamp(max=10.0),
                 atol=0.1,
                 rtol=0.1)
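
These tests all build their inputs with a random_gaussian helper defined elsewhere in the suite. Below is a minimal sketch of a plausible version, assuming the Gaussian(log_normalizer, info_vec, precision) information-form constructor from pyro.ops.gaussian; the body is illustrative, not the suite's exact code.

import torch
from pyro.ops.gaussian import Gaussian

def random_gaussian(batch_shape, dim, rank=None):
    # Draw a random Gaussian in information form whose precision is built
    # from a rank-limited factor; with rank >= dim it is full rank a.s.
    if rank is None:
        rank = dim + dim
    log_normalizer = torch.randn(batch_shape)
    info_vec = torch.randn(batch_shape + (dim,))
    factor = torch.randn(batch_shape + (dim, rank))
    precision = factor.matmul(factor.transpose(-1, -2))  # PSD, rank <= min(dim, rank)
    return Gaussian(log_normalizer, info_vec, precision)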
Example #2
def test_marginalize_condition(sample_shape, batch_shape, left, right):
    dim = left + right
    g = random_gaussian(batch_shape, dim)
    x = torch.randn(sample_shape + (1, ) * len(batch_shape) + (right, ))
    assert_close(
        g.marginalize(left=left).log_density(x),
        g.condition(x).event_logsumexp())
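
Both sides of the assertion compute log ∫ p(u, x) du, where u is the left block and x the right one: either marginalize out u first and evaluate the marginal at x, or substitute x first and integrate out the remainder. A standalone illustration of the same identity with torch.distributions (the concrete numbers are arbitrary, and this does not touch the Gaussian class itself):

import torch
from torch.distributions import MultivariateNormal

cov = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
joint = MultivariateNormal(torch.zeros(2), cov)
x = torch.tensor(0.7)

# Marginal density of the right coordinate, evaluated at x.
lhs = MultivariateNormal(torch.zeros(1), cov[1:, 1:]).log_prob(x.reshape(1))

# Fix the right coordinate at x and integrate the left one out on a grid.
u = torch.linspace(-10.0, 10.0, 10001)
grid = torch.stack([u, x.expand_as(u)], dim=-1)  # points (u, x)
rhs = joint.log_prob(grid).logsumexp(0) + (u[1] - u[0]).log()

print(lhs.item(), rhs.item())  # agree to several decimal places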
Example #3
def test_marginalize(batch_shape, left, right):
    dim = left + right
    g = random_gaussian(batch_shape, dim)
    assert_close(
        g.marginalize(left=left).event_logsumexp(), g.event_logsumexp())
    assert_close(
        g.marginalize(right=right).event_logsumexp(), g.event_logsumexp())
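
Both assertions are instances of the same fact: integrating out one block of coordinates and then integrating the resulting marginal over the remaining block yields the same total mass as integrating the original Gaussian over all coordinates at once.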
Example #4
def test_pad(shape, left, right, dim):
    expected = random_gaussian(shape, dim)
    padded = expected.event_pad(left=left, right=right)
    assert padded.batch_shape == expected.batch_shape
    assert padded.dim() == left + expected.dim() + right
    mid = slice(left, padded.dim() - right)
    assert_close(padded.info_vec[..., mid], expected.info_vec)
    assert_close(padded.precision[..., mid, mid], expected.precision)
Example #5
def test_reshape(old_shape, new_shape, dim):
    gaussian = random_gaussian(old_shape, dim)

    # reshape to new
    new = gaussian.reshape(new_shape)
    assert new.batch_shape == new_shape

    # reshape back to old
    g = new.reshape(old_shape)
    assert_close_gaussian(g, gaussian)
Example #6
def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps):
    g = random_gaussian(batch_shape + (num_steps, ), state_dim + state_dim)
    actual = _sequential_gaussian_tensordot(g)
    assert actual.dim() == g.dim()
    assert actual.batch_shape == batch_shape

    # Check against hand computation.
    expected = g[..., 0]
    for t in range(1, num_steps):
        expected = gaussian_tensordot(expected, g[..., t], state_dim)
    assert_close_gaussian(actual, expected)
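
The reference loop above is the O(num_steps) left-to-right fold. The _sequential_gaussian_tensordot under test presumably reduces the chain in fewer rounds; here is one plausible logarithmic-depth variant, sketched using only operations exercised elsewhere in these tests (reshape, indexing, slicing, Gaussian.cat) and not taken from the library itself:

from pyro.ops.gaussian import Gaussian, gaussian_tensordot

def tree_gaussian_tensordot(g, state_dim):
    # Contract adjacent pairs of time slices until one factor remains,
    # roughly halving the chain each round. Illustrative sketch only.
    while g.batch_shape[-1] > 1:
        time = g.batch_shape[-1]
        even_time = time - time % 2
        pairs = g[..., :even_time].reshape(g.batch_shape[:-1] +
                                           (even_time // 2, 2))
        g_next = gaussian_tensordot(pairs[..., 0], pairs[..., 1], state_dim)
        if time % 2:  # carry an unpaired trailing factor into the next round
            g_next = Gaussian.cat([g_next, g[..., -1:]], -1)
        g = g_next
    return g[..., 0]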
Example #7
def test_logsumexp(batch_shape, dim):
    gaussian = random_gaussian(batch_shape, dim)
    gaussian.info_vec *= 0.1  # approximately centered
    gaussian.precision += torch.eye(dim) * 0.1

    num_samples = 200000
    scale = 10
    samples = torch.rand((num_samples, ) + (1, ) * len(batch_shape) +
                         (dim, )) * scale - scale / 2
    expected = gaussian.log_density(samples).logsumexp(0) + math.log(
        scale**dim / num_samples)
    actual = gaussian.event_logsumexp()
    assert_close(actual, expected, atol=0.05, rtol=0.05)
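
For reference, here is the estimator this test (and Example #1) relies on, pulled out into a standalone helper; the name and signature are illustrative:

import math
import torch

def log_integral_mc(log_density_fn, dim, num_samples=200000, scale=10.0):
    # Box Monte Carlo in log space:
    #   log ∫ exp(f(v)) dv  ≈  logsumexp_i f(v_i) + log(volume / num_samples)
    # for v_i drawn uniformly from the box [-scale/2, scale/2]^dim, whose
    # volume is scale**dim.
    v = torch.rand(num_samples, dim) * scale - scale / 2
    return log_density_fn(v).logsumexp(0) + math.log(scale**dim / num_samples)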
Example #8
def test_condition(sample_shape, batch_shape, left, right):
    dim = left + right
    gaussian = random_gaussian(batch_shape, dim)
    gaussian.precision += torch.eye(dim) * 0.1
    value = torch.randn(sample_shape + (1, ) * len(batch_shape) + (dim, ))
    left_value, right_value = value[..., :left], value[..., left:]

    conditioned = gaussian.condition(right_value)
    assert conditioned.batch_shape == sample_shape + gaussian.batch_shape
    assert conditioned.dim() == left

    actual = conditioned.log_density(left_value)
    expected = gaussian.log_density(value)
    assert_close(actual, expected)
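
The last two assertions show that conditioning here behaves as partial evaluation of the (unnormalized) log density: substituting right_value for the right block and then evaluating at left_value gives exactly gaussian.log_density applied to the full concatenated value.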
Example #9
def test_cat(shape, cat_dim, split, dim):
    assert sum(split) == shape[cat_dim]
    gaussian = random_gaussian(shape, dim)
    parts = []
    end = 0
    for size in split:
        beg, end = end, end + size
        if cat_dim == -1:
            part = gaussian[..., beg:end]
        elif cat_dim == -2:
            part = gaussian[..., beg:end, :]
        elif cat_dim == 1:
            part = gaussian[:, beg:end]
        else:
            raise ValueError
        parts.append(part)

    actual = Gaussian.cat(parts, cat_dim)
    assert_close_gaussian(actual, gaussian)
Example #10
def test_sequential_gaussian_filter_sample(sample_shape, batch_shape,
                                           state_dim, num_steps):
    init = random_gaussian(batch_shape, state_dim)
    trans = random_gaussian(batch_shape + (num_steps, ), state_dim + state_dim)
    sample = _sequential_gaussian_filter_sample(init, trans, sample_shape)
    assert sample.shape == sample_shape + batch_shape + (num_steps, state_dim)
Example #11
def test_marginalize_shape(batch_shape, left, right):
    dim = left + right
    g = random_gaussian(batch_shape, dim)
    assert g.marginalize(left=left).dim() == right
    assert g.marginalize(right=right).dim() == left
Example #12
def test_add(shape, dim):
    x = random_gaussian(shape, dim)
    y = random_gaussian(shape, dim)
    value = torch.randn(dim)
    assert_close((x + y).log_density(value),
                 x.log_density(value) + y.log_density(value))
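
This additivity is what makes the information form convenient: assuming the (log_normalizer, info_vec, precision) parameterization, the log density is linear in each parameter, so pointwise addition of log densities corresponds to coordinatewise addition of parameters. If the implementation works that way, the following stronger, parameter-level checks would also pass (hypothetical extra lines for the same test body, not part of the original):

    z = x + y
    assert_close(z.info_vec, x.info_vec + y.info_vec)
    assert_close(z.precision, x.precision + y.precision)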