def test_message_to_parent(self):
    """
    Test the message from SumMultiply node to its parents.

    Each check compares ``SumMultiply._message_to_parent`` against the
    analytically expected first and second message arrays.
    """

    data = 2
    tau = 3

    def check_message(true_m0, true_m1, parent, *args, F=None):
        # Without F: build SumMultiply(*args), attach an observed GaussianARD
        # child with precision tau, and check the message to `parent`.
        # With F: use the pre-built node F; *args are then not used, so
        # callers pass args that describe F only for readability.
        if F is None:
            A = SumMultiply(*args)
            B = GaussianARD(A, tau)
            B.observe(data*np.ones(A.plates + A.dims[0]))
        else:
            A = F
        (A_m0, A_m1) = A._message_to_parent(parent)
        self.assertAllClose(true_m0, A_m0)
        self.assertAllClose(true_m1, A_m1)
        pass

    # Check: different message to each of multiple parents
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x2 = X2.get_moments()
    m0 = tau * data * x2[0]
    m1 = -0.5 * tau * x2[1] * np.identity(2)
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, [9], X2, [9], [9])
    m0 = tau * data * x1[0]
    m1 = -0.5 * tau * x1[1] * np.identity(2)
    check_message(m0, m1, 1, 'i,i->i', X1, X2)
    check_message(m0, m1, 1, X1, [9], X2, [9], [9])

    # Check: key not in output
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x1 = X1.get_moments()
    m0 = tau * data * np.ones(2)
    m1 = -0.5 * tau * np.ones((2,2))
    check_message(m0, m1, 0, 'i', X1)
    check_message(m0, m1, 0, 'i->', X1)
    check_message(m0, m1, 0, X1, [9])
    check_message(m0, m1, 0, X1, [9], [])

    # Check: key not in some input
    X1 = GaussianARD(np.random.randn(), np.random.rand())
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(x2[0], axis=-1)
    m1 = -0.5 * tau * np.sum(x2[1] * np.identity(2), axis=(-1,-2))
    check_message(m0, m1, 0, ',i->i', X1, X2)
    check_message(m0, m1, 0, X1, [], X2, [9], [9])
    m0 = tau * data * x1[0] * np.ones(2)
    m1 = -0.5 * tau * x1[1] * np.identity(2)
    check_message(m0, m1, 1, ',i->i', X1, X2)
    check_message(m0, m1, 1, X1, [], X2, [9], [9])

    # Check: keys in different order
    Y1 = GaussianARD(np.random.randn(3,2), np.random.rand(3,2), ndim=2)
    y1 = Y1.get_moments()
    Y2 = GaussianARD(np.random.randn(2,3), np.random.rand(2,3), ndim=2)
    y2 = Y2.get_moments()
    m0 = tau * data * y2[0].T
    m1 = -0.5 * tau * np.einsum('ijlk->jikl', y2[1] * misc.identity(2,3))
    check_message(m0, m1, 0, 'ij,ji->ij', Y1, Y2)
    check_message(m0, m1, 0, Y1, ['i','j'], Y2, ['j','i'], ['i','j'])
    m0 = tau * data * y1[0].T
    m1 = -0.5 * tau * np.einsum('ijlk->jikl', y1[1] * misc.identity(3,2))
    check_message(m0, m1, 1, 'ij,ji->ij', Y1, Y2)
    check_message(m0, m1, 1, Y1, ['i','j'], Y2, ['j','i'], ['i','j'])

    # Check: plates when different dimensionality
    X1 = GaussianARD(np.random.randn(5), np.random.rand(5),
                     shape=(), plates=(5,))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(5,3), np.random.rand(5,3),
                     shape=(3,), plates=(5,))
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(np.ones((5,3)) * x2[0], axis=-1)
    m1 = -0.5 * tau * np.sum(x2[1] * misc.identity(3), axis=(-1,-2))
    check_message(m0, m1, 0, ',i->i', X1, X2)
    check_message(m0, m1, 0, X1, [], X2, ['i'], ['i'])
    m0 = tau * data * x1[0][:,np.newaxis] * np.ones((5,3))
    m1 = -0.5 * tau * x1[1][:,np.newaxis,np.newaxis] * misc.identity(3)
    check_message(m0, m1, 1, ',i->i', X1, X2)
    check_message(m0, m1, 1, X1, [], X2, ['i'], ['i'])

    # Check: other parent's moments broadcasts over plates when node has the
    # same plates
    X1 = GaussianARD(np.random.randn(5,4,3), np.random.rand(5,4,3),
                     shape=(3,), plates=(5,4))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3,), plates=(5,4))
    x2 = X2.get_moments()
    m0 = tau * data * np.ones((5,4,3)) * x2[0]
    m1 = -0.5 * tau * x2[1] * misc.identity(3)
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

    # Check: other parent's moments broadcasts over plates when node does
    # not have that plate
    X1 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3,), plates=())
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3,), plates=(5,4))
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(np.ones((5,4,3)) * x2[0], axis=(0,1))
    m1 = -0.5 * tau * np.sum(np.ones((5,4,1,1))
                             * misc.identity(3)
                             * x2[1],
                             axis=(0,1))
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

    # Check: other parent's moments broadcasts over plates when the node
    # only broadcasts that plate
    X1 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3,), plates=(1,1))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3,), plates=(5,4))
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(np.ones((5,4,3)) * x2[0],
                             axis=(0,1),
                             keepdims=True)
    m1 = -0.5 * tau * np.sum(np.ones((5,4,1,1))
                             * misc.identity(3)
                             * x2[1],
                             axis=(0,1),
                             keepdims=True)
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

    # Check: broadcasted dimensions
    X1 = GaussianARD(np.random.randn(1,1), np.random.rand(1,1), ndim=2)
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3,2), np.random.rand(3,2), ndim=2)
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(np.ones((3,2)) * x2[0], keepdims=True)
    m1 = -0.5 * tau * np.sum(misc.identity(3,2) * x2[1], keepdims=True)
    check_message(m0, m1, 0, 'ij,ij->ij', X1, X2)
    check_message(m0, m1, 0, X1, [0,1], X2, [0,1], [0,1])
    m0 = tau * data * np.ones((3,2)) * x1[0]
    m1 = -0.5 * tau * misc.identity(3,2) * x1[1]
    check_message(m0, m1, 1, 'ij,ij->ij', X1, X2)
    check_message(m0, m1, 1, X1, [0,1], X2, [0,1], [0,1])

    # Check: non-ARD observations
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x1 = X1.get_moments()
    Lambda = np.array([[2, 1.5], [1.5, 2]])
    F = SumMultiply('i->i', X1)
    Y = Gaussian(F, Lambda)
    y = np.random.randn(2)
    Y.observe(y)
    m0 = np.dot(Lambda, y)
    m1 = -0.5 * Lambda
    check_message(m0, m1, 0, 'i->i', X1, F=F)
    check_message(m0, m1, 0, X1, ['i'], ['i'], F=F)

    # Check: mask with same shape
    X1 = GaussianARD(np.random.randn(3,2), np.random.rand(3,2),
                     shape=(2,), plates=(3,))
    x1 = X1.get_moments()
    mask = np.array([True, False, True])
    F = SumMultiply('i->i', X1)
    Y = GaussianARD(F, tau, ndim=1)
    Y.observe(data*np.ones((3,2)), mask=mask)
    m0 = tau * data * mask[:,np.newaxis] * np.ones(2)
    m1 = -0.5 * tau * mask[:,np.newaxis,np.newaxis] * np.identity(2)
    check_message(m0, m1, 0, 'i->i', X1, F=F)
    check_message(m0, m1, 0, X1, ['i'], ['i'], F=F)

    # Check: mask larger
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2),
                     shape=(2,), plates=())
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3,2), np.random.rand(3,2),
                     shape=(2,), plates=(3,))
    x2 = X2.get_moments()
    mask = np.array([True, False, True])
    F = SumMultiply('i,i->i', X1, X2)
    Y = GaussianARD(F, tau, plates=(3,), ndim=1)
    Y.observe(data*np.ones((3,2)), mask=mask)
    m0 = tau * data * np.sum(mask[:,np.newaxis] * x2[0], axis=0)
    m1 = -0.5 * tau * np.sum(mask[:,np.newaxis,np.newaxis]
                             * x2[1]
                             * np.identity(2),
                             axis=0)
    check_message(m0, m1, 0, 'i,i->i', X1, X2, F=F)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'], F=F)

    # Check: mask for broadcasted plate
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1, plates=(1,))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1, plates=(3,))
    x2 = X2.get_moments()
    mask = np.array([True, False, True])
    F = SumMultiply('i,i->i', X1, X2)
    Y = GaussianARD(F, tau, plates=(3,), ndim=1)
    Y.observe(data*np.ones((3,2)), mask=mask)
    m0 = tau * data * np.sum(mask[:,np.newaxis] * x2[0],
                             axis=0,
                             keepdims=True)
    m1 = -0.5 * tau * np.sum(mask[:,np.newaxis,np.newaxis]
                             * x2[1]
                             * np.identity(2),
                             axis=0,
                             keepdims=True)
    # The args describe F (two parents); they were previously 'i->i', X1,
    # which did not match the node actually being tested.
    check_message(m0, m1, 0, 'i,i->i', X1, X2, F=F)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'], F=F)

    # Check: Gaussian-gamma parents
    X1 = GaussianGamma(
        np.random.randn(2),
        random.covariance(2),
        np.random.rand(),
        np.random.rand()
    )
    x1 = X1.get_moments()
    X2 = GaussianGamma(
        np.random.randn(2),
        random.covariance(2),
        np.random.rand(),
        np.random.rand()
    )
    x2 = X2.get_moments()
    F = SumMultiply('i,i->i', X1, X2)
    V = random.covariance(2)
    y = np.random.randn(2)
    Y = Gaussian(F, V)
    Y.observe(y)
    m0 = np.dot(V, y) * x2[0]
    m1 = -0.5 * V * x2[1]
    # Quadratic form y^T V y scaled by the other parent's third moment.
    m2 = -0.5 * np.einsum('i,ij,j', y, V, y) * x2[2]
    # One factor of 0.5 per element of the observed vector y (len(y) == 2),
    # mirroring the per-observation 0.5 checked in other message tests.
    m3 = 0.5 * len(y)
    m = F._message_to_parent(0)
    self.assertAllClose(m[0], m0)
    self.assertAllClose(m[1], m1)
    self.assertAllClose(m[2], m2)
    self.assertAllClose(m[3], m3)

    pass
