def calc_states_info_gain(A, qs_pi):
    """
    Given a likelihood mapping A and a posterior predictive density over states Qs_pi,
    compute the Bayesian surprise (about states) expected under that policy.

    Parameters
    ----------
    A [numpy nd-array, array-of-arrays (where each entry is a numpy nd-array), or Categorical (either single-factor or AoA)]:
        Observation likelihood mapping from hidden states to observations, with different modalities (if there are multiple) stored in different arrays
    qs_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), Categorical (either single-factor or AoA), or list]:
        Posterior predictive density over hidden states. If a list, each entry of the list is the posterior predictive for a given timepoint of an expected trajectory

    Returns
    -------
    states_surprise [scalar]:
        Surprise (about states) expected under the policy in question, summed over all timepoints
    """

    A = utils.to_numpy(A)

    # Normalise `qs_pi` to a list with one entry per timestep. A new list is built so that
    # the caller's list is not modified in place.
    if isinstance(qs_pi, list):
        qs_pi = [utils.to_numpy(qs_t, flatten=True) for qs_t in qs_pi]
    else:
        qs_pi = [utils.to_numpy(qs_pi, flatten=True)]

    states_surprise = 0
    for qs_t in qs_pi:
        states_surprise += spm_MDP_G(A, qs_t)

    return states_surprise
def calc_expected_utility(qo_pi, C):
    """
    Given expected observations under a policy Qo_pi and a prior over observations C,
    compute the expected utility of the policy.

    Parameters
    ----------
    qo_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), Categorical (either single-factor or AoA), or list]:
        Expected observations under the given policy (predictive posterior over outcomes). If a list, a list of the expected observations
        over the time horizon of policy evaluation, where each entry is the expected observations at a given timestep.
    C [numpy nd-array, array-of-arrays (where each entry is a numpy nd-array)]:
        Prior beliefs over outcomes, expressed in terms of relative log probabilities

    Returns
    -------
    expected_util [scalar]:
        Utility (reward) expected under the policy in question, summed over timesteps (and modalities)
    """

    # Normalise `qo_pi` to a list with one entry per timestep, without mutating the caller's list.
    if isinstance(qo_pi, list):
        qo_pi = [utils.to_numpy(qo_t, flatten=True) for qo_t in qo_pi]
    else:
        qo_pi = [utils.to_numpy(qo_pi, flatten=True)]

    C = utils.to_numpy(C, flatten=True)

    # initialise expected utility
    expected_util = 0

    if utils.is_arr_of_arr(C):
        # Multiple observation modalities. The log-preferences do not depend on time, so they
        # are computed once per modality; the small constant guards against log(0).
        lnC = [np.log(softmax(C_m[:, np.newaxis]) + 1e-16) for C_m in C]
        for qo_t in qo_pi:
            for modality, lnC_m in enumerate(lnC):
                expected_util += qo_t[modality].dot(lnC_m)
    else:
        # Single modality. The constant is added *after* the softmax, as in the branch above;
        # adding it inside the softmax is a no-op and gives no protection against log(0).
        lnC = np.log(softmax(C[:, np.newaxis]) + 1e-16)
        for qo_t in qo_pi:
            expected_util += qo_t.dot(lnC)

    return expected_util
def process_observations(obs, n_modalities, n_observations):
    """
    Helper function for formatting observations.

        Observations can either be `Categorical`, `int` (converted to one-hot)
        or `tuple` (obs for each modality)

    Parameters
    ----------
    obs [Categorical, numpy array, int or tuple]:
        The observation to format
    n_modalities [int]:
        Number of observation modalities
    n_observations [list of int]:
        Number of possible observations in each modality

    Returns
    -------
    obs [numpy 1D array or array-of-arrays]:
        Categorical inputs are converted to numpy and squeezed (per modality when there are several);
        an `int` becomes a one-hot vector for the first modality; a `tuple` becomes an array-of-arrays
        of one-hot vectors, one per modality. Any other input is returned unchanged.

    @TODO maybe provide error messaging about observation format
    """
    if utils.is_distribution(obs):
        obs = utils.to_numpy(obs)
        if n_modalities == 1:
            obs = obs.squeeze()
        else:
            # Build a new container instead of squeezing entries in place, in case
            # `utils.to_numpy` hands back the distribution's own values array.
            squeezed = np.empty(n_modalities, dtype=object)
            for m in range(n_modalities):
                squeezed[m] = obs[m].squeeze()
            obs = squeezed

    if isinstance(obs, (int, np.integer)):
        obs = np.eye(n_observations[0])[obs]

    if isinstance(obs, tuple):
        obs_arr_arr = np.empty(n_modalities, dtype=object)
        for m in range(n_modalities):
            obs_arr_arr[m] = np.eye(n_observations[m])[obs[m]]
        obs = obs_arr_arr

    return obs
