def test_update_pA_multi_factor_some_modalities(self):
    """
    Check the Dirichlet update of the likelihood prior (pA) when only a
    subset of observation modalities is learned and the generative model
    has several hidden state factors.

    Modalities listed in ``modalities_to_update`` must receive the increment
    ``lr * spm_cross(one_hot(observation[m]), qs.values)`` on top of the
    prior. Every other modality must come back equal to the prior.
    """
    state_dims = [2, 6]
    qs = Categorical(values=construct_init_qs(state_dims))
    lr = 1.0

    # Three observation modalities; only the first and the last are learned.
    obs_dims = [3, 4, 5]
    modalities_to_update = [0, 2]
    A = Categorical(values=construct_generic_A(obs_dims, state_dims))
    pA = Dirichlet(values=construct_pA(obs_dims, state_dims))

    observation = A.dot(qs, return_numpy=False).sample()
    pA_updated = learning.update_likelihood_dirichlet(
        pA,
        A,
        observation,
        qs,
        lr=lr,
        modalities=modalities_to_update,
        return_numpy=True,
    )

    for m, dim in enumerate(obs_dims):
        if m not in modalities_to_update:
            # Untouched modality: the expected value is the prior itself.
            expected = pA[m]
        else:
            # One-hot encode the sampled outcome for this modality.
            increment = maths.spm_cross(np.eye(dim)[observation[m]], qs.values)
            expected = pA[m] + lr * increment
        self.assertTrue(np.all(pA_updated[m] == expected.values))
def test_update_pA_single_factor_all(self):
    """
    Check the Dirichlet update of the likelihood prior (pA) when every
    observation modality is learned and the generative model has a single
    hidden state factor.

    The scenario runs twice: once with one observation modality and once
    with two. Each time the learned pA must equal the prior plus
    ``lr * spm_cross(one_hot(observation), qs.values)``.
    """
    state_dims = [3]
    qs = Categorical(values=construct_init_qs(state_dims))
    lr = 1.0

    def build_and_update(obs_dims):
        # Build a fresh likelihood and prior for these modalities, sample an
        # outcome, and return (prior, outcome, learned prior).
        A = Categorical(values=construct_generic_A(obs_dims, state_dims))
        pA = Dirichlet(values=construct_pA(obs_dims, state_dims))
        observation = A.dot(qs, return_numpy=False).sample()
        pA_updated = learning.update_likelihood_dirichlet(
            pA, A, observation, qs, lr=lr, modalities="all", return_numpy=True
        )
        return pA, observation, pA_updated

    # Single observation modality: the sample is used directly as a row index.
    single_dims = [4]
    pA, observation, pA_updated = build_and_update(single_dims)
    one_hot_obs = np.eye(*single_dims)[observation]
    expected = pA + lr * maths.spm_cross(one_hot_obs, qs.values)
    self.assertTrue(np.all(pA_updated == expected.values))

    # Several observation modalities: the sample is indexed per modality.
    multi_dims = [3, 4]
    pA, observation, pA_updated = build_and_update(multi_dims)
    for m, dim in enumerate(multi_dims):
        increment = maths.spm_cross(np.eye(dim)[observation[m]], qs.values)
        expected = pA[m] + lr * increment
        self.assertTrue(np.all(pA_updated[m] == expected.values))
def update_A(self, obs):
    """
    Learn the observation likelihood from a single observation.

    Runs ``learning.update_likelihood_dirichlet`` on the current prior
    ``self.pA``, likelihood ``self.A``, the given ``obs`` and the current
    state posterior ``self.qs``. It also passes ``self.lr_pA`` and
    ``self.modalities_to_learn`` positionally, presumably as the learning
    rate and the modalities to update (confirm against the helper's
    signature).

    The updated prior replaces ``self.pA``, ``self.A`` is refreshed to its
    mean, and the updated prior is returned.
    """
    # return_numpy=False keeps the result as an object that provides .mean().
    new_pA = learning.update_likelihood_dirichlet(
        self.pA,
        self.A,
        obs,
        self.qs,
        self.lr_pA,
        self.modalities_to_learn,
        return_numpy=False,
    )
    self.pA = new_pA
    self.A = new_pA.mean()
    return new_pA