Esempio n. 1
0
def test_shorthands():
    """Exercise input wrapping via `p(x)` and the `>` / `[]` shorthands."""
    graph = Graph()
    p = GP(EQ(), graph=graph)

    # Construct a normal distribution that serves as an input.
    x = p(1)
    assert isinstance(x, At)
    assert type_parameter(x) is p
    assert x.get() == 1
    # Both `str` and `repr` of the nested call must render as `<p>(<x>)`.
    for render in (str, repr):
        assert render(p(x)) == render(p) + '(' + render(x) + ')'

    # Construct a normal distribution that does not serve as an input.
    x = Normal(np.ones((1, 1)))
    for misuse in (
        lambda: type_parameter(x),
        lambda: x.get(),
        lambda: p | (x, 1),
    ):
        with pytest.raises(RuntimeError):
            misuse()

    # Test shorthands for stretching and selection.
    p = GP(EQ(), graph=Graph())
    stretched_by_operator = str(p > 2)
    stretched_by_method = str(p.stretch(2))
    assert stretched_by_operator == stretched_by_method
    selected_by_index = str(p[0])
    selected_by_method = str(p.select(0))
    assert selected_by_index == selected_by_method
Esempio n. 2
0
def test_sum_other():
    """Adding a number or a function to a `GP` must shift the mean by that
    amount and leave the kernel unchanged, for every spelling of the sum."""
    p = GP(TensorProductMean(lambda x: x ** 2), EQ())

    def five(y):
        # Constant function: a column of fives with as many rows as `y`.
        return 5 * B.ones(B.shape(y)[0], 1)

    x = B.randn(5, 1)
    for p_sum in [
        # Add a numeric thing.
        p + 5.0,
        5.0 + p,
        p.measure.sum(GP(), p, 5.0),
        p.measure.sum(GP(), 5.0, p),
        # Add a function.
        p + five,
        five + p,
        p.measure.sum(GP(), p, five),
        p.measure.sum(GP(), five, p),
    ]:
        # One check per property is enough; the original repeated each of
        # these lines verbatim.
        approx(p.mean(x) + 5.0, p_sum.mean(x))
        approx(p.kernel(x), p_sum.kernel(x))

    # Check that a `GP` cannot be summed with a `Normal`.
    with pytest.raises(NotFoundLookupError):
        p + Normal(np.eye(3))
    with pytest.raises(NotFoundLookupError):
        Normal(np.eye(3)) + p
Esempio n. 3
0
def test_mul_other():
    """Multiplying a `GP` by a number or a function (here: the constant five)
    must scale the mean by 5 and the kernel by 25, for every spelling."""
    p = GP(TensorProductMean(lambda x: x ** 2), EQ())

    def five(y):
        # Constant function: a column of fives with as many rows as `y`.
        return 5 * B.ones(B.shape(y)[0], 1)

    x = B.randn(5, 1)
    for p_mul in [
        # Multiply numeric thing.
        p * 5.0,
        5.0 * p,
        p.measure.mul(GP(), p, 5.0),
        p.measure.mul(GP(), 5.0, p),
        # Multiply with a function.
        p * five,
        five * p,
        p.measure.mul(GP(), p, five),
        p.measure.mul(GP(), five, p),
    ]:
        # One check per property is enough; the original repeated each of
        # these lines verbatim.
        approx(5.0 * p.mean(x), p_mul.mean(x))
        approx(25.0 * p.kernel(x), p_mul.kernel(x))

    # Check that a `GP` cannot be multiplied with a `Normal`.
    with pytest.raises(NotFoundLookupError):
        p * Normal(np.eye(3))
    with pytest.raises(NotFoundLookupError):
        Normal(np.eye(3)) * p
Esempio n. 4
0
def test_normal_lazy_var_diag():
    """Check when requesting `var_diag` forces construction of the variance."""

    def lazy_identity():
        return B.eye(3)

    # Without an explicit `var_diag`, the full variance has to be built to
    # read off its diagonal, so `_var` ends up populated.
    without_diag = Normal(lazy_identity)
    approx(without_diag.var_diag, B.ones(3))
    approx(without_diag._var, B.eye(3))

    # With an explicit `var_diag`, the diagonal comes from there and the
    # variance itself must stay unconstructed.
    with_diag = Normal(lazy_identity, var_diag=lambda: 9)
    approx(with_diag.var_diag, 9)
    assert with_diag._var is None
