Beispiel #1
0
def run_fpi(A,
            obs,
            n_observations,
            n_states,
            prior=None,
            num_iter=10,
            dF=1.0,
            dF_tol=0.001):
    """
    Update marginal posterior beliefs about hidden states using variational fixed point iteration (FPI).

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to observations.
        Used to invert the generative model to obtain the marginal likelihood over hidden states,
        given the observation.
    - 'obs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        The observation (generated by the environment). If single modality, this can be a
        1D array (one-hot vector representation).
        If multi-modality, this can be an array of arrays (whose entries are 1D one-hot vectors).
    - 'n_observations' [int or list of ints]:
        Number of observation levels per modality. Not used by this function; kept for
        interface compatibility.
    - 'n_states' [list of ints]:
        Number of hidden states per factor. ``len`` and indexing are applied to it, so it
        must be a sequence (one entry per hidden state factor).
    - 'prior' [numpy object array / list of 1D arrays, or None]:
        Prior beliefs of the agent, integrated with the expected log-likelihood to obtain the
        posterior. It is added as-is inside the softmax, so it is expected to be in log space.
        It is indexed per factor (``prior[factor]``, ``prior[0]`` for a single factor).
        If absent, it is set to the log of a uniform distribution over hidden states
        (matching the initialisation of the posterior).
    - 'num_iter' [int]:
        Maximum number of variational fixed-point iterations to run.
    - 'dF' [float]:
        Starting free energy change (dF/dt) used for the first stopping check.
    - 'dF_tol' [float]:
        Threshold on the absolute change of the variational free energy between iterations.
        Iterations stop early once dF < dF_tol.

    Returns
    ----------
    - 'qs' [numpy 1D array or object array of 1D arrays]:
        Marginal posterior beliefs over hidden states. A single 1D array when there is one
        factor; otherwise an object array with one marginal per factor (mean-field).
    """

    # get model dimensions
    n_factors = len(n_states)

    # Step 1: joint likelihood over hidden factors (modalities treated as independent by
    # get_joint_likelihood), moved to log space. The 1e-16 floor avoids log(0).
    likelihood = np.log(get_joint_likelihood(A, obs, n_states) + 1e-16)

    # Step 2: flat initial posterior, and a flat log-prior if none was provided.
    qs = np.empty(n_factors, dtype=object)
    for factor in range(n_factors):
        qs[factor] = np.ones(n_states[factor]) / n_states[factor]

    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for factor in range(n_factors):
            prior[factor] = np.log(
                np.ones(n_states[factor]) / n_states[factor] + 1e-16)

    # Step 3: initial free energy (without the likelihood/accuracy term).
    prev_vfe = calc_free_energy(qs, prior, n_factors)

    # Step 4: with a single factor the posterior is obtained in one step.
    if n_factors == 1:
        qL = spm_dot(likelihood, qs, [0])
        return softmax(qL + prior[0])

    # Step 5: fixed-point iteration over all factors.
    factor_axes = list(range(n_factors))
    curr_iter = 0
    while curr_iter < num_iter and dF >= dF_tol:
        # Snapshot the marginals so every factor is updated from the same (previous)
        # beliefs, i.e. a parallel update across factors.
        old_qs = list(qs)

        for factor in range(n_factors):
            # Expected log-likelihood for `factor`: contract the log-likelihood tensor with
            # the marginals of every *other* factor. This avoids dividing the fully-weighted
            # tensor by this factor's own marginal, which yields NaN when it contains zeros.
            operands = [likelihood, factor_axes]
            for other, qs_other in enumerate(old_qs):
                if other != factor:
                    operands.extend((qs_other, [other]))
            operands.append([factor])
            qL = np.einsum(*operands)
            qs[factor] = softmax(qL + prior[factor])

        # calculate new free energy
        vfe = calc_free_energy(qs, prior, n_factors, likelihood)

        # stopping condition - change in free energy between iterations
        dF = np.abs(prev_vfe - vfe)
        prev_vfe = vfe

        curr_iter += 1

    return qs
