Ejemplo n.º 1
0
    def test_message_to_parent_mu(self):
        """
        Test that GaussianArrayARD computes the message to the 1st parent
        correctly.

        Each scenario builds a node, fixes its state (by observing data or by
        updating it from an observed child) and compares both elements of
        ``_message_to_parent(0)`` with hand-written expected values.
        """

        def expect_message(node, expected_m0, expected_m1):
            # Fetch the message to parent 0 and compare its two elements.
            (msg0, msg1) = node._message_to_parent(0)
            self.assertAllClose(msg0, expected_m0)
            self.assertAllClose(msg1, expected_m1)

        # Formula when the parent alpha is itself a (Gamma) node
        precision = Gamma(2, 1)
        node = GaussianArrayARD(0, precision)
        node.observe(3)
        expect_message(node, 2*3, -0.5*2)

        # Formula when the node is not observed but updated from a child
        node = GaussianArrayARD(1, 2)
        child = GaussianArrayARD(node, 1)
        child.observe(5)
        node.update()
        expect_message(node,
                       2 * 1/(2+1)*(2*1+1*5),
                       -0.5*2)

        # alpha has a larger shape than mu
        node = GaussianArrayARD(np.zeros((2, 3)), 2*np.ones((3, 2, 3)))
        node.observe(3*np.ones((3, 2, 3)))
        expect_message(node,
                       2*3 * 3 * np.ones((2, 3)),
                       -0.5 * 3 * 2*utils.identity(2, 3))

        # mu has a larger shape than alpha
        node = GaussianArrayARD(np.zeros((3, 2, 3)), 2*np.ones((2, 3)))
        node.observe(3*np.ones((3, 2, 3)))
        expect_message(node,
                       2 * 3 * np.ones((3, 2, 3)),
                       -0.5 * 2*utils.identity(3, 2, 3))

        # The node has a larger shape than both mu and alpha
        node = GaussianArrayARD(np.zeros((2, 3)),
                                2*np.ones((3,)),
                                shape=(3, 2, 3))
        node.observe(3*np.ones((3, 2, 3)))
        expect_message(node,
                       2*3 * 3*np.ones((2, 3)),
                       -0.5 * 2 * 3*utils.identity(2, 3))

        # mu has a unit-length axis where the node has length 3
        node = GaussianArrayARD(np.zeros((2, 1)),
                                2*np.ones((2, 3)),
                                shape=(2, 3))
        node.observe(3*np.ones((2, 3)))
        expect_message(node,
                       2*3 * 3*np.ones((2, 1)),
                       -0.5 * 2 * 3*utils.identity(2, 1))

        # Plates: the parent mu is smaller than the node
        parent = GaussianArrayARD(0, 1,
                                  shape=(3,),
                                  plates=(4, 1))
        node = GaussianArrayARD(parent,
                                2*np.ones((3,)),
                                shape=(2, 3),
                                plates=(4, 5))
        node.observe(3*np.ones((4, 5, 2, 3)))
        expect_message(node,
                       2*3 * 5*2*np.ones((4, 1, 3)),
                       -0.5*2 * 5*2*utils.identity(3))

        # Masked observations: the first mask row has three True entries and
        # the second has two, which is where the factors 3 and 2 in the
        # expected values come from.
        observed = [[True, True, True, False],
                    [False, True, False, True]]
        node = GaussianArrayARD(np.zeros((2, 1, 3)),
                                2*np.ones((2, 4, 3)),
                                shape=(3,),
                                plates=(2, 4,))
        node.observe(3*np.ones((2, 4, 3)), mask=observed)
        expect_message(node,
                       (2*3 * np.ones((2, 1, 3))
                        * np.array([[[3]], [[2]]])),
                       (-0.5*2 * utils.identity(3)
                        * np.ones((2, 1, 1, 1))
                        * np.array([[[[3]]], [[[2]]]])))

        # Child is a Gaussian node with a full (non-ARD) precision matrix
        prior_mean = np.array([1, 2])
        prior_prec = np.array([3, 4])
        child_prec = np.array([[1, 0.5],
                               [0.5, 1]])
        node = GaussianArrayARD(prior_mean, prior_prec)
        child = Gaussian(node, child_prec)
        obs = np.array([5, 6])
        child.observe(obs)
        node.update()
        prior_prec_matrix = np.diag(prior_prec)
        posterior_mean = np.dot(
            np.linalg.inv(prior_prec_matrix + child_prec),
            np.dot(prior_prec_matrix, prior_mean) + np.dot(child_prec, obs))
        expect_message(node,
                       np.dot(prior_prec_matrix, posterior_mean),
                       -0.5*prior_prec_matrix)
