def run_fpi(A, obs, n_observations, n_states, prior=None, num_iter=10, dF=1.0, dF_tol=0.001):
    """
    Infer marginal posterior beliefs over hidden states with variational
    fixed-point iteration (FPI) under a mean-field factorisation.

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model. It is handed, together
        with `obs` and `n_states`, to `get_joint_likelihood` to obtain the
        likelihood of the observation over hidden states.
    - 'obs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        The observation (generated by the environment). Single modality: a 1D
        one-hot vector. Multiple modalities: an array of 1D one-hot vectors.
    - 'n_observations' [list of ints]:
        Only `len()` of this argument is taken here, so it must be sized.
    - 'n_states' [list of ints]:
        Number of levels of each hidden state factor; its length is the number
        of factors.
    - 'prior' [array of arrays (with 1D numpy array entries) or None]:
        Prior over hidden states, one entry per factor. It is added directly to
        the expected log-likelihood inside the softmax, i.e. it is used in log
        space. If None, the log of a flat distribution is used.
    - 'num_iter' [int]:
        Maximum number of fixed-point iterations.
    - 'dF' [float]:
        Starting value of the free-energy change used by the stopping test.
    - 'dF_tol' [float]:
        Iterations continue only while `dF >= dF_tol`.

    Returns
    ----------
    - 'qs':
        With a single factor, the output of `softmax` for that factor (a bare
        vector, not wrapped in an object array). Otherwise a numpy object array
        holding one marginal posterior per hidden state factor.
    """

    # Model dimensions (`n_modalities` is not referenced again below).
    n_modalities = len(n_observations)
    n_factors = len(n_states)

    # Step 1: joint likelihood of the observation over all hidden state
    # factors, taken to log space (1e-16 guards against log(0)).
    log_likelihood = np.log(get_joint_likelihood(A, obs, n_states) + 1e-16)

    # Step 2: flat initial posterior, one marginal per factor.
    qs = np.empty(n_factors, dtype=object)
    for idx, dim in enumerate(n_states):
        qs[idx] = np.ones(dim) / dim

    # Without a user-supplied prior, fall back to the log of a flat categorical.
    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for idx, dim in enumerate(n_states):
            prior[idx] = np.log(np.ones(dim) / dim + 1e-16)

    # Step 3: free energy of the initial beliefs (no likelihood term).
    prev_vfe = calc_free_energy(qs, prior, n_factors)

    # Step 4: a single factor has a closed-form answer, no iteration needed.
    if n_factors == 1:
        expected_ll = spm_dot(log_likelihood, qs, [0])
        return softmax(expected_ll + prior[0])

    # Step 5: fixed-point iteration over all factors.
    axis_labels = list(range(n_factors))
    curr_iter = 0
    while curr_iter < num_iter and dF >= dF_tol:
        # Outer product of all current marginals, built one axis at a time.
        joint_q = qs[0]
        for marginal in qs[1:]:
            joint_q = joint_q[..., None] * marginal

        # Log-likelihood weighted by every marginal at once.
        weighted_ll = log_likelihood * joint_q

        for idx in axis_labels:
            q_idx = qs[idx]
            # Sum out every axis except `idx`, then divide this factor's own
            # marginal back out so only the *other* factors' beliefs remain.
            expected_ll = np.einsum(weighted_ll, axis_labels, [idx]) / q_idx
            qs[idx] = softmax(expected_ll + prior[idx])

        # Stop once the free energy changes by less than the tolerance.
        vfe = calc_free_energy(qs, prior, n_factors, log_likelihood)
        dF = np.abs(prev_vfe - vfe)
        prev_vfe = vfe
        curr_iter += 1

    return qs
