Exemplo n.º 1
0
    def test_message_to_parent(self):
        """
        Test the message from SumMultiply node to its parents.

        Every scenario follows the same recipe: build Gaussian parent nodes,
        combine them with a SumMultiply node, attach an observed Gaussian
        child to that node, and compare ``_message_to_parent(index)`` with a
        hand-computed pair ``(m0, m1)``.  Most scenarios are run twice, once
        with the einsum-like string syntax (e.g. ``'i,i->i'``) and once with
        the key-list syntax (e.g. ``X1, [9], X2, [9], [9]``).
        """

        # Constant observation value and precision used by the observed
        # GaussianARD child that check_message creates.
        data = 2
        tau = 3
        
        def check_message(true_m0, true_m1, parent, *args, F=None):
            """
            Compare the message to parent number ``parent`` with expectations.

            If ``F`` is None, a SumMultiply node is built from ``*args`` and
            given a GaussianARD child with precision ``tau`` observed at
            ``data`` everywhere.  If ``F`` is given, it is used as-is and
            ``*args`` are ignored (the caller is expected to have attached
            and observed a child already).
            """
            if F is None:
                A = SumMultiply(*args)
                B = GaussianARD(A, tau)
                # Observation array covers the node's plates followed by its
                # first-moment dimensions.
                B.observe(data*np.ones(A.plates + A.dims[0]))
            else:
                A = F
            (A_m0, A_m1) = A._message_to_parent(parent)
            self.assertAllClose(true_m0, A_m0)
            self.assertAllClose(true_m1, A_m1)
            pass

        # Check: different message to each of multiple parents
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1)
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1)
        x2 = X2.get_moments()
        # Expected message to parent 0 is built from the moments of the other
        # parent (X2): first element scaled by tau*data, second element scaled
        # by -0.5*tau and masked with an identity matrix.
        m0 = tau * data * x2[0]
        m1 = -0.5 * tau * x2[1] * np.identity(2)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        # Same product in key-list form (presumably equivalent to 'i,i->i'
        # with the integer 9 used as the key).
        check_message(m0, m1, 0,
                      X1,
                      [9],
                      X2,
                      [9],
                      [9])
        # Message to parent 1 uses the moments of X1 instead.
        m0 = tau * data * x1[0]
        m1 = -0.5 * tau * x1[1] * np.identity(2)
        check_message(m0, m1, 1,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [9],
                      X2,
                      [9],
                      [9])
        
        # Check: key not in output
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1)
        x1 = X1.get_moments()
        # With a single parent whose key is summed out, the expected message
        # does not depend on the parent's own moments.
        m0 = tau * data * np.ones(2)
        m1 = -0.5 * tau * np.ones((2,2))
        # All four spellings below are expected to produce the same message.
        check_message(m0, m1, 0,
                      'i',
                      X1)
        check_message(m0, m1, 0,
                      'i->',
                      X1)
        check_message(m0, m1, 0,
                      X1,
                      [9])
        check_message(m0, m1, 0,
                      X1,
                      [9],
                      [])

        # Check: key not in some input
        # X1 is created without ndim/shape and is used with an empty key list.
        X1 = GaussianARD(np.random.randn(),
                         np.random.rand())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1)
        x2 = X2.get_moments()
        # Parent 0 lacks key 'i', so the expected message sums X2's
        # contribution over that axis.
        m0 = tau * data * np.sum(x2[0], axis=-1)
        m1 = -0.5 * tau * np.sum(x2[1] * np.identity(2),
                                 axis=(-1,-2))
        check_message(m0, m1, 0,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      [],
                      X2,
                      [9],
                      [9])
        # Parent 1 has key 'i'; X1's moments are broadcast along it.
        m0 = tau * data * x1[0] * np.ones(2)
        m1 = -0.5 * tau * x1[1] * np.identity(2)
        check_message(m0, m1, 1,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [],
                      X2,
                      [9],
                      [9])

        # Check: keys in different order
        Y1 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2),
                         ndim=2)
        y1 = Y1.get_moments()
        Y2 = GaussianARD(np.random.randn(2,3),
                         np.random.rand(2,3),
                         ndim=2)
        y2 = Y2.get_moments()
        # Y2 is indexed 'ji' while Y1 is indexed 'ij', so the expected values
        # transpose Y2's moments into Y1's axis order.
        m0 = tau * data * y2[0].T
        m1 = -0.5 * tau * np.einsum('ijlk->jikl', y2[1] * misc.identity(2,3))
        check_message(m0, m1, 0,
                      'ij,ji->ij',
                      Y1,
                      Y2)
        check_message(m0, m1, 0,
                      Y1,
                      ['i','j'],
                      Y2,
                      ['j','i'],
                      ['i','j'])
        # Symmetric check for parent 1, using Y1's transposed moments.
        m0 = tau * data * y1[0].T
        m1 = -0.5 * tau * np.einsum('ijlk->jikl', y1[1] * misc.identity(3,2))
        check_message(m0, m1, 1,
                      'ij,ji->ij',
                      Y1,
                      Y2)
        check_message(m0, m1, 1,
                      Y1,
                      ['i','j'],
                      Y2,
                      ['j','i'],
                      ['i','j'])

        # Check: plates when different dimensionality
        X1 = GaussianARD(np.random.randn(5),
                         np.random.rand(5),
                         shape=(),
                         plates=(5,))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(5,3),
                         np.random.rand(5,3),
                         shape=(3,),
                         plates=(5,))
        x2 = X2.get_moments()
        # X1 has an empty shape, so X2's variable axis is summed out while the
        # plate axis of length 5 is kept.
        m0 = tau * data * np.sum(np.ones((5,3)) * x2[0], axis=-1)
        m1 = -0.5 * tau * np.sum(x2[1] * misc.identity(3), axis=(-1,-2))
        check_message(m0, m1, 0,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      [],
                      X2,
                      ['i'],
                      ['i'])
        # For parent 1, X1's per-plate moments are expanded with new axes so
        # they broadcast over X2's variable dimensions.
        m0 = tau * data * x1[0][:,np.newaxis] * np.ones((5,3))
        m1 = -0.5 * tau * x1[1][:,np.newaxis,np.newaxis] * misc.identity(3)
        check_message(m0, m1, 1,
                      ',i->i',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [],
                      X2,
                      ['i'],
                      ['i'])
        
        # Check: other parent's moments broadcasts over plates when node has the
        # same plates
        X1 = GaussianARD(np.random.randn(5,4,3),
                         np.random.rand(5,4,3),
                         shape=(3,),
                         plates=(5,4))
        x1 = X1.get_moments()
        # X2's parameters have no plate axes although plates=(5,4) is declared.
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(5,4))
        x2 = X2.get_moments()
        # m0 is expanded to the full (5,4,3) shape; m1 is left at whatever
        # shape x2[1] * identity(3) has.
        m0 = tau * data * np.ones((5,4,3)) * x2[0]
        m1 = -0.5 * tau * x2[1] * misc.identity(3)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'])
        
        # Check: other parent's moments broadcasts over plates when node does
        # not have that plate
        X1 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(5,4))
        x2 = X2.get_moments()
        # X1 has no plates, so the (5,4) plate axes are summed away in the
        # expected message.
        m0 = tau * data * np.sum(np.ones((5,4,3)) * x2[0], axis=(0,1))
        m1 = -0.5 * tau * np.sum(np.ones((5,4,1,1))
                                 * misc.identity(3)
                                 * x2[1], 
                                 axis=(0,1))
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'])
        
        # Check: other parent's moments broadcasts over plates when the node
        # only broadcasts that plate
        X1 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(1,1))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3,),
                         plates=(5,4))
        x2 = X2.get_moments()
        # Same sums as in the previous scenario, but keepdims=True keeps the
        # two singleton plate axes that match X1's plates=(1,1).
        m0 = tau * data * np.sum(np.ones((5,4,3)) * x2[0], axis=(0,1), keepdims=True)
        m1 = -0.5 * tau * np.sum(np.ones((5,4,1,1))
                                 * misc.identity(3)
                                 * x2[1], 
                                 axis=(0,1),
                                 keepdims=True)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'])
        
        # Check: broadcasted dimensions
        X1 = GaussianARD(np.random.randn(1,1),
                         np.random.rand(1,1),
                         ndim=2)
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2),
                         ndim=2)
        x2 = X2.get_moments()
        # X1's variable axes both have length 1, so the expected message sums
        # over every axis and keeps singleton dimensions.
        m0 = tau * data * np.sum(np.ones((3,2)) * x2[0], 
                                 keepdims=True)
        m1 = -0.5 * tau * np.sum(misc.identity(3,2) * x2[1], 
                                 keepdims=True)
        check_message(m0, m1, 0,
                      'ij,ij->ij',
                      X1,
                      X2)
        check_message(m0, m1, 0,
                      X1,
                      [0,1],
                      X2,
                      [0,1],
                      [0,1])
        # For parent 1, X1's moments are broadcast up to the (3,2) shape.
        m0 = tau * data * np.ones((3,2)) * x1[0]
        m1 = -0.5 * tau * misc.identity(3,2) * x1[1]
        check_message(m0, m1, 1,
                      'ij,ij->ij',
                      X1,
                      X2)
        check_message(m0, m1, 1,
                      X1,
                      [0,1],
                      X2,
                      [0,1],
                      [0,1])

        # Check: non-ARD observations
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1)
        x1 = X1.get_moments()
        # Lambda has non-zero off-diagonal entries and is passed as the second
        # argument of the Gaussian child.
        Lambda = np.array([[2, 1.5], [1.5, 2]])
        F = SumMultiply('i->i', X1)
        Y = Gaussian(F, Lambda)
        y = np.random.randn(2)
        Y.observe(y)
        m0 = np.dot(Lambda, y)
        m1 = -0.5 * Lambda
        # F is supplied, so check_message ignores the positional node
        # arguments; both calls below therefore test the same node F.
        check_message(m0, m1, 0,
                      'i->i',
                      X1,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      ['i'],
                      F=F)

        # Check: mask with same shape
        X1 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2),
                         shape=(2,),
                         plates=(3,))
        x1 = X1.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i->i', X1)
        Y = GaussianARD(F, tau, ndim=1)
        Y.observe(data*np.ones((3,2)), mask=mask)
        # Plates where mask is False are expected to contribute zero.
        m0 = tau * data * mask[:,np.newaxis] * np.ones(2)
        m1 = -0.5 * tau * mask[:,np.newaxis,np.newaxis] * np.identity(2)
        check_message(m0, m1, 0,
                      'i->i',
                      X1,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      ['i'],
                      F=F)

        # Check: mask larger
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         shape=(2,),
                         plates=())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3,2),
                         np.random.rand(3,2),
                         shape=(2,),
                         plates=(3,))
        x2 = X2.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i,i->i', X1, X2)
        Y = GaussianARD(F, tau,
                        plates=(3,),
                        ndim=1)
        Y.observe(data*np.ones((3,2)), mask=mask)
        # X1 has no plates: masked contributions are summed over the plate
        # axis of length 3.
        m0 = tau * data * np.sum(mask[:,np.newaxis] * x2[0], axis=0)
        m1 = -0.5 * tau * np.sum(mask[:,np.newaxis,np.newaxis]
                                 * x2[1]
                                 * np.identity(2),
                                 axis=0)
        check_message(m0, m1, 0,
                      'i,i->i',
                      X1,
                      X2,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      X2,
                      ['i'],
                      ['i'],
                      F=F)

        # Check: mask for broadcasted plate
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1,
                         plates=(1,))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1,
                         plates=(3,))
        x2 = X2.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i,i->i', X1, X2)
        Y = GaussianARD(F, tau,
                        plates=(3,),
                        ndim=1)
        Y.observe(data*np.ones((3,2)), mask=mask)
        # X1 has a singleton plate, so the plate sum keeps that axis.
        m0 = tau * data * np.sum(mask[:,np.newaxis] * x2[0],
                                 axis=0,
                                 keepdims=True)
        m1 = -0.5 * tau * np.sum(mask[:,np.newaxis,np.newaxis]
                                 * x2[1]
                                 * np.identity(2),
                                 axis=0,
                                 keepdims=True)
        # NOTE(review): the 'i->i' / single-parent arguments below do not
        # describe F (which is 'i,i->i' over X1 and X2); they are ignored
        # because F is passed explicitly.
        check_message(m0, m1, 0,
                      'i->i',
                      X1,
                      F=F)
        check_message(m0, m1, 0,
                      X1,
                      ['i'],
                      ['i'],
                      F=F)

        # Check: Gaussian-gamma parents
        X1 = GaussianGamma(
            np.random.randn(2),
            random.covariance(2),
            np.random.rand(),
            np.random.rand()
        )
        x1 = X1.get_moments()
        X2 = GaussianGamma(
            np.random.randn(2),
            random.covariance(2),
            np.random.rand(),
            np.random.rand()
        )
        x2 = X2.get_moments()
        F = SumMultiply('i,i->i', X1, X2)
        V = random.covariance(2)
        y = np.random.randn(2)
        Y = Gaussian(F, V)
        Y.observe(y)
        # With Gaussian-gamma parents the message has four components, which
        # are compared directly instead of through check_message.
        m0 = np.dot(V, y) * x2[0]
        m1 = -0.5 * V * x2[1]
        # NOTE(review): the trailing comments on the next two lines are
        # leftover alternative formulas; m3 is a hard-coded constant (1.0).
        # Confirm the intended expressions.
        m2 = -0.5 * np.einsum('i,ij,j', y, V, y) * x2[2]#linalg.inner(V, x2[2], ndim=2)
        m3 = 0.5 * 2 #linalg.chol_logdet(linalg.chol(V)) + 2*x2[3]
        m = F._message_to_parent(0)
        self.assertAllClose(m[0], m0)
        self.assertAllClose(m[1], m1)
        self.assertAllClose(m[2], m2)
        self.assertAllClose(m[3], m3)

        pass
