Ejemplo n.º 1
0
    def test_observed(self):
        """
        Test observed categorical nodes.

        Observing a Categorical node must make the first moment it sends to
        its children a one-hot encoding of the observed category. This is
        checked for a node without plates, with one plate axis and with two
        plate axes. Negative, too-large and non-integer observations must be
        rejected with ValueError.
        """
        # Each case: (probabilities, extra constructor kwargs, observation,
        # expected one-hot first moment).
        valid_cases = [
            # Single observation
            ([0.7, 0.2, 0.1], {}, 2,
             [0, 0, 1]),
            # One plate axis
            ([0.7, 0.2, 0.1], {'plates': (2,)}, [2, 1],
             [[0, 0, 1], [0, 1, 0]]),
            # Several plate axes
            ([0.7, 0.1, 0.1, 0.1], {'plates': (2, 3)}, [[2, 1, 1], [0, 2, 3]],
             [[[0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]],
              [[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]]),
        ]
        for probs, kwargs, observation, expected in valid_cases:
            node = Categorical(probs, **kwargs)
            node.observe(observation)
            message = node._message_to_child()
            self.assertAllClose(message[0], expected)

        # Invalid observations: below range, above range, non-integer.
        node = Categorical([0.7, 0.2, 0.1])
        for invalid in (-1, 3, 1.5):
            self.assertRaises(ValueError, node.observe, invalid)
Ejemplo n.º 2
0
    def test_observed(self):
        """
        Verify the moments of observed categorical nodes.

        After ``observe`` is called, the first element of the message to
        children must be the one-hot encoding of the observation. This holds
        for scalar nodes and for nodes with one or several plate axes.
        Observations outside ``0..K-1``, or non-integer observations, must
        raise ``ValueError``.
        """
        def check_one_hot(node, observation, expected):
            # Observe the node, then compare the first moment to `expected`.
            node.observe(observation)
            moments = node._message_to_child()
            self.assertAllClose(moments[0], expected)

        # Scalar node
        check_one_hot(Categorical([0.7, 0.2, 0.1]),
                      2,
                      [0, 0, 1])

        # A single plate axis of length 2
        check_one_hot(Categorical([0.7, 0.2, 0.1], plates=(2,)),
                      [2, 1],
                      [[0, 0, 1],
                       [0, 1, 0]])

        # Two plate axes, shape (2, 3)
        check_one_hot(Categorical([0.7, 0.1, 0.1, 0.1], plates=(2, 3)),
                      [[2, 1, 1],
                       [0, 2, 3]],
                      [[[0, 0, 1, 0], [0, 1, 0, 0], [0, 1, 0, 0]],
                       [[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]])

        # Invalid observations must be rejected
        node = Categorical([0.7, 0.2, 0.1])
        for invalid in (-1, 3, 1.5):
            with self.assertRaises(ValueError):
                node.observe(invalid)
Ejemplo n.º 3
0
# Child nodes of the network. Each Mixture selects its Categorical parameters
# from the given table according to the state of the parent node(s) listed
# before it. `_or`, TRUE and FALSE are defined outside this excerpt;
# presumably `_or` builds an OR-style table over the two inner parents.
lung = Mixture(smoking, Categorical, [
    [0.98, 0.02],
    [0.25, 0.75],
])

bronchitis = Mixture(smoking, Categorical, [
    [0.97, 0.03],
    [0.08, 0.92],
])

xray = Mixture(
    tuberculosis, Mixture, lung, Categorical,
    _or([0.96, 0.04], [0.115, 0.885]),
)

dyspnea = Mixture(
    bronchitis, Mixture, tuberculosis, Mixture, lung, Categorical,
    [
        _or([0.6, 0.4], [0.18, 0.82]),
        _or([0.11, 0.89], [0.04, 0.96]),
    ],
)

# Evidence, applied in order. The bronchitis observation is not a "chance"
# observation as in the original example.
for _node, _state in ((tuberculosis, TRUE),
                      (smoking, FALSE),
                      (bronchitis, TRUE)):
    _node.observe(_state)

# Variational Bayesian inference over all nodes (update called with repeat=100).
Q = VB(dyspnea, xray, bronchitis, lung, smoking, tuberculosis, asia)
Q.update(repeat=100)

# Report the probability of the TRUE state for each variable.
for _label, _node in (("P(asia):", asia),
                      ("P(tuberculosis):", tuberculosis),
                      ("P(smoking):", smoking),
                      ("P(lung):", lung),
                      ("P(bronchitis):", bronchitis),
                      ("P(xray):", xray),
                      ("P(dyspnea):", dyspnea)):
    print(_label, _node.get_moments()[0][TRUE])
Ejemplo n.º 4
0
        ])
    """
    data-->[[0, 0, 0, 1, 3, 0, 0], [0, 1, 0, 1, 3, 0, 0], [1, 0, 1, 0, 2, 1, 0],
    [4, 0, 0, 1, 3, 2, 1],[3, 1, 0, 0, 0, 2, 1], [2, 0, 0, 1, 1, 0, 0], [4, 0, 0, 0, 2, 0, 0],
    [0, 0, 0, 1, 3, 0, 0],[3, 1, 0, 0, 0, 2, 1], [1, 1, 1, 0, 0, 2, 0], [4, 1, 1, 1, 2, 0, 0]]
    """
# Turn the raw records into an array. Rows are records, and columns hold the
# variables observed below.
data = np.array(data)
N = len(data)
print(N)


