Esempio n. 1
0
    def test_mask_to_parent(self):
        """
        Check how a Gate node maps a mask on its own plates to each parent.
        """

        node_x = GaussianARD(2, 1, shape=(4, 5), plates=(3, 2))
        gate = Gate([0, 0, 1], node_x)
        mask = [True, False, False]

        # Parent 0 is the gating index [0, 0, 1]: the mask is passed as-is.
        weights = gate._compute_weights_to_parent(0, list(mask))
        self.assertAllClose(weights, [True, False, False])

        # Parent 1 is the gated node: the expected weights carry an extra
        # trailing singleton axis, shape (3, 1) instead of (3,).
        weights = gate._compute_weights_to_parent(1, list(mask))
        self.assertAllClose(weights, [[True], [False], [False]])
Esempio n. 2
0
    def test_mask_to_parent(self):
        """
        Verify the per-parent mask weights computed by a Gate node.

        Parent 0 is the gating index and parent 1 the gated GaussianARD node.
        """

        X = GaussianARD(2, 1, shape=(4, 5), plates=(3, 2))
        F = Gate([0, 0, 1], X)

        expectations = (
            (0, [True, False, False]),
            (1, [[True], [False], [False]]),
        )
        for parent_index, expected in expectations:
            # A fresh mask list is built on every iteration.
            self.assertAllClose(
                F._compute_weights_to_parent(parent_index,
                                             [True, False, False]),
                expected,
            )
Esempio n. 3
0
    def test_init(self):
        """
        Check the plates and moment dims of newly created Gate nodes, and
        that invalid gating configurations raise ValueError.
        """

        def assert_gate(Z, X, plates, dims, **gate_kwargs):
            # Construct Gate(Z, X) and compare its plates and dims.
            gate = Gate(Z, X, **gate_kwargs)
            self.assertEqual(gate.plates, plates)
            self.assertEqual(gate.dims, dims)

        scalar_dims = ((), ())
        vector_dims = ((2, ), (2, 2))

        # Gating scalar node
        assert_gate(Categorical(np.ones(3) / 3),
                    GaussianARD(0, 1, shape=(), plates=(3, )),
                    (), scalar_dims)

        # Gating non-scalar node
        assert_gate(Categorical(np.ones(3) / 3),
                    GaussianARD(0, 1, shape=(2, ), plates=(3, )),
                    (), vector_dims)

        # Plates from Z
        assert_gate(Categorical(np.ones(3) / 3, plates=(4, )),
                    GaussianARD(0, 1, shape=(2, ), plates=(3, )),
                    (4, ), vector_dims)

        # Plates from X
        assert_gate(Categorical(np.ones(3) / 3),
                    GaussianARD(0, 1, shape=(2, ), plates=(4, 3)),
                    (4, ), vector_dims)

        # Plates from Z and X
        assert_gate(Categorical(np.ones(3) / 3, plates=(5, )),
                    GaussianARD(0, 1, shape=(2, ), plates=(4, 1, 3)),
                    (4, 5), vector_dims)

        # Gating non-default plate
        assert_gate(Categorical(np.ones(3) / 3),
                    GaussianARD(0, 1, shape=(), plates=(3, 4)),
                    (4, ), scalar_dims,
                    gated_plate=-2)

        # Fixed gating
        assert_gate(2,
                    GaussianARD(0, 1, shape=(2, ), plates=(3, )),
                    (), vector_dims)

        # Fixed X
        assert_gate(Categorical(np.ones(3) / 3),
                    [1, 2, 3],
                    (), scalar_dims,
                    moments=GaussianMoments(0))

        # Do not accept non-negative cluster plates
        Z = Categorical(np.ones(3) / 3)
        X = GaussianARD(0, 1, plates=(3, ))
        with self.assertRaises(ValueError):
            Gate(Z, X, gated_plate=0)

        # None of the parents have the cluster plate axis
        Z = Categorical(np.ones(3) / 3)
        X = GaussianARD(0, 1)
        with self.assertRaises(ValueError):
            Gate(Z, X)

        # Inconsistent cluster plate
        Z = Categorical(np.ones(3) / 3)
        X = GaussianARD(0, 1, plates=(2, ))
        with self.assertRaises(ValueError):
            Gate(Z, X)