Exemplo n.º 2
0
    def test_message_to_parent_mu(self):
        """
        Test that GaussianARD computes the message to the 1st parent correctly.

        Each scenario creates a parent node ``mu``, a child
        ``X = GaussianARD(mu, alpha, ...)`` that is either observed directly
        or updated from an observed grandchild, and then compares the pair
        returned by ``mu._message_from_children()`` with hand-computed values.
        The scenarios vary the relative shapes of ``mu``, ``alpha`` and the
        node, the plates, and the observation mask.
        """

        # Check formula with uncertain parent alpha
        mu = GaussianARD(0, 1)
        alpha = Gamma(2, 1)
        X = GaussianARD(mu, alpha)
        X.observe(3)
        (m0, m1) = mu._message_from_children()
        # The factor 2 is presumably the mean of Gamma(2, 1); 3 is the
        # observed value.
        self.assertAllClose(m0, 2 * 3)
        self.assertAllClose(m1, -0.5 * 2)

        # Check formula with uncertain node
        mu = GaussianARD(1, 1e10)
        X = GaussianARD(mu, 2)
        Y = GaussianARD(X, 1)
        Y.observe(5)
        X.update()
        (m0, m1) = mu._message_from_children()
        # X is not observed: the expected value combines the prior
        # (precision 2, mean 1) with the observed child (precision 1, value 5).
        self.assertAllClose(m0, 2 * 1 / (2 + 1) * (2 * 1 + 1 * 5))
        self.assertAllClose(m1, -0.5 * 2)

        # Check alpha larger than mu
        mu = GaussianARD(np.zeros((2, 3)), 1e10, shape=(2, 3))
        X = GaussianARD(mu, 2 * np.ones((3, 2, 3)))
        X.observe(3 * np.ones((3, 2, 3)))
        (m0, m1) = mu._message_from_children()
        # The extra leading axis of length 3 shows up as a factor of 3.
        self.assertAllClose(m0, 2 * 3 * 3 * np.ones((2, 3)))
        self.assertAllClose(m1, -0.5 * 3 * 2 * misc.identity(2, 3))

        # Check mu larger than alpha
        mu = GaussianARD(np.zeros((3, 2, 3)), 1e10, shape=(3, 2, 3))
        X = GaussianARD(mu, 2 * np.ones((2, 3)))
        X.observe(3 * np.ones((3, 2, 3)))
        (m0, m1) = mu._message_from_children()
        self.assertAllClose(m0, 2 * 3 * np.ones((3, 2, 3)))
        self.assertAllClose(m1, -0.5 * 2 * misc.identity(3, 2, 3))

        # Check node larger than mu and alpha
        mu = GaussianARD(np.zeros((2, 3)), 1e10, shape=(2, 3))
        X = GaussianARD(mu, 2 * np.ones((3, )), shape=(3, 2, 3))
        X.observe(3 * np.ones((3, 2, 3)))
        (m0, m1) = mu._message_from_children()
        self.assertAllClose(m0, 2 * 3 * 3 * np.ones((2, 3)))
        self.assertAllClose(m1, -0.5 * 2 * 3 * misc.identity(2, 3))

        # Check broadcasting of dimensions
        mu = GaussianARD(np.zeros((2, 1)), 1e10, shape=(2, 1))
        X = GaussianARD(mu, 2 * np.ones((2, 3)), shape=(2, 3))
        X.observe(3 * np.ones((2, 3)))
        (m0, m1) = mu._message_from_children()
        # mu's trailing axis has length 1 while the node's has length 3,
        # giving the extra factor of 3.
        self.assertAllClose(m0, 2 * 3 * 3 * np.ones((2, 1)))
        self.assertAllClose(m1, -0.5 * 2 * 3 * misc.identity(2, 1))

        # Check plates for smaller mu than node
        mu = GaussianARD(0, 1, shape=(3, ), plates=(4, 1, 1))
        X = GaussianARD(mu, 2 * np.ones((3, )), shape=(2, 3), plates=(4, 5))
        X.observe(3 * np.ones((4, 5, 2, 3)))
        (m0, m1) = mu._message_from_children()
        # Both sides are multiplied by ones of the full shape so that the
        # comparison does not depend on how the message is broadcast.
        self.assertAllClose(m0 * np.ones((4, 1, 1, 3)),
                            2 * 3 * 5 * 2 * np.ones((4, 1, 1, 3)))
        self.assertAllClose(
            m1 * np.ones((4, 1, 1, 3, 3)),
            -0.5 * 2 * 5 * 2 * misc.identity(3) * np.ones((4, 1, 1, 3, 3)))

        # Check mask
        mu = GaussianARD(np.zeros((2, 1, 3)), 1e10, shape=(3, ))
        X = GaussianARD(mu,
                        2 * np.ones((2, 4, 3)),
                        shape=(3, ),
                        plates=(
                            2,
                            4,
                        ))
        X.observe(3 * np.ones((2, 4, 3)),
                  mask=[[True, True, True, False], [False, True, False, True]])
        (m0, m1) = mu._message_from_children()
        # The arrays [3, 2] are the numbers of True entries in each mask row.
        self.assertAllClose(m0, (2 * 3 * np.ones(
            (2, 1, 3)) * np.array([[[3]], [[2]]])))
        self.assertAllClose(m1, (-0.5 * 2 * misc.identity(3) * np.ones(
            (2, 1, 1, 1)) * np.array([[[[3]]], [[[2]]]])))

        # Check mask with different shapes
        mu = GaussianARD(np.zeros((2, 1, 3)), 1e10, shape=())
        X = GaussianARD(mu,
                        2 * np.ones((2, 4, 3)),
                        shape=(3, ),
                        plates=(
                            2,
                            4,
                        ))
        mask = np.array([[True, True, True, False], [False, True, False,
                                                     True]])
        X.observe(3 * np.ones((2, 4, 3)), mask=mask)
        (m0, m1) = mu._message_from_children()
        # The masked entries are summed over the axis of length 4, keeping
        # that axis as a singleton.
        self.assertAllClose(
            m0, 2 * 3 * np.sum(
                np.ones((2, 4, 3)) * mask[..., None], axis=-2, keepdims=True))
        self.assertAllClose(m1, (-0.5 * 2 * np.sum(
            np.ones((2, 4, 3)) * mask[..., None], axis=-2, keepdims=True)))

        # Check non-ARD Gaussian child
        mu = np.array([1, 2])
        Mu = GaussianARD(mu, 1e10, shape=(2, ))
        alpha = np.array([3, 4])
        Lambda = np.array([[1, 0.5], [0.5, 1]])
        # ndim=1 is given explicitly so that X is a 2-element vector variable
        # (matching the 2x2 Lambda of its Gaussian child) instead of relying
        # on implicit dimensionality inference from the parents.
        X = GaussianARD(Mu, alpha, ndim=1)
        Y = Gaussian(X, Lambda)
        y = np.array([5, 6])
        Y.observe(y)
        X.update()
        (m0, m1) = Mu._message_from_children()
        # Expected mean of X after the update: combine the diagonal prior
        # precision diag(alpha) with the full matrix Lambda of the child.
        mean = np.dot(np.linalg.inv(np.diag(alpha) + Lambda),
                      np.dot(np.diag(alpha), mu) + np.dot(Lambda, y))
        self.assertAllClose(m0, np.dot(np.diag(alpha), mean))
        self.assertAllClose(m1, -0.5 * np.diag(alpha))

        # Check broadcasted variable axes
        mu = GaussianARD(np.zeros(1), 1e10, shape=(1, ))
        X = GaussianARD(mu, 2, shape=(3, ))
        X.observe(3 * np.ones(3))
        (m0, m1) = mu._message_from_children()
        # mu has a single element shared by the three elements of X, so the
        # expected values are sums over X's axis with the axis kept.
        self.assertAllClose(m0,
                            2 * 3 * np.sum(np.ones(3), axis=-1, keepdims=True))
        self.assertAllClose(
            m1,
            -0.5 * 2 * np.sum(np.identity(3), axis=(-1, -2), keepdims=True))
