Пример #1
0
    def test_update_pB_multi_factor_with_actions_all_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are actions. All factors are updated.

        For every factor the expected result is the prior with
        ``l_rate * spm_cross(qs, qs_prev)`` added to the slice of the sampled
        action, with the increment zeroed wherever ``B`` is not positive.
        """

        n_states = [3, 4, 5]
        n_control = [3, 4, 5]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        B.normalize()
        pB = Dirichlet(values=construct_pB(n_states, n_control))
        # One randomly sampled control state per hidden state factor.
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(pB,
                                                      B,
                                                      action,
                                                      qs,
                                                      qs_prev,
                                                      lr=l_rate,
                                                      factors="all",
                                                      return_numpy=True)

        # Rebuild the expected parameters independently for each factor and
        # compare them with what the update function returned.
        for factor, _ in enumerate(n_control):
            validation_pB = pB[factor].copy()
            validation_pB[:, :, action[factor]] += (
                l_rate *
                core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                (B[factor][:, :, action[factor]].values > 0))
            self.assertTrue(
                np.all(pB_updated[factor] == validation_pB.values),
                msg="pB mismatch for factor %d" % factor)
Пример #2
0
    def test_update_pB_single_dactor_with_actions(self):
        """
        Check the Dirichlet update of the transition prior (pB) when the model
        has exactly one hidden state factor, that factor is updated, and an
        action is taken.

        The expected result is the prior with
        ``l_rate * spm_cross(qs, qs_prev)`` added to the slice of the sampled
        action, with the increment zeroed wherever ``B`` is not positive.
        """

        n_states = [3]
        n_control = [3]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=np.ones_like(B.values))
        # Single-element action vector: one control state for the one factor.
        action = np.array([np.random.randint(n_ctrl) for n_ctrl in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors="all",
            return_numpy=True,
        )

        # Build the expected parameters step by step.
        chosen = action[0]
        transition_counts = core.spm_cross(qs.values, qs_prev.values)
        allowed = B[:, :, chosen].values > 0

        expected_pB = pB.copy()
        expected_pB[:, :, chosen] += l_rate * transition_counts * allowed

        elementwise_match = pB_updated == expected_pB.values
        self.assertTrue(np.all(elementwise_match))
Пример #3
0
    def test_update_pB_multi_factor_no_actions_one_factor(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are no actions. One factor is updated.

        Every factor has a single control state, so the sampled action is
        always 0. Only the randomly chosen factor in ``factors_to_update`` is
        expected to change; the other factor must keep its prior values.
        """
        n_states = [3, 4]
        n_control = [1, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        factors_to_update = [np.random.randint(len(n_states))]

        # The per-factor arrays have different shapes, so they are stored in
        # explicitly allocated object arrays. Calling np.array() on such a
        # ragged list is rejected by recent NumPy releases.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        # Rebuild the expected parameters for each factor; factors that were
        # not selected for updating are compared against the untouched prior.
        for factor, _ in enumerate(n_control):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, action[factor]] += (
                    l_rate *
                    core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                    (B[factor][:, :, action[factor]].values > 0))

            self.assertTrue(
                np.all(pB_updated[factor] == validation_pB.values),
                msg="pB mismatch for factor %d" % factor)
Пример #4
0
    def test_update_pB_multi_factor_some_controllable_some_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and some of them
        are controllable. Some factors are updated.

        Factors 0 and 1 are updated and factor 2 is not. Only factor 1 has
        more than one control state. Factors left out of
        ``factors_to_update`` must keep their prior values.
        """

        n_states = [3, 4, 5]
        n_control = [1, 3, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        factors_to_update = [0, 1]
        # Per-factor arrays differ in shape, so they live in object arrays.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])
        pB_updated = core.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        # Rebuild the expected parameters for each factor; factors that were
        # not selected for updating are compared against the untouched prior.
        for factor, _ in enumerate(n_control):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, action[factor]] += (
                    l_rate *
                    core.spm_cross(qs[factor].values, qs_prev[factor].values) *
                    (B[factor][:, :, action[factor]].values > 0))
            self.assertTrue(
                np.all(pB_updated[factor] == validation_pB.values),
                msg="pB mismatch for factor %d" % factor)