Example #1
0
def test_reshape(old_shape, new_shape, dim):
    """Round-trip a random Gaussian through ``reshape`` and back.

    The reshaped object must report ``new_shape`` as its ``batch_shape``,
    and reshaping it back to ``old_shape`` must give a Gaussian that is
    close to the one we started from.
    """
    original = random_gaussian(old_shape, dim)

    # Forward: old_shape -> new_shape.
    reshaped = original.reshape(new_shape)
    assert reshaped.batch_shape == new_shape

    # Backward: new_shape -> old_shape should recover the original.
    restored = reshaped.reshape(old_shape)
    assert_close_gaussian(restored, original)
Example #2
0
def test_sequential_gaussian_tensordot(batch_shape, state_dim, num_steps):
    """Compare ``_sequential_gaussian_tensordot`` with a step-by-step fold.

    A random Gaussian is built with batch shape ``batch_shape + (num_steps,)``
    and ``2 * state_dim`` dimensions. The function under test must keep
    ``dim()`` unchanged, drop the trailing batch axis, and agree with
    contracting the slices along that axis one at a time, left to right,
    with ``gaussian_tensordot`` over ``state_dim`` dims.
    """
    full_batch_shape = batch_shape + (num_steps, )
    joint_dim = state_dim + state_dim
    g = random_gaussian(full_batch_shape, joint_dim)

    actual = _sequential_gaussian_tensordot(g)
    assert actual.dim() == g.dim()
    assert actual.batch_shape == batch_shape

    # Reference result: fold the slices of the last batch axis in order.
    running = g[..., 0]
    for step in range(1, num_steps):
        next_factor = g[..., step]
        running = gaussian_tensordot(running, next_factor, state_dim)
    expected = running
    assert_close_gaussian(actual, expected)
Example #3
0
def test_matrix_and_mvn_to_gaussian_2(sample_shape, batch_shape, x_dim, y_dim):
    """Check ``matrix_and_mvn_to_gaussian`` against a directly built MVN.

    With ``M`` a random ``batch_shape + (x_dim, y_dim)`` matrix and random
    MVNs ``x_mvn`` / ``y_mvn``, the expected Gaussian comes from an MVN with
    loc ``M^T x_loc + y_loc`` and covariance ``M^T cov_x M + cov_y``. The
    actual Gaussian contracts ``mvn_to_gaussian(x_mvn)`` with
    ``matrix_and_mvn_to_gaussian(matrix, y_mvn)`` over ``x_dim`` dims.

    ``sample_shape`` is not used in the body (presumably it is supplied by
    the test parametrization -- verify against the decorators).
    """
    # Draw order (matrix, y_mvn, x_mvn) is kept so seeded runs are unchanged.
    matrix = torch.randn(batch_shape + (x_dim, y_dim))
    y_mvn = random_mvn(batch_shape, y_dim)
    x_mvn = random_mvn(batch_shape, x_dim)

    # Moments of M^T x for x ~ x_mvn.
    matrix_t = matrix.transpose(-2, -1)
    projected_cov = matrix_t @ x_mvn.covariance_matrix @ matrix
    projected_loc = (matrix_t @ x_mvn.loc.unsqueeze(-1)).squeeze(-1)

    combined_loc = projected_loc + y_mvn.loc
    combined_cov = projected_cov + y_mvn.covariance_matrix
    expected = mvn_to_gaussian(dist.MultivariateNormal(combined_loc,
                                                       combined_cov))

    x_gaussian = mvn_to_gaussian(x_mvn)
    conditional = matrix_and_mvn_to_gaussian(matrix, y_mvn)
    actual = gaussian_tensordot(x_gaussian, conditional, dims=x_dim)
    assert_close_gaussian(expected, actual)
Example #4
0
def test_cat(shape, cat_dim, split, dim):
    """Slice a random Gaussian along ``cat_dim`` and re-join with ``cat``.

    ``split`` lists the consecutive chunk sizes; they must sum to
    ``shape[cat_dim]``. Only ``cat_dim`` values of -1, -2 and 1 are handled;
    any other value raises ``ValueError`` when the first chunk is sliced.
    ``Gaussian.cat`` on the chunks must reproduce the original Gaussian.
    """
    assert sum(split) == shape[cat_dim]
    gaussian = random_gaussian(shape, dim)

    # One slicing rule per supported batch axis.
    slicers = {
        -1: lambda window: gaussian[..., window],
        -2: lambda window: gaussian[..., window, :],
        1: lambda window: gaussian[:, window],
    }

    pieces = []
    stop = 0
    for width in split:
        start, stop = stop, stop + width
        if cat_dim not in slicers:
            raise ValueError
        pieces.append(slicers[cat_dim](slice(start, stop)))

    actual = Gaussian.cat(pieces, cat_dim)
    assert_close_gaussian(actual, gaussian)