Exemplo n.º 3
0
    def test_message_to_child(self):
        """
        Test moments of GaussianARD.

        Checks the pair returned by ``_message_to_child()``: its shapes under
        broadcasting, its values for scalar and array-valued parameters (with
        several ways of broadcasting ``mu`` and ``alpha``), with a plated
        parent node, and after updating from an observed child.
        """

        # Check that moments have full shape when broadcasting
        node = GaussianARD(np.zeros((2, )), np.ones((3, 2)), shape=(4, 3, 2))
        (first, second) = node._message_to_child()
        self.assertEqual(np.shape(first), (4, 3, 2))
        self.assertEqual(np.shape(second), (4, 3, 2, 4, 3, 2))

        # Check the formula
        node = GaussianARD(2, 3)
        (first, second) = node._message_to_child()
        self.assertAllClose(first, 2)
        self.assertAllClose(second, 2**2 + 1 / 3)

        # The four array-valued scenarios below all describe a (2, 3, 4)
        # variable with mean 2 and precision 3, so they share expectations.
        dims = (2, 3, 4)
        want_first = 2 * np.ones(dims)
        want_second = (2**2 * np.ones(dims + dims)
                       + 1 / 3 * misc.identity(*dims))
        scenarios = [
            # Check the formula for multidimensional arrays
            ((2, 1, 4), (2, 3, 1), dict(ndim=3)),
            # Check the formula for dim-broadcasted mu
            ((3, 1), (2, 3, 4), dict(ndim=3)),
            # Check the formula for dim-broadcasted alpha
            ((2, 3, 4), (3, 1), dict(ndim=3)),
            # Check the formula for dim-broadcasted mu and alpha
            ((3, 1), (3, 1), dict(shape=(2, 3, 4))),
        ]
        for (mu_shape, alpha_shape, options) in scenarios:
            node = GaussianARD(2 * np.ones(mu_shape),
                               3 * np.ones(alpha_shape),
                               **options)
            (first, second) = node._message_to_child()
            self.assertAllClose(first, want_first)
            self.assertAllClose(second, want_second)

        # Check the formula for dim-broadcasted mu with plates
        parent = GaussianARD(2 * np.ones((5, 1, 3, 4)),
                             np.ones((5, 1, 3, 4)),
                             shape=(3, 4),
                             plates=(5, 1))
        node = GaussianARD(parent,
                           3 * np.ones((5, 2, 3, 4)),
                           shape=(2, 3, 4),
                           plates=(5, ))
        (first, second) = node._message_to_child()
        self.assertAllClose(first, 2 * np.ones((5, ) + dims))
        self.assertAllClose(
            second,
            2**2 * np.ones((5, ) + dims + dims) + 1 / 3 * misc.identity(*dims))

        # Check posterior
        node = GaussianARD(2, 3)
        observed_child = GaussianARD(node, 1)
        observed_child.observe(10)
        node.update()
        (first, second) = node._message_to_child()
        posterior_mean = 1 / (3 + 1) * (3 * 2 + 1 * 10)
        self.assertAllClose(first, posterior_mean)
        self.assertAllClose(second, posterior_mean**2 + 1 / (3 + 1))