def sample_action(q_pi, policies, n_control, sampling_type="marginal_action"):
    """
    Samples action from posterior over policies, using one of two methods.

    Parameters
    ----------
    q_pi [1D numpy.ndarray or Categorical]:
        Posterior beliefs about (possibly multi-step) policies.
    policies [list of numpy ndarrays]:
        List of arrays that indicate the policies under consideration. Each element within the list is a matrix that stores
        the indices of the actions upon the separate hidden state factors, at each timestep (nStep x nControlFactor)
    n_control [list of integers]:
        List of the dimensionalities of the different (controllable) hidden state factors
    sampling_type [string, 'marginal_action' or 'posterior_sample']:
        Indicates whether the sampled action for a given hidden state factor is given by the evidence for that action, marginalized across different policies ('marginal_action')
        or simply the action entailed by a sample from the posterior over policies ('posterior_sample')

    Returns
    -------
    selected_policy [1D numpy ndarray]:
        Numpy array containing the indices of the actions along each control factor

    Raises
    ------
    ValueError:
        If `sampling_type` is neither 'marginal_action' nor 'posterior_sample'
    """

    n_factors = len(n_control)

    if sampling_type == "marginal_action":

        if utils.is_distribution(q_pi):
            q_pi = utils.to_numpy(q_pi)

        action_marginals = np.empty(n_factors, dtype=object)
        for c_idx in range(n_factors):
            action_marginals[c_idx] = np.zeros(n_control[c_idx])

        # weight each action according to its integrated posterior probability over policies and timesteps
        for pol_idx, policy in enumerate(policies):
            for t in range(policy.shape[0]):
                for factor_i, action_i in enumerate(policy[t, :]):
                    action_marginals[factor_i][action_i] += q_pi[pol_idx]

        action_marginals = Categorical(values=action_marginals)
        selected_policy = np.array(action_marginals.sample())

    elif sampling_type == "posterior_sample":
        # wrap raw arrays once; an existing distribution is sampled directly
        if not utils.is_distribution(q_pi):
            q_pi = Categorical(values=q_pi)
        policy_index = q_pi.sample()
        selected_policy = policies[policy_index]

    else:
        raise ValueError(f"{sampling_type} not supported")

    return selected_policy