def _observed_categorical(n_categories, column):
    """Return (prior, node) for a Categorical observed from data[:, column].

    The prior is a Dirichlet with all concentration parameters equal to 1.
    The node has one plate entry per record (plates=(N,)).
    """
    prior = Dirichlet(np.ones(n_categories))
    node = Categorical(prior, plates=(N,))
    node.observe(data[:, column])
    return prior, node


p_age, age = _observed_categorical(5, 0)
p_gender, gender = _observed_categorical(2, 1)
p_familyhistory, familyhistory = _observed_categorical(2, 2)
p_diet, diet = _observed_categorical(3, 3)

# NOTE(review): unlike the variables above, lifestyle is not observed in this
# block. Confirm whether an observe() call is meant to follow.
p_lifestyle = Dirichlet(np.ones(4))
lifestyle = Categorical(p_lifestyle, plates=(N,))
Ejemplo n.º 5
0
# Root node with a uniform prior over (FALSE, TRUE) -- presumably FALSE=0 and
# TRUE=1, both defined outside this excerpt.
smoking = Categorical([0.5, 0.5])

# Children of smoking: one row of Categorical parameters per parent state.
lung = Mixture(smoking, Categorical,
               [[0.98, 0.02],
                [0.25, 0.75]])
bronchitis = Mixture(smoking, Categorical,
                     [[0.97, 0.03],
                      [0.08, 0.92]])

# Nodes with two or three parents. `_or` is defined outside this excerpt.
xray = Mixture(tuberculosis, Mixture, lung, Categorical,
               _or([0.96, 0.04], [0.115, 0.885]))
dyspnea = Mixture(bronchitis, Mixture, tuberculosis, Mixture, lung, Categorical,
                  [_or([0.6, 0.4], [0.18, 0.82]),
                   _or([0.11, 0.89], [0.04, 0.96])])

# Mark observations (bronchitis: not a "chance" observation as in the
# original example).
for _node, _state in [(tuberculosis, TRUE), (smoking, FALSE), (bronchitis, TRUE)]:
    _node.observe(_state)

# Run variational Bayesian inference (update called with repeat=100).
Q = VB(dyspnea, xray, bronchitis, lung, smoking, tuberculosis, asia)
Q.update(repeat=100)

# Print the probability of TRUE for every variable, in a fixed order.
_results = [
    ("P(asia):", asia),
    ("P(tuberculosis):", tuberculosis),
    ("P(smoking):", smoking),
    ("P(lung):", lung),
    ("P(bronchitis):", bronchitis),
    ("P(xray):", xray),
    ("P(dyspnea):", dyspnea),
]
for _label, _node in _results:
    print(_label, _node.get_moments()[0][TRUE])
Ejemplo n.º 6
0

# Short names, per the printed labels below: A=asia, T=tuberculosis,
# S=smoking, L=lung, B=bronchitis, X=xray, D=dyspnea.
A = Categorical([0.5, 0.5])
T = Mixture(A, Categorical, [[0.99, 0.01],
                             [0.8, 0.2]])
S = Categorical([0.5, 0.5])
L = Mixture(S, Categorical, [[0.98, 0.02],
                             [0.75, 0.25]])
B = Mixture(S, Categorical, [[0.97, 0.03],
                             [0.70, 0.30]])
X = Mixture(T, Mixture, L, Categorical,
            _or([0.96, 0.04], [0.115, 0.885]))
# NOTE(review): D is conditioned on B and X (the x-ray node). In the standard
# Asia network, dyspnea depends on bronchitis and on tuberculosis-or-lung-
# cancer, not on the x-ray result. Confirm this structure and table are
# intended.
D = Mixture(B, Mixture, X, Categorical,
            _or([0.115, 0.885], [0.04, 0.96]))

# Evidence (TRUE/FALSE and `_or` are defined outside this excerpt).
for _node, _state in ((T, TRUE), (S, FALSE), (B, TRUE)):
    _node.observe(_state)

Q = VB(A, T, S, L, B, X, D)
Q.update(repeat=100)

for _label, _node in (("P(asia): ", A),
                      ("P(tuberculosis): ", T),
                      ("P(smoking): ", S),
                      ("P(lung): ", L),
                      ("P(bronchitis): ", B),
                      ("P(xray): ", X),
                      ("P(dyspnea): ", D)):
    print(_label, _node.get_moments()[0][TRUE])
Ejemplo n.º 7
0
        O_n = Gaussian(mu, lambda_, name=f"O_{n}")
        obs_nodes[n] = O_n
    X.observe(o_n)

# Observe every non-reset action on a Categorical node keyed by `n`. The node,
# with its own Dirichlet prior, is created the first time `n` is seen.
for action in actions:
    trial, agent, a_n, n = action
    # Negative action values mark an action reset; nothing to observe.
    if a_n < 0:
        continue
    if n not in action_nodes:
        category_prob = Dirichlet(1e-3 * np.ones(A_D),
                                  name='category_prob')  # FIXME: Unconfirmed!
        action_nodes[n] = Categorical(category_prob)
    A = action_nodes[n]
    A.observe(a_n)

# In[139]:

# NOTE(review): the bare expression statements in these cells look like
# leftovers from interactive notebook inspection. When this file runs as a
# plain script, each expression is evaluated and its value is discarded.
action_nodes[0].__dict__

# In[ ]:

# Constructs a Dirichlet node that is not bound to any name (value discarded
# outside an interactive shell).
Dirichlet(1e-3 * np.ones(A_D))

# In[120]:

# Product of the action-space shape; only useful as displayed notebook output.
np.prod(env.action_space.shape)

# In[102]: