コード例 #1
0
    def test_lower_bound(self):
        """
        Check the lower bound contribution of a multinomial node.

        Regression test for a bug found in multinomial: a node built with
        10 trials and fixed probabilities is expected to report a lower
        bound contribution close to zero.
        """
        node = Multinomial(10, [0.3, 0.5, 0.2])
        contribution = node.lower_bound_contribution()
        self.assertAllClose(contribution, 0.0)
コード例 #2
0
ファイル: test_multinomial.py プロジェクト: zehsilva/bayespy
    def test_lower_bound(self):
        """
        Check the lower bound contribution of a multinomial node.

        Regression test for a bug found in multinomial: a node built with
        10 trials and fixed probabilities is expected to report a lower
        bound contribution close to zero.
        """
        node = Multinomial(10, [0.3, 0.5, 0.2])
        contribution = node.lower_bound_contribution()
        self.assertAllClose(contribution, 0.0)
コード例 #3
0
ファイル: test_multinomial.py プロジェクト: zehsilva/bayespy
    def test_init(self):
        """
        Test the creation of multinomial nodes.

        Covers construction with fixed and Dirichlet-distributed
        probabilities, the resulting ``plates`` attribute for several
        combinations of trial counts and probabilities, and the
        ``ValueError`` raised for invalid arguments.
        """

        # Some simple initializations
        X = Multinomial(10, [0.1, 0.3, 0.6])
        X = Multinomial(10, Dirichlet([5, 4, 3]))

        # Check that plates are correct
        X = Multinomial(10, [0.1, 0.3, 0.6], plates=(3, 4))
        self.assertEqual(X.plates, (3, 4))
        X = Multinomial(10, 0.25 * np.ones((2, 3, 4)))
        self.assertEqual(X.plates, (2, 3))
        # Use the builtin ``int`` as dtype: the ``np.int`` alias was
        # deprecated in NumPy 1.20 and removed in NumPy 1.24, where it
        # raises AttributeError. ``np.int`` was an alias of ``int``, so the
        # resulting array is the same.
        n = 10 * np.ones((3, 4), dtype=int)
        X = Multinomial(n, [0.1, 0.3, 0.6])
        self.assertEqual(X.plates, (3, 4))
        X = Multinomial(n, Dirichlet([2, 1, 9], plates=(3, 4)))
        self.assertEqual(X.plates, (3, 4))

        # Probabilities not a vector
        self.assertRaises(ValueError, Multinomial, 10, 0.5)

        # Invalid probability
        self.assertRaises(ValueError, Multinomial, 10, [-0.5, 1.5])
        self.assertRaises(ValueError, Multinomial, 10, [0.5, 1.5])

        # Invalid number of trials
        self.assertRaises(ValueError, Multinomial, -1, [0.5, 0.5])
        self.assertRaises(ValueError, Multinomial, 8.5, [0.5, 0.5])

        # Inconsistent plates
        self.assertRaises(ValueError,
                          Multinomial,
                          10,
                          0.25 * np.ones((2, 4)),
                          plates=(3, ))

        # Explicit plates too small
        self.assertRaises(ValueError,
                          Multinomial,
                          10,
                          0.25 * np.ones((2, 4)),
                          plates=(1, ))
