def test_message_to_parent_mu(self):
    """
    Check the message GaussianArrayARD sends to its first parent (index 0).

    Each scenario builds a node, conditions it on data (directly via
    ``observe`` or through an observed child followed by ``update``) and
    compares both elements of ``_message_to_parent(0)`` with hand-written
    expected values.
    """

    def verify(node, expected_m0, expected_m1):
        # Fetch the message towards parent 0 and compare both of its parts.
        msg0, msg1 = node._message_to_parent(0)
        self.assertAllClose(msg0, expected_m0)
        self.assertAllClose(msg1, expected_m1)

    # Formula when the parent alpha is an uncertain (Gamma) node
    precision = Gamma(2, 1)
    node = GaussianArrayARD(0, precision)
    node.observe(3)
    verify(node, 2*3, -0.5*2)

    # Formula when the node itself is uncertain (updated from a child)
    node = GaussianArrayARD(1, 2)
    child = GaussianArrayARD(node, 1)
    child.observe(5)
    node.update()
    verify(node, 2 * 1/(2+1)*(2*1+1*5), -0.5*2)

    # alpha has more axes than mu
    node = GaussianArrayARD(np.zeros((2, 3)), 2*np.ones((3, 2, 3)))
    node.observe(3*np.ones((3, 2, 3)))
    verify(node,
           2*3 * 3 * np.ones((2, 3)),
           -0.5 * 3 * 2*utils.identity(2, 3))

    # mu has more axes than alpha
    node = GaussianArrayARD(np.zeros((3, 2, 3)), 2*np.ones((2, 3)))
    node.observe(3*np.ones((3, 2, 3)))
    verify(node,
           2 * 3 * np.ones((3, 2, 3)),
           -0.5 * 2*utils.identity(3, 2, 3))

    # Explicit node shape larger than both mu and alpha
    node = GaussianArrayARD(np.zeros((2, 3)),
                            2*np.ones((3,)),
                            shape=(3, 2, 3))
    node.observe(3*np.ones((3, 2, 3)))
    verify(node,
           2*3 * 3*np.ones((2, 3)),
           -0.5 * 2 * 3*utils.identity(2, 3))

    # mu given with a unit-length axis where the node shape has 3
    node = GaussianArrayARD(np.zeros((2, 1)),
                            2*np.ones((2, 3)),
                            shape=(2, 3))
    node.observe(3*np.ones((2, 3)))
    verify(node,
           2*3 * 3*np.ones((2, 1)),
           -0.5 * 2 * 3*utils.identity(2, 1))

    # Plates: the mu parent has plates (4,1) and shape (3,), the node has
    # plates (4,5) and shape (2,3)
    mu_parent = GaussianArrayARD(0, 1, shape=(3,), plates=(4, 1))
    node = GaussianArrayARD(mu_parent,
                            2*np.ones((3,)),
                            shape=(2, 3),
                            plates=(4, 5))
    node.observe(3*np.ones((4, 5, 2, 3)))
    verify(node,
           2*3 * 5*2*np.ones((4, 1, 3)),
           -0.5*2 * 5*2*utils.identity(3))

    # Mask: 3 entries are True in the first plate row and 2 in the second,
    # which are the per-row factors in the expected values
    observed_mask = [[True, True, True, False],
                     [False, True, False, True]]
    node = GaussianArrayARD(np.zeros((2, 1, 3)),
                            2*np.ones((2, 4, 3)),
                            shape=(3,),
                            plates=(2, 4))
    node.observe(3*np.ones((2, 4, 3)), mask=observed_mask)
    verify(node,
           (2*3 * np.ones((2, 1, 3)) * np.array([[[3]], [[2]]])),
           (-0.5*2
            * utils.identity(3)
            * np.ones((2, 1, 1, 1))
            * np.array([[[[3]]], [[[2]]]])))

    # Child is a non-ARD Gaussian with a full matrix Lambda
    mu = np.array([1, 2])
    alpha = np.array([3, 4])
    Lambda = np.array([[1, 0.5],
                       [0.5, 1]])
    node = GaussianArrayARD(mu, alpha)
    child = Gaussian(node, Lambda)
    y = np.array([5, 6])
    child.observe(y)
    node.update()
    # Expected mean of the node, computed by hand from mu, alpha, Lambda, y
    mean = np.dot(np.linalg.inv(np.diag(alpha)+Lambda),
                  np.dot(np.diag(alpha), mu) + np.dot(Lambda, y))
    verify(node,
           np.dot(np.diag(alpha), mean),
           -0.5*np.diag(alpha))