Exemplo n.º 4
0
    def test_message_to_parent_mu(self):
        """
        Test that GaussianARD computes the message to the 1st parent correctly.

        Every scenario wires a parent node into a GaussianARD child, observes
        the child (or a grandchild, followed by an update) and compares the
        parent's ``_message_from_children()`` with hand-computed values.
        """

        def expect_message(parent, want_m0, want_m1):
            """Fetch the message collected by ``parent`` and compare both parts."""
            (got_m0, got_m1) = parent._message_from_children()
            self.assertAllClose(got_m0, want_m0)
            self.assertAllClose(got_m1, want_m1)

        # Check formula with uncertain parent alpha
        parent = GaussianARD(0, 1)
        precision = Gamma(2,1)
        child = GaussianARD(parent, precision)
        child.observe(3)
        expect_message(parent, 2*3, -0.5*2)

        # Check formula with uncertain node
        parent = GaussianARD(1, 1e10)
        child = GaussianARD(parent, 2)
        grandchild = GaussianARD(child, 1)
        grandchild.observe(5)
        child.update()
        expect_message(parent, 2 * 1/(2+1)*(2*1+1*5), -0.5*2)

        # Check alpha larger than mu
        parent = GaussianARD(np.zeros((2,3)), 1e10, shape=(2,3))
        child = GaussianARD(parent, 2*np.ones((3,2,3)))
        child.observe(3*np.ones((3,2,3)))
        expect_message(parent,
                       2*3 * 3 * np.ones((2,3)),
                       -0.5 * 3 * 2*misc.identity(2,3))

        # Check mu larger than alpha
        parent = GaussianARD(np.zeros((3,2,3)), 1e10, shape=(3,2,3))
        child = GaussianARD(parent, 2*np.ones((2,3)))
        child.observe(3*np.ones((3,2,3)))
        expect_message(parent,
                       2 * 3 * np.ones((3,2,3)),
                       -0.5 * 2*misc.identity(3,2,3))

        # Check node larger than mu and alpha
        parent = GaussianARD(np.zeros((2,3)), 1e10, shape=(2,3))
        child = GaussianARD(parent, 2*np.ones((3,)), shape=(3,2,3))
        child.observe(3*np.ones((3,2,3)))
        expect_message(parent,
                       2*3 * 3*np.ones((2,3)),
                       -0.5 * 2 * 3*misc.identity(2,3))

        # Check broadcasting of dimensions
        parent = GaussianARD(np.zeros((2,1)), 1e10, shape=(2,1))
        child = GaussianARD(parent, 2*np.ones((2,3)), shape=(2,3))
        child.observe(3*np.ones((2,3)))
        expect_message(parent,
                       2*3 * 3*np.ones((2,1)),
                       -0.5 * 2 * 3*misc.identity(2,1))

        # Check plates for smaller mu than node
        parent = GaussianARD(0,1, shape=(3,), plates=(4,1,1))
        child = GaussianARD(parent,
                            2*np.ones((3,)),
                            shape=(2,3),
                            plates=(4,5))
        child.observe(3*np.ones((4,5,2,3)))
        (m0, m1) = parent._message_from_children()
        # The messages are expanded with ones before comparing, so this
        # scenario cannot go through expect_message.
        ones_m0 = np.ones((4,1,1,3))
        ones_m1 = np.ones((4,1,1,3,3))
        self.assertAllClose(m0 * ones_m0,
                            2*3 * 5*2*np.ones((4,1,1,3)))
        self.assertAllClose(m1 * ones_m1,
                            -0.5*2 * 5*2*misc.identity(3) * np.ones((4,1,1,3,3)))

        # Check mask
        parent = GaussianARD(np.zeros((2,1,3)), 1e10, shape=(3,))
        child = GaussianARD(parent,
                            2*np.ones((2,4,3)),
                            shape=(3,),
                            plates=(2,4,))
        child.observe(3*np.ones((2,4,3)),
                      mask=[[True, True, True, False],
                            [False, True, False, True]])
        expect_message(parent,
                       (2*3 * np.ones((2,1,3))
                        * np.array([[[3]], [[2]]])),
                       (-0.5*2 * misc.identity(3)
                        * np.ones((2,1,1,1))
                        * np.array([[[[3]]], [[[2]]]])))

        # Check mask with different shapes
        parent = GaussianARD(np.zeros((2,1,3)), 1e10, shape=())
        child = GaussianARD(parent,
                            2*np.ones((2,4,3)),
                            shape=(3,),
                            plates=(2,4,))
        mask = np.array([[True, True, True, False],
                         [False, True, False, True]])
        child.observe(3*np.ones((2,4,3)), mask=mask)
        expect_message(parent,
                       2*3 * np.sum(np.ones((2,4,3))*mask[...,None],
                                    axis=-2,
                                    keepdims=True),
                       (-0.5*2 * np.sum(np.ones((2,4,3))*mask[...,None],
                                        axis=-2,
                                        keepdims=True)))

        # Check non-ARD Gaussian child
        prior_mean = np.array([1,2])
        parent = GaussianARD(prior_mean, 1e10, shape=(2,))
        alpha = np.array([3,4])
        Lambda = np.array([[1, 0.5],
                           [0.5, 1]])
        child = GaussianARD(parent, alpha, ndim=1)
        grandchild = Gaussian(child, Lambda)
        y = np.array([5,6])
        grandchild.observe(y)
        child.update()
        # Expected mean of the child after its update
        rhs = np.dot(np.diag(alpha), prior_mean) + np.dot(Lambda, y)
        mean = np.dot(np.linalg.inv(np.diag(alpha)+Lambda), rhs)
        expect_message(parent,
                       np.dot(np.diag(alpha), mean),
                       -0.5*np.diag(alpha))

        # Check broadcasted variable axes
        parent = GaussianARD(np.zeros(1), 1e10, shape=(1,))
        child = GaussianARD(parent, 2, shape=(3,))
        child.observe(3*np.ones(3))
        expect_message(parent,
                       2*3 * np.sum(np.ones(3), axis=-1, keepdims=True),
                       -0.5*2 * np.sum(np.identity(3),
                                       axis=(-1,-2),
                                       keepdims=True))