def run_fpi_faster(A, obs, n_observations, n_states, prior=None, num_iter=10, dF=1.0, dF_tol=0.001):
    """
    Update marginal posterior beliefs about hidden states using a revised
    version of variational fixed point iteration (FPI).

    @NOTE (Conor, 26.02.2020):
    Instead of separately computing a conditional joint log likelihood of an
    outcome under the posterior of each marginal (the 'spm_dot' approach), all
    marginals are multiplied into one joint tensor, which is then marginalized
    for each factor in turn (after dividing that factor's own marginal back
    out) to obtain each marginal posterior.

    @NOTE (Conor, 28.02.2020):
    Differences between this version and 'run_fpi' were observed in testing:
    it matters in which order marginals are multiplied in and summed out.
    NOTE(review): those observations were made while the marginalisation loop
    below still indexed with a stale factor variable (fixed here) - re-test
    before relying on them.

    @NOTE (Conor, 24.04.2020):
    The forward/reverse 'factor_order' is expected to be unnecessary here,
    since marginalisation w.r.t. each factor happens only after all marginals
    have been multiplied into the joint tensor.

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden
        states to observations. Passed to `get_joint_likelihood` with `obs`.
    - 'obs' [numpy 1D array or array of arrays (with 1D numpy array entries)]:
        The observation (generated by the environment). Single modality: a 1D
        one-hot vector. Multiple modalities: an array of 1D one-hot vectors.
    - 'n_observations' [int or list of ints]:
        Kept for interface compatibility; it is not read by this function.
    - 'n_states' [list of ints]:
        Number of levels of each hidden state factor.
    - 'prior' [array of arrays (with 1D numpy array entries) or None]:
        Prior over hidden states, one entry per factor. It is added directly
        to the expected log-likelihood inside the softmax, i.e. it is used in
        log space. If None, the log of a flat distribution is used.
    - 'num_iter' [int]:
        Maximum number of variational fixed-point iterations to run.
    - 'dF' [float]:
        Starting value of the free-energy change used by the stopping test.
    - 'dF_tol' [float]:
        Iterations continue only while `dF >= dF_tol`.

    Returns
    ----------
    - 'qs':
        With a single factor, the output of `softmax` for that factor.
        Otherwise a numpy object array with one marginal posterior per factor.
    """

    n_factors = len(n_states)

    # Step 1: joint likelihood of the observation over all hidden state
    # factors, in log space (1e-16 guards against log(0)).
    likelihood = get_joint_likelihood(A, obs, n_states)
    likelihood = np.log(likelihood + 1e-16)

    # Step 2: flat posterior (and flat log prior if none was supplied).
    qs = np.empty(n_factors, dtype=object)
    for factor in range(n_factors):
        qs[factor] = np.ones(n_states[factor]) / n_states[factor]

    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for factor in range(n_factors):
            prior[factor] = np.log(np.ones(n_states[factor]) / n_states[factor] + 1e-16)

    # Step 3: free energy of the initial beliefs.
    prev_vfe = calc_free_energy(qs, prior, n_factors)

    # Step 4: with a single factor, prior and likelihood are simply added.
    if n_factors == 1:
        qL = spm_dot(likelihood, qs, [0])
        return softmax(qL + prior[0])

    # Step 5: revised fixed-point iteration.
    # Each iteration sweeps the factors once forwards and once in reverse.
    factor_orders = [range(n_factors), range(n_factors - 1, -1, -1)]

    curr_iter = 0
    while curr_iter < num_iter and dF >= dF_tol:
        for factor_order in factor_orders:
            # Start from a fresh copy of the log likelihood.
            L = likelihood.copy()

            # Multiply every marginal into one joint tensor, broadcasting each
            # marginal along its own axis.
            for f in factor_order:
                shape = np.ones(np.ndim(L), dtype=int)
                shape[f] = len(qs[f])
                L *= qs[f].reshape(tuple(shape))

            # For each factor: divide its own marginal back out, sum over all
            # other axes, and update that factor's posterior. The loop
            # variable `f` must be used for every index here.
            for f in factor_order:
                shape = np.ones(np.ndim(L), dtype=int)
                shape[f] = len(qs[f])
                temp = L * (1.0 / qs[f]).reshape(tuple(shape))
                dims2sum = tuple(ax for ax in range(n_factors) if ax != f)
                qL = np.sum(temp, dims2sum)
                qs[f] = softmax(qL + prior[f])

        # Stop once the free energy changes by less than the tolerance.
        vfe = calc_free_energy(qs, prior, n_factors, likelihood)
        dF = np.abs(prev_vfe - vfe)
        prev_vfe = vfe
        curr_iter += 1

    return qs