def softmax(dist, return_numpy=True):
    """
    Computes the softmax function on a set of values.

    The normalisation is taken along axis 0. If `dist` is an array-of-arrays, the softmax is
    applied independently to each sub-array.

    Parameters
    ----------
    dist [numpy nd-array, array-of-arrays, or Categorical]:
        Values to normalise
    return_numpy [bool]:
        If True return numpy values, otherwise convert the result with `utils.to_categorical`

    Returns
    -------
    output [numpy nd-array / array-of-arrays, or Categorical]:
        The softmax-normalised values
    """

    dist = utils.to_numpy(dist)

    if utils.is_arr_of_arr(dist):
        # apply the softmax to each sub-array separately
        output = np.empty(len(dist), dtype=object)
        for i, arr in enumerate(dist):
            output[i] = softmax(arr, return_numpy=True)
    else:
        # subtract the max before exponentiating for numerical stability
        output = dist - dist.max(axis=0)
        output = np.exp(output)
        output = output / np.sum(output, axis=0)

    if return_numpy:
        return output
    return utils.to_categorical(output)
    def cross(self, x=None, return_numpy=False, *args):
        """ Multi-dimensional outer product
            @NOTE see `spm_cross` in core.maths
            If no `x` argument is passed, the function returns the "auto-outer product" of self
            Otherwise, the function will recursively take the outer product of the initial entry
            of `x` with `self` until it has depleted the possible entries of `x` that it can outer-product

            @TODO explain the concept of `args` in a clearer fashion

        Parameters
        ----------
        - `x` [np.ndarray || Categorical] (optional)
            The values to perform the outer-product with
        - `return_numpy` [bool] (optional)
            Whether to return `np.ndarray` or `Categorical` - defaults to `Categorical`
        - `args` [np.ndarray || Categorical] (optional)
            Perform the outer product of the `args` with self (only used when `x` is given)

        Returns
        -------
        - `y` [np.ndarray || Categorical]
            The result of the outer-product
        """
        x = utils.to_numpy(x)

        if x is not None:
            # convert every extra argument (not just the first one), so a mix of
            # Categoricals and numpy arrays is handled uniformly
            arg_array = [utils.to_numpy(arg) for arg in args]
            y = maths.spm_cross(self.values, x, *arg_array)
        else:
            # no `x`: auto-outer product of self
            y = maths.spm_cross(self.values)

        if return_numpy:
            return y
        return Categorical(values=y)
    def dot(self, x, dims_to_omit=None, return_numpy=False, obs_mode=False):
        """ Dot product of this distribution with `x`
            @NOTE see `spm_dot` in core.maths
            @TODO create better workaround for `obs_mode`

            The dimensions in `dims_to_omit` will not be summed across during the dot product

        Parameters
        ----------
        - `x` [1D np.ndarray || Categorical]
            The array to perform the dot product with
        - `dims_to_omit` [list of ints] (optional)
            Which dimensions to omit
        - `return_numpy` [bool] (optional)
            Whether to return `np.ndarray` or `Categorical` - defaults to `Categorical`
        - `obs_mode` [bool] (optional)
            Whether to perform the inner product of `x` with the leading dimension of self
            @NOTE We call this `obs_mode` because it's often used to get the likelihood of an observation (leading dimension)
                  under different settings of hidden states (lagging dimensions)

        Returns
        -------
        - `y` [np.ndarray || Categorical]
            The result of the dot product (one entry per sub-array if this distribution is an array-of-arrays)
        """
        x = utils.to_numpy(x)

        if self.IS_AOA:
            # perform the dot product on each sub-array separately
            y = np.empty(self.n_arrays, dtype=object)
            for i in range(self.n_arrays):
                y[i] = maths.spm_dot(self[i].values, x, dims_to_omit, obs_mode)
        else:
            y = maths.spm_dot(self.values, x, dims_to_omit, obs_mode)

        if return_numpy:
            return y
        return Categorical(values=y)
def spm_dot(X, y, dims_to_omit=None, obs_mode=False):
    """ Dot product of a multidimensional array `X` with `y`
    The dimensions in `dims_to_omit` will not be summed across during the dot product
    @TODO: we need documentation describing `obs_mode`
        Ideally, we could find a way to avoid it altogether

    Parameters
    ----------
    `X` [numpy.ndarray]
        The multidimensional array to contract
    `y` [1D numpy.ndarray]
        Either vector or array of arrays
    `dims_to_omit` [list :: int] (optional)
        Which dimensions to omit
    `obs_mode` [bool] (optional)
        Only used when `y` is a single vector: if True, `y` is contracted against
        dimension 0 of `X`, otherwise against dimension 1

    Returns
    -------
    `X` [numpy.ndarray]
        The contracted array with singleton dimensions squeezed out; a result with at
        most one element is returned as a 1-element float64 array

    Raises
    ------
    ValueError
        If `dims_to_omit` is given but is not a `list`
    """

    X = utils.to_numpy(X)
    y = utils.to_numpy(y)

    if utils.is_arr_of_arr(y):
        # `y` holds one vector per dimension to contract: pair them with the
        # trailing `len(y)` dimensions of `X`
        dims = (np.arange(0, len(y)) + X.ndim - len(y)).astype(int)
    else:
        # Deal with particular use case - see above @TODO
        if obs_mode is True:
            # Case when you're getting the likelihood of an observation under model.
            # Equivalent to something like self.values[np.where(x),:]
            # where `y` is a discrete 'one-hot' observation vector
            dims = np.array([0], dtype=int)
        else:
            # Case when `y` leading dimension matches the lagging dimension of `values`
            # E.g. a more 'classical' dot product of a likelihood with hidden states
            dims = np.array([1], dtype=int)

        # convert `y` to array of array
        y = utils.to_arr_of_arr(y)

    # omit dims not needed for dot product
    if dims_to_omit is not None:
        if not isinstance(dims_to_omit, list):
            raise ValueError("`dims_to_omit` must be a `list` of `int`")

        # delete dims
        dims = np.delete(dims, dims_to_omit)
        if len(y) == 1:
            y = np.empty([0], dtype=object)
        else:
            y = np.delete(y, dims_to_omit)

    # perform dot product: broadcast each vector along its target dimension and sum it out;
    # keepdims=True keeps the number of dimensions fixed so the indices in `dims` stay valid
    for d in range(len(y)):
        s = np.ones(np.ndim(X), dtype=int)
        s[dims[d]] = np.shape(y[d])[0]
        X = X * y[d].reshape(tuple(s))
        X = np.sum(X, axis=dims[d], keepdims=True)
    X = np.squeeze(X)

    # if the result is a scalar, return it as a 1-element float64 array
    if np.prod(X.shape) <= 1.0:
        X = X.item()
        X = np.array([X]).astype("float64")

    return X