def test_message_to_parent_mu(self):
    """
    Test that GaussianARD computes the message to the 1st parent correctly.

    Each case builds a mean node ``mu``, a GaussianARD child of it, observes
    the child (or a grandchild), and compares ``mu._message_from_children()``
    with the expected first (m0) and second (m1) message arrays.
    """

    # Check formula with uncertain parent alpha
    mu = GaussianARD(0, 1)
    alpha = Gamma(2, 1)
    X = GaussianARD(mu, alpha)
    X.observe(3)
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(m0, 2 * 3)
    self.assertAllClose(m1, -0.5 * 2)

    # Check formula with uncertain node
    mu = GaussianARD(1, 1e10)
    X = GaussianARD(mu, 2)
    Y = GaussianARD(X, 1)
    Y.observe(5)
    X.update()
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(m0, 2 * 1 / (2 + 1) * (2 * 1 + 1 * 5))
    self.assertAllClose(m1, -0.5 * 2)

    # Check alpha larger than mu
    mu = GaussianARD(np.zeros((2, 3)), 1e10, shape=(2, 3))
    X = GaussianARD(mu, 2 * np.ones((3, 2, 3)))
    X.observe(3 * np.ones((3, 2, 3)))
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(m0, 2 * 3 * 3 * np.ones((2, 3)))
    self.assertAllClose(m1, -0.5 * 3 * 2 * misc.identity(2, 3))

    # Check mu larger than alpha
    mu = GaussianARD(np.zeros((3, 2, 3)), 1e10, shape=(3, 2, 3))
    X = GaussianARD(mu, 2 * np.ones((2, 3)))
    X.observe(3 * np.ones((3, 2, 3)))
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(m0, 2 * 3 * np.ones((3, 2, 3)))
    self.assertAllClose(m1, -0.5 * 2 * misc.identity(3, 2, 3))

    # Check node larger than mu and alpha
    mu = GaussianARD(np.zeros((2, 3)), 1e10, shape=(2, 3))
    X = GaussianARD(mu, 2 * np.ones((3, )), shape=(3, 2, 3))
    X.observe(3 * np.ones((3, 2, 3)))
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(m0, 2 * 3 * 3 * np.ones((2, 3)))
    self.assertAllClose(m1, -0.5 * 2 * 3 * misc.identity(2, 3))

    # Check broadcasting of dimensions
    mu = GaussianARD(np.zeros((2, 1)), 1e10, shape=(2, 1))
    X = GaussianARD(mu, 2 * np.ones((2, 3)), shape=(2, 3))
    X.observe(3 * np.ones((2, 3)))
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(m0, 2 * 3 * 3 * np.ones((2, 1)))
    self.assertAllClose(m1, -0.5 * 2 * 3 * misc.identity(2, 1))

    # Check plates for smaller mu than node
    mu = GaussianARD(0, 1, shape=(3, ), plates=(4, 1, 1))
    X = GaussianARD(mu, 2 * np.ones((3, )), shape=(2, 3), plates=(4, 5))
    X.observe(3 * np.ones((4, 5, 2, 3)))
    (m0, m1) = mu._message_from_children()
    # Multiply by ones so the comparison is independent of whether the
    # message is returned with broadcastable (size-1) axes.
    self.assertAllClose(m0 * np.ones((4, 1, 1, 3)),
                        2 * 3 * 5 * 2 * np.ones((4, 1, 1, 3)))
    self.assertAllClose(
        m1 * np.ones((4, 1, 1, 3, 3)),
        -0.5 * 2 * 5 * 2 * misc.identity(3) * np.ones((4, 1, 1, 3, 3)))

    # Check mask
    mu = GaussianARD(np.zeros((2, 1, 3)), 1e10, shape=(3, ))
    X = GaussianARD(mu, 2 * np.ones((2, 4, 3)), shape=(3, ), plates=(2, 4, ))
    X.observe(3 * np.ones((2, 4, 3)),
              mask=[[True, True, True, False], [False, True, False, True]])
    (m0, m1) = mu._message_from_children()
    # The mask rows contain 3 and 2 observed entries respectively.
    self.assertAllClose(m0, (2 * 3 * np.ones((2, 1, 3))
                             * np.array([[[3]], [[2]]])))
    self.assertAllClose(m1, (-0.5 * 2 * misc.identity(3)
                             * np.ones((2, 1, 1, 1))
                             * np.array([[[[3]]], [[[2]]]])))

    # Check mask with different shapes
    mu = GaussianARD(np.zeros((2, 1, 3)), 1e10, shape=())
    X = GaussianARD(mu, 2 * np.ones((2, 4, 3)), shape=(3, ), plates=(2, 4, ))
    mask = np.array([[True, True, True, False], [False, True, False, True]])
    X.observe(3 * np.ones((2, 4, 3)), mask=mask)
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(
        m0, 2 * 3 * np.sum(np.ones((2, 4, 3)) * mask[..., None],
                           axis=-2, keepdims=True))
    self.assertAllClose(
        m1, (-0.5 * 2 * np.sum(np.ones((2, 4, 3)) * mask[..., None],
                               axis=-2, keepdims=True)))

    # Check non-ARD Gaussian child
    mu = np.array([1, 2])
    Mu = GaussianARD(mu, 1e10, shape=(2, ))
    alpha = np.array([3, 4])
    Lambda = np.array([[1, 0.5], [0.5, 1]])
    # X must be a 1-dimensional (length-2) variable so that the full 2x2
    # precision Lambda of its Gaussian child applies to it; without ndim=1
    # X would default to a scalar variable (the sibling copy of this test
    # already passes ndim=1).
    X = GaussianARD(Mu, alpha, ndim=1)
    Y = Gaussian(X, Lambda)
    y = np.array([5, 6])
    Y.observe(y)
    X.update()
    (m0, m1) = Mu._message_from_children()
    mean = np.dot(np.linalg.inv(np.diag(alpha) + Lambda),
                  np.dot(np.diag(alpha), mu) + np.dot(Lambda, y))
    self.assertAllClose(m0, np.dot(np.diag(alpha), mean))
    self.assertAllClose(m1, -0.5 * np.diag(alpha))

    # Check broadcasted variable axes
    mu = GaussianARD(np.zeros(1), 1e10, shape=(1, ))
    X = GaussianARD(mu, 2, shape=(3, ))
    X.observe(3 * np.ones(3))
    (m0, m1) = mu._message_from_children()
    self.assertAllClose(m0, 2 * 3 * np.sum(np.ones(3), axis=-1, keepdims=True))
    self.assertAllClose(
        m1, -0.5 * 2 * np.sum(np.identity(3), axis=(-1, -2), keepdims=True))

    pass
def test_message_to_child(self):
    """
    Test moments of GaussianARD.

    The moments sent to children are <x> and <x x^T>; for a fixed mean mu
    and precision alpha these are mu and mu mu^T + diag(1/alpha).
    """

    # Moments are expanded to the full variable shape even when the
    # parents broadcast.
    node = GaussianARD(np.zeros((2, )), np.ones((3, 2)), shape=(4, 3, 2))
    (first, second) = node._message_to_child()
    self.assertEqual(np.shape(first), (4, 3, 2))
    self.assertEqual(np.shape(second), (4, 3, 2, 4, 3, 2))

    # Scalar formula
    node = GaussianARD(2, 3)
    (first, second) = node._message_to_child()
    self.assertAllClose(first, 2)
    self.assertAllClose(second, 2**2 + 1 / 3)

    # Multidimensional arrays and every dim-broadcast combination of mu and
    # alpha must produce the same (2, 3, 4) moments.
    want_first = 2 * np.ones((2, 3, 4))
    want_second = (2**2 * np.ones((2, 3, 4, 2, 3, 4))
                   + 1 / 3 * misc.identity(2, 3, 4))
    broadcast_cases = (
        # multidimensional arrays
        (2 * np.ones((2, 1, 4)), 3 * np.ones((2, 3, 1)), {'ndim': 3}),
        # dim-broadcasted mu
        (2 * np.ones((3, 1)), 3 * np.ones((2, 3, 4)), {'ndim': 3}),
        # dim-broadcasted alpha
        (2 * np.ones((2, 3, 4)), 3 * np.ones((3, 1)), {'ndim': 3}),
        # dim-broadcasted mu and alpha
        (2 * np.ones((3, 1)), 3 * np.ones((3, 1)), {'shape': (2, 3, 4)}),
    )
    for mean_value, precision_value, options in broadcast_cases:
        node = GaussianARD(mean_value, precision_value, **options)
        (first, second) = node._message_to_child()
        self.assertAllClose(first, want_first)
        self.assertAllClose(second, want_second)

    # Dim-broadcasted mu that also carries plates
    parent = GaussianARD(2 * np.ones((5, 1, 3, 4)),
                         np.ones((5, 1, 3, 4)),
                         shape=(3, 4),
                         plates=(5, 1))
    node = GaussianARD(parent,
                       3 * np.ones((5, 2, 3, 4)),
                       shape=(2, 3, 4),
                       plates=(5, ))
    (first, second) = node._message_to_child()
    self.assertAllClose(first, 2 * np.ones((5, 2, 3, 4)))
    self.assertAllClose(second,
                        2**2 * np.ones((5, 2, 3, 4, 2, 3, 4))
                        + 1 / 3 * misc.identity(2, 3, 4))

    # Posterior after observing a child
    node = GaussianARD(2, 3)
    child = GaussianARD(node, 1)
    child.observe(10)
    node.update()
    (first, second) = node._message_to_child()
    self.assertAllClose(first, 1 / (3 + 1) * (3 * 2 + 1 * 10))
    self.assertAllClose(second,
                        (1 / (3 + 1) * (3 * 2 + 1 * 10))**2 + 1 / (3 + 1))

    pass
