def test_update_pB_multiFactor_withActions_someFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are actions. Only the factors listed in `factors_to_update` are updated;
        every other factor is expected to come back unchanged.
        """

        n_states = [3, 4, 5]
        n_control = [3, 4, 5]
        qs_prev = Categorical(values = construct_init_qs(n_states))
        qs = Categorical(values = construct_init_qs(n_states))
        learning_rate = 1.0

        factors_to_update = [0,1]

        B = Categorical(values = construct_generic_B(n_states,n_control))
        B.normalize()
        pB = Dirichlet(values = construct_pB(n_states,n_control))

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(pB,B,action,qs,qs_prev,lr=learning_rate,factors=factors_to_update,return_numpy=True)

        for factor in range(len(n_control)):

            # Start from this factor's prior; non-updated factors must match it exactly.
            validation_pB = pB[factor].copy()

            if factor in factors_to_update:
                # Expected increment: lr * spm_cross(qs, qs_prev), applied only to the
                # slice of the action taken on this factor and masked to entries where B > 0.
                validation_pB[:,:,action[factor]] += learning_rate * core.spm_cross(qs[factor].values, qs_prev[factor].values) * (B[factor][:, :, action[factor]].values > 0)

            self.assertTrue(np.all(pB_updated[factor]==validation_pB.values))
    def test_update_pB_multiFactor_noActions_allFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and there
        are no actions (every factor has a single control state). All factors are updated.
        """

        n_states = [3, 4]
        n_control = [1, 1]
        qs_prev = Categorical(values = construct_init_qs(n_states))
        qs = Categorical(values = construct_init_qs(n_states))
        learning_rate = 1.0

        # The factors have different sizes, so the per-factor arrays are ragged and must
        # live in an explicit object array. Calling np.array() on a list of differently
        # shaped arrays raises ValueError on NumPy >= 1.24.
        B_values = np.empty(len(n_states), dtype=object)
        pB_values = np.empty(len(n_states), dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones( (ns, ns, n_control[factor]) )

        B = Categorical(values = B_values)
        B.normalize()
        pB = Dirichlet(values = pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(pB,B,action,qs,qs_prev,lr=learning_rate,factors="all",return_numpy=True)

        for factor in range(len(n_control)):
            validation_pB = pB[factor].copy()
            # Every factor is updated: add lr * spm_cross(qs, qs_prev) on the taken action's
            # slice, masked to entries where B > 0.
            validation_pB[:,:,action[factor]] += learning_rate * core.spm_cross(qs[factor].values, qs_prev[factor].values) * (B[factor][:, :, action[factor]].values > 0)
            self.assertTrue(np.all(pB_updated[factor]==validation_pB.values))
    def test_update_pA_multiFactor_somemodalities(self):
        """
        Check the Dirichlet update of the sensory likelihood prior (pA) when the
        generative model has several hidden state factors and only a subset of the
        observation modalities is selected for learning.
        """

        num_states = [2, 6]
        qs = Categorical(values=construct_init_qs(num_states))
        lr = 1.0

        # Three observation modalities; only the first and the last one are learned.
        obs_dims = [3, 4, 5]
        modalities_to_update = [0, 2]

        A = Categorical(values=construct_generic_A(obs_dims, num_states))
        pA = Dirichlet(values=construct_pA(obs_dims, num_states))

        obs = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA, A, obs, qs, lr=lr, modalities=modalities_to_update, return_numpy=True
        )

        for m, dim in enumerate(obs_dims):
            expected = pA[m]
            if m in modalities_to_update:
                # Learned modalities gain lr * spm_cross(one-hot(observation), qs).
                expected = expected + lr * core.spm_cross(np.eye(dim)[obs[m]], qs.values)
            self.assertTrue(np.all(pA_updated[m] == expected.values))
    def test_update_pA_singleFactor_all(self):
        """
        Check the Dirichlet update of the sensory likelihood prior (pA) for a model
        with one hidden state factor when every observation modality ("all") is
        learned. Covers both a single-modality and a multi-modality model.
        """

        num_states = [3]
        qs = Categorical(values=construct_init_qs(num_states))
        lr = 1.0

        # Case 1: one observation modality.
        obs_dims = [4]

        A = Categorical(values=construct_generic_A(obs_dims, num_states))
        pA = Dirichlet(values=construct_pA(obs_dims, num_states))

        obs = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA, A, obs, qs, lr=lr, modalities="all", return_numpy=True
        )

        one_hot = np.eye(*obs_dims)[obs]
        expected = pA + lr * core.spm_cross(one_hot, qs.values)
        self.assertTrue(np.all(pA_updated == expected.values))

        # Case 2: several observation modalities, each checked separately.
        obs_dims = [3, 4]

        A = Categorical(values=construct_generic_A(obs_dims, num_states))
        pA = Dirichlet(values=construct_pA(obs_dims, num_states))

        obs = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA, A, obs, qs, lr=lr, modalities="all", return_numpy=True
        )

        for m, dim in enumerate(obs_dims):
            one_hot = np.eye(dim)[obs[m]]
            expected = pA[m] + lr * core.spm_cross(one_hot, qs.values)
            self.assertTrue(np.all(pA_updated[m] == expected.values))
    def test_update_pB_multiFactor_someControllable_someFactors(self):
        """
        Test for updating prior Dirichlet parameters over transition likelihood (pB)
        in the case that there are multiple hidden state factors, and some of them
        are controllable. Only the factors listed in `factors_to_update` are updated;
        every other factor is expected to come back unchanged.
        """

        n_states = [3, 4, 5]
        n_control = [1, 3, 1]
        qs_prev = Categorical(values = construct_init_qs(n_states))
        qs = Categorical(values = construct_init_qs(n_states))
        learning_rate = 1.0

        factors_to_update = [0,1]

        # Factors have different sizes, so the per-factor arrays go in an object array.
        B_values = np.empty(len(n_states),dtype=object)
        pB_values = np.empty(len(n_states),dtype=object)
        for factor, ns in enumerate(n_states):
            B_values[factor] = np.random.rand(ns, ns, n_control[factor])
            pB_values[factor] = np.ones( (ns, ns, n_control[factor]) )

        B = Categorical(values = B_values)
        B.normalize()
        pB = Dirichlet(values = pB_values)

        action = np.array([np.random.randint(nc) for nc in n_control])

        pB_updated = core.update_transition_dirichlet(pB,B,action,qs,qs_prev,lr=learning_rate,factors=factors_to_update,return_numpy=True)

        for factor in range(len(n_control)):

            # Start from this factor's prior; non-updated factors must match it exactly.
            validation_pB = pB[factor].copy()

            if factor in factors_to_update:
                # Expected increment: lr * spm_cross(qs, qs_prev), applied only to the
                # slice of the action taken on this factor and masked to entries where B > 0.
                validation_pB[:,:,action[factor]] += learning_rate * core.spm_cross(qs[factor].values, qs_prev[factor].values) * (B[factor][:, :, action[factor]].values > 0)

            self.assertTrue(np.all(pB_updated[factor]==validation_pB.values))
    def test_update_pB_singleFactor_withActions(self):
        """
        Check the Dirichlet update of the transition prior (pB) when the model has
        exactly one hidden state factor, that factor is controllable, and all
        factors ("all") are selected for learning.
        """

        num_states = [3]
        num_controls = [3]
        qs_prev = Categorical(values=construct_init_qs(num_states))
        qs = Categorical(values=construct_init_qs(num_states))
        lr = 1.0

        B = Categorical(values=construct_generic_B(num_states, num_controls))
        pB = Dirichlet(values=np.ones_like(B.values))

        action = np.array([np.random.randint(n) for n in num_controls])
        taken = action[0]

        pB_updated = core.update_transition_dirichlet(
            pB, B, action, qs, qs_prev, lr=lr, factors="all", return_numpy=True
        )

        # Only the slice of the executed action receives the lr * spm_cross(qs, qs_prev)
        # increment, masked to the entries where B is non-zero.
        transition_mask = B[:, :, taken].values > 0
        increment = lr * core.spm_cross(qs.values, qs_prev.values) * transition_mask

        expected = pB.copy()
        expected[:, :, taken] += increment
        self.assertTrue(np.all(pB_updated == expected.values))