コード例 #4
0
    def test_moments(self):
        """
        Verify the moments that multinomial nodes send to their children.

        Each case builds a node, reads ``_message_to_child()`` and compares
        its first element with the number of trials times the
        probabilities.
        """

        # One trial with fixed probabilities: a single moment equal to p
        node = Multinomial(1, [0.7, 0.2, 0.1])
        moments = node._message_to_child()
        self.assertEqual(len(moments), 1)
        self.assertAllClose(moments[0], [0.7, 0.2, 0.1])

        # Ten trials scale the moment by n
        node = Multinomial(10, [0.7, 0.2, 0.1])
        moments = node._message_to_child()
        self.assertAllClose(moments[0], [7, 2, 1])

        # Plates coming from the probabilities
        trials = np.random.randint(1, 10)
        probs = np.random.dirichlet([1, 1], size=3)
        node = Multinomial(trials, probs)
        moments = node._message_to_child()
        expected = probs * trials
        self.assertAllClose(moments[0], expected)

        # Plates coming from the number of trials
        trials = np.random.randint(1, 10, size=(3,))
        probs = np.random.dirichlet([1, 1, 1, 1])
        node = Multinomial(trials, probs)
        moments = node._message_to_child()
        expected = probs * trials[:, None]
        self.assertAllClose(moments[0], expected)

        # Plates coming from both the probabilities and the trials
        trials = np.random.randint(1, 10, size=(4, 1))
        probs = np.random.dirichlet([1, 1], size=3)
        node = Multinomial(trials, probs)
        moments = node._message_to_child()
        expected = probs * trials[..., None]
        self.assertAllClose(moments[0], expected)

        # Dirichlet prior: expected probabilities are the normalised
        # exponentials of the prior's first moment
        prior = Dirichlet([7, 3])
        logp = prior._message_to_child()[0]
        normaliser = np.exp(logp[0]) + np.exp(logp[1])
        first = np.exp(logp[0]) / normaliser
        second = np.exp(logp[1]) / normaliser
        node = Multinomial(1, prior)
        moments = node._message_to_child()
        mean_probs = np.array([first, second])
        self.assertAllClose(moments[0], mean_probs)

        # Broadcasted plates, reusing the probabilities computed above
        prior = Dirichlet([7, 3], plates=(10,))
        node = Multinomial(5, prior)
        moments = node._message_to_child()
        broadcast_moment = moments[0] * np.ones(node.get_shape(0))
        expected = 5 * mean_probs * np.ones((10, 1))
        self.assertAllClose(broadcast_moment, expected)
コード例 #5
0
ファイル: test_multinomial.py プロジェクト: zehsilva/bayespy
    def test_moments(self):
        """
        Verify the moments that multinomial nodes send to their children.

        Each case builds a node, reads ``_message_to_child()`` and compares
        its first element with the number of trials times the
        probabilities.
        """

        # One trial with fixed probabilities: a single moment equal to p
        node = Multinomial(1, [0.7, 0.2, 0.1])
        moments = node._message_to_child()
        self.assertEqual(len(moments), 1)
        self.assertAllClose(moments[0], [0.7, 0.2, 0.1])

        # Ten trials scale the moment by n
        node = Multinomial(10, [0.7, 0.2, 0.1])
        moments = node._message_to_child()
        self.assertAllClose(moments[0], [7, 2, 1])

        # Plates coming from the probabilities
        trials = np.random.randint(1, 10)
        probs = np.random.dirichlet([1, 1], size=3)
        node = Multinomial(trials, probs)
        moments = node._message_to_child()
        expected = probs * trials
        self.assertAllClose(moments[0], expected)

        # Plates coming from the number of trials
        trials = np.random.randint(1, 10, size=(3, ))
        probs = np.random.dirichlet([1, 1, 1, 1])
        node = Multinomial(trials, probs)
        moments = node._message_to_child()
        expected = probs * trials[:, None]
        self.assertAllClose(moments[0], expected)

        # Plates coming from both the probabilities and the trials
        trials = np.random.randint(1, 10, size=(4, 1))
        probs = np.random.dirichlet([1, 1], size=3)
        node = Multinomial(trials, probs)
        moments = node._message_to_child()
        expected = probs * trials[..., None]
        self.assertAllClose(moments[0], expected)

        # Dirichlet prior: expected probabilities are the normalised
        # exponentials of the prior's first moment
        prior = Dirichlet([7, 3])
        logp = prior._message_to_child()[0]
        normaliser = np.exp(logp[0]) + np.exp(logp[1])
        first = np.exp(logp[0]) / normaliser
        second = np.exp(logp[1]) / normaliser
        node = Multinomial(1, prior)
        moments = node._message_to_child()
        mean_probs = np.array([first, second])
        self.assertAllClose(moments[0], mean_probs)

        # Broadcasted plates, reusing the probabilities computed above
        prior = Dirichlet([7, 3], plates=(10, ))
        node = Multinomial(5, prior)
        moments = node._message_to_child()
        broadcast_moment = moments[0] * np.ones(node.get_shape(0))
        expected = 5 * mean_probs * np.ones((10, 1))
        self.assertAllClose(broadcast_moment, expected)