def test_normal_arithmetic():
    """Check how ``Normal`` arithmetic transforms means and variances.

    Covers right/left matrix multiplication, scaling by a scalar from either
    side, rejection of distribution-times-distribution products, the sum of
    two distributions, and shifting the mean by a scalar from either side.
    """
    # NOTE(review): every check below is a bare ``allclose(...)`` call, so it
    # only fails the test if ``allclose`` itself raises on a mismatch. Confirm
    # it is an asserting helper rather than one that returns a bool.

    # Two random 3-dimensional normals. ``factor @ factor.T`` is symmetric
    # PSD (the first ``Normal`` argument, presumably the covariance).
    factor = np.random.randn(3, 3)
    dist = Normal(factor.dot(factor.T), np.random.randn(3, 1))
    factor = np.random.randn(3, 3)
    dist2 = Normal(factor.dot(factor.T), np.random.randn(3, 1))

    mat = np.random.randn(3, 3)
    row = np.random.randn(1, 3)
    scale = 5.

    # Matrix multiplication, with the matrix on the right and on the left.
    right_product = dist.rmatmul(row)
    allclose(right_product.mean, dist.mean.dot(row))
    allclose(right_product.var, row.dot(dense(dist.var)).dot(row.T))

    left_product = dist.lmatmul(mat)
    allclose(left_product.mean, mat.dot(dist.mean))
    allclose(left_product.var, mat.dot(dense(dist.var)).dot(mat.T))

    # Scalar multiplication: the mean scales linearly, the variance scales
    # by the square of the factor.
    right_scaled = dist * scale
    allclose(right_scaled.mean, dist.mean * scale)
    allclose(right_scaled.var, dist.var * scale**2)

    left_scaled = scale * dist
    allclose(left_scaled.mean, dist.mean * scale)
    allclose(left_scaled.var, dist.var * scale**2)

    # Multiplying a distribution by a distribution must be rejected.
    for product_op in (dist.__mul__, dist.__rmul__):
        with pytest.raises(NotImplementedError):
            product_op(dist)

    # Addition of two distributions, and a scalar shift of the mean.
    total = dist + dist2
    allclose(total.mean, dist.mean + dist2.mean)
    allclose(total.var, dist.var + dist2.var)

    allclose(dist.__add__(scale).mean, dist.mean + scale)
    allclose(dist.__radd__(scale).mean, dist.mean + scale)
def test_normal_arithmetic():
    """Check how ``Normal`` arithmetic transforms means and variances.

    This used to be a nose-style generator test (``yield ok, ...``). pytest
    4.0 removed support for yield tests, so each check now calls the same
    ``ok`` / ``raises`` helper directly with the arguments it used to yield.
    That is how the nose runner invoked each yielded tuple, so it still works
    there as well. Trade-off: the first failing check now stops the test,
    rather than each check being reported separately.

    NOTE(review): if this module defines ``test_normal_arithmetic`` more than
    once, only the last definition is collected. Confirm and dedupe.
    """
    # ``chol @ chol.T`` is symmetric PSD (the first ``Normal`` argument,
    # presumably the covariance).
    chol = np.random.randn(3, 3)
    dist = Normal(chol.dot(chol.T), np.random.randn(3, 1))
    chol = np.random.randn(3, 3)
    dist2 = Normal(chol.dot(chol.T), np.random.randn(3, 1))

    A = np.random.randn(3, 3)
    a = np.random.randn(1, 3)
    b = 5.

    # Test matrix multiplication.
    ok(allclose(dist.rmatmul(a).mean, dist.mean.dot(a)), 'mean mul')
    ok(allclose(dist.rmatmul(a).var, a.dot(dense(dist.var)).dot(a.T)),
       'var mul')
    ok(allclose(dist.lmatmul(A).mean, A.dot(dist.mean)), 'mean rmul')
    ok(allclose(dist.lmatmul(A).var, A.dot(dense(dist.var)).dot(A.T)),
       'var rmul')

    # Test multiplication.
    ok(allclose((dist * b).mean, dist.mean * b), 'mean mul 2')
    ok(allclose((dist * b).var, dist.var * b**2), 'var mul 2')
    ok(allclose((b * dist).mean, dist.mean * b), 'mean rmul 2')
    ok(allclose((b * dist).var, dist.var * b**2), 'var rmul 2')
    raises(NotImplementedError, lambda: dist.__mul__(dist))
    raises(NotImplementedError, lambda: dist.__rmul__(dist))

    # Test addition.
    ok(allclose((dist + dist2).mean, dist.mean + dist2.mean), 'mean sum')
    ok(allclose((dist + dist2).var, dist.var + dist2.var), 'var sum')
    ok(allclose(dist.__add__(b).mean, dist.mean + b), 'mean add')
    ok(allclose(dist.__radd__(b).mean, dist.mean + b), 'mean radd')