def test_message_to_child(self):
    """
    Check the message (u0, u1) GaussianArrayARD sends to its children.

    Covers the shapes of both message elements under broadcasting, the
    values for a scalar node, several broadcasting patterns of mu and
    alpha, a node-valued mu with plates, and a node updated from an
    observed child.
    """

    # Both elements must have the full shape even when parents broadcast
    node = GaussianArrayARD(np.zeros((2,)), np.ones((3, 2)), shape=(4, 3, 2))
    u0, u1 = node._message_to_child()
    self.assertEqual(np.shape(u0), (4, 3, 2))
    self.assertEqual(np.shape(u1), (4, 3, 2, 4, 3, 2))

    # Formula for a scalar node
    node = GaussianArrayARD(2, 3)
    u0, u1 = node._message_to_child()
    self.assertAllClose(u0, 2)
    self.assertAllClose(u1, 2**2 + 1/3)

    # Four ways of specifying mu and alpha that all expect the same
    # (2,3,4)-shaped u0 and (2,3,4,2,3,4)-shaped u1
    cases = (
        # multidimensional arrays
        (2*np.ones((2, 1, 4)), 3*np.ones((2, 3, 1)), dict(ndim=3)),
        # dim-broadcasted mu
        (2*np.ones((3, 1)), 3*np.ones((2, 3, 4)), dict(ndim=3)),
        # dim-broadcasted alpha
        (2*np.ones((2, 3, 4)), 3*np.ones((3, 1)), dict(ndim=3)),
        # dim-broadcasted mu and alpha
        (2*np.ones((3, 1)), 3*np.ones((3, 1)), dict(shape=(2, 3, 4))),
    )
    for (mu, alpha, options) in cases:
        node = GaussianArrayARD(mu, alpha, **options)
        u0, u1 = node._message_to_child()
        self.assertAllClose(u0, 2*np.ones((2, 3, 4)))
        self.assertAllClose(u1,
                            2**2 * np.ones((2, 3, 4, 2, 3, 4))
                            + 1/3 * utils.identity(2, 3, 4))

    # Formula for dim-broadcasted mu with plates (mu is itself a node)
    mu_node = GaussianArrayARD(2*np.ones((5, 3, 4)),
                               np.ones((5, 3, 4)),
                               shape=(3, 4),
                               plates=(5,))
    node = GaussianArrayARD(mu_node,
                            3*np.ones((5, 2, 3, 4)),
                            shape=(2, 3, 4),
                            plates=(5,))
    u0, u1 = node._message_to_child()
    self.assertAllClose(u0, 2*np.ones((5, 2, 3, 4)))
    self.assertAllClose(u1,
                        2**2 * np.ones((5, 2, 3, 4, 2, 3, 4))
                        + 1/3 * utils.identity(2, 3, 4))

    # Posterior after updating from an observed child
    node = GaussianArrayARD(2, 3)
    child = GaussianArrayARD(node, 1)
    child.observe(10)
    node.update()
    u0, u1 = node._message_to_child()
    self.assertAllClose(u0, 1/(3+1) * (3*2 + 1*10))
    self.assertAllClose(u1, (1/(3+1) * (3*2 + 1*10))**2 + 1/(3+1))