Ejemplo n.º 2
0
    def test_message_to_child(self):
        """
        Test that GaussianArrayARD computes the message to children correctly.

        The message is the pair returned by ``_message_to_child()``. Its shape
        is checked for a broadcast case, and its values are compared with
        hand-written expectations for the remaining cases.
        """

        def expect_moments(node, expected_u0, expected_u1):
            # Fetch the message to children and compare its two elements.
            (mom0, mom1) = node._message_to_child()
            self.assertAllClose(mom0, expected_u0)
            self.assertAllClose(mom1, expected_u1)

        # Both elements must have the full shape although the parents are
        # smaller than the node
        node = GaussianArrayARD(np.zeros((2,)), np.ones((3, 2)),
                                shape=(4, 3, 2))
        (mom0, mom1) = node._message_to_child()
        self.assertEqual(np.shape(mom0), (4, 3, 2))
        self.assertEqual(np.shape(mom1), (4, 3, 2, 4, 3, 2))

        # Scalar formula
        expect_moments(GaussianArrayARD(2, 3),
                       2,
                       2**2 + 1/3)

        # Multidimensional arrays
        node = GaussianArrayARD(2*np.ones((2, 1, 4)),
                                3*np.ones((2, 3, 1)),
                                ndim=3)
        expect_moments(node,
                       2*np.ones((2, 3, 4)),
                       2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1/3 * utils.identity(2, 3, 4))

        # mu is smaller than the node and gets broadcast
        node = GaussianArrayARD(2*np.ones((3, 1)),
                                3*np.ones((2, 3, 4)),
                                ndim=3)
        expect_moments(node,
                       2*np.ones((2, 3, 4)),
                       2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1/3 * utils.identity(2, 3, 4))

        # alpha is smaller than the node and gets broadcast
        node = GaussianArrayARD(2*np.ones((2, 3, 4)),
                                3*np.ones((3, 1)),
                                ndim=3)
        expect_moments(node,
                       2*np.ones((2, 3, 4)),
                       2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1/3 * utils.identity(2, 3, 4))

        # Both mu and alpha are smaller than the explicitly given shape
        node = GaussianArrayARD(2*np.ones((3, 1)),
                                3*np.ones((3, 1)),
                                shape=(2, 3, 4))
        expect_moments(node,
                       2*np.ones((2, 3, 4)),
                       2**2 * np.ones((2, 3, 4, 2, 3, 4))
                       + 1/3 * utils.identity(2, 3, 4))

        # mu is a node with plates and a smaller shape than the child
        parent = GaussianArrayARD(2*np.ones((5, 3, 4)),
                                  np.ones((5, 3, 4)),
                                  shape=(3, 4),
                                  plates=(5,))
        node = GaussianArrayARD(parent,
                                3*np.ones((5, 2, 3, 4)),
                                shape=(2, 3, 4),
                                plates=(5,))
        expect_moments(node,
                       2*np.ones((5, 2, 3, 4)),
                       2**2 * np.ones((5, 2, 3, 4, 2, 3, 4))
                       + 1/3 * utils.identity(2, 3, 4))

        # Posterior after updating from an observed child
        node = GaussianArrayARD(2, 3)
        child = GaussianArrayARD(node, 1)
        child.observe(10)
        node.update()
        posterior_mean = 1/(3+1) * (3*2 + 1*10)
        expect_moments(node,
                       posterior_mean,
                       posterior_mean**2 + 1/(3+1))