Exemplo n.º 5
0
    def test_message_to_child(self):
        """
        Test moments of GaussianARD.

        Compares the two components of ``_message_to_child()`` with expected
        values for scalar parameters, array parameters under several
        broadcasting layouts, a plated parent node, and a node updated from
        an observed child.
        """

        def expect_moments(node, want_u0, want_u1):
            """Compare both components of the node's message to its children."""
            (got_u0, got_u1) = node._message_to_child()
            self.assertAllClose(got_u0, want_u0)
            self.assertAllClose(got_u1, want_u1)

        # Check that moments have full shape when broadcasting
        node = GaussianARD(np.zeros((2,)), np.ones((3,2)), shape=(4,3,2))
        (u0, u1) = node._message_to_child()
        self.assertEqual(np.shape(u0), (4,3,2))
        self.assertEqual(np.shape(u1), (4,3,2,4,3,2))

        # Check the formula
        expect_moments(GaussianARD(2, 3), 2, 2**2 + 1/3)

        # Check the formula for multidimensional arrays
        node = GaussianARD(2*np.ones((2,1,4)), 3*np.ones((2,3,1)), ndim=3)
        expect_moments(node,
                       2*np.ones((2,3,4)),
                       2**2 * np.ones((2,3,4,2,3,4))
                       + 1/3 * misc.identity(2,3,4))

        # Check the formula for dim-broadcasted mu
        node = GaussianARD(2*np.ones((3,1)), 3*np.ones((2,3,4)), ndim=3)
        expect_moments(node,
                       2*np.ones((2,3,4)),
                       2**2 * np.ones((2,3,4,2,3,4))
                       + 1/3 * misc.identity(2,3,4))

        # Check the formula for dim-broadcasted alpha
        node = GaussianARD(2*np.ones((2,3,4)), 3*np.ones((3,1)), ndim=3)
        expect_moments(node,
                       2*np.ones((2,3,4)),
                       2**2 * np.ones((2,3,4,2,3,4))
                       + 1/3 * misc.identity(2,3,4))

        # Check the formula for dim-broadcasted mu and alpha
        node = GaussianARD(2*np.ones((3,1)), 3*np.ones((3,1)), shape=(2,3,4))
        expect_moments(node,
                       2*np.ones((2,3,4)),
                       2**2 * np.ones((2,3,4,2,3,4))
                       + 1/3 * misc.identity(2,3,4))

        # Check the formula for dim-broadcasted mu with plates
        parent = GaussianARD(2*np.ones((5,1,3,4)),
                             np.ones((5,1,3,4)),
                             shape=(3,4),
                             plates=(5,1))
        node = GaussianARD(parent,
                           3*np.ones((5,2,3,4)),
                           shape=(2,3,4),
                           plates=(5,))
        expect_moments(node,
                       2*np.ones((5,2,3,4)),
                       2**2 * np.ones((5,2,3,4,2,3,4))
                       + 1/3 * misc.identity(2,3,4))

        # Check posterior
        node = GaussianARD(2, 3)
        observed_child = GaussianARD(node, 1)
        observed_child.observe(10)
        node.update()
        expect_moments(node,
                       1/(3+1) * (3*2 + 1*10),
                       (1/(3+1) * (3*2 + 1*10))**2 + 1/(3+1))