def test_message_to_child(self):
    """
    Check the message (u0, u1) that GaussianARD sends to its children.

    The scenarios cover the shapes of both message elements under
    broadcasting, the values for a scalar node, several broadcasting
    patterns of mu and alpha, a node-valued mu with plates, and a node
    updated from an observed child.
    """

    # Both elements must have the full shape even when parents broadcast
    node = GaussianARD(np.zeros((2,)), np.ones((3, 2)), shape=(4, 3, 2))
    first, second = node._message_to_child()
    self.assertEqual(np.shape(first), (4, 3, 2))
    self.assertEqual(np.shape(second), (4, 3, 2, 4, 3, 2))

    # Formula for a scalar node
    node = GaussianARD(2, 3)
    first, second = node._message_to_child()
    self.assertAllClose(first, 2)
    self.assertAllClose(second, 2**2 + 1 / 3)

    # Formula for multidimensional arrays
    node = GaussianARD(2 * np.ones((2, 1, 4)),
                       3 * np.ones((2, 3, 1)),
                       ndim=3)
    first, second = node._message_to_child()
    expected_first = 2 * np.ones((2, 3, 4))
    expected_second = (2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1 / 3 * utils.identity(2, 3, 4))
    self.assertAllClose(first, expected_first)
    self.assertAllClose(second, expected_second)

    # Formula for dim-broadcasted mu
    node = GaussianARD(2 * np.ones((3, 1)),
                       3 * np.ones((2, 3, 4)),
                       ndim=3)
    first, second = node._message_to_child()
    expected_first = 2 * np.ones((2, 3, 4))
    expected_second = (2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1 / 3 * utils.identity(2, 3, 4))
    self.assertAllClose(first, expected_first)
    self.assertAllClose(second, expected_second)

    # Formula for dim-broadcasted alpha
    node = GaussianARD(2 * np.ones((2, 3, 4)),
                       3 * np.ones((3, 1)),
                       ndim=3)
    first, second = node._message_to_child()
    expected_first = 2 * np.ones((2, 3, 4))
    expected_second = (2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1 / 3 * utils.identity(2, 3, 4))
    self.assertAllClose(first, expected_first)
    self.assertAllClose(second, expected_second)

    # Formula for dim-broadcasted mu and alpha
    node = GaussianARD(2 * np.ones((3, 1)),
                       3 * np.ones((3, 1)),
                       shape=(2, 3, 4))
    first, second = node._message_to_child()
    expected_first = 2 * np.ones((2, 3, 4))
    expected_second = (2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1 / 3 * utils.identity(2, 3, 4))
    self.assertAllClose(first, expected_first)
    self.assertAllClose(second, expected_second)

    # Formula for dim-broadcasted mu with plates (mu is itself a node)
    mu_node = GaussianARD(2 * np.ones((5, 3, 4)),
                          np.ones((5, 3, 4)),
                          shape=(3, 4),
                          plates=(5,))
    node = GaussianARD(mu_node,
                       3 * np.ones((5, 2, 3, 4)),
                       shape=(2, 3, 4),
                       plates=(5,))
    first, second = node._message_to_child()
    expected_first = 2 * np.ones((5, 2, 3, 4))
    expected_second = (2**2 * np.ones((5, 2, 3, 4, 2, 3, 4))
                       + 1 / 3 * utils.identity(2, 3, 4))
    self.assertAllClose(first, expected_first)
    self.assertAllClose(second, expected_second)

    # Posterior after updating from an observed child
    node = GaussianARD(2, 3)
    child = GaussianARD(node, 1)
    child.observe(10)
    node.update()
    first, second = node._message_to_child()
    self.assertAllClose(first, 1 / (3 + 1) * (3 * 2 + 1 * 10))
    self.assertAllClose(second,
                        (1 / (3 + 1) * (3 * 2 + 1 * 10))**2 + 1 / (3 + 1))
