def test_dot_function_e_cat(self):
        """ Check ``Categorical.dot`` against the reference values in ``dot_e.mat``.

        Fixture case: CONTINUOUS states and outcomes, with a final (fourth)
        hidden state factor added. Here the observation and the hidden states
        are themselves wrapped in ``Categorical`` before being handed to ``dot``.
        """
        mat_path = os.path.join(os.getcwd(), "tests/data/dot_e.mat")
        fixture = loadmat(file_name=mat_path)

        A = Categorical(values=fixture["A"])
        observation = Categorical(values=fixture["o"])

        # Copy the entries of the loaded "s" array into an object array with
        # one slot per column of "s", then wrap the result in a Categorical.
        raw_states = fixture["s"]
        per_factor = np.empty(raw_states.shape[1], dtype=object)
        for idx, entry in enumerate(raw_states[0]):
            per_factor[idx] = entry[0]
        hidden_states = Categorical(values=per_factor)

        expected_1, expected_2, expected_3 = (
            fixture[key] for key in ("result1", "result2", "result3")
        )

        # Case 1: dot with the observation.
        actual_1 = A.dot(observation, return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_1, actual_1)))

        # Case 2: dot with the hidden states; a trailing axis is added to the
        # result before it is compared with the stored reference.
        actual_2 = A.dot(hidden_states, return_numpy=True)
        actual_2 = actual_2.astype("float64")
        self.assertTrue(np.all(np.isclose(expected_2, actual_2[:, np.newaxis])))

        # Case 3: same as case 2, but with dimension 0 listed in dims_to_omit.
        actual_3 = A.dot(hidden_states, dims_to_omit=[0], return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_3, actual_3)))
    def test_dot_function_c(self):
        """ Check ``Categorical.dot`` against the reference values in ``dot_c.mat``.

        Fixture case: DISCRETE states and outcomes, with a third hidden state
        factor. The observation and states are passed to ``dot`` as arrays,
        i.e. without being wrapped in ``Categorical`` first.
        """
        mat_path = os.path.join(os.getcwd(), "tests/data/dot_c.mat")
        fixture = loadmat(file_name=mat_path)

        A = Categorical(values=fixture["A"])
        observation = fixture["o"]

        # Copy the entries of the loaded "s" array into an object array with
        # one slot per column of "s".
        raw_states = fixture["s"]
        per_factor = np.empty(raw_states.shape[1], dtype=object)
        for idx, entry in enumerate(raw_states[0]):
            per_factor[idx] = entry[0]

        expected_1, expected_2, expected_3 = (
            fixture[key] for key in ("result1", "result2", "result3")
        )

        # Case 1: dot with the observation.
        actual_1 = A.dot(observation, return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_1, actual_1)))

        # Case 2: dot with the hidden states; a trailing axis is added to the
        # result before it is compared with the stored reference.
        actual_2 = A.dot(per_factor, return_numpy=True)
        actual_2 = actual_2.astype("float64")
        self.assertTrue(np.all(np.isclose(expected_2, actual_2[:, np.newaxis])))

        # Case 3: same as case 2, but with dimension 0 listed in dims_to_omit.
        actual_3 = A.dot(per_factor, dims_to_omit=[0], return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_3, actual_3)))
    def test_dot_function_a_cat(self):
        """ Check ``Categorical.dot`` against the reference values in ``dot_a.mat``.

        Fixture case: vectors and matrices, discrete states / outcomes. Here the
        observation and the hidden state are themselves instances of
        ``Categorical`` when passed to ``dot``.
        """
        mat_path = os.path.join(os.getcwd(), "tests/data/dot_a.mat")
        fixture = loadmat(file_name=mat_path)

        A = Categorical(values=fixture["A"])
        observation = Categorical(values=fixture["o"])
        # Only the first entry of the loaded "s" array is used as the state.
        hidden_state = Categorical(values=fixture["s"][0])

        expected_1, expected_2, expected_3 = (
            fixture[key] for key in ("result1", "result2", "result3")
        )

        # Case 1: dot with the observation.
        actual_1 = A.dot(observation, return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_1, actual_1)))

        # Case 2: dot with the hidden state; a trailing axis is added to the
        # result before it is compared with the stored reference.
        actual_2 = A.dot(hidden_state, return_numpy=True)
        actual_2 = actual_2.astype("float64")
        self.assertTrue(np.all(np.isclose(expected_2, actual_2[:, np.newaxis])))

        # Case 3: same as case 2, but with dimension 0 listed in dims_to_omit.
        actual_3 = A.dot(hidden_state, dims_to_omit=[0], return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_3, actual_3)))
    def test_update_pA_multiFactor_somemodalities(self):
        """
        Test for updating prior Dirichlet parameters over sensory likelihood (pA)
        when only SOME observation modalities are updated and the generative
        model has multiple hidden state factors.

        Modalities listed for updating are expected to equal the prior plus
        ``lr * spm_cross(one_hot_observation, qs)``; every other modality is
        expected to come back equal to its prior.
        """

        # two hidden state factors
        n_states = [2, 6]
        qs = Categorical(values=construct_init_qs(n_states))
        lr = 1.0

        # three observation modalities, of which only the first and last are updated
        num_obs = [3, 4, 5]
        updated = [0, 2]

        A = Categorical(values=construct_generic_A(num_obs, n_states))
        pA = Dirichlet(values=construct_pA(num_obs, n_states))

        # draw an observation from the outcome distribution returned by A.dot(qs)
        predicted_obs = A.dot(qs, return_numpy=False)
        observation = predicted_obs.sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA,
            A,
            observation,
            qs,
            lr=lr,
            modalities=updated,
            return_numpy=True,
        )

        for m, n_outcomes in enumerate(num_obs):
            expected = pA[m]
            if m in updated:
                # one-hot row of the identity matrix picked out by the sampled outcome
                one_hot = np.eye(n_outcomes)[observation[m]]
                expected = expected + lr * core.spm_cross(one_hot, qs.values)
            matches = pA_updated[m] == expected.values
            self.assertTrue(np.all(matches))
    def test_dot_function_f(self):
        """ Test for a trivially one-dimensional outcome modality, where the
        return of spm_dot is a scalar - checks that such scalar returns are
        wrapped into an array, i.e. the result exposes a shape whose product
        is 1, both as a Categorical and as a numpy array """

        # two beliefs stored in a length-2 object array
        beliefs = [np.array([0.75, 0.25]), np.array([0.5, 0.5])]
        states = np.empty(len(beliefs), dtype=object)
        for idx, belief in enumerate(beliefs):
            states[idx] = belief

        # NOTE(review): states.shape is (2,), so the likelihood below has shape
        # (1, 2) - confirm this is the intended shape for two 2-state factors.
        n_outcomes = 1
        likelihood_shape = (n_outcomes, *states.shape)
        A = Categorical(values=np.ones(likelihood_shape))
        A.normalize()

        # result returned as a Categorical
        as_categorical = A.dot(states, return_numpy=False)
        self.assertTrue(np.prod(as_categorical.shape) == 1)

        # result returned as a numpy array
        as_numpy = A.dot(states, return_numpy=True)
        self.assertTrue(np.prod(as_numpy.shape) == 1)
    def test_update_pA_singleFactor_all(self):
        """
        Test for updating prior Dirichlet parameters over sensory likelihood (pA)
        when ALL observation modalities are updated and the generative model
        has a single hidden state factor.

        Two scenarios are run in sequence with the same ``qs``: one observation
        modality, then two. In both, the update is expected to equal the prior
        plus ``lr * spm_cross(one_hot_observation, qs)``.
        """

        # one hidden state factor
        n_states = [3]
        qs = Categorical(values=construct_init_qs(n_states))
        lr = 1.0

        # --- scenario 1: single observation modality ---
        single_obs_dims = [4]

        A = Categorical(values=construct_generic_A(single_obs_dims, n_states))
        pA = Dirichlet(values=construct_pA(single_obs_dims, n_states))

        # draw an observation from the outcome distribution returned by A.dot(qs)
        observation = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA,
            A,
            observation,
            qs,
            lr=lr,
            modalities="all",
            return_numpy=True,
        )

        # one-hot row of the identity matrix picked out by the sampled outcome
        one_hot = np.eye(*single_obs_dims)[observation]
        expected = pA + lr * core.spm_cross(one_hot, qs.values)
        self.assertTrue(np.all(pA_updated == expected.values))

        # --- scenario 2: multiple observation modalities ---
        multi_obs_dims = [3, 4]

        A = Categorical(values=construct_generic_A(multi_obs_dims, n_states))
        pA = Dirichlet(values=construct_pA(multi_obs_dims, n_states))

        observation = A.dot(qs, return_numpy=False).sample()

        pA_updated = core.update_likelihood_dirichlet(
            pA,
            A,
            observation,
            qs,
            lr=lr,
            modalities="all",
            return_numpy=True,
        )

        # every modality is checked because modalities="all" was requested
        for m, n_outcomes in enumerate(multi_obs_dims):
            one_hot = np.eye(n_outcomes)[observation[m]]
            expected = pA[m] + lr * core.spm_cross(one_hot, qs.values)
            self.assertTrue(np.all(pA_updated[m] == expected.values))
    def test_dot_function_b(self):
        """ Check ``Categorical.dot`` against the reference values in ``dot_b.mat``.

        Fixture case: continuous states and outcomes. The observation and the
        states are passed to ``dot`` as arrays rather than ``Categorical``s.
        """
        mat_path = os.path.join(os.getcwd(), "tests/data/dot_b.mat")
        fixture = loadmat(file_name=mat_path)

        A = Categorical(values=fixture["A"])
        obs = fixture["o"]
        # re-wrap the loaded states as an object-dtype array
        states = np.array(fixture["s"], dtype=object)

        expected_1, expected_2, expected_3 = (
            fixture[key] for key in ("result1", "result2", "result3")
        )

        # Case 1: dot with the observation.
        actual_1 = A.dot(obs, return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_1, actual_1)))

        # Case 2: dot with the states; a trailing axis is added to the result
        # before it is compared with the stored reference.
        actual_2 = A.dot(states, return_numpy=True)
        actual_2 = actual_2.astype("float64")
        self.assertTrue(np.all(np.isclose(expected_2, actual_2[:, np.newaxis])))

        # Case 3: same as case 2, but with dimension 0 listed in dims_to_omit.
        actual_3 = A.dot(states, dims_to_omit=[0], return_numpy=True)
        self.assertTrue(np.all(np.isclose(expected_3, actual_3)))