def run_mmp(lh_seq, B, policy, prev_actions=None, prior=None, num_iter=10, grad_descent=False, tau=0.25, last_timestep=False, save_vfe_seq=False):
    """
    Marginal message passing scheme for updating posterior beliefs about
    multi-factor hidden states over time, conditioned on a particular policy.

    Parameters:
    --------------
    `lh_seq`[numpy object array]:
        Likelihoods of hidden state factors given a sequence of observations
        over time; `len(lh_seq)` defines the number of past timesteps.
    `B`[numpy object array]:
        Transition likelihood of the generative model, mapping from hidden
        states at T to hidden states at T+1. One B matrix per hidden state
        factor (`B[f]` is the f-th factor's B matrix, indexed as
        `B[f][:, :, action]`). Used to compute the 'forward' and 'backward'
        messages between temporally-adjacent timepoints.
    `policy` [2-D numpy.ndarray]:
        Matrix of shape (policy_len, num_control_factors); `policy[t, f]` is the
        action (control state index) at timestep t for control factor f.
    `prev_actions` [None or 2-D numpy.ndarray]:
        If provided, previous actions that are stacked on top of `policy`.
        If None, a (len(lh_seq), policy.shape[1]) block of zeros is used.
    `prior`[None or numpy object array]:
        One sub-array per hidden state factor with prior beliefs about the
        initial states (t = 0). If None, a uniform prior is used.
    `num_iter`[Int]:
        Number of variational iterations.
    `grad_descent` [Bool]:
        Flag for whether to use gradient descent (predictive coding style).
    `tau` [Float]:
        Step size applied to the prediction error in the `grad_descent` version.
    `last_timestep` [Bool]:
        If True the inference window is one timestep shorter.
    `save_vfe_seq` [Bool]:
        If True, free energy is recorded as a list of running totals; if False
        it is accumulated into a single value across time and iterations.

    Returns:
    --------------
    `qs_seq`[numpy object array]:
        Beliefs about the hidden state factors over time, one multi-factor
        posterior per timestep of the inference window.
    `F`[scalar-like or list, depending on `save_vfe_seq`]:
        If `save_vfe_seq` is False: the accumulated free energy. If True: a
        list starting at 0.0 with one running total appended per timestep of
        every iteration.
    """

    # inference window
    past_len = len(lh_seq)
    future_len = policy.shape[0]
    if last_timestep:
        infer_len = past_len + future_len - 1
    else:
        infer_len = past_len + future_len

    # from this timestep on, no message from the future is computed
    future_cutoff = past_len + future_len - 2

    # dimensions
    _, num_states, _, num_factors = get_model_dimensions(A=None, B=B)
    B = to_arr_of_arr(B)

    # beliefs: uniform at every timestep
    qs_seq = obj_array(infer_len)
    for t in range(infer_len):
        qs_seq[t] = obj_array_uniform(num_states)

    # message used in place of the future message at/after `future_cutoff`
    qs_T = obj_array_zeros(num_states)

    # prior over initial states
    if prior is None:
        prior = obj_array_uniform(num_states)

    # transposed (and re-normalised) transitions for the backward messages
    trans_B = obj_array(num_factors)
    for f in range(num_factors):
        trans_B[f] = spm_norm(np.swapaxes(B[f], 0, 1))

    # full policy: past actions followed by the policy's future actions
    if prev_actions is None:
        prev_actions = np.zeros((past_len, policy.shape[1]))
    policy = np.vstack((prev_actions, policy))

    # variational free energy of the policy (accumulated over time)
    F = [0.0] if save_vfe_seq else 0.0

    for _ in range(num_iter):
        for t in range(infer_len):
            # Running total for this timestep when recording a sequence in
            # gradient-descent mode (a list cannot be incremented by a scalar).
            running_vfe = F[-1] if (grad_descent and save_vfe_seq) else None

            for f in range(num_factors):
                # likelihood
                if t < past_len:
                    lnA = spm_log(spm_dot(lh_seq[t], qs_seq[t], [f]))
                else:
                    lnA = np.zeros(num_states[f])

                # past message
                if t == 0:
                    lnB_past = spm_log(prior[f])
                else:
                    past_msg = B[f][:, :, int(policy[t - 1, f])].dot(qs_seq[t - 1][f])
                    lnB_past = spm_log(past_msg)

                # future message
                if t >= future_cutoff:
                    lnB_future = qs_T[f]
                else:
                    future_msg = trans_B[f][:, :, int(policy[t, f])].dot(qs_seq[t + 1][f])
                    lnB_future = spm_log(future_msg)

                # inference
                if grad_descent:
                    lnqs = spm_log(qs_seq[t][f])
                    coeff = 1 if (t >= future_cutoff) else 2
                    err = (coeff * lnA + lnB_past + lnB_future) - coeff * lnqs
                    err -= err.mean()
                    lnqs = lnqs + tau * err
                    qs_seq[t][f] = softmax(lnqs)

                    if (t == 0) or (t == (infer_len - 1)):
                        vfe_f = 0.5 * lnqs.dot(0.5 * err)
                    else:
                        # @NOTE: not sure why Karl does this in SPM_MDP_VB_X,
                        # we should look into this
                        vfe_f = lnqs.dot(0.5 * (err - (num_factors - 1) * lnA / num_factors))

                    if save_vfe_seq:
                        running_vfe = running_vfe + vfe_f
                    else:
                        F += vfe_f
                else:
                    qs_seq[t][f] = softmax(lnA + lnB_past + lnB_future)

            if grad_descent:
                if save_vfe_seq:
                    F.append(running_vfe)
            else:
                # NOTE(review): `prior` is passed to calc_free_energy as given,
                # while the t == 0 message above uses spm_log(prior[f]) -
                # confirm which form calc_free_energy expects.
                if t < past_len:
                    step_vfe = calc_free_energy(
                        qs_seq[t], prior, num_factors, likelihood=spm_log(lh_seq[t])
                    )
                else:
                    step_vfe = calc_free_energy(qs_seq[t], prior, num_factors)

                if save_vfe_seq:
                    F.append(F[-1] + step_vfe[0])
                else:
                    F += step_vfe

    return qs_seq, F