def update_posterior_states(A, obs, prior=None, method=FPI, return_numpy=True, **kwargs):
    """
    Update marginal posterior over hidden states using variational inference.
        Can optionally set message passing algorithm used for inference

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays) or Categorical]:
        Observation likelihood of the generative model, mapping from hidden states to observations
        Used to invert generative model to obtain marginal likelihood over hidden states, given the observation
    - 'obs' [numpy 1D array, array of arrays (with 1D numpy array entries), int or tuple]:
        The observation (generated by the environment). If single modality, this can be a 1D array
        (one-hot vector representation) or an int (observation index)
        If multi-modality, this can be an array of arrays (whose entries are 1D one-hot vectors) or a tuple (of observation indices)
    - 'prior' [numpy 1D array, array of arrays (with 1D numpy array entries), Categorical, or None]:
        Prior beliefs about hidden states, to be integrated with the marginal likelihood to obtain a posterior distribution.
        If None, prior is set to be equal to a flat categorical distribution (at the level of the individual inference functions).
    - 'method' [str]:
        Algorithm used to perform the variational inference.
        Options: 'FPI' - Fixed point iteration
                    - http://www.cs.cmu.edu/~guestrin/Class/10708/recitations/r9/VI-view.pdf, slides 13- 18
                    - http://citeseerx.ist.psu.edu/viewdoc/download?doi=, slides 24 - 38
                 'VMP' - Variational message passing (not implemented)
                 'MMP' - Marginal message passing (not implemented)
                 'BP'  - Belief propagation (not implemented)
                 'EP'  - Expectation propagation (not implemented)
                 'CV'  - Cluster variation method (not implemented)
    - 'return_numpy' [bool]:
        True/False flag to determine whether the posterior is returned as a numpy array or a Categorical
    **kwargs: keyword arguments forwarded to the selected inference routine (currently `run_fpi`)

    Returns
    -------
    - 'qs' [numpy 1D array, array of arrays (with 1D numpy array entries), or Categorical]:
        Marginal posterior beliefs over hidden states

    Raises
    ------
    NotImplementedError: if `method` names a known but unimplemented algorithm
    ValueError: if `method` is not recognised
    """

    # safe convert to numpy
    A = utils.to_numpy(A)

    # collect model dimensions
    if utils.is_arr_of_arr(A):
        n_factors = A[0].ndim - 1
        n_states = list(A[0].shape[1:])
        n_modalities = len(A)
        n_observations = [A[m].shape[0] for m in range(n_modalities)]
    else:
        n_factors = A.ndim - 1
        n_states = list(A.shape[1:])
        n_modalities = 1
        n_observations = [A.shape[0]]

    obs = process_observations(obs, n_modalities, n_observations)
    if prior is not None:
        prior = process_priors(prior, n_factors)

    # compare method names by value, not identity: equal strings need not be the same object
    if method == FPI:
        qs = run_fpi(A, obs, n_observations, n_states, prior, **kwargs)
    elif method in (VMP, MMP, BP, EP, CV):
        raise NotImplementedError(f"{method} is not implemented")
    else:
        raise ValueError(f"{method} is not implemented")

    if return_numpy:
        return qs
    return utils.to_categorical(qs)