Ejemplo n.º 3
0
    def test_message_to_child(self):
        """
        Test that GaussianARD computes the message to children correctly.

        The message is the pair returned by ``_message_to_child()``. Its shape
        is checked for a broadcast case, and its values are compared with
        hand-written expectations for scalar, array, plated and posterior
        cases.
        """

        # Full shape is required even though both parents are smaller
        node = GaussianARD(np.zeros((2, )), np.ones((3, 2)), shape=(4, 3, 2))
        first, second = node._message_to_child()
        self.assertEqual(np.shape(first), (4, 3, 2))
        self.assertEqual(np.shape(second), (4, 3, 2, 4, 3, 2))

        # Scalar formula
        node = GaussianARD(2, 3)
        first, second = node._message_to_child()
        self.assertAllClose(first, 2)
        self.assertAllClose(second, 2**2 + 1 / 3)

        # Four ways of specifying a (2, 3, 4)-shaped node; all of them are
        # expected to produce the same message.
        same_message_cases = (
            # multidimensional arrays
            (2 * np.ones((2, 1, 4)), 3 * np.ones((2, 3, 1)), {'ndim': 3}),
            # dim-broadcasted mu
            (2 * np.ones((3, 1)), 3 * np.ones((2, 3, 4)), {'ndim': 3}),
            # dim-broadcasted alpha
            (2 * np.ones((2, 3, 4)), 3 * np.ones((3, 1)), {'ndim': 3}),
            # dim-broadcasted mu and alpha
            (2 * np.ones((3, 1)), 3 * np.ones((3, 1)), {'shape': (2, 3, 4)}),
        )
        for (mu_value, alpha_value, extra) in same_message_cases:
            node = GaussianARD(mu_value, alpha_value, **extra)
            first, second = node._message_to_child()
            self.assertAllClose(first, 2 * np.ones((2, 3, 4)))
            self.assertAllClose(
                second,
                2**2 * np.ones((2, 3, 4, 2, 3, 4))
                + 1 / 3 * utils.identity(2, 3, 4))

        # mu is a node with plates and a smaller shape than the child
        parent = GaussianARD(2 * np.ones((5, 3, 4)),
                             np.ones((5, 3, 4)),
                             shape=(3, 4),
                             plates=(5, ))
        node = GaussianARD(parent,
                           3 * np.ones((5, 2, 3, 4)),
                           shape=(2, 3, 4),
                           plates=(5, ))
        first, second = node._message_to_child()
        self.assertAllClose(first, 2 * np.ones((5, 2, 3, 4)))
        self.assertAllClose(
            second,
            2**2 * np.ones((5, 2, 3, 4, 2, 3, 4))
            + 1 / 3 * utils.identity(2, 3, 4))

        # Posterior after updating from an observed child
        node = GaussianARD(2, 3)
        child = GaussianARD(node, 1)
        child.observe(10)
        node.update()
        first, second = node._message_to_child()
        posterior_mean = 1 / (3 + 1) * (3 * 2 + 1 * 10)
        self.assertAllClose(first, posterior_mean)
        self.assertAllClose(second, posterior_mean**2 + 1 / (3 + 1))
Ejemplo n.º 4
0
    def test_message_to_parent_mu(self):
        """
        Test that GaussianARD computes the message to the 1st parent correctly.

        Each scenario builds a node, fixes its state (by observing data or by
        updating it from an observed child) and compares both elements of
        ``_message_to_parent(0)`` with hand-written expected values.
        """

        # Parent alpha is itself a (Gamma) node
        node = GaussianARD(0, Gamma(2, 1))
        node.observe(3)
        msg0, msg1 = node._message_to_parent(0)
        self.assertAllClose(msg0, 2 * 3)
        self.assertAllClose(msg1, -0.5 * 2)

        # Node is not observed but updated from an observed child
        node = GaussianARD(1, 2)
        child = GaussianARD(node, 1)
        child.observe(5)
        node.update()
        msg0, msg1 = node._message_to_parent(0)
        self.assertAllClose(msg0, 2 * 1 / (2 + 1) * (2 * 1 + 1 * 5))
        self.assertAllClose(msg1, -0.5 * 2)

        # alpha has a larger shape than mu
        node = GaussianARD(np.zeros((2, 3)), 2 * np.ones((3, 2, 3)))
        node.observe(3 * np.ones((3, 2, 3)))
        msg0, msg1 = node._message_to_parent(0)
        expected_m0 = 2 * 3 * 3 * np.ones((2, 3))
        expected_m1 = -0.5 * 3 * 2 * utils.identity(2, 3)
        self.assertAllClose(msg0, expected_m0)
        self.assertAllClose(msg1, expected_m1)

        # mu has a larger shape than alpha
        node = GaussianARD(np.zeros((3, 2, 3)), 2 * np.ones((2, 3)))
        node.observe(3 * np.ones((3, 2, 3)))
        msg0, msg1 = node._message_to_parent(0)
        expected_m0 = 2 * 3 * np.ones((3, 2, 3))
        expected_m1 = -0.5 * 2 * utils.identity(3, 2, 3)
        self.assertAllClose(msg0, expected_m0)
        self.assertAllClose(msg1, expected_m1)

        # The node has a larger shape than both mu and alpha
        node = GaussianARD(np.zeros((2, 3)),
                           2 * np.ones((3, )),
                           shape=(3, 2, 3))
        node.observe(3 * np.ones((3, 2, 3)))
        msg0, msg1 = node._message_to_parent(0)
        expected_m0 = 2 * 3 * 3 * np.ones((2, 3))
        expected_m1 = -0.5 * 2 * 3 * utils.identity(2, 3)
        self.assertAllClose(msg0, expected_m0)
        self.assertAllClose(msg1, expected_m1)

        # mu has a unit-length axis where the node has length 3
        node = GaussianARD(np.zeros((2, 1)),
                           2 * np.ones((2, 3)),
                           shape=(2, 3))
        node.observe(3 * np.ones((2, 3)))
        msg0, msg1 = node._message_to_parent(0)
        expected_m0 = 2 * 3 * 3 * np.ones((2, 1))
        expected_m1 = -0.5 * 2 * 3 * utils.identity(2, 1)
        self.assertAllClose(msg0, expected_m0)
        self.assertAllClose(msg1, expected_m1)

        # Plates: the parent mu is smaller than the node
        parent = GaussianARD(0, 1, shape=(3, ), plates=(4, 1))
        node = GaussianARD(parent,
                           2 * np.ones((3, )),
                           shape=(2, 3),
                           plates=(4, 5))
        node.observe(3 * np.ones((4, 5, 2, 3)))
        msg0, msg1 = node._message_to_parent(0)
        expected_m0 = 2 * 3 * 5 * 2 * np.ones((4, 1, 3))
        expected_m1 = -0.5 * 2 * 5 * 2 * utils.identity(3)
        self.assertAllClose(msg0, expected_m0)
        self.assertAllClose(msg1, expected_m1)

        # Masked observations: the first mask row has three True entries and
        # the second has two, matching the factors 3 and 2 below.
        observed = [[True, True, True, False], [False, True, False, True]]
        node = GaussianARD(np.zeros((2, 1, 3)),
                           2 * np.ones((2, 4, 3)),
                           shape=(3, ),
                           plates=(2, 4))
        node.observe(3 * np.ones((2, 4, 3)), mask=observed)
        msg0, msg1 = node._message_to_parent(0)
        expected_m0 = (2 * 3 * np.ones((2, 1, 3))
                       * np.array([[[3]], [[2]]]))
        expected_m1 = (-0.5 * 2 * utils.identity(3)
                       * np.ones((2, 1, 1, 1))
                       * np.array([[[[3]]], [[[2]]]]))
        self.assertAllClose(msg0, expected_m0)
        self.assertAllClose(msg1, expected_m1)

        # Child is a Gaussian node with a full (non-ARD) precision matrix
        prior_mean = np.array([1, 2])
        prior_prec = np.array([3, 4])
        child_prec = np.array([[1, 0.5], [0.5, 1]])
        node = GaussianARD(prior_mean, prior_prec)
        child = Gaussian(node, child_prec)
        obs = np.array([5, 6])
        child.observe(obs)
        node.update()
        msg0, msg1 = node._message_to_parent(0)
        prior_prec_matrix = np.diag(prior_prec)
        posterior_mean = np.dot(
            np.linalg.inv(prior_prec_matrix + child_prec),
            np.dot(prior_prec_matrix, prior_mean) + np.dot(child_prec, obs))
        self.assertAllClose(msg0, np.dot(prior_prec_matrix, posterior_mean))
        self.assertAllClose(msg1, -0.5 * prior_prec_matrix)
Ejemplo n.º 5
0
    def test_message_to_parent(self):
        """
        Test the message from SumMultiply node to its parents.

        Each case builds parent nodes, writes the expected message pair
        (m0, m1) by hand from the moments of the *other* parents, and compares
        it with ``SumMultiply._message_to_parent``.  Most cases are checked
        twice: once with a string specification (e.g. ``'i,i->i'``) and once
        with per-argument key lists (e.g. ``X1, [9], X2, [9], [9]``); both
        forms are expected to give the same message.
        """

        # Constant observed value and the second argument of the GaussianARD
        # child that check_message places on top of the SumMultiply node.
        data = 2
        tau = 3

        def check_message(true_m0, true_m1, parent, *args, F=None):
            """
            Compare the message to parent number `parent` with the expected
            pair (true_m0, true_m1).

            If F is None, a SumMultiply node A is built from *args, a
            GaussianARD(A, tau) child is attached and observed at the constant
            value `data` over the shape A.plates + A.dims[0].  If F is given,
            it is used as the node directly: no child is created here and
            *args are ignored.
            """
            if F is None:
                A = SumMultiply(*args)
                B = GaussianARD(A, tau)
                B.observe(data*np.ones(A.plates + A.dims[0]))
            else:
                A = F
            (A_m0, A_m1) = A._message_to_parent(parent)
            self.assertAllClose(true_m0, A_m0)
            self.assertAllClose(true_m1, A_m1)
            pass

        # Check: different message to each of multiple parents
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2),
                         np.random.rand(2))
        x2 = X2.get_moments()
        # Expected message to parent 0 (X1) is written with X2's moments
        m0 = tau * data * x2[0]
        m1 = -0.5 * tau * x2[1] * np.identity(2)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      [9],
                      X2,
                      [9],
                      [9])
        # Expected message to parent 1 (X2) is written with X1's moments
        m0 = tau * data * x1[0]
        m1 = -0.5 * tau * x1[1] * np.identity(2)
        check_message(m0, m1, 1,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [9],
                      X2,
                      [9],
                      [9])

        # Check: key not in output
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2))
        x1 = X1.get_moments()
        m0 = tau * data * np.ones(2)
        m1 = -0.5 * tau * np.ones((2,2))
        # The same expected message is used for all four specifications below
        check_message(m0, m1, 0,
                      'i',
                      X1)
        check_message(m0, m1, 0,
                      'i->',
                      X1)
        check_message(m0, m1, 0,
                      X1,
                      [9])
        check_message(m0, m1, 0,
                      X1,
                      [9],
                      [])

        # Check: key not in some input
        X1 = GaussianARD(np.random.randn(),
                         np.random.rand())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2),
                         np.random.rand(2))
        x2 = X2.get_moments()
        # X1 has no key 'i', so the expected message to it sums over that axis
        m0 = tau * data * np.sum(x2[0], axis=-1)
        m1 = -0.5 * tau * np.sum(x2[1] * np.identity(2),
                                 axis=(-1,-2))
        check_message(m0, m1, 0,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      [],
                      X2,
                      [9],
                      [9])
        m0 = tau * data * x1[0] * np.ones(2)
        m1 = -0.5 * tau * x1[1] * np.identity(2)
        check_message(m0, m1, 1,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [],
                      X2,
                      [9],
                      [9])

        # Check: keys in different order
        Y1 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2))
        y1 = Y1.get_moments()
        Y2 = GaussianARD(np.random.randn(2,3),
                         np.random.rand(2,3))
        y2 = Y2.get_moments()
        # The other parent's moments are transposed to match this parent's
        # key order ('ij' vs 'ji')
        m0 = tau * data * y2[0].T
        m1 = -0.5 * tau * np.einsum('ijlk->jikl', y2[1] * utils.identity(2,3))
        check_message(m0, m1, 0,
                      'ij,ji->ij',
                      Y1,
                      Y2)
        check_message(m0, m1, 0,
                      Y1,
                      ['i','j'],
                      Y2,
                      ['j','i'],
                      ['i','j'])
        m0 = tau * data * y1[0].T
        m1 = -0.5 * tau * np.einsum('ijlk->jikl', y1[1] * utils.identity(3,2))
        check_message(m0, m1, 1,
                      'ij,ji->ij',
                      Y1,
                      Y2)
        check_message(m0, m1, 1,
                      Y1,
                      ['i','j'],
                      Y2,
                      ['j','i'],
                      ['i','j'])

        # Check: plates when different dimensionality
        X1 = GaussianARD(np.random.randn(5),
                         np.random.rand(5),
                         shape=(),
                         plates=(5,))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(5,3),
                         np.random.rand(5,3),
                         shape=(3,),
                         plates=(5,))
        x2 = X2.get_moments()
        m0 = tau * data * np.sum(np.ones((5,3)) * x2[0], axis=-1)
        m1 = -0.5 * tau * np.sum(x2[1] * utils.identity(3), axis=(-1,-2))
        check_message(m0, m1, 0,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      [],
                      X2,
                      ['i'],
                      ['i'])
        # New axes are appended to X1's moments so that they line up with the
        # (5, 3) and (5, 3, 3) shapes of the expected message to X2
        m0 = tau * data * x1[0][:,np.newaxis] * np.ones((5,3))
        m1 = -0.5 * tau * x1[1][:,np.newaxis,np.newaxis] * utils.identity(3)
        check_message(m0, m1, 1,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [],
                      X2,
                      ['i'],
                      ['i'])

        # Check: other parent's moments broadcasts over plates when node has the
        # same plates
        X1 = GaussianARD(np.random.randn(5,4,3),
                         np.random.rand(5,4,3),
                         shape=(3,),
                         plates=(5,4))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(5,4))
        x2 = X2.get_moments()
        m0 = tau * data * np.ones((5,4,3)) * x2[0]
        m1 = -0.5 * tau * x2[1] * utils.identity(3)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'])

        # Check: other parent's moments broadcasts over plates when node does
        # not have that plate
        X1 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(5,4))
        x2 = X2.get_moments()
        # X1 has no plates, so the expected message sums over both plate axes
        m0 = tau * data * np.sum(np.ones((5,4,3)) * x2[0], axis=(0,1))
        m1 = -0.5 * tau * np.sum(np.ones((5,4,1,1))
                                 * utils.identity(3)
                                 * x2[1],
                                 axis=(0,1))
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'])

        # Check: other parent's moments broadcasts over plates when the node
        # only broadcasts that plate
        X1 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(1,1))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(5,4))
        x2 = X2.get_moments()
        # X1 has plates (1,1), so the summed plate axes are kept with length 1
        m0 = tau * data * np.sum(np.ones((5,4,3)) * x2[0], axis=(0,1), keepdims=True)
        m1 = -0.5 * tau * np.sum(np.ones((5,4,1,1))
                                 * utils.identity(3)
                                 * x2[1],
                                 axis=(0,1),
                                 keepdims=True)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'])

        # Check: broadcasted dimensions
        X1 = GaussianARD(np.random.randn(1,1),
                         np.random.rand(1,1))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2))
        x2 = X2.get_moments()
        # np.sum without `axis` sums over all axes; keepdims=True keeps them
        # with length 1
        m0 = tau * data * np.sum(np.ones((3,2)) * x2[0],
                                 keepdims=True)
        m1 = -0.5 * tau * np.sum(utils.identity(3,2) * x2[1],
                                 keepdims=True)
        check_message(m0, m1, 0,
                      'ij,ij->ij',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      [0,1],
                      X2,
                      [0,1],
                      [0,1])
        m0 = tau * data * np.ones((3,2)) * x1[0]
        m1 = -0.5 * tau * utils.identity(3,2) * x1[1]
        check_message(m0, m1, 1,
                      'ij,ij->ij',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [0,1],
                      X2,
                      [0,1],
                      [0,1])

        # Check: non-ARD observations
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2))
        x1 = X1.get_moments()
        Lambda = np.array([[2, 1.5], [1.5, 2]])
        F = SumMultiply('i->i', X1)
        Y = Gaussian(F, Lambda)
        y = np.random.randn(2)
        Y.observe(y)
        m0 = np.dot(Lambda, y)
        m1 = -0.5 * Lambda
        # F is passed explicitly, so check_message does not build a node from
        # the specification arguments; both calls below check the same F.
        check_message(m0, m1, 0,
                      'i->i',
                      X1,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      ['i'],
                      F=F)

        # Check: mask with same shape
        X1 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2),
                         shape=(2,),
                         plates=(3,))
        x1 = X1.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i->i', X1)
        Y = GaussianARD(F, tau)
        Y.observe(data*np.ones((3,2)), mask=mask)
        # The expected message is zero on the plate where mask is False
        m0 = tau * data * mask[:,np.newaxis] * np.ones(2)
        m1 = -0.5 * tau * mask[:,np.newaxis,np.newaxis] * np.identity(2)
        # As above: F is given, so the specification arguments are ignored.
        check_message(m0, m1, 0,
                      'i->i',
                      X1,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      ['i'],
                      F=F)

        # Check: mask larger
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         shape=(2,),
                         plates=())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2),
                         shape=(2,),
                         plates=(3,))
        x2 = X2.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i,i->i', X1, X2)
        Y = GaussianARD(F, tau,
                        plates=(3,))
        Y.observe(data*np.ones((3,2)), mask=mask)
        # Only the plates where mask is True contribute to the sum over axis 0
        m0 = tau * data * np.sum(mask[:,np.newaxis] * x2[0], axis=0)
        m1 = -0.5 * tau * np.sum(mask[:,np.newaxis,np.newaxis]
                                 * x2[1]
                                 * np.identity(2),
                                 axis=0)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'],
                      F=F)

        # Check: mask for broadcasted plate
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         plates=(1,))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         plates=(3,))
        x2 = X2.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i,i->i', X1, X2)
        Y = GaussianARD(F, tau,
                        plates=(3,))
        Y.observe(data*np.ones((3,2)), mask=mask)
        m0 = tau * data * np.sum(mask[:,np.newaxis] * x2[0],
                                 axis=0,
                                 keepdims=True)
        m1 = -0.5 * tau * np.sum(mask[:,np.newaxis,np.newaxis]
                                 * x2[1]
                                 * np.identity(2),
                                 axis=0,
                                 keepdims=True)
        # NOTE(review): F is passed explicitly, so the specification arguments
        # are ignored by check_message.  The ones given here ('i->i', X1) do
        # not match how F was constructed above ('i,i->i', X1, X2); this has
        # no effect on what is tested but is misleading to read.
        check_message(m0, m1, 0,
                      'i->i',
                      X1,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      ['i'],
                      F=F)

        pass