Esempio n. 5
0
def test_normal_arithmetic():
    # Generator-style test: yields one check per arithmetic identity on
    # `Normal` (matrix products, scaling, addition).
    root = np.random.randn(3, 3)
    dist = Normal(root.dot(root.T), np.random.randn(3, 1))
    root = np.random.randn(3, 3)
    dist2 = Normal(root.dot(root.T), np.random.randn(3, 1))

    A = np.random.randn(3, 3)
    a = np.random.randn(1, 3)
    b = 5.

    # Each entry is (actual thunk, expected thunk, label). The thunks keep the
    # computation lazy, so each comparison is evaluated only when its check
    # is about to be yielded.
    product_checks = [
        # Test matrix multiplication.
        (lambda: (dist.rmatmul(a)).mean,
         lambda: dist.mean.dot(a), 'mean mul'),
        (lambda: (dist.rmatmul(a)).var,
         lambda: a.dot(dense(dist.var)).dot(a.T), 'var mul'),
        (lambda: (dist.lmatmul(A)).mean,
         lambda: A.dot(dist.mean), 'mean rmul'),
        (lambda: (dist.lmatmul(A)).var,
         lambda: A.dot(dense(dist.var)).dot(A.T), 'var rmul'),
        # Test multiplication.
        (lambda: (dist * b).mean, lambda: dist.mean * b, 'mean mul 2'),
        (lambda: (dist * b).var, lambda: dist.var * b**2, 'var mul 2'),
        (lambda: (b * dist).mean, lambda: dist.mean * b, 'mean rmul 2'),
        (lambda: (b * dist).var, lambda: dist.var * b**2, 'var rmul 2'),
    ]
    for actual, expected, label in product_checks:
        yield ok, allclose(actual(), expected()), label

    # Multiplying two distributions is not supported.
    yield raises, NotImplementedError, lambda: dist.__mul__(dist)
    yield raises, NotImplementedError, lambda: dist.__rmul__(dist)

    # Test addition.
    sum_checks = [
        (lambda: (dist + dist2).mean,
         lambda: dist.mean + dist2.mean, 'mean sum'),
        (lambda: (dist + dist2).var,
         lambda: dist.var + dist2.var, 'var sum'),
        (lambda: (dist.__add__(b)).mean, lambda: dist.mean + b, 'mean add'),
        (lambda: (dist.__radd__(b)).mean, lambda: dist.mean + b, 'mean radd'),
    ]
    for actual, expected, label in sum_checks:
        yield ok, allclose(actual(), expected()), label
Esempio n. 6
0
def test_normal_sampling():
    """Empirical mean and variance of samples should match expectations,
    both without and with observation noise."""
    tolerance = 5e-2
    for mean in (0, 1):
        dist = Normal(mean, 3 * B.eye(np.int32, 200))

        # Sample without noise: the test expects a variance of 3.
        clean = dist.sample(2000)
        approx(B.mean(clean), mean, atol=tolerance)
        approx(B.std(clean) ** 2, 3, atol=tolerance)

        # Sample with noise: the test expects a variance of 5.
        noisy = dist.sample(2000, noise=2)
        approx(B.mean(noisy), mean, atol=tolerance)
        approx(B.std(noisy) ** 2, 5, atol=tolerance)
Esempio n. 7
0
def test_normal_mean_is_zero():
    """`mean_is_zero` should be set for omitted or `Zero` means only."""
    # Zero cases: mean left out entirely, and mean given as a `Zero` matrix.
    # Factories are used so each distribution is built right before its checks.
    for build in (
        lambda: Normal(B.eye(3)),
        lambda: Normal(Zero(np.float32, 3, 1), B.eye(3)),
    ):
        dist = build()
        assert dist.mean_is_zero
        approx(dist.mean, B.zeros(3, 1))

    # Check nonzero case.
    nonzero = Normal(B.randn(3, 1), B.eye(3))
    assert not nonzero.mean_is_zero
Esempio n. 8
0
def test_mul_other():
    # Generator-style test: scaling a `GP` by 5 from either side should scale
    # the mean by 5 and the kernel by 25.
    model = Graph()
    base = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=model)
    scaled_left = 5. * base
    scaled_right = base * 5.
    both = (scaled_left, scaled_right)

    x = np.random.randn(5, 1)
    for scaled in both:
        yield assert_allclose, 5. * base.mean(x), scaled.mean(x)
    for scaled in both:
        yield assert_allclose, 25. * base.kernel(x), scaled.kernel(x)
    # Cross-kernel between the two scaled processes, looked up on the graph.
    yield (assert_allclose,
           model.kernels[scaled_left, scaled_right](x, x),
           25. * base.kernel(x))

    # Check that a `GP` cannot be multiplied with a `Normal`.
    for forbidden in (
        lambda: base * Normal(np.eye(3)),
        lambda: Normal(np.eye(3)) * base,
    ):
        yield raises, NotImplementedError, forbidden
Esempio n. 9
0
def test_mul_other():
    """Scaling a `GP` by 5 from either side should scale the mean by 5 and the
    kernel by 25; multiplying with a `Normal` is rejected."""
    model = Graph()
    base = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=model)
    scaled_left = 5. * base
    scaled_right = base * 5.
    both = (scaled_left, scaled_right)

    x = np.random.randn(5, 1)
    for scaled in both:
        allclose(5. * base.mean(x), scaled.mean(x))
    for scaled in both:
        allclose(25. * base.kernel(x), scaled.kernel(x))
    # Cross-kernel between the two scaled processes, looked up on the graph.
    cross_kernel = model.kernels[scaled_left, scaled_right]
    allclose(cross_kernel(x, x), 25. * base.kernel(x))

    # Check that a `GP` cannot be multiplied with a `Normal`.
    for forbidden in (
        lambda: base * Normal(np.eye(3)),
        lambda: Normal(np.eye(3)) * base,
    ):
        with pytest.raises(NotImplementedError):
            forbidden()
Esempio n. 10
0
def test_normal_sampling():
    """Check sample statistics with and without noise, and that sampling from
    identically seeded random states gives matching samples."""
    for mean in (0, 1):
        dist = Normal(mean, 3 * B.eye(np.int32, 200))

        # Pairs of (extra keyword arguments for `sample`, expected variance):
        # first without noise, then with noise.
        for sample_kw, expected_var in (({}, 3), ({'noise': 2}, 5)):
            samples = dist.sample(2000, **sample_kw)
            approx(B.mean(samples), mean, atol=5e-2)
            approx(B.std(samples) ** 2, expected_var, atol=5e-2)

        # Draw twice, each time from a fresh state created with seed 0.
        draws = []
        for _ in range(2):
            state, sample = dist.sample(
                B.create_random_state(B.dtype(dist), seed=0)
            )
            draws.append(sample)
        assert isinstance(state, B.RandomState)
        approx(draws[0], draws[1])
Esempio n. 11
0
def test_normal_lazy_nonzero_mean():
    """Lazily specified mean and variance should only be built on access."""

    def lazy_mean():
        return B.ones(3, 1)

    def lazy_var():
        return B.eye(3)

    dist = Normal(lazy_mean, lazy_var)

    # Nothing should be populated yet.
    assert dist._mean is None
    assert dist._var is None

    # Asking for the mean must populate it without touching the variance.
    approx(dist.mean, B.ones(3, 1))
    assert dist._var is None

    # The variance is populated once it is requested itself.
    approx(dist.var, B.eye(3))
Esempio n. 12
0
def test_normal_logpdf_missing_data(normal1):
    """`logpdf` with a NaN entry should equal the `logpdf` of the distribution
    restricted to the remaining entries."""
    x = B.randn(3, 1)
    x[1] = B.nan

    # Indices of the entries that are not NaN.
    kept = [0, 2]

    with_missing = normal1.logpdf(x)
    restricted = Normal(
        normal1.mean[kept],
        normal1.var[kept, :][:, kept],
    )
    approx(with_missing, restricted.logpdf(x[kept]))
Esempio n. 13
0
def test_normal_lazy_nonzero_mean():
    """Track which lazy parts get built when checking and reading the mean."""

    def lazy_mean():
        return B.ones(3, 1)

    def lazy_var():
        return B.eye(3)

    dist = Normal(lazy_mean, lazy_var)

    # After querying `mean_is_zero` the mean is expected to be available in
    # `_mean`, while the variance is still unbuilt.
    assert not dist.mean_is_zero
    approx(dist._mean, B.ones(3, 1))
    assert dist._var is None

    # Reading the mean must not build the variance either.
    approx(dist.mean, B.ones(3, 1))
    assert dist._var is None

    # The variance is built when requested.
    approx(dist.var, B.eye(3))
