Ejemplo n.º 1
0
    def test_update_pA_multi_factor_some_modalities(self):
        """
        Check the update of the prior Dirichlet parameters over the sensory
        likelihood (pA) when only SOME observation modalities are updated and
        the generative model has multiple hidden state factors.

        Modalities listed in `modalities_to_update` are expected to gain
        `l_rate * spm_cross(one_hot_observation, qs)`; all remaining
        modalities are expected to equal the original pA.
        """
        l_rate = 1.0
        n_states = [2, 6]
        num_obs = [3, 4, 5]  # multiple observation modalities
        modalities_to_update = [0, 2]

        # Build the posterior, the likelihood and its prior, then sample an observation.
        qs = Categorical(values=construct_init_qs(n_states))
        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))
        observation = A.dot(qs, return_numpy=False).sample()

        pA_updated = learning.update_likelihood_dirichlet(
            pA, A, observation, qs,
            lr=l_rate,
            modalities=modalities_to_update,
            return_numpy=True,
        )

        for modality, no in enumerate(num_obs):
            # Start from the untouched prior; add the increment only for learned modalities.
            expected = pA[modality]
            if modality in modalities_to_update:
                one_hot = np.eye(no)[observation[modality]]
                expected = expected + l_rate * maths.spm_cross(one_hot, qs.values)
            matches = np.all(pA_updated[modality] == expected.values)
            self.assertTrue(matches)
Ejemplo n.º 2
0
    def test_update_pA_single_factor_all(self):
        """
        Check the update of the prior Dirichlet parameters over the sensory
        likelihood (pA) when ALL observation modalities are updated and the
        generative model has a single hidden state factor.

        Two scenarios are exercised with the same posterior `qs`: one
        observation modality, then two observation modalities.
        """
        l_rate = 1.0
        n_states = [3]
        qs = Categorical(values=construct_init_qs(n_states))

        # --- Scenario 1: single observation modality ---
        single_obs = [4]
        A_single = Categorical(values=construct_generic_A(single_obs, n_states))
        pA_single = Dirichlet(values=construct_pA(single_obs, n_states))
        obs_single = A_single.dot(qs, return_numpy=False).sample()

        updated_single = learning.update_likelihood_dirichlet(
            pA_single, A_single, obs_single, qs,
            lr=l_rate,
            modalities="all",
            return_numpy=True,
        )
        one_hot = np.eye(single_obs[0])[obs_single]
        expected_single = pA_single + l_rate * maths.spm_cross(one_hot, qs.values)
        self.assertTrue(np.all(updated_single == expected_single.values))

        # --- Scenario 2: multiple observation modalities ---
        multi_obs = [3, 4]
        A_multi = Categorical(values=construct_generic_A(multi_obs, n_states))
        pA_multi = Dirichlet(values=construct_pA(multi_obs, n_states))
        obs_multi = A_multi.dot(qs, return_numpy=False).sample()

        updated_multi = learning.update_likelihood_dirichlet(
            pA_multi, A_multi, obs_multi, qs,
            lr=l_rate,
            modalities="all",
            return_numpy=True,
        )

        for modality, no in enumerate(multi_obs):
            # Every modality is learned, so each one receives the increment.
            one_hot = np.eye(no)[obs_multi[modality]]
            increment = maths.spm_cross(one_hot, qs.values)
            expected = pA_multi[modality] + l_rate * increment
            matches = np.all(updated_multi[modality] == expected.values)
            self.assertTrue(matches)
Ejemplo n.º 3
0
    def update_A(self, obs):
        """
        Update `self.pA` for the observation `obs` and refresh `self.A` from
        the result.

        The update is delegated to `learning.update_likelihood_dirichlet`,
        called with the current `self.pA`, `self.A`, `self.qs`, `self.lr_pA`
        and `self.modalities_to_learn`.

        Parameters
        ----------
        obs
            Observation forwarded unchanged to the update routine.

        Returns
        -------
        The updated pA object, which is also stored in `self.pA`.
        """
        new_pA = learning.update_likelihood_dirichlet(
            self.pA, self.A, obs, self.qs,
            self.lr_pA, self.modalities_to_learn,
            return_numpy=False,
        )

        # A is recomputed as the mean of the freshly updated pA.
        self.pA = new_pA
        self.A = new_pA.mean()
        return new_pA