Esempio n. 4
0
    def test_message_to_parent(self):
        """
        Check the messages a Gate node sends to its two parents.

        Parent 0 receives the message for the gating index and parent 1 the
        message for the gated node. In every scenario a GaussianARD(gate, 1)
        child is attached below the gate and, except in the first scenario,
        observed before the messages are computed.
        """

        def gate_with_child(Z, X, observation=None, **gate_kwargs):
            # Build Gate(Z, X), hang GaussianARD(gate, 1) below it and
            # optionally observe that child. Both nodes are returned so the
            # child stays referenced for the rest of the scenario.
            gate = Gate(Z, X, **gate_kwargs)
            child = GaussianARD(gate, 1)
            if observation is not None:
                child.observe(observation)
            return gate, child

        # Index message shared by scenarios with X = GaussianARD([1, 2, 3], 1)
        # and a single observation of 10.
        index_msg_10 = [10 * 1 - 0.5 * 2,
                        10 * 2 - 0.5 * 5,
                        10 * 3 - 0.5 * 10]
        # Index message shared by the two scenarios that observe [10, 20]
        # against means [[1, 4], [2, 5], [3, 6]] (one row per cluster).
        index_msg_10_20 = [10 * 1 - 0.5 * 2 + 20 * 4 - 0.5 * 17,
                           10 * 2 - 0.5 * 5 + 20 * 5 - 0.5 * 26,
                           10 * 3 - 0.5 * 10 + 20 * 6 - 0.5 * 37]

        # Unobserved and broadcasting
        F, Y = gate_with_child(2, GaussianARD(0, 1, shape=(), plates=(3, )))
        to_index = F._message_to_parent(0)
        self.assertEqual(len(to_index), 1)
        self.assertAllClose(to_index[0], 0 * np.ones(3))
        to_gated = F._message_to_parent(1)
        self.assertEqual(len(to_gated), 2)
        self.assertAllClose(to_gated[0] * np.ones(3), [0, 0, 0])
        self.assertAllClose(to_gated[1] * np.ones(3), [0, 0, 0])

        # Gating scalar node
        F, Y = gate_with_child(
            2, GaussianARD([1, 2, 3], 1, shape=(), plates=(3, )), 10)
        to_index = F._message_to_parent(0)
        self.assertAllClose(to_index[0], index_msg_10)
        to_gated = F._message_to_parent(1)
        self.assertAllClose(to_gated[0], [0, 0, 10])
        self.assertAllClose(to_gated[1], [0, 0, -0.5])

        # Fixed X
        F, Y = gate_with_child(2, [1, 2, 3], 10, moments=GaussianMoments(0))
        to_index = F._message_to_parent(0)
        self.assertAllClose(to_index[0],
                            [10 * 1 - 0.5 * 1,
                             10 * 2 - 0.5 * 4,
                             10 * 3 - 0.5 * 9])
        to_gated = F._message_to_parent(1)
        self.assertAllClose(to_gated[0], [0, 0, 10])
        self.assertAllClose(to_gated[1], [0, 0, -0.5])

        # Uncertain gating
        F, Y = gate_with_child(
            Categorical([0.2, 0.3, 0.5]),
            GaussianARD([1, 2, 3], 1, shape=(), plates=(3, )),
            10)
        to_index = F._message_to_parent(0)
        self.assertAllClose(to_index[0], index_msg_10)
        to_gated = F._message_to_parent(1)
        self.assertAllClose(to_gated[0], [0.2 * 10, 0.3 * 10, 0.5 * 10])
        self.assertAllClose(to_gated[1],
                            [-0.5 * 0.2, -0.5 * 0.3, -0.5 * 0.5])

        # Plates in Z
        F, Y = gate_with_child(
            [2, 0], GaussianARD([1, 2, 3], 1, shape=(), plates=(3, )),
            [10, 20])
        to_index = F._message_to_parent(0)
        self.assertAllClose(
            to_index[0],
            [[10 * 1 - 0.5 * 2, 10 * 2 - 0.5 * 5, 10 * 3 - 0.5 * 10],
             [20 * 1 - 0.5 * 2, 20 * 2 - 0.5 * 5, 20 * 3 - 0.5 * 10]])
        to_gated = F._message_to_parent(1)
        self.assertAllClose(to_gated[0], [20, 0, 10])
        self.assertAllClose(to_gated[1], [-0.5, 0, -0.5])

        # Plates in X
        F, Y = gate_with_child(
            2, GaussianARD([[1, 2, 3], [4, 5, 6]], 1, shape=(),
                           plates=(2, 3)),
            [10, 20])
        to_index = F._message_to_parent(0)
        self.assertAllClose(to_index[0], index_msg_10_20)
        to_gated = F._message_to_parent(1)
        self.assertAllClose(to_gated[0], [[0, 0, 10], [0, 0, 20]])
        self.assertAllClose(to_gated[1] * np.ones((2, 3)),
                            [[0, 0, -0.5], [0, 0, -0.5]])

        # Gating non-default plate
        F, Y = gate_with_child(
            2, GaussianARD([[1], [2], [3]], 1, shape=(), plates=(3, 1)),
            [10], gated_plate=-2)
        to_index = F._message_to_parent(0)
        self.assertAllClose(to_index[0], index_msg_10)
        to_gated = F._message_to_parent(1)
        self.assertAllClose(to_gated[0], [[0], [0], [10]])
        self.assertAllClose(to_gated[1], [[0], [0], [-0.5]])

        # Gating non-scalar node
        F, Y = gate_with_child(
            2, GaussianARD([[1, 4], [2, 5], [3, 6]], 1, shape=(2, ),
                           plates=(3, )),
            [10, 20])
        to_index = F._message_to_parent(0)
        self.assertAllClose(to_index[0], index_msg_10_20)
        to_gated = F._message_to_parent(1)
        eye = np.identity(2)
        self.assertAllClose(to_gated[0], [[0, 0], [0, 0], [10, 20]])
        self.assertAllClose(to_gated[1], [0 * eye, 0 * eye, -0.5 * eye])

        # Broadcasting the moments on the cluster axis
        F, Y = gate_with_child(2, GaussianARD(2, 1, shape=(), plates=(3, )),
                               10)
        to_index = F._message_to_parent(0)
        self.assertAllClose(to_index[0], [10 * 2 - 0.5 * 5] * 3)
        to_gated = F._message_to_parent(1)
        self.assertAllClose(to_gated[0], [0, 0, 10])
        self.assertAllClose(to_gated[1], [0, 0, -0.5])