def test_message_to_parent_mu(self):
    """
    Test that GaussianARD computes the message to the 1st parent correctly.

    The mean parent's combined incoming message is compared against the
    expected (m0, m1) pair for a range of shape, plate and mask setups.
    """

    def expect_message(parent, want_m0, want_m1):
        # Fetch the message collected by `parent` and compare both parts.
        (got_m0, got_m1) = parent._message_from_children()
        self.assertAllClose(got_m0, want_m0)
        self.assertAllClose(got_m1, want_m1)

    # Uncertain precision parent: E[alpha] = 2 for Gamma(2, 1)
    parent = GaussianARD(0, 1)
    precision = Gamma(2,1)
    node = GaussianARD(parent, precision)
    node.observe(3)
    expect_message(parent, 2*3, -0.5*2)

    # Uncertain node: the message uses the posterior mean of the node
    parent = GaussianARD(1, 1e10)
    node = GaussianARD(parent, 2)
    child = GaussianARD(node, 1)
    child.observe(5)
    node.update()
    expect_message(parent, 2 * 1/(2+1)*(2*1+1*5), -0.5*2)

    # Precision array larger than mu
    parent = GaussianARD(np.zeros((2,3)), 1e10, shape=(2,3))
    node = GaussianARD(parent, 2*np.ones((3,2,3)))
    node.observe(3*np.ones((3,2,3)))
    expect_message(parent,
                   2*3 * 3 * np.ones((2,3)),
                   -0.5 * 3 * 2*misc.identity(2,3))

    # mu larger than the precision array
    parent = GaussianARD(np.zeros((3,2,3)), 1e10, shape=(3,2,3))
    node = GaussianARD(parent, 2*np.ones((2,3)))
    node.observe(3*np.ones((3,2,3)))
    expect_message(parent,
                   2 * 3 * np.ones((3,2,3)),
                   -0.5 * 2*misc.identity(3,2,3))

    # Node larger than both mu and alpha
    parent = GaussianARD(np.zeros((2,3)), 1e10, shape=(2,3))
    node = GaussianARD(parent, 2*np.ones((3,)), shape=(3,2,3))
    node.observe(3*np.ones((3,2,3)))
    expect_message(parent,
                   2*3 * 3*np.ones((2,3)),
                   -0.5 * 2 * 3*misc.identity(2,3))

    # Broadcasting of variable dimensions
    parent = GaussianARD(np.zeros((2,1)), 1e10, shape=(2,1))
    node = GaussianARD(parent, 2*np.ones((2,3)), shape=(2,3))
    node.observe(3*np.ones((2,3)))
    expect_message(parent,
                   2*3 * 3*np.ones((2,1)),
                   -0.5 * 2 * 3*misc.identity(2,1))

    # Plates of mu smaller than the node's; compared after broadcasting
    parent = GaussianARD(0,1, shape=(3,), plates=(4,1,1))
    node = GaussianARD(parent, 2*np.ones((3,)), shape=(2,3), plates=(4,5))
    node.observe(3*np.ones((4,5,2,3)))
    (got_m0, got_m1) = parent._message_from_children()
    self.assertAllClose(got_m0 * np.ones((4,1,1,3)),
                        2*3 * 5*2*np.ones((4,1,1,3)))
    self.assertAllClose(got_m1 * np.ones((4,1,1,3,3)),
                        -0.5*2 * 5*2*misc.identity(3) * np.ones((4,1,1,3,3)))

    # Observation mask
    parent = GaussianARD(np.zeros((2,1,3)), 1e10, shape=(3,))
    node = GaussianARD(parent, 2*np.ones((2,4,3)), shape=(3,), plates=(2,4,))
    node.observe(3*np.ones((2,4,3)), mask=[[True, True, True, False],
                                           [False, True, False, True]])
    expect_message(parent,
                   (2*3 * np.ones((2,1,3)) * np.array([[[3]], [[2]]])),
                   (-0.5*2 * misc.identity(3) * np.ones((2,1,1,1))
                    * np.array([[[[3]]], [[[2]]]])))

    # Observation mask when mu and the node have different shapes
    parent = GaussianARD(np.zeros((2,1,3)), 1e10, shape=())
    node = GaussianARD(parent, 2*np.ones((2,4,3)), shape=(3,), plates=(2,4,))
    observed = np.array([[True, True, True, False],
                         [False, True, False, True]])
    node.observe(3*np.ones((2,4,3)), mask=observed)
    expect_message(parent,
                   2*3 * np.sum(np.ones((2,4,3))*observed[...,None],
                                axis=-2, keepdims=True),
                   (-0.5*2 * np.sum(np.ones((2,4,3))*observed[...,None],
                                    axis=-2, keepdims=True)))

    # Non-ARD Gaussian grandchild with a full precision matrix
    mean_value = np.array([1,2])
    Mu = GaussianARD(mean_value, 1e10, shape=(2,))
    precision = np.array([3,4])
    Lambda = np.array([[1, 0.5],
                       [0.5, 1]])
    node = GaussianARD(Mu, precision, ndim=1)
    child = Gaussian(node, Lambda)
    observation = np.array([5,6])
    child.observe(observation)
    node.update()
    (got_m0, got_m1) = Mu._message_from_children()
    posterior_mean = np.dot(np.linalg.inv(np.diag(precision)+Lambda),
                            np.dot(np.diag(precision), mean_value)
                            + np.dot(Lambda, observation))
    self.assertAllClose(got_m0, np.dot(np.diag(precision), posterior_mean))
    self.assertAllClose(got_m1, -0.5*np.diag(precision))

    # Broadcasted variable axes
    parent = GaussianARD(np.zeros(1), 1e10, shape=(1,))
    node = GaussianARD(parent, 2, shape=(3,))
    node.observe(3*np.ones(3))
    expect_message(parent,
                   2*3 * np.sum(np.ones(3), axis=-1, keepdims=True),
                   -0.5*2 * np.sum(np.identity(3), axis=(-1,-2), keepdims=True))

    pass