Esempio n. 14
0
def test_sum_other():
    # Generator-style test: adding the constant 5 to a `GP` should shift the
    # mean by 5 and leave the kernel unchanged, for each spelling of the sum.
    model = Graph()
    p1 = GP(EQ(), TensorProductMean(lambda x: x ** 2), graph=model)
    p2 = p1 + 5.
    p3 = 5. + p1
    p4 = model.sum(5., p1)
    shifted = (p2, p3, p4)

    x = np.random.randn(5, 1)
    for q in shifted:
        yield assert_allclose, p1.mean(x) + 5., q.mean(x)
    for q in shifted:
        yield assert_allclose, p1.kernel(x), q.kernel(x)
    # Kernel evaluated at inputs tagged with different shifted processes.
    for q in (p3, p4):
        yield assert_allclose, p1.kernel(p2(x), q(x)), p1.kernel(x)

    # Check that a `GP` cannot be summed with a `Normal`.
    for forbidden in (
        lambda: p1 + Normal(np.eye(3)),
        lambda: Normal(np.eye(3)) + p1,
    ):
        yield raises, NotImplementedError, forbidden
Esempio n. 15
0
def test_normal_lazy_zero_mean():
    """A lazily specified variance without a mean should give a zero mean,
    with the variance built only once the mean is actually requested."""
    dist = Normal(lambda: B.eye(3))

    assert dist.mean_is_zero
    # Compare by type and value instead of `is 0`: identity against an int
    # literal depends on CPython's small-integer caching and triggers a
    # SyntaxWarning on Python 3.8+.
    assert isinstance(dist._mean, int) and dist._mean == 0
    assert dist._var is None

    approx(dist.mean, B.zeros(3, 1))
    # At this point, the variance should be constructed, because it is used to get the
    # dimensionality and data type for the mean.
    assert dist._var is not None

    approx(dist.var, B.eye(3))
Esempio n. 16
0
def test_shorthands():
    # Generator-style test for input wrapping via `p(x)` and the `>` / `[]`
    # shorthands.
    model = Graph()
    p = GP(EQ(), graph=model)

    # Construct a normal distribution that serves as an input.
    x = p(1)
    yield assert_instance, x, At
    yield ok, type_parameter(x) is p
    yield eq, x.get(), 1
    # Both `str` and `repr` of the nested call must render as `<p>(<x>)`.
    for render in (str, repr):
        yield eq, render(p(x)), render(p) + '(' + render(x) + ')'

    # Construct a normal distribution that does not serve as an input.
    x = Normal(np.ones((1, 1)))
    for misuse in (
        lambda: type_parameter(x),
        lambda: x.get(),
        lambda: p | (x, 1),
    ):
        yield raises, RuntimeError, misuse

    # Test shorthands for stretching and selection.
    p = GP(EQ(), graph=Graph())
    stretched_by_operator = str(p > 2)
    stretched_by_method = str(p.stretch(2))
    yield eq, stretched_by_operator, stretched_by_method
    selected_by_index = str(p[0])
    selected_by_method = str(p.select(0))
    yield eq, selected_by_index, selected_by_method
Esempio n. 17
0
def test_normal_lazy_mean_var():
    """The lazy `mean_var` should only be used when neither the mean nor the
    variance exists yet; otherwise the missing one is constructed directly.
    Every access order is exercised on a fresh distribution."""

    def fresh():
        # The joint `mean_var` deliberately returns values (8, 9) that differ
        # from the individual lazy mean and variance, so the test can tell
        # which code path produced a result.
        return Normal(
            lambda: B.ones(3, 1),
            lambda: B.eye(3),
            mean_var=lambda: (8, 9),
        )

    # Joint request first: everything comes from `mean_var`.
    dist = fresh()
    approx(dist.mean_var, (8, 9))
    approx(dist.mean, 8)
    approx(dist.var, 9)

    # Mean first, then the joint request, then the variance.
    dist = fresh()
    approx(dist.mean, B.ones(3, 1))
    approx(dist.mean_var, (B.ones(3, 1), B.eye(3)))
    approx(dist.var, B.eye(3))

    # Variance first, then the joint request, then the mean.
    dist = fresh()
    approx(dist.var, B.eye(3))
    approx(dist.mean_var, (B.ones(3, 1), B.eye(3)))
    approx(dist.mean, B.ones(3, 1))

    # Variance and mean individually, joint request last.
    dist = fresh()
    approx(dist.var, B.eye(3))
    approx(dist.mean, B.ones(3, 1))
    approx(dist.mean_var, (B.ones(3, 1), B.eye(3)))
