Exemplo n.º 1
0
def calc_expected_utility(qo_pi, C):
    """
    Given expected observations under a policy (qo_pi) and a prior over observations (C),
    compute the expected utility of the policy.

    Parameters
    ----------
    qo_pi [numpy 1D array, array-of-arrays (where each entry is a numpy 1D array),
    Categorical (either single-factor or AoA), or list]:
        Expected observations under the given policy (predictive posterior over outcomes).
        If a list, a list of the expected observations over the time horizon of policy
        evaluation, where each entry is the expected observations at a given timestep.
        The caller's list is not modified.
    C [numpy nd-array, array-of-arrays (where each entry is a numpy nd-array)]:
        Prior beliefs over outcomes, expressed in terms of relative log probabilities.
        Each (modality-specific) vector is normalised with a softmax and then logged
        before being used as the utility of each outcome.

    Returns
    -------
    expected_util [scalar]:
        Utility (reward) expected under the policy in question, summed over timesteps
        (and over modalities, in the multi-modality case).
    """
    # Build a fresh list so the caller's list of expected observations is not mutated.
    if isinstance(qo_pi, list):
        qo_pi = [utils.to_numpy(qo_t, flatten=True) for qo_t in qo_pi]
    else:
        qo_pi = [utils.to_numpy(qo_pi, flatten=True)]
    n_steps = len(qo_pi)

    C = utils.to_numpy(C, flatten=True)

    # initialise expected utility
    expected_util = 0

    if utils.is_arr_of_arr(C):
        # Multiple observation modalities: the log-preferences do not depend on time,
        # so compute them once per modality and reuse them at every timestep.
        lnC = [np.log(softmax(C_m[:, np.newaxis]) + 1e-16) for C_m in C]
        for t in range(n_steps):
            for modality, lnC_m in enumerate(lnC):
                expected_util += qo_pi[t][modality].dot(lnC_m)
    else:
        # Single modality: the epsilon guards the log (added after the softmax),
        # matching the multi-modality branch above.
        lnC = np.log(softmax(C[:, np.newaxis]) + 1e-16)
        for t in range(n_steps):
            expected_util += qo_pi[t].dot(lnC)

    return expected_util
Exemplo n.º 2
0
def infer_action(qs, A, B, C, n_control, policies):
    """
    Sample a control state from the posterior over policies.

    Each policy is scored with ``evaluate_policy`` (treated here as the negative expected
    free energy). The scores are passed through ``maths.softmax`` to obtain a distribution
    over policies. Each policy's probability is then credited to the control state it
    prescribes at its first timestep, and a control state is sampled from the result.

    Parameters
    ----------
    qs : current beliefs about hidden states (forwarded to ``evaluate_policy``)
    A, B, C : generative-model components (forwarded to ``evaluate_policy``)
    n_control : number of control states (dimension of the ``Categorical`` built here)
    policies : sequence of 2D arrays; ``policy[0, :]`` must contain exactly one action
        index (i.e. a single control factor)

    Returns
    -------
    u : the control state sampled from the (normalised) marginal over control states

    Raises
    ------
    TypeError : if ``policy[0, :]`` does not contain exactly one entry
    """
    n_policies = len(policies)

    # negative expected free energy
    neg_G = np.zeros([n_policies, 1])
    for i, policy in enumerate(policies):
        neg_G[i] = evaluate_policy(policy, qs, A, B, C)

    # get distribution over policies
    q_pi = maths.softmax(neg_G)

    # probabilities of control states
    qu = Categorical(dims=n_control)

    # sum the probabilities of policies that share the same first action
    for i, policy in enumerate(policies):
        first_action = np.asarray(policy[0, :])
        # int() on an ndim>0 array is deprecated since NumPy 1.25, so extract the
        # single element explicitly. Keep TypeError for the multi-entry case, as before.
        if first_action.size != 1:
            raise TypeError(
                "infer_action expects a single control factor, but policy[0, :] has "
                f"{first_action.size} entries"
            )
        u = int(first_action.item())
        qu[u] += q_pi[i]

    # normalize
    qu.normalize()

    # sample control
    u = qu.sample()

    return u