Esempio n. 5
0
    def test_message_to_child(self):
        """
        Check the moments a Gate node passes down to its children.
        """

        def gated_moments(Z, X, **gate_kwargs):
            # Moments message that Gate(Z, X) sends to its children.
            return Gate(Z, X, **gate_kwargs)._message_to_child()

        # Gating scalar node
        u = gated_moments(2,
                          GaussianARD([1, 2, 3], 1, shape=(), plates=(3, )))
        self.assertEqual(len(u), 2)
        self.assertAllClose(u[0], 3)
        self.assertAllClose(u[1], 3**2 + 1)

        # Fixed X
        u = gated_moments(2, [1, 2, 3], moments=GaussianMoments(0))
        self.assertEqual(len(u), 2)
        self.assertAllClose(u[0], 3)
        self.assertAllClose(u[1], 3**2)

        # Uncertain gating
        u = gated_moments(Categorical([0.2, 0.3, 0.5]),
                          GaussianARD([1, 2, 3], 1, shape=(), plates=(3, )))
        self.assertAllClose(u[0], 0.2 * 1 + 0.3 * 2 + 0.5 * 3)
        self.assertAllClose(u[1], 0.2 * 2 + 0.3 * 5 + 0.5 * 10)

        # Plates in Z
        u = gated_moments([2, 0],
                          GaussianARD([1, 2, 3], 1, shape=(), plates=(3, )))
        self.assertAllClose(u[0], [3, 1])
        self.assertAllClose(u[1], [10, 2])

        # Plates in X
        u = gated_moments(2,
                          GaussianARD([1, 2, 3], 1, shape=(), plates=(4, 3)))
        self.assertAllClose(np.ones(4) * u[0], np.ones(4) * 3)
        self.assertAllClose(np.ones(4) * u[1], np.ones(4) * 10)

        # Gating non-default plate
        u = gated_moments(2,
                          GaussianARD([[1], [2], [3]], 1, shape=(),
                                      plates=(3, 4)),
                          gated_plate=-2)
        self.assertAllClose(np.ones(4) * u[0], np.ones(4) * 3)
        self.assertAllClose(np.ones(4) * u[1], np.ones(4) * 10)

        # Gating non-scalar node
        means = [1 * np.ones(4), 2 * np.ones(4), 3 * np.ones(4)]
        u = gated_moments(2, GaussianARD(means, 1, shape=(4, ),
                                         plates=(3, )))
        self.assertAllClose(u[0], 3 * np.ones(4))
        self.assertAllClose(u[1], 9 * np.ones((4, 4)) + 1 * np.identity(4))

        # Broadcasting the moments on the cluster axis
        u = gated_moments(2, GaussianARD(1, 1, shape=(), plates=(3, )))
        self.assertEqual(len(u), 2)
        self.assertAllClose(u[0], 1)
        self.assertAllClose(u[1], 1**2 + 1)