def test_message_to_child(self):
    """
    Test moments of GaussianARD.

    Verifies shapes after broadcasting, the prior formula
    (<x> = mu, <x x^T> = mu mu^T + diag(1/alpha)) for several broadcasting
    layouts, and the moments after a posterior update.
    """

    def verify(node, expected_u0, expected_u1):
        # Compare the two moments the node sends to its children.
        (u0, u1) = node._message_to_child()
        self.assertAllClose(u0, expected_u0)
        self.assertAllClose(u1, expected_u1)

    # Shapes are expanded to the full variable shape under broadcasting
    wide = GaussianARD(np.zeros((2,)), np.ones((3,2)), shape=(4,3,2))
    (u0, u1) = wide._message_to_child()
    self.assertEqual(np.shape(u0), (4,3,2))
    self.assertEqual(np.shape(u1), (4,3,2,4,3,2))

    # Scalar case
    verify(GaussianARD(2, 3), 2, 2**2 + 1/3)

    # Common expectation for all (2,3,4)-shaped variants below
    full_u0 = 2*np.ones((2,3,4))
    full_u1 = 2**2 * np.ones((2,3,4,2,3,4)) + 1/3 * misc.identity(2,3,4)

    # Multidimensional arrays
    verify(GaussianARD(2*np.ones((2,1,4)), 3*np.ones((2,3,1)), ndim=3),
           full_u0, full_u1)
    # Dim-broadcasted mu
    verify(GaussianARD(2*np.ones((3,1)), 3*np.ones((2,3,4)), ndim=3),
           full_u0, full_u1)
    # Dim-broadcasted alpha
    verify(GaussianARD(2*np.ones((2,3,4)), 3*np.ones((3,1)), ndim=3),
           full_u0, full_u1)
    # Dim-broadcasted mu and alpha
    verify(GaussianARD(2*np.ones((3,1)), 3*np.ones((3,1)), shape=(2,3,4)),
           full_u0, full_u1)

    # Dim-broadcasted mu with plates
    mean_node = GaussianARD(2*np.ones((5,1,3,4)), np.ones((5,1,3,4)),
                            shape=(3,4), plates=(5,1))
    verify(GaussianARD(mean_node, 3*np.ones((5,2,3,4)),
                       shape=(2,3,4), plates=(5,)),
           2*np.ones((5,2,3,4)),
           2**2 * np.ones((5,2,3,4,2,3,4)) + 1/3 * misc.identity(2,3,4))

    # Posterior moments after observing a child
    latent = GaussianARD(2, 3)
    observed = GaussianARD(latent, 1)
    observed.observe(10)
    latent.update()
    verify(latent,
           1/(3+1) * (3*2 + 1*10),
           (1/(3+1) * (3*2 + 1*10))**2 + 1/(3+1))

    pass
