예제 #1
0
 def test_copy(self):
     """Check that ``Dirichlet.copy`` gives equal values that can diverge from the source."""
     original = Dirichlet(values=np.random.rand(3, 2))
     duplicate = original.copy()

     # Immediately after copying, both objects must report the same values.
     self.assertTrue(np.array_equal(duplicate.values, original.values))

     # Doubling the copy's values must not be reflected in the source object.
     duplicate.values = duplicate.values * 2
     self.assertFalse(np.array_equal(duplicate.values, original.values))
예제 #2
0
    def test_update_pB_multiFactor_noActions_allFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are no actions. All factors are updated.
        """

        n_states = [3, 4]
        n_control = [1, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        # The per-factor arrays have different shapes, so they are placed into
        # explicit object arrays. Passing a ragged list straight to np.array()
        # raises ValueError on NumPy >= 1.24.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        # One randomly drawn control index per factor (always 0 here, since
        # every factor has a single control state).
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=learning_rate, factors="all", return_numpy=True,
        )

        for factor, chosen_action in enumerate(action):
            # Expected result: the prior for this factor, with the slice of the
            # chosen action incremented by learning_rate * spm_cross(qs, qs_prev),
            # masked to the entries where B is strictly positive.
            validation_pB = pB[factor].copy()
            validation_pB[:, :, chosen_action] += (
                learning_rate
                * core.spm_cross(qs[factor].values, qs_prev[factor].values)
                * (B[factor][:, :, chosen_action].values > 0)
            )
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
예제 #3
0
    def test_update_pB_multiFactor_withActions_someFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are actions. Some factors are updated.
        """

        n_states = [3, 4, 5]
        n_control = [3, 4, 5]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        factors_to_update = [0, 1]

        B = Categorical(values=construct_generic_B(n_states, n_control))
        B.normalize()
        pB = Dirichlet(values=construct_pB(n_states, n_control))

        # One randomly drawn control index per factor.
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=learning_rate, factors=factors_to_update, return_numpy=True,
        )

        for factor, chosen_action in enumerate(action):
            validation_pB = pB[factor].copy()

            # Only the factors listed in factors_to_update are expected to
            # change; for the others the result must equal the prior.
            if factor in factors_to_update:
                validation_pB[:, :, chosen_action] += (
                    learning_rate
                    * core.spm_cross(qs[factor].values, qs_prev[factor].values)
                    * (B[factor][:, :, chosen_action].values > 0)
                )

            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
예제 #4
0
    def test_update_pB_multiFactor_someControllable_someFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and some of them
        are controllable. Some factors are updated.
        """

        n_states = [3, 4, 5]
        n_control = [1, 3, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        factors_to_update = [0, 1]

        # Per-factor arrays differ in shape, so they are stored in object arrays.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        # One randomly drawn control index per factor (0 for the factors that
        # have a single control state).
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=learning_rate, factors=factors_to_update, return_numpy=True,
        )

        for factor, chosen_action in enumerate(action):
            validation_pB = pB[factor].copy()

            # Only the factors listed in factors_to_update are expected to
            # change; for the others the result must equal the prior.
            if factor in factors_to_update:
                validation_pB[:, :, chosen_action] += (
                    learning_rate
                    * core.spm_cross(qs[factor].values, qs_prev[factor].values)
                    * (B[factor][:, :, chosen_action].values > 0)
                )

            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
예제 #5
0
    def test_update_pB_singleFactor_withActions(self):
        """
        Check the Dirichlet update over the transition likelihood (pB) when there
        is exactly one hidden state factor, it is the one being updated, and
        there are several possible actions.
        """

        n_states = [3]
        n_control = [3]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        learning_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=np.ones_like(B.values))

        # One randomly drawn control index for the single factor.
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=learning_rate, factors="all", return_numpy=True,
        )

        # Expected result: the prior with the slice of the chosen action
        # incremented by learning_rate * spm_cross(qs, qs_prev), masked to the
        # entries where B is strictly positive.
        chosen_action = action[0]
        positive_mask = B[:, :, chosen_action].values > 0
        increment = learning_rate * core.spm_cross(qs.values, qs_prev.values) * positive_mask

        validation_pB = pB.copy()
        validation_pB[:, :, chosen_action] += increment
        self.assertTrue(np.all(pB_updated == validation_pB.values))