Esempio n. 6
0
    def test_message_to_parent(self):
        """
        Test the message to parents of Mixture node.

        Sections:
          * parameters with the cluster plate axis, broadcast moments;
          * a parameter (Alpha) without the cluster plate axis;
          * a cluster axis other than the last plate (cluster_plate=-2);
          * a regression check for Mixture over Categorical with fixed
            probabilities;
          * nested mixtures, with and without sample plates;
          * equivalence of nested Mixture and nested Gate constructions.
        """

        # Number of clusters. All expected values below are expressed in
        # terms of K so the test stays consistent if K is changed.
        K = 3

        # Broadcasting the moments on the cluster axis
        Mu = GaussianARD(2, 1,
                         ndim=0,
                         plates=(K,))
        (mu, mumu) = Mu._message_to_child()
        Alpha = Gamma(3, 1,
                      plates=(K,))
        (alpha, logalpha) = Alpha._message_to_child()
        z = Categorical(np.ones(K)/K)
        X = Mixture(z, GaussianARD, Mu, Alpha)
        tau = 4
        Y = GaussianARD(X, tau)
        y = 5
        Y.observe(y)
        (x, xx) = X._message_to_child()
        m = z._message_from_children()
        self.assertAllClose(m[0] * np.ones(K),
                            random.gaussian_logpdf(xx*alpha,
                                                   x*alpha*mu,
                                                   mumu*alpha,
                                                   logalpha,
                                                   0)
                            * np.ones(K))
        m = Mu._message_from_children()
        # z has uniform probabilities 1/K, which weight the message to Mu.
        self.assertAllClose(m[0],
                            1/K * (alpha*x) * np.ones(K))
        self.assertAllClose(m[1],
                            -0.5 * 1/K * alpha * np.ones(K))

        # Some parameters do not have cluster plate axis
        Mu = GaussianARD(2, 1,
                         ndim=0,
                         plates=(K,))
        (mu, mumu) = Mu._message_to_child()
        Alpha = Gamma(3, 1)  # Note: no cluster plate axis!
        (alpha, logalpha) = Alpha._message_to_child()
        z = Categorical(np.ones(K)/K)
        X = Mixture(z, GaussianARD, Mu, Alpha)
        tau = 4
        Y = GaussianARD(X, tau)
        y = 5
        Y.observe(y)
        (x, xx) = X._message_to_child()
        m = z._message_from_children()
        self.assertAllClose(m[0] * np.ones(K),
                            random.gaussian_logpdf(xx*alpha,
                                                   x*alpha*mu,
                                                   mumu*alpha,
                                                   logalpha,
                                                   0)
                            * np.ones(K))
        m = Mu._message_from_children()
        self.assertAllClose(m[0],
                            1/K * (alpha*x) * np.ones(K))
        self.assertAllClose(m[1],
                            -0.5 * 1/K * alpha * np.ones(K))

        # Cluster assignments do not have as many plate axes as parameters.
        M = 2
        Mu = GaussianARD(2, 1,
                         ndim=0,
                         plates=(K, M))
        (mu, mumu) = Mu._message_to_child()
        Alpha = Gamma(3, 1,
                      plates=(K, M))
        (alpha, logalpha) = Alpha._message_to_child()
        z = Categorical(np.ones(K)/K)
        X = Mixture(z, GaussianARD, Mu, Alpha, cluster_plate=-2)
        tau = 4
        Y = GaussianARD(X, tau)
        y = 5 * np.ones(M)
        Y.observe(y)
        (x, xx) = X._message_to_child()
        m = z._message_from_children()
        # The single z receives the log-densities summed over the M plates.
        self.assertAllClose(m[0]*np.ones(K),
                            np.sum(random.gaussian_logpdf(xx*alpha,
                                                          x*alpha*mu,
                                                          mumu*alpha,
                                                          logalpha,
                                                          0) *
                                   np.ones((K, M)),
                                   axis=-1))
        m = Mu._message_from_children()
        self.assertAllClose(m[0] * np.ones((K, M)),
                            1/K * (alpha*x) * np.ones((K, M)))
        self.assertAllClose(m[1] * np.ones((K, M)),
                            -0.5 * 1/K * alpha * np.ones((K, M)))

        # Mixed distribution broadcasts g (presumably the log-normalizer
        # term -- TODO confirm). Regression test for a bug that made this
        # computation raise an error; only successful completion is checked.
        Z = Categorical([0.3, 0.5, 0.2])
        X = Mixture(Z, Categorical, [[0.2, 0.8], [0.1, 0.9], [0.3, 0.7]])
        Z._message_from_children()

        #
        # Test nested mixtures
        #
        t1 = [1, 1, 0, 3, 3]
        t2 = [2]
        p = Dirichlet([1, 1], plates=(4, 3))
        X = Mixture(t1, Mixture, t2, Categorical, p)
        X.observe([1, 1, 0, 0, 0])
        p.update()
        self.assertAllClose(
            p.phi[0],
            [
                [[1, 1], [1, 1], [2, 1]],
                [[1, 1], [1, 1], [1, 3]],
                [[1, 1], [1, 1], [1, 1]],
                [[1, 1], [1, 1], [3, 1]],
            ]
        )

        # Test sample plates in nested mixtures
        t1 = Categorical([0.3, 0.7], plates=(5,))
        t2 = [[1], [1], [0], [3], [3]]
        t3 = 2
        p = Dirichlet([1, 1], plates=(2, 4, 3))
        X = Mixture(t1, Mixture, t2, Mixture, t3, Categorical, p)
        X.observe([1, 1, 0, 0, 0])
        p.update()
        self.assertAllClose(
            p.phi[0],
            [
                [
                    [[1, 1], [1, 1], [1.3, 1]],
                    [[1, 1], [1, 1], [1, 1.6]],
                    [[1, 1], [1, 1], [1, 1]],
                    [[1, 1], [1, 1], [1.6, 1]],
                ],
                [
                    [[1, 1], [1, 1], [1.7, 1]],
                    [[1, 1], [1, 1], [1, 2.4]],
                    [[1, 1], [1, 1], [1, 1]],
                    [[1, 1], [1, 1], [2.4, 1]],
                ]
            ]
        )

        # Check that Gate and nested Mixture are equal: build the same model
        # both ways and compare the messages every parent receives.
        t1 = Categorical([0.3, 0.7], plates=(5,))
        t2 = Categorical([0.1, 0.3, 0.6], plates=(5, 1))
        p = Dirichlet([1, 2, 3, 4], plates=(2, 3))
        X = Mixture(t1, Mixture, t2, Categorical, p)
        X.observe([3, 3, 1, 2, 2])
        t1_msg = t1._message_from_children()
        t2_msg = t2._message_from_children()
        p_msg = p._message_from_children()
        t1 = Categorical([0.3, 0.7], plates=(5,))
        t2 = Categorical([0.1, 0.3, 0.6], plates=(5, 1))
        p = Dirichlet([1, 2, 3, 4], plates=(2, 3))
        X = Categorical(Gate(t1, Gate(t2, p)))
        X.observe([3, 3, 1, 2, 2])
        t1_msg2 = t1._message_from_children()
        t2_msg2 = t2._message_from_children()
        p_msg2 = p._message_from_children()
        self.assertAllClose(t1_msg[0], t1_msg2[0])
        self.assertAllClose(t2_msg[0], t2_msg2[0])
        self.assertAllClose(p_msg[0], p_msg2[0])
Esempio n. 7
0
    def test_message_to_parent(self):
        """
        Verify the messages from a Gate node to the gating index (parent 0)
        and to the gated node (parent 1), under a GaussianARD(gate, 1) child.

        Expected index messages are written per cluster as
        y * mean - 0.5 * (mean**2 + 1), i.e. observation times <x> minus half
        of <x^2>, assuming the second GaussianARD argument (1 here) is the
        precision.
        """

        # Unobserved and broadcasting
        index = 2
        gated = GaussianARD(0, 1, shape=(), plates=(3,))
        gate = Gate(index, gated)
        child = GaussianARD(gate, 1)
        msg = gate._message_to_parent(0)
        self.assertEqual(len(msg), 1)
        self.assertAllClose(msg[0], 0*np.ones(3))
        msg = gate._message_to_parent(1)
        self.assertEqual(len(msg), 2)
        self.assertAllClose(msg[0]*np.ones(3), [0, 0, 0])
        self.assertAllClose(msg[1]*np.ones(3), [0, 0, 0])

        # Gating scalar node
        index = 2
        gated = GaussianARD([1, 2, 3], 1, shape=(), plates=(3,))
        gate = Gate(index, gated)
        child = GaussianARD(gate, 1)
        child.observe(10)
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0],
                            [10*mean - 0.5*(mean**2 + 1) for mean in (1, 2, 3)])
        msg = gate._message_to_parent(1)
        self.assertAllClose(msg[0], [0, 0, 10])
        self.assertAllClose(msg[1], [0, 0, -0.5])

        # Fixed X (no variance, so the second moment is mean**2)
        index = 2
        gated = [1, 2, 3]
        gate = Gate(index, gated, moments=GaussianMoments(0))
        child = GaussianARD(gate, 1)
        child.observe(10)
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0],
                            [10*value - 0.5*value**2 for value in (1, 2, 3)])
        msg = gate._message_to_parent(1)
        self.assertAllClose(msg[0], [0, 0, 10])
        self.assertAllClose(msg[1], [0, 0, -0.5])

        # Uncertain gating: messages to X are weighted by the probabilities
        probs = (0.2, 0.3, 0.5)
        index = Categorical([0.2, 0.3, 0.5])
        gated = GaussianARD([1, 2, 3], 1, shape=(), plates=(3,))
        gate = Gate(index, gated)
        child = GaussianARD(gate, 1)
        child.observe(10)
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0],
                            [10*mean - 0.5*(mean**2 + 1) for mean in (1, 2, 3)])
        msg = gate._message_to_parent(1)
        self.assertAllClose(msg[0], [p*10 for p in probs])
        self.assertAllClose(msg[1], [-0.5*p for p in probs])

        # Plates in Z
        index = [2, 0]
        gated = GaussianARD([1, 2, 3], 1, shape=(), plates=(3,))
        gate = Gate(index, gated)
        child = GaussianARD(gate, 1)
        child.observe([10, 20])
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0],
                            [[obs*mean - 0.5*(mean**2 + 1)
                              for mean in (1, 2, 3)]
                             for obs in (10, 20)])
        msg = gate._message_to_parent(1)
        self.assertAllClose(msg[0], [20, 0, 10])
        self.assertAllClose(msg[1], [-0.5, 0, -0.5])

        # Plates in X: the two plate rows are summed per cluster
        index = 2
        gated = GaussianARD([[1, 2, 3], [4, 5, 6]], 1, shape=(), plates=(2, 3))
        gate = Gate(index, gated)
        child = GaussianARD(gate, 1)
        child.observe([10, 20])
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0],
                            [10*a - 0.5*(a**2 + 1) + 20*b - 0.5*(b**2 + 1)
                             for a, b in ((1, 4), (2, 5), (3, 6))])
        msg = gate._message_to_parent(1)
        self.assertAllClose(msg[0], [[0, 0, 10],
                                     [0, 0, 20]])
        self.assertAllClose(msg[1]*np.ones((2, 3)), [[0, 0, -0.5],
                                                     [0, 0, -0.5]])

        # Gating non-default plate
        index = 2
        gated = GaussianARD([[1], [2], [3]], 1, shape=(), plates=(3, 1))
        gate = Gate(index, gated, gated_plate=-2)
        child = GaussianARD(gate, 1)
        child.observe([10])
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0],
                            [10*mean - 0.5*(mean**2 + 1) for mean in (1, 2, 3)])
        msg = gate._message_to_parent(1)
        self.assertAllClose(msg[0], [[0], [0], [10]])
        self.assertAllClose(msg[1], [[0], [0], [-0.5]])

        # Gating non-scalar node: the vector components are summed per cluster
        index = 2
        gated = GaussianARD([[1, 4], [2, 5], [3, 6]], 1, shape=(2,), plates=(3,))
        gate = Gate(index, gated)
        child = GaussianARD(gate, 1)
        child.observe([10, 20])
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0],
                            [10*a - 0.5*(a**2 + 1) + 20*b - 0.5*(b**2 + 1)
                             for a, b in ((1, 4), (2, 5), (3, 6))])
        msg = gate._message_to_parent(1)
        eye = np.identity(2)
        self.assertAllClose(msg[0], [[0, 0], [0, 0], [10, 20]])
        self.assertAllClose(msg[1], [0*eye, 0*eye, -0.5*eye])

        # Broadcasting the moments on the cluster axis
        index = 2
        gated = GaussianARD(2, 1, shape=(), plates=(3,))
        gate = Gate(index, gated)
        child = GaussianARD(gate, 1)
        child.observe(10)
        msg = gate._message_to_parent(0)
        self.assertAllClose(msg[0], [10*2 - 0.5*(2**2 + 1)] * 3)
        msg = gate._message_to_parent(1)
        self.assertAllClose(msg[0], [0, 0, 10])
        self.assertAllClose(msg[1], [0, 0, -0.5])
Esempio n. 8
0
    def test_message_to_child(self):
        """
        Verify the first and second moments a Gate node sends to children.
        """

        # Gating scalar node: cluster 2 has mean 3
        moments = Gate(2, GaussianARD([1, 2, 3], 1, shape=(), plates=(3,)))._message_to_child()
        self.assertEqual(len(moments), 2)
        self.assertAllClose(moments[0], 3)
        self.assertAllClose(moments[1], 3**2+1)

        # Fixed X: a constant has no variance term
        moments = Gate(2, [1, 2, 3], moments=GaussianMoments(0))._message_to_child()
        self.assertEqual(len(moments), 2)
        self.assertAllClose(moments[0], 3)
        self.assertAllClose(moments[1], 3**2)

        # Uncertain gating: moments are averaged with the gate probabilities
        probs = (0.2, 0.3, 0.5)
        selector = Categorical([0.2, 0.3, 0.5])
        source = GaussianARD([1, 2, 3], 1, shape=(), plates=(3,))
        moments = Gate(selector, source)._message_to_child()
        self.assertAllClose(moments[0],
                            sum(p*mean for p, mean in zip(probs, (1, 2, 3))))
        self.assertAllClose(moments[1],
                            sum(p*second for p, second in zip(probs, (2, 5, 10))))

        # Plates in Z
        source = GaussianARD([1, 2, 3], 1, shape=(), plates=(3,))
        moments = Gate([2, 0], source)._message_to_child()
        self.assertAllClose(moments[0], [3, 1])
        self.assertAllClose(moments[1], [10, 2])

        # Plates in X
        source = GaussianARD([1, 2, 3], 1, shape=(), plates=(4, 3))
        moments = Gate(2, source)._message_to_child()
        self.assertAllClose(np.ones(4)*moments[0], np.ones(4)*3)
        self.assertAllClose(np.ones(4)*moments[1], np.ones(4)*10)

        # Gating non-default plate
        source = GaussianARD([[1], [2], [3]], 1, shape=(), plates=(3, 4))
        moments = Gate(2, source, gated_plate=-2)._message_to_child()
        self.assertAllClose(np.ones(4)*moments[0], np.ones(4)*3)
        self.assertAllClose(np.ones(4)*moments[1], np.ones(4)*10)

        # Gating non-scalar node
        source = GaussianARD([k*np.ones(4) for k in (1, 2, 3)],
                             1,
                             shape=(4,), plates=(3,))
        moments = Gate(2, source)._message_to_child()
        self.assertAllClose(moments[0], 3*np.ones(4))
        self.assertAllClose(moments[1], 9*np.ones((4, 4)) + 1*np.identity(4))

        # Broadcasting the moments on the cluster axis
        source = GaussianARD(1, 1, shape=(), plates=(3,))
        moments = Gate(2, source)._message_to_child()
        self.assertEqual(len(moments), 2)
        self.assertAllClose(moments[0], 1)
        self.assertAllClose(moments[1], 1**2+1)