def calc_pB_info_gain(pB, qs_pi, qs_prev, policy):
    """
    Compute expected Dirichlet information gain about parameters pB under a given policy.

    Parameters
    ----------
    pB [numpy nd-array, array-of-arrays (where each entry is a numpy nd-array), or Dirichlet (either single-factor or AoA)]:
        Prior dirichlet parameters parameterizing beliefs about the likelihood describing transitions between hidden states,
        with different factors (if there are multiple) stored in different arrays.
    qs_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), Categorical (either single-factor or AoA), or list]:
        Posterior predictive density over hidden states. If a list, each entry of the list is the posterior predictive for a given timepoint of an expected trajectory
    qs_prev [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), or Categorical (either single-factor or AoA)]:
        Posterior over hidden states (before getting observations)
    policy [numpy 2D ndarray, of size n_steps x n_control_factors]:
        Policy to consider. Each row of the matrix encodes the action index along a different control factor for a given timestep.

    Returns
    -------
    infogain_pB [scalar]:
        Surprise (about dirichlet parameters) expected under the policy in question,
        accumulated over all timesteps (and over all factors, if there are several)
    """

    # Normalise `qs_pi` to a list with one entry per timestep, without mutating the caller's list.
    if isinstance(qs_pi, list):
        qs_pi = [utils.to_numpy(qs_t, flatten=True) for qs_t in qs_pi]
    else:
        qs_pi = [utils.to_numpy(qs_pi, flatten=True)]
    n_steps = len(qs_pi)

    if isinstance(qs_prev, Categorical):
        qs_prev = utils.to_numpy(qs_prev, flatten=True)

    if isinstance(pB, Dirichlet):
        num_factors = pB.n_arrays if pB.IS_AOA else 1
        wB = pB.expectation_of_log()
    elif utils.is_arr_of_arr(pB):
        num_factors = len(pB)
        wB = np.empty(num_factors, dtype=object)
        for factor in range(num_factors):
            wB[factor] = spm_wnorm(pB[factor])
    else:
        num_factors = 1
        wB = spm_wnorm(pB)

    pB = utils.to_numpy(pB)

    pB_infogain = 0

    for t in range(n_steps):
        # The 'past posterior' used for the information gain about pB is the posterior over expected
        # states at the previous timestep. On the first timestep we use the latest posterior of the
        # whole action-perception cycle (`qs_prev`).
        previous_qs = qs_prev if t == 0 else qs_pi[t - 1]

        if num_factors == 1:
            a_i = policy[t, 0]
            # only transitions with non-zero Dirichlet counts contribute
            wB_t = wB[:, :, a_i] * (pB[:, :, a_i] > 0).astype("float")
            pB_infogain -= qs_pi[t].dot(wB_t.dot(previous_qs))
        else:
            # get the list of action-indices for the current timestep
            for factor, a_i in enumerate(policy[t, :]):
                wB_factor_t = wB[factor][:, :, a_i] * (pB[factor][:, :, a_i] > 0).astype("float")
                pB_infogain -= qs_pi[t][factor].dot(wB_factor_t.dot(previous_qs[factor]))

    return pB_infogain
