Esempio n. 1
0
def test_reshape(old_shape, new_shape, dim):
    """Round-trip a random GammaGaussian through ``reshape``.

    The object is reshaped to ``new_shape`` and its ``batch_shape`` is
    checked. It is then reshaped back to ``old_shape`` and compared with
    the original via ``assert_close_gamma_gaussian``.
    """
    original = random_gamma_gaussian(old_shape, dim)

    reshaped = original.reshape(new_shape)
    assert reshaped.batch_shape == new_shape

    restored = reshaped.reshape(old_shape)
    assert_close_gamma_gaussian(restored, original)
Esempio n. 2
0
def test_sequential_gamma_gaussian_tensordot(batch_shape, state_dim,
                                             num_steps):
    """Check ``_sequential_gamma_gaussian_tensordot`` against a manual fold.

    A random GammaGaussian is built with ``num_steps`` appended to
    ``batch_shape`` and an event dim of ``2 * state_dim``. Presumably this
    is a (previous, next) state pair; TODO confirm. The sequential
    contraction must keep ``dim()`` unchanged and drop the step axis from
    the batch shape. Its result must also match contracting the steps one
    at a time with ``gamma_gaussian_tensordot``.
    """
    event_dim = 2 * state_dim
    chain = random_gamma_gaussian(batch_shape + (num_steps, ), event_dim)

    result = _sequential_gamma_gaussian_tensordot(chain)
    assert result.dim() == chain.dim()
    assert result.batch_shape == batch_shape

    # Reference: left fold over the step axis, absorbing one step at a time.
    reference = chain[..., 0]
    step = 1
    while step < num_steps:
        reference = gamma_gaussian_tensordot(reference, chain[..., step],
                                             state_dim)
        step += 1
    assert_close_gamma_gaussian(result, reference)
Esempio n. 3
0
def test_cat(shape, cat_dim, split, dim):
    """Split a random GammaGaussian along ``cat_dim`` and re-join it.

    ``split`` lists consecutive chunk sizes that must sum to
    ``shape[cat_dim]``. Only ``cat_dim`` values -1, -2 and 1 are supported
    by the slicing below; any other value raises ``ValueError`` when the
    first chunk is taken. The pieces are joined with ``GammaGaussian.cat``
    and compared with the original via ``assert_close_gamma_gaussian``.
    """
    assert sum(split) == shape[cat_dim]
    gamma_gaussian = random_gamma_gaussian(shape, dim)

    def take(lo, hi):
        # Slice [lo, hi) along the requested batch dimension.
        if cat_dim == -1:
            return gamma_gaussian[..., lo:hi]
        if cat_dim == -2:
            return gamma_gaussian[..., lo:hi, :]
        if cat_dim == 1:
            return gamma_gaussian[:, lo:hi]
        raise ValueError

    pieces = []
    offset = 0
    for width in split:
        pieces.append(take(offset, offset + width))
        offset += width

    rejoined = GammaGaussian.cat(pieces, cat_dim)
    assert_close_gamma_gaussian(rejoined, gamma_gaussian)