Beispiel #2
0
def run_fpi_faster(A,
                   obs,
                   n_observations,
                   n_states,
                   prior=None,
                   num_iter=10,
                   dF=1.0,
                   dF_tol=0.001):
    """
    Update marginal posterior beliefs about hidden states
    using a new version of variational fixed point iteration (FPI).
    @NOTE (Conor, 26.02.2020):
    This method uses a faster algorithm than the traditional 'spm_dot' approach. Instead of
    separately computing a conditional joint log likelihood of an outcome, under the
    posterior probabilities of a certain marginal, instead all marginals are multiplied into one
    joint tensor that gives the joint likelihood of an observation under all hidden states,
    that is then sequentially (and *parallelizably*) marginalized out to get each marginal posterior.
    This method is less RAM-intensive, admits heavy parallelization, and runs (about 2x) faster.
    @NOTE (Conor, 28.02.2020):
    After further testing, discovered interesting differences between this version and the
    original version. It appears that the
    original version (simple 'run_FPI') shows mean-field biases or 'explaining away'
    effects, whereas this version spreads probabilities more 'fairly' among possibilities.
    To summarize: it actually matters what order you do the summing across the joint likelihood tensor.
    In this version, all marginals are multiplied into the likelihood tensor before summing out,
    whereas in the previous version, marginals are recursively multiplied and summed out.
    @NOTE (Conor, 24.04.2020): I would expect that the factor_order approach used above would help
    ameliorate the effects of the mean-field bias. I would also expect that the use of a factor_order
    below is unnecessary, since the marginalisation w.r.t. each factor is done only after all marginals
    are multiplied into the larger tensor.
    NOTE(review): the marginalisation loop previously indexed with the stale loop variable
    `factor` (left over from the multiplication loop) instead of `f`, so only one factor was
    written per sweep, using a sum over the wrong axes. The observations in the notes above
    may predate that fix and should be re-checked.

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to observations.
        Used to invert the generative model to obtain the marginal likelihood over hidden states,
        given the observation.
    - 'obs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        The observation (generated by the environment). If single modality, this can be a 1D array
        (one-hot vector representation). If multi-modality, this can be an array of arrays
        (whose entries are 1D one-hot vectors).
    - 'n_observations' [int or list of ints]:
        Number of observation levels per modality. Not used by this function; kept for
        interface compatibility.
    - 'n_states' [list of ints]:
        Number of hidden states per factor. ``len`` and indexing are applied to it.
    - 'prior' [numpy object array / list of 1D arrays, or None]:
        Prior beliefs of the agent, added as-is to the expected log-likelihood inside the softmax
        (so expected in log space), indexed per factor. If absent, it is set to the log of a
        uniform distribution over hidden states (matching the initialisation of the posterior).
    - 'num_iter' [int]:
        Maximum number of variational fixed-point iterations to run.
    - 'dF' [float]:
        Starting free energy change (dF/dt) used for the first stopping check.
    - 'dF_tol' [float]:
        Threshold on the absolute change of the variational free energy between iterations.
        Iterations stop early once dF < dF_tol.

    Returns
    ----------
    - 'qs' [numpy 1D array or object array of 1D arrays]:
        Marginal posterior beliefs over hidden states. A single 1D array when there is one
        factor; otherwise an object array with one marginal per factor (mean-field).
    """

    # get model dimensions
    n_factors = len(n_states)

    # Step 1: joint likelihood over hidden factors, moved to log space (1e-16 avoids log(0)).
    likelihood = np.log(get_joint_likelihood(A, obs, n_states) + 1e-16)

    # Step 2: flat initial posterior, and a flat log-prior if none was provided.
    qs = np.empty(n_factors, dtype=object)
    for factor in range(n_factors):
        qs[factor] = np.ones(n_states[factor]) / n_states[factor]

    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for factor in range(n_factors):
            prior[factor] = np.log(
                np.ones(n_states[factor]) / n_states[factor] + 1e-16)

    # Step 3: initial free energy (without the likelihood/accuracy term).
    prev_vfe = calc_free_energy(qs, prior, n_factors)

    # Step 4: with a single factor the posterior is obtained in one step.
    if n_factors == 1:
        qL = spm_dot(likelihood, qs, [0])
        return softmax(qL + prior[0])

    # Step 5: revised fixed-point iteration scheme.
    # Each iteration runs two sweeps: factors in forward order, then in reverse order.
    factor_orders = [range(n_factors), range(n_factors - 1, -1, -1)]

    curr_iter = 0
    while curr_iter < num_iter and dF >= dF_tol:
        for factor_order in factor_orders:
            # Start each sweep from a fresh copy of the log-likelihood.
            L = likelihood.copy()

            # Multiply every current marginal into one joint tensor, broadcasting each
            # marginal along its own axis.
            for factor in factor_order:
                s = np.ones(np.ndim(L), dtype=int)
                s[factor] = len(qs[factor])
                L *= qs[factor].reshape(tuple(s))

            # For each factor f, divide its own marginal back out and sum over all other
            # axes. L is not rebuilt inside this loop, so every factor in the sweep is
            # updated from the same joint tensor.
            for f in factor_order:
                s = np.ones(np.ndim(L), dtype=int)
                s[f] = len(qs[f])

                # qs[f] has not been updated yet in this sweep, so this removes exactly
                # the marginal that was multiplied into L above.
                temp = L * (1.0 / qs[f]).reshape(tuple(s))
                dims2sum = tuple(d for d in range(n_factors) if d != f)
                qL = np.sum(temp, dims2sum)

                qs[f] = softmax(qL + prior[f])

        # calculate new free energy
        vfe = calc_free_energy(qs, prior, n_factors, likelihood)

        # stopping condition - change in free energy between iterations
        dF = np.abs(prev_vfe - vfe)
        prev_vfe = vfe

        curr_iter += 1

    return qs