Exemplo n.º 3
0
def run_fpi_faster(A,
                   obs,
                   n_observations,
                   n_states,
                   prior=None,
                   num_iter=10,
                   dF=1.0,
                   dF_tol=0.001):
    """
    Update marginal posterior beliefs about hidden states
    using a new version of variational fixed point iteration (FPI).

    @NOTE (Conor, 26.02.2020):
    This method uses a faster algorithm than the traditional 'spm_dot' approach. Instead of
    separately computing a conditional joint log likelihood of an outcome, under the
    posterior probabilities of a certain marginal, all marginals are multiplied into one
    joint tensor that gives the joint likelihood of an observation under all hidden states.
    That tensor is then sequentially (and *parallelizably*) marginalized out to get each
    marginal posterior.
    This method is less RAM-intensive, admits heavy parallelization, and runs (about 2x) faster.

    @NOTE (Conor, 28.02.2020):
    After further testing, discovered interesting differences between this version and the
    original version. It appears that the original version (simple 'run_FPI') shows
    mean-field biases or 'explaining away' effects, whereas this version spreads
    probabilities more 'fairly' among possibilities.
    To summarize: it actually matters what order you do the summing across the joint
    likelihood tensor. In this version, all marginals are multiplied into the likelihood
    tensor before summing out, whereas in the previous version, marginals are recursively
    multiplied and summed out.

    @NOTE (Conor, 24.04.2020): I would expect that the factor_order approach would help
    ameliorate the effects of the mean-field bias. I would also expect that the use of a
    factor_order below is unnecessary, since the marginalisation w.r.t. each factor is done
    only after all marginals are multiplied into the larger tensor.

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to
        observations. Used to invert the generative model to obtain the marginal likelihood
        over hidden states, given the observation.
    - 'obs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        The observation (generated by the environment). If single modality, this can be a
        1D array (one-hot vector representation). If multi-modality, this can be an array
        of arrays (whose entries are 1D one-hot vectors).
    - 'n_observations' [int or list of ints]:
        Accepted for interface compatibility; not used by this function (the likelihood is
        built from 'A' and 'obs').
    - 'n_states' [int or list of ints]:
        Number of states per hidden-state factor. A bare int is treated as a single factor.
    - 'prior' [numpy 1D array, array of arrays (with 1D numpy array entries) or None]:
        Prior (log) beliefs of the agent, to be integrated with the marginal likelihood to
        obtain the posterior. If absent, the prior is set to the log of a uniform
        distribution over hidden states (identical to the initialisation of the posterior).
    -'num_iter' [int]:
        Number of variational fixed-point iterations to run.
    -'dF' [float]:
        Starting free energy gradient (dF/dt) before updating in the course of gradient descent.
    -'dF_tol' [float]:
        Threshold value of the gradient of the variational free energy (dF/dt), to be
        checked at each iteration. If dF < dF_tol, the iterations are halted pre-emptively
        and the final marginal posterior belief(s) is(are) returned.

    Returns
    ----------
    -'qs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        Marginal posterior beliefs over hidden states (single- or multi-factor) achieved
        via variational fixed point iteration (mean-field).
    """

    # A bare int means a single hidden-state factor; normalise it to a list so that
    # len() and per-factor indexing below work for both documented forms.
    if isinstance(n_states, (int, np.integer)):
        n_states = [n_states]
    n_factors = len(n_states)

    # =========== Step 1 ===========
    # Use the assumed independence among observation modalities to build a single joint
    # likelihood over hidden factors (shape n_states), then take its log.
    likelihood = get_joint_likelihood(A, obs, n_states)
    likelihood = np.log(likelihood + 1e-16)

    # =========== Step 2 ===========
    # Create a flat posterior (and a flat log-prior if none was given).
    qs = np.empty(n_factors, dtype=object)
    for factor in range(n_factors):
        qs[factor] = np.ones(n_states[factor]) / n_states[factor]

    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for factor in range(n_factors):
            prior[factor] = np.log(
                np.ones(n_states[factor]) / n_states[factor] + 1e-16)

    # =========== Step 3 ===========
    # Initial free energy
    prev_vfe = calc_free_energy(qs, prior, n_factors)

    # =========== Step 4 ===========
    # With a single factor, the posterior is just softmax(log-likelihood + log-prior).
    if n_factors == 1:
        qL = spm_dot(likelihood, qs, [0])
        return softmax(qL + prior[0])

    # =========== Step 5 ===========
    # Run the revised fixed-point iteration scheme.

    # Orders in which the marginal posteriors are multiplied into the joint
    # log-likelihood: forward (0 .. n_factors-1), then reverse.
    factor_orders = [range(n_factors), range(n_factors - 1, -1, -1)]

    curr_iter = 0
    while curr_iter < num_iter and dF >= dF_tol:
        for factor_order in factor_orders:
            # reset the log likelihood
            L = likelihood.copy()

            # multiply each marginal onto a growing single joint tensor
            for factor in factor_order:
                shape = np.ones(np.ndim(L), dtype=int)
                shape[factor] = len(qs[factor])
                L *= qs[factor].reshape(tuple(shape))

            # Loop over factors again, dividing out each factor's own marginal before
            # summing out all the other axes.
            # !!! KEY DIFFERENCE BETWEEN THIS AND 'VANILLA' FPI,
            # WHERE THE ORDER OF THE MARGINALIZATION MATTERS !!!
            for f in factor_order:
                shape = np.ones(np.ndim(L), dtype=int)
                shape[f] = len(qs[f])

                # qs[f] has not been updated yet in this pass, so it is exactly the
                # marginal that was multiplied into L above.
                # NOTE(review): an exact zero in qs[f] would give 0 * inf = nan here.
                temp = L * (1.0 / qs[f]).reshape(tuple(shape))
                dims2sum = tuple(d for d in range(n_factors) if d != f)
                qL = np.sum(temp, dims2sum)

                qs[f] = softmax(qL + prior[f])

        # calculate new free energy
        vfe = calc_free_energy(qs, prior, n_factors, likelihood)

        # stopping condition - time derivative of free energy
        dF = np.abs(prev_vfe - vfe)
        prev_vfe = vfe

        curr_iter += 1

    return qs
Exemplo n.º 4
0
def run_fpi(A,
            obs,
            n_observations,
            n_states,
            prior=None,
            num_iter=10,
            dF=1.0,
            dF_tol=0.001):
    """
    Update marginal posterior beliefs about hidden states using variational fixed point
    iteration (FPI).

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to
        observations. Used to invert the generative model to obtain the marginal likelihood
        over hidden states, given the observation.
    - 'obs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        The observation (generated by the environment). If single modality, this can be a
        1D array (one-hot vector representation). If multi-modality, this can be an array
        of arrays (whose entries are 1D one-hot vectors).
    - 'n_observations' [int or list of ints]:
        Accepted for interface compatibility; not used by this function (the likelihood is
        built from 'A' and 'obs').
    - 'n_states' [int or list of ints]:
        Number of states per hidden-state factor. A bare int is treated as a single factor.
    - 'prior' [numpy 1D array, array of arrays (with 1D numpy array entries) or None]:
        Prior (log) beliefs of the agent, to be integrated with the marginal likelihood to
        obtain the posterior. If absent, the prior is set to the log of a uniform
        distribution over hidden states (identical to the initialisation of the posterior).
    -'num_iter' [int]:
        Number of variational fixed-point iterations to run.
    -'dF' [float]:
        Starting free energy gradient (dF/dt) before updating in the course of gradient descent.
    -'dF_tol' [float]:
        Threshold value of the gradient of the variational free energy (dF/dt), to be
        checked at each iteration. If dF < dF_tol, the iterations are halted pre-emptively
        and the final marginal posterior belief(s) is(are) returned.

    Returns
    ----------
    -'qs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        Marginal posterior beliefs over hidden states (single- or multi-factor)
        achieved via variational fixed point iteration (mean-field).
    """

    # A bare int means a single hidden-state factor; normalise it to a list so that
    # len() and per-factor indexing below work for both documented forms.
    if isinstance(n_states, (int, np.integer)):
        n_states = [n_states]
    n_factors = len(n_states)

    # =========== Step 1 ===========
    # Use the assumed independence among observation modalities to build a single joint
    # likelihood over hidden factors (shape n_states), then take its log.
    likelihood = get_joint_likelihood(A, obs, n_states)
    likelihood = np.log(likelihood + 1e-16)

    # =========== Step 2 ===========
    # Create a flat posterior (and a flat log-prior if none was given).
    qs = np.empty(n_factors, dtype=object)
    for factor in range(n_factors):
        qs[factor] = np.ones(n_states[factor]) / n_states[factor]

    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for factor in range(n_factors):
            prior[factor] = np.log(
                np.ones(n_states[factor]) / n_states[factor] + 1e-16)

    # =========== Step 3 ===========
    # Initial free energy
    prev_vfe = calc_free_energy(qs, prior, n_factors)

    # =========== Step 4 ===========
    # With a single factor, the posterior is just softmax(log-likelihood + log-prior).
    if n_factors == 1:
        qL = spm_dot(likelihood, qs, [0])
        return softmax(qL + prior[0])

    # =========== Step 5 ===========
    # Run the FPI scheme.
    curr_iter = 0
    while curr_iter < num_iter and dF >= dF_tol:
        # Joint posterior over all factors, built as an outer product of the marginals
        qs_all = qs[0]
        for factor in range(n_factors - 1):
            qs_all = qs_all[..., None] * qs[factor + 1]
        LL_tensor = likelihood * qs_all

        # For each factor: sum the weighted log-likelihood over every other axis, then
        # divide out that factor's own marginal. qs_i is read before qs[factor] is
        # replaced, so it is the marginal that was multiplied into LL_tensor.
        # NOTE(review): an exact zero in qs_i would make this division produce inf/nan.
        for factor, qs_i in enumerate(qs):
            qL = np.einsum(LL_tensor, list(range(n_factors)),
                           [factor]) / qs_i
            qs[factor] = softmax(qL + prior[factor])

        # calculate new free energy
        vfe = calc_free_energy(qs, prior, n_factors, likelihood)

        # stopping condition - time derivative of free energy
        dF = np.abs(prev_vfe - vfe)
        prev_vfe = vfe

        curr_iter += 1

    return qs