def test_message_to_parent(self):
    """
    Test the message from SumMultiply node to its parents.

    Covers einsum-style and list-style index specifications, plates,
    broadcasting, masks, non-ARD observations, constant (array) parents,
    Gaussian-gamma parents and their delta moments.
    """

    data = 2
    tau = 3

    def check_message(true_m0, true_m1, parent, *args, F=None):
        # Without F: build SumMultiply(*args), attach an observed GaussianARD
        # child with precision `tau` (read from the enclosing scope at call
        # time), and check the message to `parent`.
        # With F: use the pre-built node F; *args are then not used.
        if F is None:
            A = SumMultiply(*args)
            B = GaussianARD(A, tau)
            B.observe(data * np.ones(A.plates + A.dims[0]))
        else:
            A = F
        (A_m0, A_m1) = A._message_to_parent(parent)
        self.assertAllClose(true_m0, A_m0)
        self.assertAllClose(true_m1, A_m1)
        pass

    # Check: different message to each of multiple parents
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x2 = X2.get_moments()
    m0 = tau * data * x2[0]
    m1 = -0.5 * tau * x2[1] * np.identity(2)
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, [9], X2, [9], [9])
    m0 = tau * data * x1[0]
    m1 = -0.5 * tau * x1[1] * np.identity(2)
    check_message(m0, m1, 1, 'i,i->i', X1, X2)
    check_message(m0, m1, 1, X1, [9], X2, [9], [9])

    # Check: key not in output
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x1 = X1.get_moments()
    m0 = tau * data * np.ones(2)
    m1 = -0.5 * tau * np.ones((2, 2))
    check_message(m0, m1, 0, 'i', X1)
    check_message(m0, m1, 0, 'i->', X1)
    check_message(m0, m1, 0, X1, [9])
    check_message(m0, m1, 0, X1, [9], [])

    # Check: key not in some input
    X1 = GaussianARD(np.random.randn(), np.random.rand())
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(x2[0], axis=-1)
    m1 = -0.5 * tau * np.sum(x2[1] * np.identity(2), axis=(-1, -2))
    check_message(m0, m1, 0, ',i->i', X1, X2)
    check_message(m0, m1, 0, X1, [], X2, [9], [9])
    m0 = tau * data * x1[0] * np.ones(2)
    m1 = -0.5 * tau * x1[1] * np.identity(2)
    check_message(m0, m1, 1, ',i->i', X1, X2)
    check_message(m0, m1, 1, X1, [], X2, [9], [9])

    # Check: keys in different order
    Y1 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2), ndim=2)
    y1 = Y1.get_moments()
    Y2 = GaussianARD(np.random.randn(2, 3), np.random.rand(2, 3), ndim=2)
    y2 = Y2.get_moments()
    m0 = tau * data * y2[0].T
    m1 = -0.5 * tau * np.einsum('ijlk->jikl', y2[1] * misc.identity(2, 3))
    check_message(m0, m1, 0, 'ij,ji->ij', Y1, Y2)
    check_message(m0, m1, 0, Y1, ['i', 'j'], Y2, ['j', 'i'], ['i', 'j'])
    m0 = tau * data * y1[0].T
    m1 = -0.5 * tau * np.einsum('ijlk->jikl', y1[1] * misc.identity(3, 2))
    check_message(m0, m1, 1, 'ij,ji->ij', Y1, Y2)
    check_message(m0, m1, 1, Y1, ['i', 'j'], Y2, ['j', 'i'], ['i', 'j'])

    # Check: plates when different dimensionality
    X1 = GaussianARD(np.random.randn(5), np.random.rand(5),
                     shape=(), plates=(5, ))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(5, 3), np.random.rand(5, 3),
                     shape=(3, ), plates=(5, ))
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(np.ones((5, 3)) * x2[0], axis=-1)
    m1 = -0.5 * tau * np.sum(x2[1] * misc.identity(3), axis=(-1, -2))
    check_message(m0, m1, 0, ',i->i', X1, X2)
    check_message(m0, m1, 0, X1, [], X2, ['i'], ['i'])
    m0 = tau * data * x1[0][:, np.newaxis] * np.ones((5, 3))
    m1 = -0.5 * tau * x1[1][:, np.newaxis, np.newaxis] * misc.identity(3)
    check_message(m0, m1, 1, ',i->i', X1, X2)
    check_message(m0, m1, 1, X1, [], X2, ['i'], ['i'])

    # Check: other parent's moments broadcasts over plates when node has the
    # same plates
    X1 = GaussianARD(np.random.randn(5, 4, 3), np.random.rand(5, 4, 3),
                     shape=(3, ), plates=(5, 4))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3, ), plates=(5, 4))
    x2 = X2.get_moments()
    m0 = tau * data * np.ones((5, 4, 3)) * x2[0]
    m1 = -0.5 * tau * x2[1] * misc.identity(3)
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

    # Check: other parent's moments broadcasts over plates when node does
    # not have that plate
    X1 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3, ), plates=())
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3, ), plates=(5, 4))
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(np.ones((5, 4, 3)) * x2[0], axis=(0, 1))
    m1 = -0.5 * tau * np.sum(np.ones((5, 4, 1, 1))
                             * misc.identity(3)
                             * x2[1],
                             axis=(0, 1))
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

    # Check: other parent's moments broadcasts over plates when the node
    # only broadcasts that plate
    X1 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3, ), plates=(1, 1))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3), np.random.rand(3),
                     shape=(3, ), plates=(5, 4))
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(
        np.ones((5, 4, 3)) * x2[0], axis=(0, 1), keepdims=True)
    m1 = -0.5 * tau * np.sum(np.ones((5, 4, 1, 1))
                             * misc.identity(3)
                             * x2[1],
                             axis=(0, 1),
                             keepdims=True)
    check_message(m0, m1, 0, 'i,i->i', X1, X2)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

    # Check: broadcasted dimensions
    X1 = GaussianARD(np.random.randn(1, 1), np.random.rand(1, 1), ndim=2)
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2), ndim=2)
    x2 = X2.get_moments()
    m0 = tau * data * np.sum(np.ones((3, 2)) * x2[0], keepdims=True)
    m1 = -0.5 * tau * np.sum(misc.identity(3, 2) * x2[1], keepdims=True)
    check_message(m0, m1, 0, 'ij,ij->ij', X1, X2)
    check_message(m0, m1, 0, X1, [0, 1], X2, [0, 1], [0, 1])
    m0 = tau * data * np.ones((3, 2)) * x1[0]
    m1 = -0.5 * tau * misc.identity(3, 2) * x1[1]
    check_message(m0, m1, 1, 'ij,ij->ij', X1, X2)
    check_message(m0, m1, 1, X1, [0, 1], X2, [0, 1], [0, 1])

    # Check: non-ARD observations
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
    x1 = X1.get_moments()
    Lambda = np.array([[2, 1.5], [1.5, 2]])
    F = SumMultiply('i->i', X1)
    Y = Gaussian(F, Lambda)
    y = np.random.randn(2)
    Y.observe(y)
    m0 = np.dot(Lambda, y)
    m1 = -0.5 * Lambda
    check_message(m0, m1, 0, 'i->i', X1, F=F)
    check_message(m0, m1, 0, X1, ['i'], ['i'], F=F)

    # Check: mask with same shape
    X1 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2),
                     shape=(2, ), plates=(3, ))
    x1 = X1.get_moments()
    mask = np.array([True, False, True])
    F = SumMultiply('i->i', X1)
    Y = GaussianARD(F, tau, ndim=1)
    Y.observe(data * np.ones((3, 2)), mask=mask)
    m0 = tau * data * mask[:, np.newaxis] * np.ones(2)
    m1 = -0.5 * tau * mask[:, np.newaxis, np.newaxis] * np.identity(2)
    check_message(m0, m1, 0, 'i->i', X1, F=F)
    check_message(m0, m1, 0, X1, ['i'], ['i'], F=F)

    # Check: mask larger
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2),
                     shape=(2, ), plates=())
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2),
                     shape=(2, ), plates=(3, ))
    x2 = X2.get_moments()
    mask = np.array([True, False, True])
    F = SumMultiply('i,i->i', X1, X2)
    Y = GaussianARD(F, tau, plates=(3, ), ndim=1)
    Y.observe(data * np.ones((3, 2)), mask=mask)
    m0 = tau * data * np.sum(mask[:, np.newaxis] * x2[0], axis=0)
    m1 = -0.5 * tau * np.sum(
        mask[:, np.newaxis, np.newaxis] * x2[1] * np.identity(2), axis=0)
    check_message(m0, m1, 0, 'i,i->i', X1, X2, F=F)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'], F=F)

    # Check: mask for broadcasted plate
    X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1,
                     plates=(1, ))
    x1 = X1.get_moments()
    X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1,
                     plates=(3, ))
    x2 = X2.get_moments()
    mask = np.array([True, False, True])
    F = SumMultiply('i,i->i', X1, X2)
    Y = GaussianARD(F, tau, plates=(3, ), ndim=1)
    Y.observe(data * np.ones((3, 2)), mask=mask)
    m0 = tau * data * np.sum(
        mask[:, np.newaxis] * x2[0], axis=0, keepdims=True)
    m1 = -0.5 * tau * np.sum(
        mask[:, np.newaxis, np.newaxis] * x2[1] * np.identity(2),
        axis=0,
        keepdims=True)
    # The args describe F (two parents); they were previously 'i->i', X1,
    # which did not match the node actually being tested.
    check_message(m0, m1, 0, 'i,i->i', X1, X2, F=F)
    check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'], F=F)

    # Test with constant nodes.
    # Use a separate name for the per-observation precision: check_message
    # reads `tau` from this scope when called, so rebinding `tau` here would
    # silently change any check_message call placed after this section.
    N = 10
    M = 8
    D = 5
    K = 3
    a = np.random.randn(N, D)
    B = Gaussian(
        np.random.randn(D),
        random.covariance(D),
    )
    C = GaussianARD(np.random.randn(M, 1, D, K),
                    np.random.rand(M, 1, D, K),
                    ndim=2)
    F = SumMultiply('i,i,ij->', a, B, C)
    tau_mn = np.random.rand(M, N)
    Y = GaussianARD(F, tau_mn, plates=(M, N))
    y = np.random.randn(M, N)
    Y.observe(y)
    (m0, m1) = F._message_to_parent(1)
    np.testing.assert_allclose(
        m0,
        np.einsum('mn,ni,mnik->i', tau_mn * y, a, C.get_moments()[0]),
    )
    np.testing.assert_allclose(
        m1,
        np.einsum('mn,ni,nj,mnikjl->ij', -0.5 * tau_mn, a, a,
                  C.get_moments()[1]),
    )

    # Check: Gaussian-gamma parents
    X1 = GaussianGamma(np.random.randn(2), random.covariance(2),
                       np.random.rand(), np.random.rand())
    x1 = X1.get_moments()
    X2 = GaussianGamma(np.random.randn(2), random.covariance(2),
                       np.random.rand(), np.random.rand())
    x2 = X2.get_moments()
    F = SumMultiply('i,i->i', X1, X2)
    V = random.covariance(2)
    y = np.random.randn(2)
    Y = Gaussian(F, V)
    Y.observe(y)
    m0 = np.dot(V, y) * x2[0]
    m1 = -0.5 * V * x2[1]
    # Quadratic form y^T V y scaled by the other parent's third moment.
    m2 = -0.5 * np.einsum('i,ij,j', y, V, y) * x2[2]
    # One factor of 0.5 per element of the observed vector y (len(y) == 2),
    # matching the per-observation 0.5 checked in the delta-moment test below.
    m3 = 0.5 * len(y)
    m = F._message_to_parent(0)
    self.assertAllClose(m[0], m0)
    self.assertAllClose(m[1], m1)
    self.assertAllClose(m[2], m2)
    self.assertAllClose(m[3], m3)

    # Delta moments (separate precision name for the same reason as above)
    N = 10
    M = 8
    D = 5
    a = np.random.randn(N, D)
    B = GaussianGamma(np.random.randn(D), random.covariance(D),
                      np.random.rand(), np.random.rand(), ndim=1)
    F = SumMultiply('i,i->', a, B)
    tau_mn = np.random.rand(M, N)
    Y = GaussianARD(F, tau_mn, plates=(M, N))
    y = np.random.randn(M, N)
    Y.observe(y)
    (m0, m1, m2, m3) = F._message_to_parent(1)
    np.testing.assert_allclose(
        m0,
        np.einsum('mn,ni->i', tau_mn * y, a),
    )
    np.testing.assert_allclose(
        m1,
        np.einsum('mn,ni,nj->ij', -0.5 * tau_mn, a, a),
    )
    np.testing.assert_allclose(
        m2,
        np.einsum('mn->', -0.5 * tau_mn * y**2),
    )
    np.testing.assert_allclose(
        m3,
        np.einsum('mn->', 0.5 * np.ones(np.shape(tau_mn))),
    )

    pass