Esempio n. 18
0
def test_normal_printing():
    """`str`/`repr` should show `unresolved` for lazy parts until the mean and
    variance have been accessed."""
    dist = Normal(lambda: np.array([1, 2]), lambda: np.array([[3, 1], [1, 3]]))

    # Nothing accessed yet: both renderings coincide.
    unresolved = "<Normal:\n mean=unresolved,\n var=unresolved>"
    assert str(dist) == unresolved
    assert repr(dist) == unresolved

    # Resolve mean.
    _ = dist.mean
    assert str(dist) == "<Normal:\n mean=[1 2],\n var=unresolved>"
    assert repr(dist) == "<Normal:\n mean=array([1, 2]),\n var=unresolved>"

    # Resolve variance.
    _ = dist.var
    expected_str = (
        "<Normal:\n"
        " mean=[1 2],\n"
        " var=<dense matrix: batch=(), shape=(2, 2), dtype=int64>>"
    )
    expected_repr = (
        "<Normal:\n"
        " mean=array([1, 2]),\n"
        " var=<dense matrix: batch=(), shape=(2, 2), dtype=int64\n"
        "      mat=[[3 1]\n"
        "           [1 3]]>>"
    )
    assert str(dist) == expected_str
    assert repr(dist) == expected_repr
Esempio n. 19
0
def test_normal_sampling():
    """Check sample variance with an integer-typed variance, the string
    representations, and the zero-mean detection of `Normal`."""
    # Test sampling and dtype conversion.
    # `dtype=int` replaces the abstract `np.integer`: converting abstract
    # scalar types to a dtype is deprecated since NumPy 1.20, and `int` is the
    # concrete dtype that conversion resolved to.
    dist = Normal(3 * np.eye(200, dtype=int))
    assert np.abs(np.std(dist.sample(1000))**2 - 3) <= 5e-2, 'full'
    assert np.abs(np.std(dist.sample(1000, noise=2))**2 - 5) <= 5e-2, 'full 2'

    # Test `__str__` and `__repr__`.
    assert str(dist) == RandomVector.__str__(dist)
    assert repr(dist) == RandomVector.__repr__(dist)

    # Test zero mean determination.
    assert Normal(np.eye(3))._zero_mean
    assert not Normal(np.eye(3), np.random.randn(3, 1))._zero_mean

    x = np.random.randn(3)
    assert GP(1)(x)._zero_mean
    assert not GP(1, 1)(x)._zero_mean
    assert GP(1, 0)(x)._zero_mean
Esempio n. 20
0
def test_normal_sampling():
    # Generator-style test: checks the sample variance (without and with
    # noise) for dense, diagonal and uniformly diagonal variances, then the
    # string representations.
    #
    # `dtype=int` replaces the abstract `np.integer` below: converting
    # abstract scalar types to a dtype is deprecated since NumPy 1.20, and
    # `int` is the concrete dtype that conversion resolved to.

    # Test sampling and dtype conversion.
    dist = Normal(3 * np.eye(200, dtype=int))
    yield le, np.abs(np.std(dist.sample(1000))**2 - 3), 5e-2, 'full'
    yield le, np.abs(np.std(dist.sample(1000, noise=2)) ** 2 - 5), 5e-2, \
          'full 2'

    dist = Normal(Diagonal(3 * np.ones(200, dtype=int)))
    yield le, np.abs(np.std(dist.sample(1000))**2 - 3), 5e-2, 'diag'
    yield le, np.abs(np.std(dist.sample(1000, noise=2)) ** 2 - 5), 5e-2, \
          'diag 2'

    dist = Normal(UniformlyDiagonal(3, 200))
    yield le, np.abs(np.std(dist.sample(1000))**2 - 3), 5e-2, 'unif'
    yield le, np.abs(np.std(dist.sample(1000, noise=2)) ** 2 - 5), 5e-2, \
          'unif 2'

    # Test `__str__` and `__repr__`.
    yield eq, str(dist), RandomVector.__str__(dist)
    yield eq, repr(dist), RandomVector.__repr__(dist)