Exemplo n.º 5
0
"""

# reset environment
first_state = env.reset(init_state=s[0])

# get an observation, given the state
first_obs = gp_likelihood[:, first_state]

# turn observation into an index
first_obs = np.where(first_obs)[0][0]

print("Initial Location {}".format(first_state))
print("Initial Observation {}".format(first_obs))

# infer initial state, given first observation
qs = maths.softmax(A[first_obs, :].log() + D.log(), return_numpy=False)

# loop over time
for t in range(T):

    qs_past = qs.copy()

    s_t = env.set_state(s[t + 1])

    # evoke observation, given the state
    obs = gp_likelihood[:, s_t]

    # turn observation into an index
    obs = np.where(obs)[0][0]

    # get transition likelihood
Exemplo n.º 6
0
def run_mmp(lh_seq,
            B,
            policy,
            prev_actions=None,
            prior=None,
            num_iter=10,
            grad_descent=False,
            tau=0.25,
            last_timestep=False,
            save_vfe_seq=False):
    """
    Marginal message passing scheme for updating posterior beliefs about multi-factor hidden
    states over time, conditioned on a particular policy.

    Parameters:
    --------------
    `lh_seq`[numpy object array]:
        Likelihoods of hidden state factors given a sequence of observations over time (one
        entry per past timestep). The log is taken inside this function (via `spm_log`).
    `B`[numpy object array]:
        Transition likelihood of the generative model, mapping from hidden states at t to
        hidden states at t+1. One B array per hidden state factor (`B[f]` is the f-th
        factor's B array, indexed as `B[f][:, :, action]`). Used to compute the 'forward'
        and 'backward' messages conveyed between beliefs about temporally-adjacent timepoints.
    `policy` [2-D numpy.ndarray]:
        Matrix of shape (policy_len, num_control_factors) whose element `policy[t, f]` is
        the index of the action (control state) taken at timestep t on control factor f.
    `prev_actions` [None or 2-D numpy.ndarray]:
        If provided, a matrix of previous actions of shape (infer_len, num_control_factors)
        that indicates the indices of each action (control state index) taken in the past
        (up until the current timestep). If None, zeros of shape
        (len(lh_seq), policy.shape[1]) are used.
    `prior`[None or numpy object array]:
        If provided, a numpy object array with one sub-array per hidden state factor that
        stores the prior beliefs about initial states (at t = 0, relative to `infer_len`).
        If None, uniform beliefs are used.
    `num_iter`[Int]:
        Number of variational iterations
    `grad_descent` [Bool]:
        Flag for whether to use gradient descent (predictive coding style)
    `tau` [Float]:
        Decay constant for use in `grad_descent` version
    `last_timestep` [Bool]:
        Flag for whether we are at the last timestep of belief updating
    `save_vfe_seq` [Bool]:
        If True, `F` is a list that starts at [0.0], and the running total of the
        variational free energy is appended to it once per timestep per iteration.
        If False, `F` is the running total (integrated across time/iterations).

    Returns:
    --------------
    `qs_seq`[numpy object array]: the sequence of beliefs about the different hidden state
        factors over time, one multi-factor posterior belief per timestep in `infer_len`
    `F`[Float or list, depending on setting of save_vfe_seq]
    """

    # window
    past_len = len(lh_seq)
    future_len = policy.shape[0]

    if last_timestep:
        infer_len = past_len + future_len - 1
    else:
        infer_len = past_len + future_len

    future_cutoff = past_len + future_len - 2

    # dimensions
    _, num_states, _, num_factors = get_model_dimensions(A=None, B=B)
    B = to_arr_of_arr(B)

    # beliefs
    qs_seq = obj_array(infer_len)
    for t in range(infer_len):
        qs_seq[t] = obj_array_uniform(num_states)

    # last message (zeros, used as the log future message at/after the cutoff)
    qs_T = obj_array_zeros(num_states)

    # prior
    if prior is None:
        prior = obj_array_uniform(num_states)

    # transposed transition (for backward messages)
    trans_B = obj_array(num_factors)
    for f in range(num_factors):
        trans_B[f] = spm_norm(np.swapaxes(B[f], 0, 1))

    # full policy: past actions followed by the policy under evaluation
    if prev_actions is None:
        prev_actions = np.zeros((past_len, policy.shape[1]))
    policy = np.vstack((prev_actions, policy))

    # initialise variational free energy of policy (accumulated over time)
    if save_vfe_seq:
        F = [0.0]
    else:
        F = 0.0

    for _ in range(num_iter):
        for t in range(infer_len):
            # Gradient-descent VFE contribution for this timestep, summed over factors.
            # It is committed to F once per timestep so that list-valued F (when
            # save_vfe_seq is True) is appended to rather than `+=`-ed with a float.
            vfe_t = 0.0

            for f in range(num_factors):
                # likelihood
                if t < past_len:
                    lnA = spm_log(spm_dot(lh_seq[t], qs_seq[t], [f]))
                else:
                    lnA = np.zeros(num_states[f])

                # past message
                if t == 0:
                    lnB_past = spm_log(prior[f])
                else:
                    past_msg = B[f][:, :, int(policy[t - 1, f])].dot(qs_seq[t - 1][f])
                    lnB_past = spm_log(past_msg)

                # future message
                if t >= future_cutoff:
                    lnB_future = qs_T[f]
                else:
                    future_msg = trans_B[f][:, :, int(policy[t, f])].dot(
                        qs_seq[t + 1][f])
                    lnB_future = spm_log(future_msg)

                # inference
                if grad_descent:
                    lnqs = spm_log(qs_seq[t][f])
                    coeff = 1 if (t >= future_cutoff) else 2
                    err = (coeff * lnA + lnB_past + lnB_future) - coeff * lnqs
                    err -= err.mean()
                    lnqs = lnqs + tau * err
                    qs_seq[t][f] = softmax(lnqs)
                    if (t == 0) or (t == (infer_len - 1)):
                        vfe_t += 0.5 * lnqs.dot(0.5 * err)
                    else:
                        # @NOTE: not sure why Karl does this in SPM_MDP_VB_X, we should look into this
                        vfe_t += lnqs.dot(
                            0.5 * (err - (num_factors - 1) * lnA / num_factors))
                else:
                    qs_seq[t][f] = softmax(lnA + lnB_past + lnB_future)

            if grad_descent:
                if save_vfe_seq:
                    F.append(F[-1] + vfe_t)
                else:
                    F += vfe_t
            elif save_vfe_seq:
                if t < past_len:
                    F.append(
                        F[-1] +
                        calc_free_energy(qs_seq[t],
                                         prior,
                                         num_factors,
                                         likelihood=spm_log(lh_seq[t]))[0])
                else:
                    F.append(
                        F[-1] +
                        calc_free_energy(qs_seq[t], prior, num_factors)[0])
            else:
                if t < past_len:
                    F += calc_free_energy(qs_seq[t],
                                          prior,
                                          num_factors,
                                          likelihood=spm_log(lh_seq[t]))
                else:
                    F += calc_free_energy(qs_seq[t], prior, num_factors)

    return qs_seq, F
Exemplo n.º 7
0
# Labels for the two hidden-state factors (index 0 and index 1).
# NOTE(review): presumably one label per entry of num_states below — confirm.
action_names = ["uncontrolled", "decision_state"]

# Three observation modalities with 3 outcomes each; two hidden-state factors with 2 and 3 states.
num_obs = [3, 3, 3]
num_states = [2, 3]
num_modalities = len(num_obs)
num_factors = len(num_states)

# One likelihood array per modality, of shape (num_obs[m], num_states[0], num_states[1]).
# Assumes utils.obj_array_zeros allocates zero-filled arrays of these shapes — TODO confirm.
A = utils.obj_array_zeros([[o] + num_states for _, o in enumerate(num_obs)])

# Modality 0: uninformative (uniform) when the second factor is in state 0 or 1; in state 2
# the first factor maps to outcome 0 / outcome 2 with probabilities 0.8 / 0.2 (and vice versa).
A[0][:, :, 0] = np.ones((num_obs[0], num_states[0])) / num_obs[0]
A[0][:, :, 1] = np.ones((num_obs[0], num_states[0])) / num_obs[0]
A[0][:, :, 2] = np.array([[0.8, 0.2], [0.0, 0.0], [0.2, 0.8]])

# Modality 1: second-factor states 0 and 2 deterministically yield outcome 2; in state 1,
# outcomes 0-1 depend on the first factor through softmax(eye(2)).
A[1][2, :, 0] = np.ones(num_states[0])
A[1][0:2, :, 1] = softmax(
    np.eye(num_obs[1] - 1)
)  # bandit statistics (mapping between reward-state (first hidden state factor) and rewards (Good vs Bad))
A[1][2, :, 2] = np.ones(num_states[0])

# establish a proprioceptive mapping that determines how the agent perceives its own `decision_state`
# (outcome k of modality 2 is observed exactly when the second factor is in state k)
A[2][0, :, 0] = 1.0
A[2][1, :, 1] = 1.0
A[2][2, :, 2] = 1.0

# Only the second hidden-state factor (index 1) is controllable.
control_fac_idx = [1]
B = utils.obj_array(num_factors)
for f, ns in enumerate(num_states):
    B[f] = np.eye(ns)
    # NOTE(review): as written here, B[0] stays 2-D (no action axis) while B[1] becomes
    # (ns, ns, ns). Every action slice B[1][:, :, u] is the identity, so no action changes
    # the decision_state. If actions are meant to *set* the state, a transpose such as
    # .transpose(1, 2, 0) after the tile would be needed — confirm (snippet may be truncated).
    if f in control_fac_idx:
        B[f] = B[f].reshape(ns, ns, 1)
        B[f] = np.tile(B[f], (1, 1, ns))
Exemplo n.º 8
0
def update_posterior_policies(
    qs,
    A,
    B,
    C,
    policies,
    use_utility=True,
    use_states_info_gain=True,
    use_param_info_gain=False,
    pA=None,
    pB=None,
    gamma=16.0,
    return_numpy=True,
):
    """ Updates the posterior beliefs about policies based on expected free energy prior

        @TODO: Needs to be amended for use with multi-step policies (where possible_policies is a
        list of np.arrays (n_step x n_factor), not just a list of tuples as it is now)

        Parameters
        ----------
        - `qs` [1D numpy array, array-of-arrays, or Categorical (either single- or multi-factor)]:
            Current marginal beliefs about hidden state factors
        - `A` [numpy ndarray, array-of-arrays (in case of multiple modalities), or Categorical
                (both single and multi-modality)]:
            Observation likelihood model (beliefs about the likelihood mapping entertained by the agent)
        - `B` [numpy ndarray, array-of-arrays (in case of multiple hidden state factors), or Categorical
                (both single and multi-factor)]:
            Transition likelihood model (beliefs about the likelihood mapping entertained by the agent)
        - `C` [numpy 1D-array, array-of-arrays (in case of multiple modalities), or Categorical
                (both single and multi-modality)]:
            Prior beliefs about outcomes (prior preferences)
        - `policies` [list of tuples]:
            A list of all the possible policies, each expressed as a tuple of indices, where a given
            index corresponds to an action on a particular hidden state factor e.g. policies[1][2] yields the
            index of the action under policy 1 that affects hidden state factor 2
        - `use_utility` [bool]:
            Whether to calculate the utility term, i.e. how much expected observations conform with prior preferences
        - `use_states_info_gain` [bool]:
            Whether to calculate state information gain
        - `use_param_info_gain` [bool]:
            Whether to calculate parameter information gain @NOTE requires pA or pB to be specified
        - `pA` [numpy ndarray, array-of-arrays (in case of multiple modalities), or Dirichlet
                (both single and multi-modality)]:
            Prior dirichlet parameters for A. Defaults to None, in which case info gain w.r.t. Dirichlet
            parameters over A is skipped.
        - `pB` [numpy ndarray, array-of-arrays (in case of multiple hidden state factors), or
            Dirichlet (both single and multi-factor)]:
            Prior dirichlet parameters for B. Defaults to None, in which case info gain w.r.t.
            Dirichlet parameters over B is skipped.
        - `gamma` [float, defaults to 16.0]:
            Precision over policies, used as the inverse temperature parameter of a softmax transformation
            of the per-policy scores in `efe`
        - `return_numpy` [Boolean]:
            True/False flag to determine whether `qp` is returned as a numpy array or a Categorical

        Returns
        --------
        - `qp` [1D numpy array or Categorical]:
            Posterior beliefs about policies, computed as softmax(gamma * efe) and renormalised
        - `efe` - [1D numpy array]:
            Per-policy scores accumulated as (utility + information gain terms). Larger values
            are preferred by the softmax, i.e. these are *negative* expected free energies.
            Always returned as a numpy array, regardless of `return_numpy`.

    """
    n_policies = len(policies)
    efe = np.zeros(n_policies)

    for idx, policy in enumerate(policies):
        qs_pi = get_expected_states(qs, B, policy)
        qo_pi = get_expected_obs(qs_pi, A)

        if use_utility:
            efe[idx] += calc_expected_utility(qo_pi, C)

        if use_states_info_gain:
            efe[idx] += calc_states_info_gain(A, qs_pi)

        if use_param_info_gain:
            if pA is not None:
                efe[idx] += calc_pA_info_gain(pA, qo_pi, qs_pi)
            if pB is not None:
                efe[idx] += calc_pB_info_gain(pB, qs_pi, qs, policy)

    q_pi = softmax(efe * gamma)

    if return_numpy:
        # renormalise explicitly (guards against any drift in the softmax output)
        q_pi = q_pi / q_pi.sum(axis=0)  # type: ignore
    else:
        q_pi = utils.to_categorical(q_pi)
        q_pi.normalize()

    return q_pi, efe
Exemplo n.º 9
0
def run_mmp_old(
    A,
    B,
    obs_t,
    policy,
    curr_t,
    t_horizon,
    T,
    qs_bma=None,
    prior=None,
    num_iter=10,
    dF=1.0,
    dF_tol=0.001,
    previous_actions=None,
    use_gradient_descent=False,
    tau=0.25,
):
    """
    Optimise marginal posterior beliefs about hidden states using the marginal message-passing (MMP)
    scheme developed by Thomas Parr and colleagues, see https://github.com/tejparr/nmpassing

    Older variant of `run_mmp`. Messages follow the SPM ("Karl") convention: the joint likelihood
    is dotted with the current posterior first and only then logged.

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to observations.
        Used in inference to get the likelihood of an observation under different hidden state configurations.
    - 'B' [numpy.ndarray (tensor or array-of-arrays)]:
        Transition likelihood of the generative model, mapping from hidden states at t to hidden states
        at t+1, indexed as B[f][:, :, action]. Used to compute past and future messages.
    - 'obs_t' [list of numpy 1D arrays or array of arrays (with 1D numpy array entries)]:
        Sequence of observations. It is indexed with absolute time indices
        range(max(0, curr_t - t_horizon), curr_t + 1) (only index 0 when curr_t == 0), so it must
        hold at least that many entries. In case of multiple modalities, each entry is an
        array-of-arrays with one 1D numpy.ndarray per modality.
    - 'policy' [2D np.ndarray]:
        Array of actions constituting a single policy, of shape (n_steps, n_control_factors); entry
        [t, f] is the action along control factor f at step t.
    - 'curr_t' [int]:
        Current timestep (relative to the 'absolute' time of the generative process).
    - 't_horizon' [int]:
        Temporal horizon of inference for states and policies.
    - 'T' [int]:
        Temporal horizon of the generative process (absolute time).
    - 'qs_bma':
        Currently unused; kept for interface compatibility.
    - 'prior' [numpy 1D array, array of arrays (with 1D numpy array entries) or None]:
        Prior beliefs at the beginning of the inference window; used as the past message at the
        first timestep. If None, a uniform distribution over hidden states is used.
    - 'num_iter' [int]:
        Number of variational iterations to run.
    - 'dF', 'dF_tol' [float]:
        Currently unused (no convergence check is performed); kept for interface compatibility.
    - 'previous_actions' [numpy.ndarray with shape (num_steps, n_control_factors) or None]:
        Actions already taken, stacked in front of `policy` to form the full action sequence used
        for the past/future messages. If None, a single row of zeros is prepended.
    - 'use_gradient_descent' [bool]:
        If True, take a gradient step of size `tau` on log-beliefs; otherwise set each marginal to
        the softmax of the summed log-messages.
    - 'tau' [float]:
        Learning rate for gradient descent (only used if use_gradient_descent is True).

    Returns
    -------
    - 'qs' [list of length window_len + 1 of object arrays (one 1D numpy array per factor)]:
        Marginal posterior beliefs over hidden states, where
        window_len = len(range(max(0, curr_t - t_horizon), min(T, curr_t + t_horizon))).
    - 'qss' [list of length num_iter]:
        qss[n] is a one-element list holding a snapshot (deep copy) of `qs` after iteration n.
    - 'F' [2D np.ndarray of shape (len(qs), num_iter)]:
        Placeholder for per-timestep variational free energy; currently always zeros.
    - 'F_pol' [float]:
        Accumulated 0.5 * q . error over gradient steps; stays 0.0 without gradient descent.
    """
    import copy

    # ---- temporal window for inference ----
    min_time = max(0, curr_t - t_horizon)
    max_time = min(T, curr_t + t_horizon)
    window_len = len(range(min_time, max_time))
    # from this index onward the future message is ignored (end of the chain)
    future_cutoff = window_len - 1
    inference_len = window_len + 1

    # ---- observations that fall inside the window (absolute indices into obs_t) ----
    if curr_t == 0:
        obs_range = [0]
    else:
        obs_range = range(max(0, curr_t - t_horizon), curr_t + 1)

    # ---- model dimensions ----
    if utils.is_arr_of_arr(B):
        num_states = [b.shape[0] for b in B]
    else:
        num_states = [B[0].shape[0]]
        B = utils.to_arr_of_arr(B)
    num_factors = len(num_states)

    # ---- Step 1: joint likelihood over hidden states, one entry per observed timestep ----
    # SPM convention: keep the raw likelihood; it is logged only after dotting with the posterior.
    likelihood = np.empty(len(obs_range), dtype=object)
    for t, obs_idx in enumerate(obs_range):
        likelihood[t] = get_joint_likelihood(A, obs_t[obs_idx], num_states)

    # ---- Step 2: flat posterior (and flat prior if none was given) ----
    qs = [np.empty(num_factors, dtype=object) for _ in range(inference_len)]
    for qs_t in qs:
        for f in range(num_factors):
            qs_t[f] = np.ones(num_states[f]) / num_states[f]

    if prior is None:
        prior = np.empty(num_factors, dtype=object)
        for f in range(num_factors):
            prior[f] = np.ones(num_states[f]) / num_states[f]

    # ---- Step 3: normalised transpose of B, used for backward messages 'from the future' ----
    B_transposed = np.empty(num_factors, dtype=object)
    for f in range(num_factors):
        B_transposed[f] = np.zeros_like(B[f])
        for u in range(B[f].shape[2]):
            B_transposed[f][:, :, u] = spm_norm(B[f][:, :, u].T)

    # Log-space zeros: the final future message carries no information.
    last_message = np.empty(num_factors, dtype=object)
    for f in range(num_factors):
        last_message[f] = np.zeros(num_states[f])

    # Without previous actions, prepend a row of zeros.
    if previous_actions is None:
        previous_actions = np.zeros((1, policy.shape[1]))
    # Cast to int once: np.vstack with float zeros yields floats, which numpy cannot index with.
    full_policy = np.vstack((previous_actions, policy)).astype(int)

    # ---- Step 4: iterate over the window, updating each marginal in turn ----
    qss = [[] for _ in range(num_iter)]
    free_energy = np.zeros((len(qs), num_iter))
    free_energy_pol = 0.0

    for n in range(num_iter):
        for t in range(inference_len):
            for f in range(num_factors):
                # 4.a likelihood message (only for timesteps with an observation)
                if t < len(obs_range):
                    lnA = np.log(spm_dot(likelihood[t], qs[t], [f]) + 1e-16)
                else:
                    lnA = np.zeros(num_states[f])

                # 4.b past message: the prior at the start of the window, otherwise the
                # previous marginal propagated through the previous action's transition
                if t == 0:
                    lnB_past = np.log(prior[f] + 1e-16)
                else:
                    lnB_past = np.log(
                        B[f][:, :, full_policy[t - 1, f]].dot(qs[t - 1][f]) + 1e-16
                    )

                # 4.c future message
                if t >= future_cutoff:
                    lnB_future = last_message[f]
                else:
                    B_future = B_transposed[f][:, :, full_policy[t, f]].dot(qs[t + 1][f])
                    lnB_future = np.log(B_future + 1e-16)

                # 4.d posterior update
                if use_gradient_descent:
                    lns = np.log(qs[t][f] + 1e-16)
                    # SPM-style prediction error (no future term at the end of the chain)
                    if t >= future_cutoff:
                        error = lnA + lnB_past - lns
                    else:
                        error = (2 * lnA + lnB_past + lnB_future) - 2 * lns
                    error -= error.mean()
                    lns = lns + tau * error
                    # free energy uses the pre-update belief
                    free_energy_pol += 0.5 * qs[t][f].dot(error)
                    qs[t][f] = softmax(lns)
                else:
                    qs[t][f] = softmax(lnA + lnB_past + lnB_future)

        # snapshot, so later iterations do not overwrite earlier entries
        qss[n].append(copy.deepcopy(qs))

    return qs, qss, free_energy, free_energy_pol
Exemplo n.º 10
0
def run_mmp(
    A,
    B,
    obs_t,
    policy,
    curr_t,
    t_horizon,
    T,
    qs_bma=None,
    prior=None,
    num_iter=10,
    dF=1.0,
    dF_tol=0.001,
    previous_actions=None,
    use_gradient_descent=False,
    tau=0.25,
):
    """
    Optimise marginal posterior beliefs about hidden states using the marginal message-passing (MMP)
    scheme developed by Thomas Parr and colleagues, see https://github.com/tejparr/nmpassing

    Messages follow the SPM ("Karl") convention: the joint likelihood is dotted with the current
    posterior first and only then logged.

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to observations.
    - 'B' [numpy.ndarray (tensor or array-of-arrays)]:
        Transition likelihood of the generative model, indexed as B[f][:, :, action]. A single
        tensor is wrapped into a one-factor array-of-arrays.
    - 'obs_t' [list of numpy 1D arrays or array of arrays (with 1D numpy array entries)]:
        Sequence of observations. It is indexed with absolute time indices
        range(max(0, curr_t - t_horizon), curr_t + 1) (only index 0 when curr_t == 0), so it must
        hold at least that many entries. Multi-modality entries are array-of-arrays with one 1D
        numpy.ndarray per modality.
    - 'policy' [2D np.ndarray]:
        Actions of a single policy, shape (n_steps, n_control_factors); entry [t, f] is the action
        along control factor f at step t.
    - 'curr_t' [int]:
        Current timestep (relative to the 'absolute' time of the generative process).
    - 't_horizon' [int]:
        Temporal horizon of inference for states and policies.
    - 'T' [int]:
        Temporal horizon of the generative process (absolute time).
    - 'qs_bma':
        Currently unused; kept for interface compatibility.
    - 'prior' [numpy 1D array, array of arrays (with 1D numpy array entries) or None]:
        Prior beliefs at the beginning of the inference window; used as the past message at the
        first timestep. If None, a uniform distribution over hidden states is used.
    - 'num_iter' [int]:
        Number of variational iterations to run.
    - 'dF', 'dF_tol' [float]:
        Currently unused (no convergence check is performed); kept for interface compatibility.
    - 'previous_actions' [numpy.ndarray with shape (num_steps, n_control_factors) or None]:
        Actions already taken, stacked in front of `policy` to form the full action sequence.
        If None, a single row of zeros is prepended.
    - 'use_gradient_descent' [bool]:
        If True, take a gradient step of size `tau` on log-beliefs; otherwise set each marginal to
        the softmax of the summed log-messages.
    - 'tau' [float]:
        Learning rate for gradient descent (only used if use_gradient_descent is True).

    Returns
    -------
    - 'qs' [list of length window_len + 1 of object arrays (one 1D numpy array per factor)]:
        Marginal posterior beliefs over hidden states, where
        window_len = len(range(max(0, curr_t - t_horizon), min(T, curr_t + t_horizon))).
        The extra final slot is initialised to zeros.
    - 'qss' [list of length num_iter]:
        qss[n] is a one-element list holding a snapshot (deep copy) of `qs` after iteration n.
    - 'F' [2D np.ndarray of shape (len(qs), num_iter)]:
        Placeholder for per-timestep variational free energy; currently always zeros.
    - 'F_pol' [float]:
        Accumulated 0.5 * q . error over gradient steps; stays 0.0 without gradient descent.
    """
    import copy

    # ---- temporal window for inference ----
    window_len = len(range(max(0, curr_t - t_horizon), min(T, curr_t + t_horizon)))

    # ---- model dimensions ----
    if utils.is_arr_of_arr(B):
        n_states = [sub_B.shape[0] for sub_B in B]
    else:
        n_states = [B[0].shape[0]]
        B = utils.to_arr_of_arr(B)
    n_factors = len(n_states)

    # ---- Step 1: joint likelihood over hidden states for each observed timestep ----
    # Modalities are assumed independent, so their likelihoods multiply (see get_joint_likelihood).
    if curr_t == 0:
        obs_range = [0]
    else:
        obs_range = range(max(0, curr_t - t_horizon), curr_t + 1)

    # SPM convention: keep the raw likelihood; it is logged only after dotting with the posterior.
    likelihood = np.empty(len(obs_range), dtype=object)
    for t, obs_idx in enumerate(obs_range):
        likelihood[t] = get_joint_likelihood(A, obs_t[obs_idx], n_states)

    # ---- Step 2: flat posterior (zeros in the extra final slot) and flat prior if needed ----
    qs = [np.empty(n_factors, dtype=object) for _ in range(window_len + 1)]
    for t, qs_t in enumerate(qs):
        for f in range(n_factors):
            if t == window_len:
                qs_t[f] = np.zeros(n_states[f])
            else:
                qs_t[f] = np.ones(n_states[f]) / n_states[f]

    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for f in range(n_factors):
            prior[f] = np.ones(n_states[f]) / n_states[f]

    # Normalised transpose of each factor's transition tensor, for backward messages.
    # B is always an array-of-arrays at this point, so a single-factor model is handled by the
    # same per-factor loop.
    B_t = np.empty(n_factors, dtype=object)
    for f in range(n_factors):
        B_t[f] = np.zeros_like(B[f])
        for u in range(B[f].shape[2]):
            B_t[f][:, :, u] = spm_norm(B[f][:, :, u].T)

    # Final future message: zeros in log space, i.e. no information from beyond the horizon.
    last_message = np.empty(n_factors, dtype=object)
    for f in range(n_factors):
        last_message[f] = np.zeros(n_states[f])

    # ---- Step 3: iterate over the window (past steps plus the policy horizon) ----
    # full_policy has shape (n_steps, n_factors), e.g. [[0 1 2 0], [1 2 0 1]].T
    if previous_actions is None:
        previous_actions = np.zeros((1, policy.shape[1]))
    # Cast to int once: np.vstack with float zeros yields floats, which numpy cannot index with.
    full_policy = np.vstack((previous_actions, policy)).astype(int)

    # From this index onward the future message is ignored (end of the inference chain).
    future_cutoff = len(qs) - 2
    num_observed = len(likelihood)

    qss = [[] for _ in range(num_iter)]
    F = np.zeros((len(qs), num_iter))
    F_pol = 0.0

    for n in range(num_iter):
        for t in range(len(qs)):
            for f in range(n_factors):
                # likelihood message: only timesteps that have a likelihood entry contribute
                if t < num_observed:
                    lnA = np.log(spm_dot(likelihood[t], qs[t], [f]) + 1e-16)
                else:
                    lnA = np.zeros(n_states[f])

                # past message
                if t == 0:
                    lnBpast = np.log(prior[f] + 1e-16)
                else:
                    lnBpast = np.log(
                        B[f][:, :, full_policy[t - 1, f]].dot(qs[t - 1][f]) + 1e-16
                    )

                # future message (SPM version, without the 0.5 factor of the Parr scheme)
                if t >= future_cutoff:
                    lnBfuture = last_message[f]
                else:
                    lnBfuture = np.log(
                        B_t[f][:, :, full_policy[t, f]].dot(qs[t + 1][f]) + 1e-16
                    )

                if use_gradient_descent:
                    lns = np.log(qs[t][f] + 1e-16)  # current estimate
                    # SPM-style prediction error (no future term at the end of the chain)
                    if t >= future_cutoff:
                        e = lnA + lnBpast - lns
                    else:
                        e = (2 * lnA + lnBpast + lnBfuture) - 2 * lns
                    e -= e.mean()
                    lns += tau * e
                    # free energy uses the pre-update belief
                    F_pol += 0.5 * qs[t][f].dot(e)
                    qs[t][f] = softmax(lns)
                else:
                    # fixed point of the free energy for this factor
                    qs[t][f] = softmax(lnA + lnBpast + lnBfuture)

        # snapshot, so later iterations do not overwrite earlier entries
        qss[n].append(copy.deepcopy(qs))

    return qs, qss, F, F_pol
Exemplo n.º 11
0
def update_posterior_policies_mmp(
    qs_seq_pi,
    A,
    B,
    C,
    policies,
    use_utility=True,
    use_states_info_gain=True,
    use_param_info_gain=False,
    prior=None,
    pA=None,
    pB=None,
    F=None,
    E=None,
    gamma=16.0,
    return_numpy=True,
):
    """
    Compute the posterior over policies from per-policy sequences of expected hidden states.

    `qs_seq_pi`: numpy object array of posterior marginal beliefs over hidden states for each policy,
                nested as policies --> timesteps --> hidden state factors, so qs_seq_pi[p_idx][t][f]
                is the belief about factor `f` at time `t` under policy `p_idx`.
                All policies are assumed to share the horizon len(qs_seq_pi[0]).
    `A`: numpy object array that stores likelihood mappings for each modality.
    `B`: numpy object array that stores transition matrices (possibly action-conditioned) for each
         hidden state factor.
    `C`: prior preferences over outcomes, passed unchanged to `calc_expected_utility`.
    `policies`: numpy object array that stores each (potentially multi-factorial) policy in
                `policies[p_idx]`, of shape `(num_timesteps, num_factors)`.
    `use_utility`: whether expected utility is added to the EFE (default: `True`).
    `use_states_info_gain`: whether state info gain is added to the EFE (default: `True`).
    `use_param_info_gain`: whether parameter info gain (pA and/or pB) is added to the EFE
                           (default: `False`).
    `prior`: priors over hidden states, used as the "previous" beliefs for the pB info gain at
             t == 0; if None, the pB term is skipped at t == 0.
    `pA`: Dirichlet priors over likelihood mappings (one per modality), or None.
    `pB`: Dirichlet priors over transition mappings (one per hidden state factor), or None.
    `F`: 1D numpy array of the variational free energy of each policy (defaults to zeros).
    `E`: 1D numpy array added inside the softmax, i.e. a *log*-prior over policies such as
         'habits' (defaults to zeros, a flat prior).
    `gamma`: precision over policies, multiplying the EFE inside the softmax.
    `return_numpy`: if True return numpy arrays, otherwise a Categorical (default: `True`).

    Returns
    -------
    `q_pi`: posterior over policies, softmax(gamma * efe - F + E).
    `efe`: 1D numpy array with the accumulated expected free energy term of each policy.
    """
    A = utils.to_numpy(A)
    B = utils.to_numpy(B)
    horizon = len(qs_seq_pi[0])
    num_policies = len(qs_seq_pi)

    efe = np.zeros(num_policies)

    if F is None:
        F = np.zeros(num_policies)
    if E is None:
        E = np.zeros(num_policies)

    for p_idx, policy in enumerate(policies):
        qs_seq_pi_i = qs_seq_pi[p_idx]

        for t in range(horizon):
            qs_pi_t = qs_seq_pi_i[t]
            # Expected observations under this policy at time t. Used directly rather than
            # stored in a container shared between policies.
            qo_pi_t = get_expected_obs(qs_pi_t, A)

            if use_utility:
                efe[p_idx] += calc_expected_utility(qo_pi_t, C)

            if use_states_info_gain:
                efe[p_idx] += calc_states_info_gain(A, qs_pi_t)

            if use_param_info_gain:
                if pA is not None:
                    efe[p_idx] += calc_pA_info_gain(pA, qo_pi_t, qs_pi_t)
                if pB is not None:
                    # pB info gain needs the beliefs one step earlier: the previous
                    # timestep's marginals, or the supplied prior at t == 0.
                    if t > 0:
                        efe[p_idx] += calc_pB_info_gain(
                            pB, qs_pi_t, qs_seq_pi_i[t - 1], policy)
                    elif prior is not None:
                        efe[p_idx] += calc_pB_info_gain(pB, qs_pi_t, prior, policy)

    q_pi = softmax(efe * gamma - F + E)
    if return_numpy:
        q_pi = q_pi / q_pi.sum(axis=0)
    else:
        q_pi = utils.to_categorical(q_pi)
        q_pi.normalize()
    return q_pi, efe
Exemplo n.º 12
0
    return u


"""
Experiment 
"""

# number of time steps
T = 10

# reset environment
obs = env.reset()
print("Initial Location {}".format(env.state))
# infer initial state
qs = maths.softmax(A[obs, :].log())

# loop over time
for t in range(T):

    # infer action
    action = infer_action(qs, A, B, C, n_control, policies)

    # perform action
    obs = env.step(action)

    # infer new hidden state
    qs = maths.softmax(A[obs, :].log() + B[:, :, action].dot(qs).log())

    # print information
    print("Time step {} Location {}".format(t, env.state))
Exemplo n.º 13
0
def run_mmp_v2(A,
               B,
               ll_seq,
               policy,
               prev_actions=None,
               prior=None,
               num_iter=10,
               grad_descent=False,
               tau=0.25):
    """Marginal message passing over a window of past and future timesteps.

    Repeatedly sweeps over every timestep and hidden-state factor in the
    window, combining three log-messages into an updated belief:

    * evidence: ``log(spm_dot(ll_seq[t], qs_seq[t], [f]))`` for past
      timesteps, zeros for future ones;
    * forward message: the prior at ``t == 0``, otherwise the previous
      timestep's belief pushed through ``B[f][:, :, action]``;
    * backward message: the next timestep's belief pushed through the
      ``spm_norm``-normalised transpose of ``B[f][:, :, action]``, or an
      all-zero terminal message for the last timesteps of the window.

    Parameters
    ----------
    A : likelihood model. It is only used to obtain model dimensions via
        ``get_model_dimensions``; the converted copy is not read afterwards.
    B : transition model, one array per factor; ``B[f][:, :, u].dot(q)``
        propagates a belief ``q`` one step forward under action ``u``.
    ll_seq : sequence with one entry per past timestep; each entry is
        contracted with the current beliefs via ``spm_dot`` (presumably the
        likelihood of the observed outcomes — confirm against callers).
        Its length sets the number of past timesteps.
    policy : 2D array of action indices, shape (future timesteps, factors).
    prev_actions : optional 2D array of past action indices, stacked above
        ``policy``. Defaults to ``len(ll_seq)`` rows of action 0.
    prior : optional object array of per-factor beliefs used as the forward
        message at the first timestep. Defaults to uniform.
    num_iter : number of full sweeps over the window.
    grad_descent : if True, take a gradient step of size ``tau`` on the
        log-beliefs; otherwise set each belief to the softmax of the summed
        messages.
    tau : step size used when ``grad_descent`` is True.

    Returns
    -------
    qs_seq : list with ``len(ll_seq) + policy.shape[0]`` object arrays, each
        holding one belief vector per hidden-state factor.
    """
    eps = 1e-16

    # Extent of the inference window.
    n_past = len(ll_seq)
    n_future = policy.shape[0]
    window = n_past + n_future
    # From this timestep on, the backward message is the terminal message
    # (zeros) and the gradient step uses weight 1 instead of 2.
    cutoff = n_past + n_future - 2

    _, num_states, _, num_factors = get_model_dimensions(A, B)
    A = to_arr_of_arr(A)  # NOTE(review): the converted A is never read below
    B = to_arr_of_arr(B)

    def uniform_beliefs():
        # One flat distribution per factor, packed into an object array.
        beliefs = np.empty(num_factors, dtype=object)
        for factor in range(num_factors):
            beliefs[factor] = np.ones(num_states[factor]) / num_states[factor]
        return beliefs

    # Every timestep in the window starts from a uniform belief.
    qs_seq = [uniform_beliefs() for _ in range(window)]

    # Terminal backward message: zeros, used directly as a log-message.
    qs_T = np.empty(num_factors, dtype=object)
    for factor in range(num_factors):
        qs_T[factor] = np.zeros(num_states[factor])

    if prior is None:
        prior = uniform_beliefs()

    # Per-action transposed (and spm_norm-normalised) transitions for the
    # backward messages; zeros_like keeps B's dtype, as before.
    trans_B = np.empty(num_factors, dtype=object)
    for factor in range(num_factors):
        B_f = B[factor]
        trans_B[factor] = np.zeros_like(B_f)
        for action in range(B_f.shape[2]):
            trans_B[factor][:, :, action] = spm_norm(B_f[:, :, action].T)

    # Actions over the whole window: past actions followed by the policy.
    if prev_actions is None:
        prev_actions = np.zeros((n_past, policy.shape[1]))
    full_policy = np.vstack((prev_actions, policy))

    for _ in range(num_iter):
        for t in range(window):
            for factor in range(num_factors):
                # Evidence is only available for past timesteps.
                if t < n_past:
                    evidence = spm_dot(ll_seq[t], qs_seq[t], [factor])
                    ln_lik = np.log(evidence + eps)
                else:
                    ln_lik = np.zeros(num_states[factor])

                # Forward message from t - 1 (the prior at t == 0).
                if t == 0:
                    ln_fwd = np.log(prior[factor] + eps)
                else:
                    u_prev = int(full_policy[t - 1, factor])
                    fwd = B[factor][:, :, u_prev].dot(qs_seq[t - 1][factor])
                    ln_fwd = np.log(fwd + eps)

                # Backward message from t + 1 (terminal message near the end).
                if t >= cutoff:
                    ln_bwd = qs_T[factor]
                else:
                    u_now = int(full_policy[t, factor])
                    bwd = trans_B[factor][:, :, u_now].dot(qs_seq[t + 1][factor])
                    ln_bwd = np.log(bwd + eps)

                if not grad_descent:
                    # Fixed-point update: normalised product of messages.
                    qs_seq[t][factor] = softmax(ln_lik + ln_fwd + ln_bwd)
                    continue

                # Gradient step on the log-belief; mean-centring the error
                # removes the component that softmax would ignore anyway.
                ln_q = np.log(qs_seq[t][factor] + eps)
                weight = 1 if t >= cutoff else 2
                grad = (weight * ln_lik + ln_fwd + ln_bwd) - weight * ln_q
                grad = grad - grad.mean()
                qs_seq[t][factor] = softmax(ln_q + tau * grad)

    return qs_seq