Ejemplo n.º 1
0
def test_normal_arithmetic():
    """Exercise the arithmetic operations of ``Normal``.

    Two ``Normal`` objects are built from random 3x3 products ``M M^T`` and
    random 3x1 columns. The ``mean`` and ``var`` attributes of matrix
    products, scalar multiples and sums are then compared with the same
    quantities computed from the inputs.

    NOTE(review): the results of ``allclose`` are not asserted here, so this
    presumably relies on ``allclose`` raising on a mismatch -- confirm.
    """
    root = np.random.randn(3, 3)
    gram = np.dot(root, root.T)
    dist = Normal(gram, np.random.randn(3, 1))

    root = np.random.randn(3, 3)
    gram = np.dot(root, root.T)
    dist2 = Normal(gram, np.random.randn(3, 1))

    square = np.random.randn(3, 3)
    row = np.random.randn(1, 3)
    scale = 5.0

    # Products through ``rmatmul`` (with a 1x3 row) and ``lmatmul`` (3x3).
    actual = dist.rmatmul(row).mean
    expected = dist.mean.dot(row)
    allclose(actual, expected)

    actual = dist.rmatmul(row).var
    expected = row.dot(dense(dist.var)).dot(row.T)
    allclose(actual, expected)

    actual = dist.lmatmul(square).mean
    expected = square.dot(dist.mean)
    allclose(actual, expected)

    actual = dist.lmatmul(square).var
    expected = square.dot(dense(dist.var)).dot(square.T)
    allclose(actual, expected)

    # Scaling by a number from either side: the mean is compared against
    # ``mean * scale`` and the variance against ``var * scale**2``.
    allclose((dist * scale).mean, dist.mean * scale)
    allclose((dist * scale).var, dist.var * scale**2)
    allclose((scale * dist).mean, dist.mean * scale)
    allclose((scale * dist).var, dist.var * scale**2)

    # Multiplying one distribution by another is expected to be rejected.
    for multiply in (dist.__mul__, dist.__rmul__):
        with pytest.raises(NotImplementedError):
            multiply(dist)

    # Sums of two distributions, then of a distribution and a number.
    total = dist + dist2
    allclose(total.mean, dist.mean + dist2.mean)
    total = dist + dist2
    allclose(total.var, dist.var + dist2.var)
    allclose(dist.__add__(scale).mean, dist.mean + scale)
    allclose(dist.__radd__(scale).mean, dist.mean + scale)
Ejemplo n.º 2
0
def test_normal_arithmetic():
    """Yield checks for the arithmetic operations of ``Normal``.

    Two ``Normal`` objects are built from random 3x3 products ``M M^T`` and
    random 3x1 columns. Each check is yielded as a ``(callable, *args)``
    tuple: ``ok`` with the outcome of ``allclose`` and a label, or ``raises``
    with an exception type and a thunk. Presumably these tuples are consumed
    by a generator-based test runner -- confirm.
    """
    root = np.random.randn(3, 3)
    gram = np.dot(root, root.T)
    dist = Normal(gram, np.random.randn(3, 1))

    root = np.random.randn(3, 3)
    gram = np.dot(root, root.T)
    dist2 = Normal(gram, np.random.randn(3, 1))

    square = np.random.randn(3, 3)
    row = np.random.randn(1, 3)
    scale = 5.0

    # Products through ``rmatmul`` (with a 1x3 row) and ``lmatmul`` (3x3).
    # Every outcome is computed immediately before it is yielded.
    outcome = allclose(dist.rmatmul(row).mean, dist.mean.dot(row))
    yield ok, outcome, 'mean mul'

    outcome = allclose(
        dist.rmatmul(row).var,
        row.dot(dense(dist.var)).dot(row.T),
    )
    yield ok, outcome, 'var mul'

    outcome = allclose(dist.lmatmul(square).mean, square.dot(dist.mean))
    yield ok, outcome, 'mean rmul'

    outcome = allclose(
        dist.lmatmul(square).var,
        square.dot(dense(dist.var)).dot(square.T),
    )
    yield ok, outcome, 'var rmul'

    # Scaling by a number from either side.
    outcome = allclose((dist * scale).mean, dist.mean * scale)
    yield ok, outcome, 'mean mul 2'

    outcome = allclose((dist * scale).var, dist.var * scale**2)
    yield ok, outcome, 'var mul 2'

    outcome = allclose((scale * dist).mean, dist.mean * scale)
    yield ok, outcome, 'mean rmul 2'

    outcome = allclose((scale * dist).var, dist.var * scale**2)
    yield ok, outcome, 'var rmul 2'

    # Multiplying one distribution by another is expected to be rejected.
    yield raises, NotImplementedError, lambda: dist.__mul__(dist)
    yield raises, NotImplementedError, lambda: dist.__rmul__(dist)

    # Sums of two distributions, then of a distribution and a number.
    outcome = allclose((dist + dist2).mean, dist.mean + dist2.mean)
    yield ok, outcome, 'mean sum'

    outcome = allclose((dist + dist2).var, dist.var + dist2.var)
    yield ok, outcome, 'var sum'

    outcome = allclose(dist.__add__(scale).mean, dist.mean + scale)
    yield ok, outcome, 'mean add'

    outcome = allclose(dist.__radd__(scale).mean, dist.mean + scale)
    yield ok, outcome, 'mean radd'