Exemplo n.º 6
0
    def test_message_to_parent(self):
        """
        Test the message from SumMultiply node to its parents.

        Each scenario builds a SumMultiply node (from an einsum-style string
        or from the equivalent list-of-keys form), attaches an observed
        Gaussian child to it, and compares the message the node sends to one
        of its parents against a value computed by hand with NumPy.

        The parents are initialised with random parameters and no seed is set
        here, so the concrete numbers differ between runs.
        """

        # Observed value and precision shared by the check_message scenarios
        data = 2
        tau = 3

        def check_message(true_m0, true_m1, parent, *args, F=None):
            """
            Assert that the message to parent number `parent` matches the
            expected pair (true_m0, true_m1).

            If F is None, a SumMultiply node is built from *args, given a
            GaussianARD child with precision `tau`, and that child is observed
            with the constant `data`.  If F is given it is used as is and
            *args are ignored, so the caller must have set up the observed
            child already.
            """
            if F is None:
                A = SumMultiply(*args)
                B = GaussianARD(A, tau)
                B.observe(data * np.ones(A.plates + A.dims[0]))
            else:
                A = F
            (A_m0, A_m1) = A._message_to_parent(parent)
            self.assertAllClose(true_m0, A_m0)
            self.assertAllClose(true_m1, A_m1)
            pass

        # Check: different message to each of multiple parents
        # (the expected message to one parent is built from the moments of
        # the other parent)
        X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
        x2 = X2.get_moments()
        m0 = tau * data * x2[0]
        m1 = -0.5 * tau * x2[1] * np.identity(2)
        check_message(m0, m1, 0, 'i,i->i', X1, X2)
        check_message(m0, m1, 0, X1, [9], X2, [9], [9])
        m0 = tau * data * x1[0]
        m1 = -0.5 * tau * x1[1] * np.identity(2)
        check_message(m0, m1, 1, 'i,i->i', X1, X2)
        check_message(m0, m1, 1, X1, [9], X2, [9], [9])

        # Check: key not in output
        # (the key is summed out, with and without an explicit empty output)
        X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
        x1 = X1.get_moments()
        m0 = tau * data * np.ones(2)
        m1 = -0.5 * tau * np.ones((2, 2))
        check_message(m0, m1, 0, 'i', X1)
        check_message(m0, m1, 0, 'i->', X1)
        check_message(m0, m1, 0, X1, [9])
        check_message(m0, m1, 0, X1, [9], [])

        # Check: key not in some input
        # (X1 is scalar, X2 carries the key)
        X1 = GaussianARD(np.random.randn(), np.random.rand())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
        x2 = X2.get_moments()
        m0 = tau * data * np.sum(x2[0], axis=-1)
        m1 = -0.5 * tau * np.sum(x2[1] * np.identity(2), axis=(-1, -2))
        check_message(m0, m1, 0, ',i->i', X1, X2)
        check_message(m0, m1, 0, X1, [], X2, [9], [9])
        m0 = tau * data * x1[0] * np.ones(2)
        m1 = -0.5 * tau * x1[1] * np.identity(2)
        check_message(m0, m1, 1, ',i->i', X1, X2)
        check_message(m0, m1, 1, X1, [], X2, [9], [9])

        # Check: keys in different order
        # (the expected moments of the other parent are transposed to match)
        Y1 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2), ndim=2)
        y1 = Y1.get_moments()
        Y2 = GaussianARD(np.random.randn(2, 3), np.random.rand(2, 3), ndim=2)
        y2 = Y2.get_moments()
        m0 = tau * data * y2[0].T
        m1 = -0.5 * tau * np.einsum('ijlk->jikl', y2[1] * misc.identity(2, 3))
        check_message(m0, m1, 0, 'ij,ji->ij', Y1, Y2)
        check_message(m0, m1, 0, Y1, ['i', 'j'], Y2, ['j', 'i'], ['i', 'j'])
        m0 = tau * data * y1[0].T
        m1 = -0.5 * tau * np.einsum('ijlk->jikl', y1[1] * misc.identity(3, 2))
        check_message(m0, m1, 1, 'ij,ji->ij', Y1, Y2)
        check_message(m0, m1, 1, Y1, ['i', 'j'], Y2, ['j', 'i'], ['i', 'j'])

        # Check: plates when different dimensionality
        # (X1 is scalar-valued, X2 is vector-valued, both with plates (5,))
        X1 = GaussianARD(np.random.randn(5),
                         np.random.rand(5),
                         shape=(),
                         plates=(5, ))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(5, 3),
                         np.random.rand(5, 3),
                         shape=(3, ),
                         plates=(5, ))
        x2 = X2.get_moments()
        m0 = tau * data * np.sum(np.ones((5, 3)) * x2[0], axis=-1)
        m1 = -0.5 * tau * np.sum(x2[1] * misc.identity(3), axis=(-1, -2))
        check_message(m0, m1, 0, ',i->i', X1, X2)
        check_message(m0, m1, 0, X1, [], X2, ['i'], ['i'])
        m0 = tau * data * x1[0][:, np.newaxis] * np.ones((5, 3))
        m1 = -0.5 * tau * x1[1][:, np.newaxis, np.newaxis] * misc.identity(3)
        check_message(m0, m1, 1, ',i->i', X1, X2)
        check_message(m0, m1, 1, X1, [], X2, ['i'], ['i'])

        # Check: the other parent's moments broadcast over plates when the
        # node has the same plates
        X1 = GaussianARD(np.random.randn(5, 4, 3),
                         np.random.rand(5, 4, 3),
                         shape=(3, ),
                         plates=(5, 4))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3, ),
                         plates=(5, 4))
        x2 = X2.get_moments()
        m0 = tau * data * np.ones((5, 4, 3)) * x2[0]
        m1 = -0.5 * tau * x2[1] * misc.identity(3)
        check_message(m0, m1, 0, 'i,i->i', X1, X2)
        check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

        # Check: the other parent's moments broadcast over plates when the
        # node does not have that plate (expected message sums the plates out)
        X1 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3, ),
                         plates=())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3, ),
                         plates=(5, 4))
        x2 = X2.get_moments()
        m0 = tau * data * np.sum(np.ones((5, 4, 3)) * x2[0], axis=(0, 1))
        m1 = -0.5 * tau * np.sum(np.ones(
            (5, 4, 1, 1)) * misc.identity(3) * x2[1],
                                 axis=(0, 1))
        check_message(m0, m1, 0, 'i,i->i', X1, X2)
        check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

        # Check: the other parent's moments broadcast over plates when the
        # node only broadcasts that plate (summed plates are kept as length-1
        # axes in the expected message)
        X1 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3, ),
                         plates=(1, 1))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3),
                         np.random.rand(3),
                         shape=(3, ),
                         plates=(5, 4))
        x2 = X2.get_moments()
        m0 = tau * data * np.sum(
            np.ones((5, 4, 3)) * x2[0], axis=(0, 1), keepdims=True)
        m1 = -0.5 * tau * np.sum(np.ones(
            (5, 4, 1, 1)) * misc.identity(3) * x2[1],
                                 axis=(0, 1),
                                 keepdims=True)
        check_message(m0, m1, 0, 'i,i->i', X1, X2)
        check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'])

        # Check: broadcasted dimensions
        # (X1 has variable shape (1, 1), X2 has variable shape (3, 2))
        X1 = GaussianARD(np.random.randn(1, 1), np.random.rand(1, 1), ndim=2)
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3, 2), np.random.rand(3, 2), ndim=2)
        x2 = X2.get_moments()
        m0 = tau * data * np.sum(np.ones((3, 2)) * x2[0], keepdims=True)
        m1 = -0.5 * tau * np.sum(misc.identity(3, 2) * x2[1], keepdims=True)
        check_message(m0, m1, 0, 'ij,ij->ij', X1, X2)
        check_message(m0, m1, 0, X1, [0, 1], X2, [0, 1], [0, 1])
        m0 = tau * data * np.ones((3, 2)) * x1[0]
        m1 = -0.5 * tau * misc.identity(3, 2) * x1[1]
        check_message(m0, m1, 1, 'ij,ij->ij', X1, X2)
        check_message(m0, m1, 1, X1, [0, 1], X2, [0, 1], [0, 1])

        # Check: non-ARD observations
        # (the child is a Gaussian with the full matrix Lambda; the node is
        # passed via F=, so the positional arguments after `parent` are
        # ignored by check_message)
        X1 = GaussianARD(np.random.randn(2), np.random.rand(2), ndim=1)
        x1 = X1.get_moments()
        Lambda = np.array([[2, 1.5], [1.5, 2]])
        F = SumMultiply('i->i', X1)
        Y = Gaussian(F, Lambda)
        y = np.random.randn(2)
        Y.observe(y)
        m0 = np.dot(Lambda, y)
        m1 = -0.5 * Lambda
        check_message(m0, m1, 0, 'i->i', X1, F=F)
        check_message(m0, m1, 0, X1, ['i'], ['i'], F=F)

        # Check: mask with same shape
        # (masked-out plates contribute zero to the expected message)
        X1 = GaussianARD(np.random.randn(3, 2),
                         np.random.rand(3, 2),
                         shape=(2, ),
                         plates=(3, ))
        x1 = X1.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i->i', X1)
        Y = GaussianARD(F, tau, ndim=1)
        Y.observe(data * np.ones((3, 2)), mask=mask)
        m0 = tau * data * mask[:, np.newaxis] * np.ones(2)
        m1 = -0.5 * tau * mask[:, np.newaxis, np.newaxis] * np.identity(2)
        check_message(m0, m1, 0, 'i->i', X1, F=F)
        check_message(m0, m1, 0, X1, ['i'], ['i'], F=F)

        # Check: mask larger
        # (X1 has no plates, the mask has shape (3,); the expected message
        # sums over the masked plate axis)
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         shape=(2, ),
                         plates=())
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(3, 2),
                         np.random.rand(3, 2),
                         shape=(2, ),
                         plates=(3, ))
        x2 = X2.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i,i->i', X1, X2)
        Y = GaussianARD(F, tau, plates=(3, ), ndim=1)
        Y.observe(data * np.ones((3, 2)), mask=mask)
        m0 = tau * data * np.sum(mask[:, np.newaxis] * x2[0], axis=0)
        m1 = -0.5 * tau * np.sum(
            mask[:, np.newaxis, np.newaxis] * x2[1] * np.identity(2), axis=0)
        check_message(m0, m1, 0, 'i,i->i', X1, X2, F=F)
        check_message(m0, m1, 0, X1, ['i'], X2, ['i'], ['i'], F=F)

        # Check: mask for broadcasted plate
        # (X1 has plates (1,), so the summed plate axis is kept with length 1)
        X1 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1,
                         plates=(1, ))
        x1 = X1.get_moments()
        X2 = GaussianARD(np.random.randn(2),
                         np.random.rand(2),
                         ndim=1,
                         plates=(3, ))
        x2 = X2.get_moments()
        mask = np.array([True, False, True])
        F = SumMultiply('i,i->i', X1, X2)
        Y = GaussianARD(F, tau, plates=(3, ), ndim=1)
        Y.observe(data * np.ones((3, 2)), mask=mask)
        m0 = tau * data * np.sum(
            mask[:, np.newaxis] * x2[0], axis=0, keepdims=True)
        m1 = -0.5 * tau * np.sum(
            mask[:, np.newaxis, np.newaxis] * x2[1] * np.identity(2),
            axis=0,
            keepdims=True)
        # NOTE: the positional arguments below do not describe F (which was
        # built from 'i,i->i' with X1 and X2); they are ignored because F= is
        # given.
        check_message(m0, m1, 0, 'i->i', X1, F=F)
        check_message(m0, m1, 0, X1, ['i'], ['i'], F=F)

        # Test with constant nodes
        # (`a` is a plain NumPy array used directly as a parent; the message
        # to parent 1, i.e. B, is compared with explicit einsum expressions)
        N = 10
        M = 8
        D = 5
        K = 3
        a = np.random.randn(N, D)
        B = Gaussian(
            np.random.randn(D),
            random.covariance(D),
        )
        C = GaussianARD(np.random.randn(M, 1, D, K),
                        np.random.rand(M, 1, D, K),
                        ndim=2)
        F = SumMultiply('i,i,ij->', a, B, C)
        # NOTE: this rebinds `tau`, which check_message reads from the
        # enclosing scope; check_message is not called again below.
        tau = np.random.rand(M, N)
        Y = GaussianARD(F, tau, plates=(M, N))
        y = np.random.randn(M, N)
        Y.observe(y)
        (m0, m1) = F._message_to_parent(1)
        np.testing.assert_allclose(
            m0,
            np.einsum('mn,ni,mnik->i', tau * y, a,
                      C.get_moments()[0]),
        )
        np.testing.assert_allclose(
            m1,
            np.einsum('mn,ni,nj,mnikjl->ij', -0.5 * tau, a, a,
                      C.get_moments()[1]),
        )

        # Check: Gaussian-gamma parents
        # (the message has four components instead of two)
        X1 = GaussianGamma(np.random.randn(2), random.covariance(2),
                           np.random.rand(), np.random.rand())
        x1 = X1.get_moments()
        X2 = GaussianGamma(np.random.randn(2), random.covariance(2),
                           np.random.rand(), np.random.rand())
        x2 = X2.get_moments()
        F = SumMultiply('i,i->i', X1, X2)
        V = random.covariance(2)
        y = np.random.randn(2)
        Y = Gaussian(F, V)
        Y.observe(y)
        m0 = np.dot(V, y) * x2[0]
        m1 = -0.5 * V * x2[1]
        # Commented-out alternative for m2: linalg.inner(V, x2[2], ndim=2)
        m2 = -0.5 * np.einsum('i,ij,j', y, V,
                              y) * x2[2]
        # Commented-out alternative for m3:
        # linalg.chol_logdet(linalg.chol(V)) + 2*x2[3]
        m3 = 0.5 * 2
        m = F._message_to_parent(0)
        self.assertAllClose(m[0], m0)
        self.assertAllClose(m[1], m1)
        self.assertAllClose(m[2], m2)
        self.assertAllClose(m[3], m3)

        # Delta moments
        # (constant array `a` combined with a Gaussian-gamma parent; all four
        # message components to parent 1 are checked against einsum sums over
        # the (M, N) plates)
        N = 10
        M = 8
        D = 5
        a = np.random.randn(N, D)
        B = GaussianGamma(np.random.randn(D),
                          random.covariance(D),
                          np.random.rand(),
                          np.random.rand(),
                          ndim=1)
        F = SumMultiply('i,i->', a, B)
        tau = np.random.rand(M, N)
        Y = GaussianARD(F, tau, plates=(M, N))
        y = np.random.randn(M, N)
        Y.observe(y)
        (m0, m1, m2, m3) = F._message_to_parent(1)
        np.testing.assert_allclose(
            m0,
            np.einsum('mn,ni->i', tau * y, a),
        )
        np.testing.assert_allclose(
            m1,
            np.einsum('mn,ni,nj->ij', -0.5 * tau, a, a),
        )
        np.testing.assert_allclose(
            m2,
            np.einsum('mn->', -0.5 * tau * y**2),
        )
        np.testing.assert_allclose(
            m3,
            np.einsum('mn->', 0.5 * np.ones(np.shape(tau))),
        )
        pass