def calc_pA_info_gain(pA, qo_pi, qs_pi):
    """
    Compute expected Dirichlet information gain about parameters pA under a policy.

    Parameters
    ----------
    pA [numpy nd-array, array-of-arrays (where each entry is a numpy nd-array), or Dirichlet (either single-factor or AoA)]:
        Prior dirichlet parameters parameterizing beliefs about the likelihood mapping from hidden states to observations,
        with different modalities (if there are multiple) stored in different arrays.
    qo_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), Categorical (either single-factor or AoA), or list]:
        Expected observations. If a list, each entry of the list is the posterior predictive for a given timepoint of an expected trajectory
    qs_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), Categorical (either single-factor or AoA), or list]:
        Posterior predictive density over hidden states. If a list, each entry of the list is the posterior predictive for a given timepoint of an expected trajectory

    Returns
    -------
    infogain_pA [scalar]:
        Surprise (about dirichlet parameters) expected under the policy in question,
        accumulated over all timesteps (and over all modalities, if there are several)
    """

    # Normalise both inputs to per-timestep lists, without mutating the caller's lists.
    if isinstance(qo_pi, list):
        qo_pi = [utils.to_numpy(qo_t, flatten=True) for qo_t in qo_pi]
    else:
        qo_pi = [utils.to_numpy(qo_pi, flatten=True)]

    if isinstance(qs_pi, list):
        qs_pi = [utils.to_numpy(qs_t, flatten=True) for qs_t in qs_pi]
    else:
        qs_pi = [utils.to_numpy(qs_pi, flatten=True)]

    # only timesteps for which both expected observations and expected states are available
    n_steps = min(len(qo_pi), len(qs_pi))

    if isinstance(pA, Dirichlet):
        num_modalities = pA.n_arrays if pA.IS_AOA else 1
        wA = pA.expectation_of_log()
    elif utils.is_arr_of_arr(pA):
        num_modalities = len(pA)
        wA = np.empty(num_modalities, dtype=object)
        for modality in range(num_modalities):
            wA[modality] = spm_wnorm(pA[modality])
    else:
        num_modalities = 1
        wA = spm_wnorm(pA)

    pA = utils.to_numpy(pA)

    pA_infogain = 0

    if num_modalities == 1:
        # only entries with non-zero Dirichlet counts contribute
        wA = wA * (pA > 0).astype("float")
        for t in range(n_steps):
            pA_infogain -= qo_pi[t].dot(spm_dot(wA, qs_pi[t])[:, np.newaxis])
    else:
        for modality in range(num_modalities):
            wA_modality = wA[modality] * (pA[modality] > 0).astype("float")
            for t in range(n_steps):
                pA_infogain -= qo_pi[t][modality].dot(
                    spm_dot(wA_modality, qs_pi[t])[:, np.newaxis])

    return pA_infogain
def get_expected_obs(qs_pi, A, return_numpy=False):
    """
    Given a posterior predictive density Qs_pi and an observation likelihood model A,
    get the expected observations given the predictive posterior.

    Parameters
    ----------
    qs_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), Categorical (either single-factor or AoA), or list]:
        Posterior predictive density over hidden states. If a list, each entry of the list is the posterior predictive for a given timepoint of an expected trajectory
    A [numpy nd-array, array-of-arrays (where each entry is a numpy nd-array), or Categorical (either single-factor or AoA)]:
        Observation likelihood mapping from hidden states to observations, with different modalities (if there are multiple) stored in different arrays
    return_numpy [Boolean]:
        True/False flag to determine whether output of function is a numpy array or a Categorical

    Returns
    -------
    qo_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), Categorical (either single-factor or AoA), or list]:
        Expected observations under the given policy. If there is more than one timestep, a list of the expected observations
        over the time horizon of policy evaluation, where each entry is the expected observations at a given timestep;
        with a single timestep, that entry is returned on its own.
    """

    A = utils.to_numpy(A)

    # Normalise `qs_pi` to a list with one entry per timestep, without mutating the caller's list.
    if isinstance(qs_pi, list):
        qs_pi = [utils.to_numpy(qs_t, flatten=True) for qs_t in qs_pi]
    else:
        qs_pi = [utils.to_numpy(qs_pi, flatten=True)]
    n_steps = len(qs_pi)

    # get expected observations over time
    qo_pi = []
    if utils.is_arr_of_arr(A):
        num_modalities = len(A)
        for qs_t in qs_pi:
            qo_pi_t = np.empty(num_modalities, dtype=object)
            for modality in range(num_modalities):
                qo_pi_t[modality] = spm_dot(A[modality], qs_t)
            qo_pi.append(qo_pi_t)
    else:
        for qs_t in qs_pi:
            qo_pi.append(spm_dot(A, qs_t))

    if not return_numpy:
        qo_pi = [utils.to_categorical(qo_t) for qo_t in qo_pi]

    if n_steps == 1:
        return qo_pi[0]
    return qo_pi