Esempio n. 21
0
def test_normal_dtype(normal1):
    """`B.dtype` of a `Normal` should reflect the types of its mean and
    variance. The `normal1` argument is not used by the body."""
    # Each entry is (factory for the distribution, expected data type).
    cases = [
        (lambda: Normal(0, B.eye(3)), np.float64),
        (lambda: Normal(B.ones(3), B.zeros(int, 3)), np.float64),
        (lambda: Normal(B.ones(int, 3), B.zeros(int, 3)), np.int64),
    ]
    for build, expected in cases:
        assert B.dtype(build()) == expected
Esempio n. 22
0
def test_normal_lazy_mean_var_diag():
    """The lazy `mean_var_diag` should only be used when neither the mean nor
    the diagonal of the variance exists yet; otherwise the missing one is
    constructed directly. Every access order is exercised on a fresh
    distribution."""

    def fresh():
        # The joint `mean_var_diag` deliberately returns values (8, 9) that
        # differ from the individual lazy mean and variance, so the test can
        # tell which code path produced a result.
        return Normal(
            lambda: B.ones(3, 1),
            lambda: B.eye(3),
            mean_var_diag=lambda: (8, 9),
        )

    # Marginals first: everything comes from `mean_var_diag`.
    dist = fresh()
    approx(dist.marginals(), (8, 9))
    approx(dist.mean, 8)
    approx(dist.var_diag, 9)

    # Mean first, then the marginals, then the diagonal.
    dist = fresh()
    approx(dist.mean, B.ones(3, 1))
    approx(dist.marginals(), (B.ones(3), B.ones(3)))
    approx(dist.var_diag, B.ones(3))

    # Diagonal first, then the marginals, then the mean.
    dist = fresh()
    approx(dist.var_diag, B.ones(3))
    approx(dist.marginals(), (B.ones(3), B.ones(3)))
    approx(dist.mean, B.ones(3, 1))

    # Diagonal and mean individually, marginals last.
    dist = fresh()
    approx(dist.var_diag, B.ones(3))
    approx(dist.mean, B.ones(3, 1))
    approx(dist.marginals(), (B.ones(3), B.ones(3)))
Esempio n. 23
0
def test_normal_arithmetic():
    """Check matrix products, scaling and addition of `Normal`s against the
    same operations applied directly to the mean and variance."""
    root = np.random.randn(3, 3)
    dist = Normal(root.dot(root.T), np.random.randn(3, 1))
    root = np.random.randn(3, 3)
    dist2 = Normal(root.dot(root.T), np.random.randn(3, 1))

    A = np.random.randn(3, 3)
    a = np.random.randn(1, 3)
    b = 5.

    # Each entry is (actual thunk, expected thunk); the thunks keep every
    # comparison evaluated one at a time, in the listed order.
    product_checks = [
        # Test matrix multiplication.
        (lambda: (dist.rmatmul(a)).mean, lambda: dist.mean.dot(a)),
        (lambda: (dist.rmatmul(a)).var,
         lambda: a.dot(dense(dist.var)).dot(a.T)),
        (lambda: (dist.lmatmul(A)).mean, lambda: A.dot(dist.mean)),
        (lambda: (dist.lmatmul(A)).var,
         lambda: A.dot(dense(dist.var)).dot(A.T)),
        # Test multiplication.
        (lambda: (dist * b).mean, lambda: dist.mean * b),
        (lambda: (dist * b).var, lambda: dist.var * b**2),
        (lambda: (b * dist).mean, lambda: dist.mean * b),
        (lambda: (b * dist).var, lambda: dist.var * b**2),
    ]
    for actual, expected in product_checks:
        allclose(actual(), expected())

    # Multiplying two distributions is not supported.
    for multiply in (
        lambda: dist.__mul__(dist),
        lambda: dist.__rmul__(dist),
    ):
        with pytest.raises(NotImplementedError):
            multiply()

    # Test addition.
    sum_checks = [
        (lambda: (dist + dist2).mean, lambda: dist.mean + dist2.mean),
        (lambda: (dist + dist2).var, lambda: dist.var + dist2.var),
        (lambda: (dist.__add__(b)).mean, lambda: dist.mean + b),
        (lambda: (dist.__radd__(b)).mean, lambda: dist.mean + b),
    ]
    for actual, expected in sum_checks:
        allclose(actual(), expected())