def test_message_to_parent_mu(self):
    """
    Check the message GaussianARD sends to its first parent (index 0).

    Every scenario conditions a node on data (directly via ``observe`` or
    through an observed child followed by ``update``) and compares both
    elements of ``_message_to_parent(0)`` with hand-written expected
    values.
    """

    # Formula when the parent alpha is an uncertain (Gamma) node
    precision = Gamma(2, 1)
    node = GaussianARD(0, precision)
    node.observe(3)
    msg0, msg1 = node._message_to_parent(0)
    self.assertAllClose(msg0, 2 * 3)
    self.assertAllClose(msg1, -0.5 * 2)

    # Formula when the node itself is uncertain (updated from a child)
    node = GaussianARD(1, 2)
    child = GaussianARD(node, 1)
    child.observe(5)
    node.update()
    msg0, msg1 = node._message_to_parent(0)
    self.assertAllClose(msg0, 2 * 1 / (2 + 1) * (2 * 1 + 1 * 5))
    self.assertAllClose(msg1, -0.5 * 2)

    # alpha has more axes than mu
    node = GaussianARD(np.zeros((2, 3)), 2 * np.ones((3, 2, 3)))
    node.observe(3 * np.ones((3, 2, 3)))
    msg0, msg1 = node._message_to_parent(0)
    expected0 = 2 * 3 * 3 * np.ones((2, 3))
    expected1 = -0.5 * 3 * 2 * utils.identity(2, 3)
    self.assertAllClose(msg0, expected0)
    self.assertAllClose(msg1, expected1)

    # mu has more axes than alpha
    node = GaussianARD(np.zeros((3, 2, 3)), 2 * np.ones((2, 3)))
    node.observe(3 * np.ones((3, 2, 3)))
    msg0, msg1 = node._message_to_parent(0)
    expected0 = 2 * 3 * np.ones((3, 2, 3))
    expected1 = -0.5 * 2 * utils.identity(3, 2, 3)
    self.assertAllClose(msg0, expected0)
    self.assertAllClose(msg1, expected1)

    # Explicit node shape larger than both mu and alpha
    node = GaussianARD(np.zeros((2, 3)),
                       2 * np.ones((3,)),
                       shape=(3, 2, 3))
    node.observe(3 * np.ones((3, 2, 3)))
    msg0, msg1 = node._message_to_parent(0)
    expected0 = 2 * 3 * 3 * np.ones((2, 3))
    expected1 = -0.5 * 2 * 3 * utils.identity(2, 3)
    self.assertAllClose(msg0, expected0)
    self.assertAllClose(msg1, expected1)

    # mu given with a unit-length axis where the node shape has 3
    node = GaussianARD(np.zeros((2, 1)),
                       2 * np.ones((2, 3)),
                       shape=(2, 3))
    node.observe(3 * np.ones((2, 3)))
    msg0, msg1 = node._message_to_parent(0)
    expected0 = 2 * 3 * 3 * np.ones((2, 1))
    expected1 = -0.5 * 2 * 3 * utils.identity(2, 1)
    self.assertAllClose(msg0, expected0)
    self.assertAllClose(msg1, expected1)

    # Plates: the mu parent has plates (4,1) and shape (3,), the node has
    # plates (4,5) and shape (2,3)
    mu_parent = GaussianARD(0, 1, shape=(3,), plates=(4, 1))
    node = GaussianARD(mu_parent,
                       2 * np.ones((3,)),
                       shape=(2, 3),
                       plates=(4, 5))
    node.observe(3 * np.ones((4, 5, 2, 3)))
    msg0, msg1 = node._message_to_parent(0)
    expected0 = 2 * 3 * 5 * 2 * np.ones((4, 1, 3))
    expected1 = -0.5 * 2 * 5 * 2 * utils.identity(3)
    self.assertAllClose(msg0, expected0)
    self.assertAllClose(msg1, expected1)

    # Mask: 3 entries are True in the first plate row and 2 in the second,
    # which are the per-row factors in the expected values
    observed_mask = [[True, True, True, False],
                     [False, True, False, True]]
    node = GaussianARD(np.zeros((2, 1, 3)),
                       2 * np.ones((2, 4, 3)),
                       shape=(3,),
                       plates=(2, 4))
    node.observe(3 * np.ones((2, 4, 3)), mask=observed_mask)
    msg0, msg1 = node._message_to_parent(0)
    expected0 = (2 * 3 * np.ones((2, 1, 3)) * np.array([[[3]], [[2]]]))
    expected1 = (-0.5 * 2 * utils.identity(3) * np.ones((2, 1, 1, 1))
                 * np.array([[[[3]]], [[[2]]]]))
    self.assertAllClose(msg0, expected0)
    self.assertAllClose(msg1, expected1)

    # Child is a non-ARD Gaussian with a full matrix Lambda
    mu = np.array([1, 2])
    alpha = np.array([3, 4])
    Lambda = np.array([[1, 0.5],
                       [0.5, 1]])
    node = GaussianARD(mu, alpha)
    child = Gaussian(node, Lambda)
    y = np.array([5, 6])
    child.observe(y)
    node.update()
    msg0, msg1 = node._message_to_parent(0)
    # Expected mean of the node, computed by hand from mu, alpha, Lambda, y
    mean = np.dot(np.linalg.inv(np.diag(alpha) + Lambda),
                  np.dot(np.diag(alpha), mu) + np.dot(Lambda, y))
    self.assertAllClose(msg0, np.dot(np.diag(alpha), mean))
    self.assertAllClose(msg1, -0.5 * np.diag(alpha))