def get_expected_states(qs, B, policy, return_numpy=False):
    """
    Given a posterior density qs, a transition likelihood model B, and a policy,
    get the state distribution expected under that policy's pursuit.

    Parameters
    ----------
    - `qs` [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), or Categorical (either single-factor or AoA)]:
        Current posterior beliefs about hidden states
    - `B` [numpy nd-array, array-of-arrays (where each entry is a numpy nd-array), or Categorical (either single-factor or AoA)]:
        Transition likelihood mapping from states at t to states at t + 1, with different actions (per factor) stored along the lagging dimension
    - `policy` [np.array]:
        np.array of size (policy_len x n_factors) where each value corresponds to a control state
    - `return_numpy` [Boolean]:
        True/False flag to determine whether output of function is a numpy array or a Categorical

    Returns
    -------
    - `qs_pi` [list of np.arrays with len n_steps, where in case of multiple hidden state factors, each np.array in the list is a 1 x n_factors array-of-arrays, otherwise a list of 1D numpy arrays]:
        Expected states under the given policy - also known as the 'posterior predictive density'.
        If the policy has a single timestep, that timestep's entry is returned on its own rather than in a list.
    """

    n_steps = policy.shape[0]
    n_factors = policy.shape[1]

    qs = utils.to_numpy(qs, flatten=True)
    B = utils.to_numpy(B)

    # initialise beliefs over expected states
    qs_pi = []

    # each step propagates the previous beliefs (the current posterior at t = 0)
    # through the transition matrix selected by the policy
    previous = qs
    if utils.is_arr_of_arr(B):
        for t in range(n_steps):
            qs_pi_t = np.empty(n_factors, dtype=object)
            for control_factor, control in enumerate(policy[t, :]):
                qs_pi_t[control_factor] = spm_dot(
                    B[control_factor][:, :, control], previous[control_factor])
            qs_pi.append(qs_pi_t)
            previous = qs_pi_t
    else:
        for t in range(n_steps):
            previous = spm_dot(B[:, :, policy[t, 0]], previous)
            qs_pi.append(previous)

    if not return_numpy:
        qs_pi = [utils.to_categorical(qs_t) for qs_t in qs_pi]

    if len(qs_pi) == 1:
        return qs_pi[0]
    return qs_pi
def update_transition_dirichlet(
    pB, B, actions, qs, qs_prev, lr=1.0, return_numpy=True, factors="all"
):
    """
    Update Dirichlet parameters that parameterize the transition model of the generative model
    (describing the probabilistic mapping between hidden states over time).

    Parameters
    ----------
    - pB [numpy nd.array, array-of-arrays (with np.ndarray entries), or Dirichlet (either single-modality or AoA)]:
        The prior Dirichlet parameters of the generative model, parameterizing the agent's beliefs about the transition likelihood.
    - B [numpy nd.array, object-like array of arrays, or Categorical (either single-modality or AoA)]:
        The transition likelihood of the generative model.
    - actions [tuple]:
        A tuple containing the action(s) performed at a given timestep.
    - qs [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), or Categorical (either single-factor or AoA)]:
        Current marginal posterior beliefs about hidden state factors
    - qs_prev [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), or Categorical (either single-factor or AoA)]:
        Past marginal posterior beliefs about hidden state factors
    - lr [float, optional]:
        Learning rate (default 1.0).
    - return_numpy [bool, optional]:
        Logical flag to determine whether output is a numpy array or a Dirichlet
    - factors [list, optional]:
        Indices (in terms of range(Nf)) of the hidden state factors to include in learning.
        Defaults to 'all', meaning that transition likelihood matrices for all hidden state factors
        are updated as a function of transitions in the different control factors (i.e. actions)

    Returns
    -------
    - pB_updated [numpy nd.array, array-of-arrays, or Dirichlet]:
        The updated Dirichlet parameters (a copy; `pB` itself is not modified)
    """

    pB = utils.to_numpy(pB)
    # `B` may be given as a Categorical; convert so it can be sliced by action
    B = utils.to_numpy(B)

    if utils.is_arr_of_arr(pB):
        n_factors = len(pB)
    else:
        n_factors = 1

    if return_numpy:
        pB_updated = pB.copy()
    else:
        pB_updated = utils.to_dirichlet(pB.copy())

    if not utils.is_distribution(qs):
        qs = utils.to_categorical(qs)

    if factors == "all":
        if n_factors == 1:
            # outer product of current and previous beliefs, masked to transitions allowed by B
            db = qs.cross(qs_prev, return_numpy=True)
            db = db * (B[:, :, actions[0]] > 0).astype("float")
            pB_updated = pB_updated + (lr * db)
        else:
            for f in range(n_factors):
                db = qs[f].cross(qs_prev[f], return_numpy=True)
                db = db * (B[f][:, :, actions[f]] > 0).astype("float")
                pB_updated[f] = pB_updated[f] + (lr * db)
    else:
        for f_idx in factors:
            db = qs[f_idx].cross(qs_prev[f_idx], return_numpy=True)
            db = db * (B[f_idx][:, :, actions[f_idx]] > 0).astype("float")
            pB_updated[f_idx] = pB_updated[f_idx] + (lr * db)

    return pB_updated