Esempio n. 24
0
def test_normal():
    """Compare `Normal` against SciPy's `multivariate_normal` and check second
    moment, marginals, `logpdf` output shapes and the KL divergence."""
    mean = np.random.randn(3, 1)
    root = np.random.randn(3, 3)
    var = root.dot(root.T)

    dist = Normal(var, mean)
    reference = multivariate_normal(mean[:, 0], var)

    # Test second moment.
    allclose(dist.m2, var + mean.dot(mean.T))

    # Test marginals: the bounds are the mean plus/minus two standard deviations.
    marg_mean, lower, upper = dist.marginals()
    allclose(mean.squeeze(), marg_mean)
    two_std = 2 * np.diag(var)**.5
    allclose(lower, marg_mean - two_std)
    allclose(upper, marg_mean + two_std)

    # Test `logpdf` and `entropy`.
    for _ in range(5):
        x = np.random.randn(3, 10)
        allclose(dist.logpdf(x), reference.logpdf(x.T), desc='logpdf')
        allclose(dist.entropy(), reference.entropy(), desc='entropy')

    # Test that the output of `logpdf` is flattened appropriately.
    for columns, expected_shape in ((1, ()), (2, (2, ))):
        assert np.shape(dist.logpdf(np.ones((3, columns)))) == expected_shape

    # Test KL with Monte Carlo estimate.
    mean2 = np.random.randn(3, 1)
    root2 = np.random.randn(3, 3)
    var2 = root2.dot(root2.T)
    dist2 = Normal(var2, mean2)
    samples = dist.sample(50000)
    avg_logpdf_own = np.mean(dist.logpdf(samples))
    avg_logpdf_other = np.mean(dist2.logpdf(samples))
    kl_est = avg_logpdf_own - avg_logpdf_other
    kl = dist.kl(dist2)
    relative_error = np.abs(kl_est - kl) / np.abs(kl)
    assert relative_error < 5e-2, 'kl sampled'
Esempio n. 25
0
def normal2():
    """Return a three-dimensional `Normal` with a random mean and a random
    variance of the form `R @ R.T`."""
    # NOTE(review): presumably used as a pytest fixture; no decorator is
    # visible here -- confirm.
    random_mean = B.randn(3, 1)
    root = B.randn(3, 3)
    return Normal(random_mean, root @ root.T)
Esempio n. 26
0
def test_normal_comparison():
    """A `Normal` with a dense variance and one with the equivalent structured
    variance should agree on `logpdf`, entropy, KL and W2."""

    def assert_equivalent(first, second):
        # Run all comparisons between two distributions that are meant to be
        # the same, using samples drawn from `first`.
        samples = first.sample(100)
        allclose(first.logpdf(samples), second.logpdf(samples), desc='logpdf')
        allclose(first.entropy(), second.entropy(), desc='entropy')
        for left, right in (
            (first, second),
            (first, first),
            (second, second),
            (second, first),
        ):
            allclose(left.kl(right), 0.)
        for left, right in (
            (first, first),
            (first, second),
            (second, first),
            (second, second),
        ):
            assert left.w2(right) <= 1e-3

    # Compare a diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag = np.random.randn(3)**2
    var = np.diag(var_diag)
    assert_equivalent(Normal(var, mean), Normal(Diagonal(var_diag), mean))

    # Check a uniformly diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag_scale = np.random.randn()**2
    var = np.eye(3) * var_diag_scale
    assert_equivalent(
        Normal(var, mean),
        Normal(UniformlyDiagonal(var_diag_scale, 3), mean),
    )