def test_message_to_parent(self):
    """
    Check the messages a SumMultiply node sends to its parents.

    Each scenario derives the expected message by hand from the moments of
    the parent nodes and compares it with ``_message_to_parent``. Every
    expectation is checked for both calling conventions used here: a
    subscript string followed by the parents, and parents interleaved with
    explicit key lists.
    """

    data = 2
    tau = 3

    def check_message(true_m0, true_m1, parent, *args, F=None):
        # Without a ready-made node F, build SumMultiply(*args) and attach
        # an observed GaussianARD child (second argument tau, all observed
        # values equal to data).  When F is given, args are not used.
        node = F
        if node is None:
            node = SumMultiply(*args)
            child = GaussianARD(node, tau)
            child.observe(data*np.ones(node.plates + node.dims[0]))
        (got_m0, got_m1) = node._message_to_parent(parent)
        self.assertAllClose(true_m0, got_m0)
        self.assertAllClose(true_m1, got_m1)

    # Check: different message to each of multiple parents
    p1 = GaussianARD(np.random.randn(2), np.random.rand(2))
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(2), np.random.rand(2))
    u2 = p2.get_moments()
    m0 = tau * data * u2[0]
    m1 = -0.5 * tau * u2[1] * np.identity(2)
    for form in (('i,i->i', p1, p2),
                 (p1, [9], p2, [9], [9])):
        check_message(m0, m1, 0, *form)
    m0 = tau * data * u1[0]
    m1 = -0.5 * tau * u1[1] * np.identity(2)
    for form in (('i,i->i', p1, p2),
                 (p1, [9], p2, [9], [9])):
        check_message(m0, m1, 1, *form)

    # Check: key not in output
    p1 = GaussianARD(np.random.randn(2), np.random.rand(2))
    u1 = p1.get_moments()
    m0 = tau * data * np.ones(2)
    m1 = -0.5 * tau * np.ones((2, 2))
    for form in (('i', p1),
                 ('i->', p1),
                 (p1, [9]),
                 (p1, [9], [])):
        check_message(m0, m1, 0, *form)

    # Check: key not in some input
    p1 = GaussianARD(np.random.randn(), np.random.rand())
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(2), np.random.rand(2))
    u2 = p2.get_moments()
    m0 = tau * data * np.sum(u2[0], axis=-1)
    m1 = -0.5 * tau * np.sum(u2[1] * np.identity(2), axis=(-1, -2))
    for form in ((',i->i', p1, p2),
                 (p1, [], p2, [9], [9])):
        check_message(m0, m1, 0, *form)
    m0 = tau * data * u1[0] * np.ones(2)
    m1 = -0.5 * tau * u1[1] * np.identity(2)
    for form in ((',i->i', p1, p2),
                 (p1, [], p2, [9], [9])):
        check_message(m0, m1, 1, *form)

    # Check: keys in different order
    q1 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2))
    v1 = q1.get_moments()
    q2 = GaussianARD(np.random.randn(2, 3), np.random.rand(2, 3))
    v2 = q2.get_moments()
    m0 = tau * data * v2[0].T
    m1 = -0.5 * tau * np.einsum('ijlk->jikl',
                                v2[1] * utils.identity(2, 3))
    for form in (('ij,ji->ij', q1, q2),
                 (q1, ['i', 'j'], q2, ['j', 'i'], ['i', 'j'])):
        check_message(m0, m1, 0, *form)
    m0 = tau * data * v1[0].T
    m1 = -0.5 * tau * np.einsum('ijlk->jikl',
                                v1[1] * utils.identity(3, 2))
    for form in (('ij,ji->ij', q1, q2),
                 (q1, ['i', 'j'], q2, ['j', 'i'], ['i', 'j'])):
        check_message(m0, m1, 1, *form)

    # Check: plates when different dimensionality
    p1 = GaussianARD(np.random.randn(5),
                     np.random.rand(5),
                     shape=(),
                     plates=(5,))
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(5, 3),
                     np.random.rand(5, 3),
                     shape=(3,),
                     plates=(5,))
    u2 = p2.get_moments()
    m0 = tau * data * np.sum(np.ones((5, 3)) * u2[0], axis=-1)
    m1 = -0.5 * tau * np.sum(u2[1] * utils.identity(3), axis=(-1, -2))
    for form in ((',i->i', p1, p2),
                 (p1, [], p2, ['i'], ['i'])):
        check_message(m0, m1, 0, *form)
    m0 = tau * data * u1[0][:, np.newaxis] * np.ones((5, 3))
    m1 = -0.5 * tau * u1[1][:, np.newaxis, np.newaxis] * utils.identity(3)
    for form in ((',i->i', p1, p2),
                 (p1, [], p2, ['i'], ['i'])):
        check_message(m0, m1, 1, *form)

    # Check: other parent's moments broadcasts over plates when node has
    # the same plates
    p1 = GaussianARD(np.random.randn(5, 4, 3),
                     np.random.rand(5, 4, 3),
                     shape=(3,),
                     plates=(5, 4))
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(3),
                     np.random.rand(3),
                     shape=(3,),
                     plates=(5, 4))
    u2 = p2.get_moments()
    m0 = tau * data * np.ones((5, 4, 3)) * u2[0]
    m1 = -0.5 * tau * u2[1] * utils.identity(3)
    for form in (('i,i->i', p1, p2),
                 (p1, ['i'], p2, ['i'], ['i'])):
        check_message(m0, m1, 0, *form)

    # Check: other parent's moments broadcasts over plates when node does
    # not have that plate
    p1 = GaussianARD(np.random.randn(3),
                     np.random.rand(3),
                     shape=(3,),
                     plates=())
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(3),
                     np.random.rand(3),
                     shape=(3,),
                     plates=(5, 4))
    u2 = p2.get_moments()
    m0 = tau * data * np.sum(np.ones((5, 4, 3)) * u2[0], axis=(0, 1))
    m1 = -0.5 * tau * np.sum(np.ones((5, 4, 1, 1))
                             * utils.identity(3)
                             * u2[1],
                             axis=(0, 1))
    for form in (('i,i->i', p1, p2),
                 (p1, ['i'], p2, ['i'], ['i'])):
        check_message(m0, m1, 0, *form)

    # Check: other parent's moments broadcasts over plates when the node
    # only broadcasts that plate
    p1 = GaussianARD(np.random.randn(3),
                     np.random.rand(3),
                     shape=(3,),
                     plates=(1, 1))
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(3),
                     np.random.rand(3),
                     shape=(3,),
                     plates=(5, 4))
    u2 = p2.get_moments()
    m0 = tau * data * np.sum(np.ones((5, 4, 3)) * u2[0],
                             axis=(0, 1),
                             keepdims=True)
    m1 = -0.5 * tau * np.sum(np.ones((5, 4, 1, 1))
                             * utils.identity(3)
                             * u2[1],
                             axis=(0, 1),
                             keepdims=True)
    for form in (('i,i->i', p1, p2),
                 (p1, ['i'], p2, ['i'], ['i'])):
        check_message(m0, m1, 0, *form)

    # Check: broadcasted dimensions
    p1 = GaussianARD(np.random.randn(1, 1), np.random.rand(1, 1))
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2))
    u2 = p2.get_moments()
    m0 = tau * data * np.sum(np.ones((3, 2)) * u2[0], keepdims=True)
    m1 = -0.5 * tau * np.sum(utils.identity(3, 2) * u2[1], keepdims=True)
    for form in (('ij,ij->ij', p1, p2),
                 (p1, [0, 1], p2, [0, 1], [0, 1])):
        check_message(m0, m1, 0, *form)
    m0 = tau * data * np.ones((3, 2)) * u1[0]
    m1 = -0.5 * tau * utils.identity(3, 2) * u1[1]
    for form in (('ij,ij->ij', p1, p2),
                 (p1, [0, 1], p2, [0, 1], [0, 1])):
        check_message(m0, m1, 1, *form)

    # Check: non-ARD observations
    p1 = GaussianARD(np.random.randn(2), np.random.rand(2))
    u1 = p1.get_moments()
    Lambda = np.array([[2, 1.5], [1.5, 2]])
    prebuilt = SumMultiply('i->i', p1)
    observed = Gaussian(prebuilt, Lambda)
    y = np.random.randn(2)
    observed.observe(y)
    m0 = np.dot(Lambda, y)
    m1 = -0.5 * Lambda
    for form in (('i->i', p1),
                 (p1, ['i'], ['i'])):
        check_message(m0, m1, 0, *form, F=prebuilt)

    # Check: mask with same shape
    p1 = GaussianARD(np.random.randn(3, 2),
                     np.random.rand(3, 2),
                     shape=(2,),
                     plates=(3,))
    u1 = p1.get_moments()
    mask = np.array([True, False, True])
    prebuilt = SumMultiply('i->i', p1)
    observed = GaussianARD(prebuilt, tau)
    observed.observe(data*np.ones((3, 2)), mask=mask)
    m0 = tau * data * mask[:, np.newaxis] * np.ones(2)
    m1 = -0.5 * tau * mask[:, np.newaxis, np.newaxis] * np.identity(2)
    for form in (('i->i', p1),
                 (p1, ['i'], ['i'])):
        check_message(m0, m1, 0, *form, F=prebuilt)

    # Check: mask larger
    p1 = GaussianARD(np.random.randn(2),
                     np.random.rand(2),
                     shape=(2,),
                     plates=())
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(3, 2),
                     np.random.rand(3, 2),
                     shape=(2,),
                     plates=(3,))
    u2 = p2.get_moments()
    mask = np.array([True, False, True])
    prebuilt = SumMultiply('i,i->i', p1, p2)
    observed = GaussianARD(prebuilt, tau, plates=(3,))
    observed.observe(data*np.ones((3, 2)), mask=mask)
    m0 = tau * data * np.sum(mask[:, np.newaxis] * u2[0], axis=0)
    m1 = -0.5 * tau * np.sum(mask[:, np.newaxis, np.newaxis]
                             * u2[1]
                             * np.identity(2),
                             axis=0)
    for form in (('i,i->i', p1, p2),
                 (p1, ['i'], p2, ['i'], ['i'])):
        check_message(m0, m1, 0, *form, F=prebuilt)

    # Check: mask for broadcasted plate
    p1 = GaussianARD(np.random.randn(2), np.random.rand(2), plates=(1,))
    u1 = p1.get_moments()
    p2 = GaussianARD(np.random.randn(2), np.random.rand(2), plates=(3,))
    u2 = p2.get_moments()
    mask = np.array([True, False, True])
    prebuilt = SumMultiply('i,i->i', p1, p2)
    observed = GaussianARD(prebuilt, tau, plates=(3,))
    observed.observe(data*np.ones((3, 2)), mask=mask)
    m0 = tau * data * np.sum(mask[:, np.newaxis] * u2[0],
                             axis=0,
                             keepdims=True)
    m1 = -0.5 * tau * np.sum(mask[:, np.newaxis, np.newaxis]
                             * u2[1]
                             * np.identity(2),
                             axis=0,
                             keepdims=True)
    # The positional arguments below are ignored because F is supplied
    for form in (('i->i', p1),
                 (p1, ['i'], ['i'])):
        check_message(m0, m1, 0, *form, F=prebuilt)