Beispiel #1
0
    def test_update_pB_multi_factor_with_actions_all_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are actions. All factors are updated.
        """

        n_states = [3, 4, 5]
        n_control = [3, 4, 5]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        B = Categorical(values=construct_generic_B(n_states, n_control))
        B.normalize()
        pB = Dirichlet(values=construct_pB(n_states, n_control))
        # One randomly chosen action index per control factor.
        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = learning.update_transition_dirichlet(pB,
                                                          B,
                                                          action,
                                                          qs,
                                                          qs_prev,
                                                          lr=l_rate,
                                                          factors="all",
                                                          return_numpy=True)

        # Every factor is updated. The expected pB for a factor is its prior plus
        # lr * outer(qs, qs_prev), masked to the entries where B is non-zero.
        # The increment is added only to the slice of the action taken for that factor.
        for factor in range(len(n_control)):
            validation_pB = pB[factor].copy()
            validation_pB[:, :, action[factor]] += (
                l_rate *
                maths.spm_cross(qs[factor].values, qs_prev[factor].values) *
                (B[factor][:, :, action[factor]].values > 0))
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Beispiel #2
0
    def test_update_pB_single_dactor_with_actions(self):
        """
        Check the pB (transition-likelihood Dirichlet prior) update for a model
        with exactly one hidden state factor that has actions; that sole factor
        is the one being updated.
        """

        n_states, n_control = [3], [3]
        l_rate = 1.0
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))

        B = Categorical(values=construct_generic_B(n_states, n_control))
        pB = Dirichlet(values=np.ones_like(B.values))
        action = np.array([np.random.randint(num) for num in n_control])
        chosen = action[0]

        pB_updated = learning.update_transition_dirichlet(
            pB, B, action, qs, qs_prev,
            lr=l_rate, factors="all", return_numpy=True)

        # Expected result: the prior plus lr * outer(qs, qs_prev), masked to the
        # entries where B is non-zero, applied only to the chosen action slice.
        expected = pB.copy()
        increment = maths.spm_cross(qs.values, qs_prev.values)
        mask = B[:, :, chosen].values > 0
        expected[:, :, chosen] += l_rate * increment * mask
        self.assertTrue(np.all(pB_updated == expected.values))
Beispiel #3
0
    def test_update_pA_multi_factor_some_modalities(self):
        """
        Check the pA (sensory-likelihood Dirichlet prior) update when only a
        subset of observation modalities is updated, for a generative model
        with several hidden state factors.
        """
        n_states = [2, 6]
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        # Three observation modalities; only modalities 0 and 2 get updated.
        num_obs = [3, 4, 5]
        modalities_to_update = [0, 2]
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = learning.update_likelihood_dirichlet(
            pA, A, observation, qs,
            lr=l_rate, modalities=modalities_to_update, return_numpy=True)

        for m, n_outcomes in enumerate(num_obs):
            expected = pA[m]
            if m in modalities_to_update:
                # One-hot encoding of the observed outcome, crossed with qs.
                one_hot = np.eye(n_outcomes)[observation[m]]
                expected = expected + l_rate * maths.spm_cross(one_hot, qs.values)
            # Modalities that are not selected must be left unchanged.
            self.assertTrue(np.all(pA_updated[m] == expected.values))
Beispiel #4
0
    def test_update_pA_single_factor_all(self):
        """
        Check the pA (sensory-likelihood Dirichlet prior) update when every
        observation modality is updated and the generative model has a single
        hidden state factor. Exercises both a one-modality and a
        two-modality observation model.
        """
        n_states = [3]
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        # Case 1: a single observation modality.
        num_obs = [4]
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))

        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = learning.update_likelihood_dirichlet(
            pA, A, observation, qs,
            lr=l_rate, modalities="all", return_numpy=True)
        one_hot = np.eye(num_obs[0])[observation]
        expected = pA + l_rate * maths.spm_cross(one_hot, qs.values)
        self.assertTrue(np.all(pA_updated == expected.values))

        # Case 2: several observation modalities, each updated independently.
        num_obs = [3, 4]
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = learning.update_likelihood_dirichlet(
            pA, A, observation, qs,
            lr=l_rate, modalities="all", return_numpy=True)

        for m, n_outcomes in enumerate(num_obs):
            one_hot = np.eye(n_outcomes)[observation[m]]
            expected = pA[m] + l_rate * maths.spm_cross(one_hot, qs.values)
            self.assertTrue(np.all(pA_updated[m] == expected.values))
Beispiel #5
0
    def cross(self, x=None, return_numpy=False, *args):
        """ Multi-dimensional outer product

            If no `x` argument is passed, the function returns the "auto-outer product"
            of self. Otherwise, the function will recursively take the outer product
            of the initial entry of `x` with `self` until it has depleted the possible
            entries of `x` that it can outer-product. Any extra operands in `args`
            are appended after `x` in the call to `maths.spm_cross`.

        Parameters
        ----------
        - `x` [np.ndarray || Categorical] (optional)
            The values to perform the outer-product with
        - `return_numpy` [bool] (optional)
            If True, return the raw np.ndarray; otherwise wrap it in a Categorical
        - `args` [np.ndarray || Categorical] (optional)
            Further operands for the outer product. Distributions and plain arrays
            may be mixed; every distribution is replaced by its `.values`.
            `args` are only used when `x` is given.
            NOTE: because `return_numpy` precedes `*args` in the signature, a
            second positional argument binds to `return_numpy`, not to `args`.

        Returns
        -------
        - `y` [np.ndarray || Categorical]
            The result of the outer-product
        """
        x = utils.to_numpy(x)

        if x is None:
            y = maths.spm_cross(self.values)
        else:
            # Convert each extra operand on its own. Inspecting only args[0] and
            # then calling `.values` on every arg would break on mixed inputs.
            arg_arrays = [
                arg.values if utils.is_distribution(arg) else arg for arg in args
            ]
            y = maths.spm_cross(self.values, x, *arg_arrays)

        if return_numpy:
            return y
        return Categorical(values=y)
Beispiel #6
0
    def test_update_pB_multi_factor_no_actions_one_factor(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are no actions. One factor is updated.
        """
        n_states = [3, 4]
        n_control = [1, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        factors_to_update = [np.random.randint(len(n_states))]

        # Fill pre-allocated 1-D object arrays (one entry per factor) instead of
        # using np.array([...], dtype=object). That constructor yields a
        # higher-rank array, or raises, when per-factor shapes share leading
        # dimensions.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = learning.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        for factor in range(len(n_control)):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, action[factor]] += (
                    l_rate * maths.spm_cross(qs[factor].values,
                                             qs_prev[factor].values) *
                    (B[factor][:, :, action[factor]].values > 0))
            # Factors that were not selected must come back unchanged.
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))
Beispiel #7
0
    def test_update_pB_multi_factor_some_controllable_some_factors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and some of them
        are controllable. Some factors are updated.
        """

        n_states = [3, 4, 5]
        n_control = [1, 3, 1]
        qs_prev = Categorical(values=construct_init_qs(n_states))
        qs = Categorical(values=construct_init_qs(n_states))
        l_rate = 1.0

        factors_to_update = [0, 1]
        # One (ns, ns, n_control[factor]) array per factor, stored in 1-D object arrays.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones((ns, ns, n_control[factor]))

        B = Categorical(values=B_values)
        B.normalize()
        pB = Dirichlet(values=pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])
        pB_updated = learning.update_transition_dirichlet(
            pB,
            B,
            action,
            qs,
            qs_prev,
            lr=l_rate,
            factors=factors_to_update,
            return_numpy=True)

        for factor in range(len(n_control)):
            validation_pB = pB[factor].copy()
            if factor in factors_to_update:
                validation_pB[:, :, action[factor]] += (
                    l_rate * maths.spm_cross(qs[factor].values,
                                             qs_prev[factor].values) *
                    (B[factor][:, :, action[factor]].values > 0))
            # Factors outside factors_to_update must be left as their prior.
            self.assertTrue(np.all(pB_updated[factor] == validation_pB.values))