Esempio n. 27
0
def test_normal_diagonalise(normal1):
    """`diagonalise` should keep the mean and zero out the off-diagonal
    entries of the variance."""
    actual = normal1.diagonalise()
    mean = normal1.mean
    # Inner `B.diag` extracts the diagonal; the outer one (below) turns it
    # back into a matrix.
    diagonal = B.diag(B.dense(normal1.var))
    approx(actual, Normal(mean, B.diag(diagonal)))
Esempio n. 28
0
def test_normal():
    # Generator-style test: compares `Normal` against SciPy's
    # `multivariate_normal` and checks second moment, marginals, `logpdf`
    # input conversion and output shapes, and the KL divergence.
    mean = np.random.randn(3, 1)
    root = np.random.randn(3, 3)
    var = root.dot(root.T)

    dist = Normal(var, mean)
    reference = multivariate_normal(mean[:, 0], var)

    # Test second moment.
    yield assert_allclose, dist.m2, var + mean.dot(mean.T)

    # Test marginals: the bounds are the mean plus/minus two standard deviations.
    marg_mean, lower, upper = dist.marginals()
    yield assert_allclose, mean.squeeze(), marg_mean
    two_std = 2 * np.diag(var)**.5
    yield assert_allclose, lower, marg_mean - two_std
    yield assert_allclose, upper, marg_mean + two_std

    # Test `logpdf` and `entropy`.
    for _ in range(5):
        x = np.random.randn(3, 10)
        yield ok, allclose(dist.logpdf(x), reference.logpdf(x.T)), 'logpdf'
        yield ok, allclose(dist.entropy(), reference.entropy()), 'entropy'

    # Test that inputs to `logpdf` are converted appropriately: a list and a
    # tuple must give the same result as the equivalent array.
    for raw in ([0, 1, 2], (0, 1, 2)):
        yield (assert_allclose,
               dist.logpdf(np.array([0, 1, 2])),
               dist.logpdf(raw))

    # Test that the output of `logpdf` is flattened appropriately.
    for columns, expected_shape in ((1, ()), (2, (2, ))):
        yield eq, np.shape(dist.logpdf(np.ones((3, columns)))), expected_shape

    # Test KL with Monte Carlo estimate.
    mean2 = np.random.randn(3, 1)
    root2 = np.random.randn(3, 3)
    var2 = root2.dot(root2.T)
    dist2 = Normal(var2, mean2)
    samples = dist.sample(50000)
    avg_logpdf_own = np.mean(dist.logpdf(samples))
    avg_logpdf_other = np.mean(dist2.logpdf(samples))
    kl_est = avg_logpdf_own - avg_logpdf_other
    kl = dist.kl(dist2)
    yield ok, np.abs(kl_est - kl) / np.abs(kl) < 5e-2, 'kl sampled'
Esempio n. 29
0
def test_normal_comparison():
    # Generator-style test: a `Normal` with a dense variance and one with the
    # equivalent structured variance should agree on `logpdf`, entropy, KL
    # and W2.

    def equivalence_checks(first, second):
        # Generate all checks between two distributions that are meant to be
        # the same, using samples drawn from `first`.
        samples = first.sample(100)
        yield ok, allclose(first.logpdf(samples), second.logpdf(samples)), \
            'logpdf'
        yield ok, allclose(first.entropy(), second.entropy()), 'entropy'
        kl_pairs = (
            (first, second),
            (first, first),
            (second, second),
            (second, first),
        )
        for number, (left, right) in enumerate(kl_pairs, 1):
            yield ok, allclose(left.kl(right), 0.), 'kl %d' % number
        w2_pairs = (
            (first, first),
            (first, second),
            (second, first),
            (second, second),
        )
        for number, (left, right) in enumerate(w2_pairs, 1):
            yield le, left.w2(right), 5e-4, 'w2 %d' % number

    # Compare a diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag = np.random.randn(3)**2
    var = np.diag(var_diag)
    dist1 = Normal(var, mean)
    dist2 = Normal(Diagonal(var_diag), mean)
    for check in equivalence_checks(dist1, dist2):
        yield check

    # Check a uniformly diagonal normal and dense normal.
    mean = np.random.randn(3, 1)
    var_diag_scale = np.random.randn()**2
    var = np.eye(3) * var_diag_scale
    dist1 = Normal(var, mean)
    dist2 = Normal(UniformlyDiagonal(var_diag_scale, 3), mean)
    for check in equivalence_checks(dist1, dist2):
        yield check