Exemplo n.º 1
0
    def test_update_pB_multi_factor_with_actions_all_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are actions. All factors are updated.

        For every factor, the expected value is that factor's prior with the
        slice of the sampled action incremented by
        ``l_rate * spm_cross(qs, qs_prev)``, restricted to the entries where
        the matching slice of ``B`` is strictly positive.
        """

        n_states = [3, 4, 5]
        n_control = [3, 4, 5]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        B.normalize()
        pB = Dirichlet(values=construct_pB(n_states, n_control))
        # one randomly sampled action index per control factor
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(pB,
                                                      B,
                                                      action,
                                                      qs,
                                                      qs_prev,
                                                      lr=l_rate,
                                                      factors="all",
                                                      return_numpy=True)

        # Each factor gets its own fresh copy of the prior inside the loop, so
        # no copy of the whole `pB` is needed beforehand.
        for factor, chosen_action in enumerate(action):
            validation_pB = pB[factor].copy()
            validation_pB[:, :, chosen_action] += (
                l_rate *
                core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                (B[factor][:, :, chosen_action].values > 0))
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Exemplo n.º 2
0
    def test_update_pB_single_dactor_with_actions(self):
        """
        Check the Dirichlet update of the transition prior (pB) for a model
        with a single hidden state factor that has several available actions.

        The expected value is a copy of the prior whose slice for the sampled
        action is incremented by the learning rate times
        ``spm_cross(qs, qs_prev)``, masked by the positive entries of the
        matching slice of ``B``.
        """

        num_states = [3]
        num_controls = [3]
        prev_posterior = Categorical(values=construct_init_qs(num_states))
        posterior = Categorical(values=construct_init_qs(num_states))
        learning_rate = 1.0

        B = Categorical(values=construct_generic_B(num_states, num_controls))
        prior = Dirichlet(values=np.ones_like(B.values))
        # one randomly sampled action index per control factor
        action = np.array(
            [np.random.randint(n_ctrl) for n_ctrl in num_controls])

        updated = core.update_transition_dirichlet(
            prior,
            B,
            action,
            posterior,
            prev_posterior,
            lr=learning_rate,
            factors="all",
            return_numpy=True)

        expected = prior.copy()
        chosen_action = action[0]
        joint = core.spm_cross(posterior.values, prev_posterior.values)
        support = B[:, :, chosen_action].values > 0
        expected[:, :, chosen_action] += learning_rate * joint * support

        is_match = np.all(updated == expected.values)
        self.assertTrue(is_match)
Exemplo n.º 3
0
 def test_copy(self):
     """
     A copy of a Dirichlet starts out with equal values, and reassigning
     the copy's values leaves the original's values different from them.
     """
     source = Dirichlet(values=np.random.rand(3, 2))
     clone = source.copy()

     same_before = np.array_equal(clone.values, source.values)
     self.assertTrue(same_before)

     # rebind the copy's values to a doubled array
     doubled = clone.values * 2
     clone.values = doubled

     same_after = np.array_equal(clone.values, source.values)
     self.assertFalse(same_after)
Exemplo n.º 4
0
    def test_update_pB_single_factor_no_actions(self):
        """
        Check the Dirichlet update of the transition prior (pB) for a model
        with a single hidden state factor and no actions.

        The expected value is a copy of the prior whose only control slice is
        incremented by the learning rate times ``spm_cross(qs, qs_prev)``,
        masked by the positive entries of the matching slice of ``B``.
        """

        num_states = [3]
        # a single control state is how the absence of actions is encoded
        num_controls = [1]
        prev_posterior = Categorical(values=construct_init_qs(num_states))
        posterior = Categorical(values=construct_init_qs(num_states))
        learning_rate = 1.0

        state_dim = num_states[0]
        transition_values = np.random.rand(state_dim, state_dim,
                                           num_controls[0])
        B = Categorical(values=transition_values)
        B.normalize()
        prior = Dirichlet(values=np.ones_like(B.values))
        action = np.array(
            [np.random.randint(n_ctrl) for n_ctrl in num_controls])

        updated = learning.update_transition_dirichlet(
            prior,
            B,
            action,
            posterior,
            prev_posterior,
            lr=learning_rate,
            factors="all",
            return_numpy=True)

        expected = prior.copy()
        joint = maths.spm_cross(posterior.values, prev_posterior.values)
        support = B[:, :, action[0]].values > 0
        expected[:, :, 0] += learning_rate * joint * support

        is_match = np.all(updated == expected.values)
        self.assertTrue(is_match)
Exemplo n.º 5
0
    def test_update_pB_multi_factor_no_actions_one_factor(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are no actions. One factor is updated.

        The factor to update is drawn at random. For that factor the expected
        value is the prior with the slice of the sampled action incremented by
        ``l_rate * spm_cross(qs, qs_prev)``, restricted to the entries where
        the matching slice of ``B`` is strictly positive. For the other factor
        the expected value is the unchanged prior.
        """
        n_states = [3, 4]
        n_control = [1, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        # a single, randomly chosen factor index
        factors_to_update = [np.random.randint(len(n_states))]

        B = Categorical(values=np.array([
            np.random.rand(ns, ns, n_control[factor])
            for factor, ns in enumerate(n_states)
        ],
                                        dtype=object))
        B.normalize()
        pB = Dirichlet(values=np.array([
            np.ones_like(B[factor].values) for factor in range(len(n_states))
        ],
                                       dtype=object))

        # one randomly sampled action index per control factor
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = learning.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        # Each factor gets its own fresh copy of the prior inside the loop, so
        # no copy of the whole `pB` is needed beforehand.
        for factor, chosen_action in enumerate(action):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, chosen_action] += (
                    l_rate * maths.spm_cross(qs[factor].values,
                                             qs_prev[factor].values) *
                    (B[factor][:, :, chosen_action].values > 0))

            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Exemplo n.º 6
0
    def test_update_pB_multi_factor_some_controllable_some_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and some of them
        are controllable. Some factors are updated.

        For factors listed in ``factors_to_update`` the expected value is the
        prior with the slice of the sampled action incremented by
        ``l_rate * spm_cross(qs, qs_prev)``, restricted to the entries where
        the matching slice of ``B`` is strictly positive. For the remaining
        factor the expected value is the unchanged prior.
        """

        n_states = [3, 4, 5]
        n_control = [1, 3, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        factors_to_update = [0, 1]
        # per-factor arrays are stored in object arrays, one entry per factor
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        # one randomly sampled action index per control factor
        action = np.array([np.random.randint(nc) for nc in n_control])
        pB_updated = core.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        # Each factor gets its own fresh copy of the prior inside the loop, so
        # no copy of the whole `pB` is needed beforehand.
        for factor, chosen_action in enumerate(action):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, chosen_action] += (
                    l_rate *
                    core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                    (B[factor][:, :, chosen_action].values > 0))
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))