Beispiel #3
0
def run_mmp(lh_seq,
            B,
            policy,
            prev_actions=None,
            prior=None,
            num_iter=10,
            grad_descent=False,
            tau=0.25,
            last_timestep=False,
            save_vfe_seq=False):
    """
    Marginal message passing scheme for updating posterior beliefs about multi-factor hidden states over time,
    conditioned on a particular policy.
    Parameters:
    --------------
    `lh_seq`[numpy object array]:
        Likelihoods of hidden state factors given a sequence of observations over time (one entry per
        past timestep). They are log-transformed inside this function (via `spm_log`) where needed.
    `B`[numpy object array]:
        Transition likelihood of the generative model, mapping from hidden states at T to hidden states at T+1.
        One B matrix per hidden state factor (e.g. `B[f]` corresponds to the f-th factor's B matrix), indexed
        as `B[f][:, :, action]`. It is used to compute the 'forward' and 'backward' messages conveyed between
        beliefs about temporally-adjacent timepoints.
    `policy` [2-D numpy.ndarray]:
        Matrix of shape (policy_len, num_control_factors) whose element `policy[t, f]` is the index of the
        action (control state) taken at timestep t on control factor f under this policy.
    `prev_actions` [None or 2-D numpy.ndarray]:
        If provided, a matrix of previous actions with one column per control factor that gives the indices of
        each action (control state index) taken in the past (up until the current timestep). It is stacked on
        top of `policy`. If None, `len(lh_seq)` rows of zeros are used.
    `prior`[None or numpy object array]:
        If provided, a numpy object array with one sub-array per hidden state factor that stores the prior
        beliefs about initial states (at t = 0, relative to `infer_len`). It is passed through `spm_log` at
        t = 0, so it is expected to hold probabilities. Defaults to a uniform prior.
    `num_iter`[Int]:
        Number of variational iterations
    `grad_descent` [Bool]:
        Flag for whether to use gradient descent (predictive coding style)
    `tau` [Float]:
        Decay constant for use in the `grad_descent` version
    `last_timestep` [Bool]:
        Flag for whether we are at the last timestep of belief updating (shortens the inference window by one)
    `save_vfe_seq` [Bool]:
        Flag for whether to save the sequence of variational free energies. If `True`, `F` is a list of
        running totals that starts at 0.0 and gets one entry per timestep per iteration. If `False`, the VFE
        is summed across time/iterations into a single value.
    Returns:
    --------------
    `qs_seq`: the sequence of beliefs about the different hidden state factors over time, one multi-factor
        posterior belief per timestep in `infer_len`
    `F`[Float or list, depending on setting of save_vfe_seq]
    """

    # window
    past_len = len(lh_seq)
    future_len = policy.shape[0]

    if last_timestep:
        infer_len = past_len + future_len - 1
    else:
        infer_len = past_len + future_len

    # timesteps at or beyond this index receive no backward (future) message
    future_cutoff = past_len + future_len - 2

    # dimensions
    _, num_states, _, num_factors = get_model_dimensions(A=None, B=B)
    B = to_arr_of_arr(B)

    # beliefs: start every timestep from a uniform posterior
    qs_seq = obj_array(infer_len)
    for t in range(infer_len):
        qs_seq[t] = obj_array_uniform(num_states)

    # last (log) message: zeros, i.e. no information from beyond the horizon
    qs_T = obj_array_zeros(num_states)

    # prior
    if prior is None:
        prior = obj_array_uniform(num_states)

    # transposed (and re-normalised) transitions for the backward messages
    trans_B = obj_array(num_factors)
    for f in range(num_factors):
        trans_B[f] = spm_norm(np.swapaxes(B[f], 0, 1))

    # full policy: past actions followed by the policy's future actions
    if prev_actions is None:
        prev_actions = np.zeros((past_len, policy.shape[1]))
    policy = np.vstack((prev_actions, policy))

    # initialise variational free energy of policy (accumulated over time)
    if save_vfe_seq:
        F = [0.0]
    else:
        F = 0.0

    for _ in range(num_iter):
        for t in range(infer_len):
            # grad-descent VFE of timestep t summed over factors (used only when save_vfe_seq)
            vfe_t = 0.0

            for f in range(num_factors):
                # likelihood
                if t < past_len:
                    lnA = spm_log(spm_dot(lh_seq[t], qs_seq[t], [f]))
                else:
                    lnA = np.zeros(num_states[f])

                # past message
                if t == 0:
                    lnB_past = spm_log(prior[f])
                else:
                    past_msg = B[f][:, :, int(policy[t - 1, f])].dot(qs_seq[t - 1][f])
                    lnB_past = spm_log(past_msg)

                # future message
                if t >= future_cutoff:
                    lnB_future = qs_T[f]
                else:
                    future_msg = trans_B[f][:, :, int(policy[t, f])].dot(
                        qs_seq[t + 1][f])
                    lnB_future = spm_log(future_msg)

                # inference
                if grad_descent:
                    lnqs = spm_log(qs_seq[t][f])
                    coeff = 1 if (t >= future_cutoff) else 2
                    err = (coeff * lnA + lnB_past + lnB_future) - coeff * lnqs
                    err -= err.mean()
                    lnqs = lnqs + tau * err
                    qs_seq[t][f] = softmax(lnqs)

                    # NOTE(review): this weights the error by the *log* of the updated posterior
                    # and uses the mean-centred error; SPM's spm_MDP_VB_X appears to weight by
                    # the posterior itself. Verify before relying on these VFE values.
                    if (t == 0) or (t == (infer_len - 1)):
                        vfe_f = 0.5 * lnqs.dot(0.5 * err)
                    else:
                        vfe_f = lnqs.dot(
                            0.5 * (err - (num_factors - 1) * lnA / num_factors)
                        )  # @NOTE: not sure why Karl does this in SPM_MDP_VB_X, we should look into this

                    if save_vfe_seq:
                        vfe_t += vfe_f
                    else:
                        F += vfe_f
                else:
                    qs_seq[t][f] = softmax(lnA + lnB_past + lnB_future)

            if grad_descent:
                # F is a list here: record the running total instead of `+=` on a list.
                if save_vfe_seq:
                    F.append(F[-1] + vfe_t)
                continue

            # NOTE(review): `prior` holds probabilities here (it is spm_log'd above), while
            # run_fpi passes a log-prior to calc_free_energy; confirm which space it expects.
            if save_vfe_seq:
                if t < past_len:
                    F.append(
                        F[-1] +
                        calc_free_energy(qs_seq[t],
                                         prior,
                                         num_factors,
                                         likelihood=spm_log(lh_seq[t]))[0])
                else:
                    F.append(
                        F[-1] +
                        calc_free_energy(qs_seq[t], prior, num_factors)[0])
            else:
                if t < past_len:
                    F += calc_free_energy(qs_seq[t],
                                          prior,
                                          num_factors,
                                          likelihood=spm_log(lh_seq[t]))
                else:
                    F += calc_free_energy(qs_seq[t], prior, num_factors)

    return qs_seq, F