def update_likelihood_dirichlet(
    pA, A, obs, qs, lr=1.0, return_numpy=True, modalities="all"
):
    """ Update Dirichlet parameters of the likelihood distribution

    Parameters
    ----------
    - pA [numpy nd.array, array-of-arrays (with np.ndarray entries), or Dirichlet (either single-modality or AoA)]:
        The prior Dirichlet parameters of the generative model, parameterizing the agent's beliefs about the observation likelihood.
    - A [numpy nd.array, object-like array of arrays, or Categorical (either single-modality or AoA)]:
        The observation likelihood of the generative model.
    - obs [numpy 1D array, array-of-arrays (with 1D numpy array entries), int or tuple]:
        A discrete observation used in the update equation
    - qs [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array), or Categorical (either single-factor or AoA)]:
        Current marginal posterior beliefs about hidden state factors
    - lr [float, optional]:
        Learning rate (default 1.0).
    - return_numpy [bool, optional]:
        Logical flag to determine whether output is a numpy array or a Dirichlet
    - modalities [list, optional]:
        Indices (in terms of range(n_modalities)) of the observation modalities to include in learning.
        Defaults to 'all', meaning that observation likelihood matrices for all modalities
        are updated as a function of observations in the different modalities.

    Returns
    -------
    - pA_updated [numpy nd.array, array-of-arrays, or Dirichlet]:
        The updated Dirichlet parameters (a copy; `pA` itself is not modified)
    """

    pA = utils.to_numpy(pA)
    # `A` may be given as a Categorical; convert so it can be compared elementwise
    A = utils.to_numpy(A)

    if utils.is_arr_of_arr(pA):
        n_modalities = len(pA)
        n_observations = [pA[m].shape[0] for m in range(n_modalities)]
    else:
        n_modalities = 1
        n_observations = [pA.shape[0]]

    if return_numpy:
        pA_updated = pA.copy()
    else:
        pA_updated = utils.to_dirichlet(pA.copy())

    # observation index -> one-hot vector
    if isinstance(obs, (int, np.integer)):
        obs = np.eye(n_observations[0])[obs]

    # observation indices -> array-of-arrays of one-hot vectors. Filled element by element:
    # np.array(list, dtype=object) would build a 2-D object array when all modalities
    # have the same number of outcomes.
    elif isinstance(obs, tuple):
        obs_arr = np.empty(n_modalities, dtype=object)
        for g in range(n_modalities):
            obs_arr[g] = np.eye(n_observations[g])[obs[g]]
        obs = obs_arr

    # convert to Categorical to make the cross product easier
    obs = utils.to_categorical(obs)

    if modalities == "all":
        if n_modalities == 1:
            # outer product of observation and beliefs, masked to entries allowed by A
            da = obs.cross(qs, return_numpy=True)
            da = da * (A > 0).astype("float")
            pA_updated = pA_updated + (lr * da)
        else:
            for g in range(n_modalities):
                da = obs[g].cross(qs, return_numpy=True)
                da = da * (A[g] > 0).astype("float")
                pA_updated[g] = pA_updated[g] + (lr * da)
    else:
        for g_idx in modalities:
            da = obs[g_idx].cross(qs, return_numpy=True)
            da = da * (A[g_idx] > 0).astype("float")
            pA_updated[g_idx] = pA_updated[g_idx] + (lr * da)

    return pA_updated