def get_joint_likelihood(A, obs, num_states):
    """
    Multiply the per-modality likelihoods of one observation into a single
    joint likelihood array.

    Parameters
    ----------
    - `A` [numpy.ndarray or array of arrays]
        Observation likelihood; wrapped with `utils.to_arr_of_arr` so that
        `A[modality]` can be indexed uniformly.
    - `obs` [numpy.ndarray or array of arrays]
        Observation; wrapped the same way, one entry per modality.
    - `num_states` [int or list :: int]
        Number of states per hidden state factor. A bare integer (Python or
        NumPy) is treated as a single-factor model.

    Returns
    -------
    - `ll` [numpy.ndarray]
        Starts as ones of shape `tuple(num_states)` and is multiplied by
        `dot_likelihood(A[m], obs[m])` for every modality `m`.
    """
    # a lone integer means a single hidden state factor; NumPy integer
    # scalars are accepted too, since `tuple()` below cannot iterate them
    if isinstance(num_states, (int, np.integer)):
        num_states = [num_states]

    A = utils.to_arr_of_arr(A)
    obs = utils.to_arr_of_arr(obs)

    ll = np.ones(tuple(num_states))
    for modality in range(len(A)):
        ll = ll * dot_likelihood(A[modality], obs[modality])
    return ll
def spm_dot_classic(X, x, dims_to_omit=None):
    """ Dot product of a multidimensional array `X` with `x`.

    Every entry of `x` is paired with one axis of `X`; `X` is weighted by the
    entry along that axis and the axis is summed out. Entries listed in
    `dims_to_omit` are skipped, so their axes survive in the result.

    Parameters
    ----------
    - `X` [numpy.ndarray]
        The array to contract
    - `x` [1D numpy.ndarray] - either vector or array of arrays
        The alternative array to perform the dot product with
    - `dims_to_omit` [list :: int] (optional)
        Indices (into `x`) of the entries that are not summed across

    Returns
    -------
    - `Y` [numpy.ndarray]
        The squeezed result; a result with at most one element is returned
        as a length-1 float64 array

    Raises
    ------
    - `ValueError` if `dims_to_omit` is given but is not a `list`
    """
    # pair each entry of `x` with an axis of `X`: an array of arrays maps
    # onto the trailing axes, a lone vector onto axis 1
    if utils.is_arr_of_arr(x):
        n_vecs = len(x)
        axes = (np.arange(0, n_vecs) + X.ndim - n_vecs).astype(int)
    else:
        axes = np.array([1], dtype=int)
        x = utils.to_arr_of_arr(x)

    # discard the omitted entries together with their axes
    if dims_to_omit is not None:
        if not isinstance(dims_to_omit, list):
            raise ValueError("`dims_to_omit` must be a `list` of `int`")
        axes = np.delete(axes, dims_to_omit)
        x = np.empty([0], dtype=object) if len(x) == 1 else np.delete(x, dims_to_omit)

    # weight `X` by every remaining vector, broadcast along its own axis
    weighted = X
    for k, vec in enumerate(x):
        bcast_shape = np.ones(np.ndim(weighted), dtype=int)
        bcast_shape[axes[k]] = np.shape(vec)[0]
        weighted = weighted * vec.reshape(tuple(bcast_shape))

    # sum over all contracted axes in one go
    result = np.sum(weighted, axis=tuple(axes.astype(int))).squeeze()

    # a scalar result is handed back as a length-1 float64 array
    if np.prod(result.shape) <= 1.0:
        result = np.array([result.item()]).astype("float64")
    return result
def process_priors(prior, n_factors):
    """
    Helper function for formatting priors as an array of arrays.

    - If `utils.is_distribution(prior)` holds, the `.values` of the
      distribution (single factor) or of each `prior[factor]` (several
      factors) are squeezed and stored in an object array of length
      `n_factors`.
    - If `prior` is already an array of arrays it is returned unchanged.
    - Anything else is wrapped with `utils.to_arr_of_arr`.

    @TODO
    """
    if utils.is_distribution(prior):
        formatted = np.empty(n_factors, dtype=object)
        if n_factors == 1:
            formatted[0] = prior.values.squeeze()
        else:
            for idx in range(n_factors):
                formatted[idx] = prior[idx].values.squeeze()
        return formatted

    if utils.is_arr_of_arr(prior):
        return prior

    return utils.to_arr_of_arr(prior)
def spm_dot(X, x, dims_to_omit=None):
    """ Dot product of a multidimensional array `X` with `x`, via `np.einsum`.

    Each entry of `x` is contracted against one axis of `X`. Entries whose
    index appears in `dims_to_omit` are left out of the contraction.

    Parameters
    ----------
    - `X` [numpy.ndarray]
        The array to contract
    - `x` [1D numpy.ndarray] - either vector or array of arrays
        The alternative array to perform the dot product with
    - `dims_to_omit` [list :: int] (optional)
        Which dimensions to omit; also used verbatim as the output axes

    Returns
    -------
    - `Y` [numpy.ndarray]
        The result of the dot product; a result with at most one element is
        returned as a length-1 float64 array
    """
    # pair each entry of `x` with an axis of `X`: an array of arrays maps
    # onto the trailing axes, a lone vector onto axis 1
    if utils.is_arr_of_arr(x):
        first_axis = X.ndim - len(x)
        dims = list(range(first_axis, len(x) + first_axis))
    else:
        dims = [1]
        x = utils.to_arr_of_arr(x)

    # decide which entries take part and which axes of `X` are kept
    if dims_to_omit is None:
        kept = range(len(x))
        out_axes = [0]
    else:
        kept = [idx for idx in range(len(x)) if idx not in dims_to_omit]
        # NOTE(review): these are indices into `x`, used directly as axis
        # labels of `X`; the two only coincide when X.ndim == len(x) - confirm
        out_axes = dims_to_omit

    # einsum "interleaved" call: operand, axis labels, ..., output labels
    operands = [X, list(range(X.ndim))]
    for idx in kept:
        operands.extend((x[idx], [dims[idx]]))
    operands.append(out_axes)

    Y = np.einsum(*operands)

    # a scalar result is handed back as a length-1 float64 array
    if np.prod(Y.shape) <= 1.0:
        Y = np.array([Y.item()]).astype("float64")
    return Y
def run_mmp(lh_seq, B, policy, prev_actions=None, prior=None, num_iter=10,
            grad_descent=False, tau=0.25, last_timestep=False, save_vfe_seq=False):
    """
    Marginal message passing scheme for updating posterior beliefs about multi-factor
    hidden states over time, conditioned on a particular policy.

    Parameters:
    --------------
    `lh_seq`[numpy object array]:
        Likelihoods of hidden state factors given a sequence of observations over time.
        Passed in un-logged; the log is taken inside this function.
    `B`[numpy object array]:
        Transition likelihood of the generative model, mapping from hidden states at T
        to hidden states at T+1. One B matrix per hidden state factor (e.g. `B[f]`
        corresponds to f-th factor's B matrix). This is used in inference to compute the
        'forward' and 'backward' messages conveyed between beliefs about
        temporally-adjacent timepoints.
    `policy` [2-D numpy.ndarray]:
        Matrix of shape (policy_len, num_control_factors) that indicates the indices of
        each action (control state index) upon timestep t and control_factor f in the
        element `policy[t,f]` for a given policy.
    `prev_actions` [None or 2-D numpy.ndarray]:
        If provided, should be a matrix of previous actions of shape
        (infer_len, num_control_factors) that indicates the indices of each action
        (control state index) taken in the past (up until the current timestep).
        If `None`, zeros of shape (len(lh_seq), num_control_factors) are used.
    `prior`[None or numpy object array]:
        If provided, this a numpy object array with one sub-array per hidden state
        factor, that stores the prior beliefs about initial states (at t = 0, relative
        to `infer_len`). If `None`, a uniform prior is used.
    `num_iter`[Int]:
        Number of variational iterations
    `grad_descent` [Bool]:
        Flag for whether to use gradient descent (predictive coding style)
    `tau` [Float]:
        Decay constant for use in `grad_descent` version
    `last_timestep` [Bool]:
        Flag for whether we are at the last timestep of belief updating; shortens the
        inference window by one step.
    `save_vfe_seq` [Bool]:
        Flag for whether to save the sequence of variational free energies over time
        (for this policy). If `False`, then VFE is integrated across time/iterations.

    Returns:
    --------------
    `qs_seq`[numpy object array]:
        the sequence of beliefs about the different hidden state factors over time, one
        multi-factor posterior belief per timestep in `infer_len`
    `F`[Float or list, depending on setting of save_vfe_seq]:
        with `save_vfe_seq`, a list starting at 0.0 with one cumulative entry appended
        per (iteration, timestep); otherwise the accumulated total.
    """

    # window
    past_len = len(lh_seq)
    future_len = policy.shape[0]
    if last_timestep:
        infer_len = past_len + future_len - 1
    else:
        infer_len = past_len + future_len
    # from this timestep on, the message 'from the future' is ignored
    future_cutoff = past_len + future_len - 2

    # dimensions
    _, num_states, _, num_factors = get_model_dimensions(A=None, B=B)
    B = to_arr_of_arr(B)

    # beliefs
    qs_seq = obj_array(infer_len)
    for t in range(infer_len):
        qs_seq[t] = obj_array_uniform(num_states)

    # last message
    qs_T = obj_array_zeros(num_states)

    # prior
    if prior is None:
        prior = obj_array_uniform(num_states)

    # transposed transition
    trans_B = obj_array(num_factors)
    for f in range(num_factors):
        trans_B[f] = spm_norm(np.swapaxes(B[f], 0, 1))

    # full policy
    if prev_actions is None:
        prev_actions = np.zeros((past_len, policy.shape[1]))
    policy = np.vstack((prev_actions, policy))

    # initialise variational free energy of policy (accumulated over time)
    if save_vfe_seq:
        F = [0.0]
    else:
        F = 0.0

    for itr in range(num_iter):
        for t in range(infer_len):
            # per-timestep VFE, only used on the gradient path with `save_vfe_seq`
            vfe_t = 0.0
            for f in range(num_factors):
                # likelihood
                if t < past_len:
                    lnA = spm_log(spm_dot(lh_seq[t], qs_seq[t], [f]))
                else:
                    lnA = np.zeros(num_states[f])

                # past message
                if t == 0:
                    lnB_past = spm_log(prior[f])
                else:
                    past_msg = B[f][:, :, int(policy[t - 1, f])].dot(qs_seq[t - 1][f])
                    lnB_past = spm_log(past_msg)

                # future message
                if t >= future_cutoff:
                    lnB_future = qs_T[f]
                else:
                    future_msg = trans_B[f][:, :, int(policy[t, f])].dot(
                        qs_seq[t + 1][f])
                    lnB_future = spm_log(future_msg)

                # inference
                if grad_descent:
                    lnqs = spm_log(qs_seq[t][f])
                    coeff = 1 if (t >= future_cutoff) else 2
                    err = (coeff * lnA + lnB_past + lnB_future) - coeff * lnqs
                    err -= err.mean()
                    lnqs = lnqs + tau * err
                    qs_seq[t][f] = softmax(lnqs)
                    if (t == 0) or (t == (infer_len - 1)):
                        vfe_term = +0.5 * lnqs.dot(0.5 * err)
                    else:
                        # @NOTE: not sure why Karl does this in SPM_MDP_VB_X, we should look into this
                        vfe_term = lnqs.dot(
                            0.5 * (err - (num_factors - 1) * lnA / num_factors)
                        )
                    if save_vfe_seq:
                        # `F` is a list here, so `F += <float>` would raise;
                        # collect the factor terms and append once per timestep
                        vfe_t += vfe_term
                    else:
                        F += vfe_term
                else:
                    qs_seq[t][f] = softmax(lnA + lnB_past + lnB_future)

            if grad_descent:
                if save_vfe_seq:
                    F.append(F[-1] + vfe_t)
            else:
                if save_vfe_seq:
                    if t < past_len:
                        F.append(
                            F[-1] + calc_free_energy(qs_seq[t], prior, num_factors,
                                                     likelihood=spm_log(lh_seq[t]))[0])
                    else:
                        F.append(
                            F[-1] + calc_free_energy(qs_seq[t], prior, num_factors)[0])
                else:
                    if t < past_len:
                        F += calc_free_energy(qs_seq[t], prior, num_factors,
                                              likelihood=spm_log(lh_seq[t]))
                    else:
                        F += calc_free_energy(qs_seq[t], prior, num_factors)

    return qs_seq, F
def spm_dot(X, x, dims_to_omit=None, obs_mode=False):
    """ Dot product of a multidimensional array `X` with `x`.

    Each entry of `x` is paired with one axis of `X`; `X` is weighted by the
    entry along that axis and the axis is summed out. Entries listed in
    `dims_to_omit` are skipped, so their axes survive in the result.

    #TODO: we should look for an alternative to obs_mode

    Parameters
    ----------
    - `X` [numpy.ndarray]
        The array to contract
    - `x` [1D numpy.ndarray] - either vector or array of arrays
        The alternative array to perform the dot product with
    - `dims_to_omit` [list :: int] (optional)
        Which dimensions to omit
    - `obs_mode` [bool] (optional)
        Only consulted when `x` is a lone vector: `True` contracts it against
        axis 0 of `X`, anything else against axis 1

    Returns
    -------
    - `Y` [numpy.ndarray]
        The squeezed result; a result with at most one element is returned
        as a length-1 float64 array

    Raises
    ------
    - `ValueError` if `dims_to_omit` is given but is not a `list`
    """
    if utils.is_arr_of_arr(x):
        # an array of arrays is matched against the trailing axes of `X`
        axes = (np.arange(0, len(x)) + X.ndim - len(x)).astype(int)
    else:
        # @NOTE obs_mode: likelihood of an observation under the generative
        # model, equivalent to something like self.values[np.where(x),:] when
        # `x` is a discrete 'one-hot' observation vector (axis 0).
        # @NOTE otherwise: `x` leading dimension matches the lagging dimension
        # of `values`, e.g. a more 'classical' dot product of a likelihood
        # with hidden states (axis 1).
        lone_axis = 0 if obs_mode is True else 1
        axes = np.array([lone_axis], dtype=int)
        x = utils.to_arr_of_arr(x)

    # discard the omitted entries together with their axes
    if dims_to_omit is not None:
        if not isinstance(dims_to_omit, list):
            raise ValueError("`dims_to_omit` must be a `list` of `int`")
        axes = np.delete(axes, dims_to_omit)
        if len(x) == 1:
            x = np.empty([0], dtype=object)
        else:
            x = np.delete(x, dims_to_omit)

    # weight and sum one axis at a time, keeping rank fixed until the end
    reduced = X
    for k, vec in enumerate(x):
        bcast_shape = np.ones(np.ndim(reduced), dtype=int)
        bcast_shape[axes[k]] = np.shape(vec)[0]
        reduced = np.sum(reduced * vec.reshape(tuple(bcast_shape)),
                         axis=axes[k], keepdims=True)
    Y = np.squeeze(reduced)

    # a scalar result is handed back as a length-1 float64 array
    if np.prod(Y.shape) <= 1.0:
        Y = np.array([Y.item()]).astype("float64")
    return Y
def spm_dot(X, y, dims_to_omit=None, obs_mode=False):
    """ Dot product of a multidimensional array `X` with `y`

    The dimensions in `dims_to_omit` will not be summed across during dot product

    @TODO: we need documentation describing `obs_mode`
        Ideally, we could find a way to avoid it altogether

    Parameters
    ----------
    `X` [array-like, converted with `utils.to_numpy`]
        The array to contract
    `y` [1D numpy.ndarray]
        Either vector or array of arrays
    `dims_to_omit` [list :: int] (optional)
        Which dimensions to omit
    `obs_mode` [bool] (optional)
        Only consulted when `y` is a lone vector: `True` contracts it against
        axis 0 of `X`, anything else against axis 1

    Returns
    -------
    [numpy.ndarray]
        The squeezed result; a result with at most one element is returned
        as a length-1 float64 array

    Raises
    ------
    `ValueError` if `dims_to_omit` is given but is not a `list`
    """

    X = utils.to_numpy(X)
    y = utils.to_numpy(y)

    # if `y` is array of array, each entry is matched against one of the
    # trailing axes of `X`. The test has to be on `y`: `dims` is sized from
    # `len(y)` and the loop below indexes `dims[d]` for every entry of `y`.
    if utils.is_arr_of_arr(y):
        dims = (np.arange(0, len(y)) + X.ndim - len(y)).astype(int)
    else:
        if obs_mode is True:
            # Case when you're getting the likelihood of an observation under
            # model. Equivalent to something like self.values[np.where(x),:]
            # where `y` is a discrete 'one-hot' observation vector
            dims = np.array([0], dtype=int)
        else:
            # Case when `y` leading dimension matches the lagging dimension of
            # `values`. E.g. a more 'classical' dot product of a likelihood
            # with hidden states
            dims = np.array([1], dtype=int)

    # convert `y` to array of array
    y = utils.to_arr_of_arr(y)

    # omit dims not needed for dot product
    if dims_to_omit is not None:
        if not isinstance(dims_to_omit, list):
            raise ValueError("`dims_to_omit` must be a `list` of `int`")

        # delete dims
        dims = np.delete(dims, dims_to_omit)
        if len(y) == 1:
            y = np.empty([0], dtype=object)
        else:
            y = np.delete(y, dims_to_omit)

    # perform dot product, one axis at a time
    for d in range(len(y)):
        s = np.ones(np.ndim(X), dtype=int)
        s[dims[d]] = np.shape(y[d])[0]
        X = X * y[d].reshape(tuple(s))
        X = np.sum(X, axis=dims[d], keepdims=True)
    X = np.squeeze(X)

    # perform check to see if `X` is a scalar
    if np.prod(X.shape) <= 1.0:
        X = X.item()
        X = np.array([X]).astype("float64")

    return X
def run_mmp_old(
    A,
    B,
    obs_t,
    policy,
    curr_t,
    t_horizon,
    T,
    qs_bma=None,
    prior=None,
    num_iter=10,
    dF=1.0,
    dF_tol=0.001,
    previous_actions=None,
    use_gradient_descent=False,
    tau=0.25,
):
    """
    Optimise marginal posterior beliefs about hidden states using marginal message-passing
    scheme (MMP) developed by Thomas Parr and colleagues,
    see https://github.com/tejparr/nmpassing

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to
        observations. Used to get the likelihood of each observation.
    - 'B' [numpy.ndarray (tensor or array-of-arrays)]:
        Transition likelihood of the generative model, mapping from hidden states at t
        to hidden states at t+1. Used for the past and future messages.
    - 'obs_t' [list of numpy 1D array or array of arrays (with 1D numpy array entries)]:
        Sequence of observations. It is indexed here with the absolute timesteps
        `max(0, curr_t - t_horizon) .. curr_t`.
    - 'policy' [2D np.ndarray]:
        Array of actions constituting a single policy, of shape
        (n_steps, n_control_factors).
    - 'curr_t' [int]:
        Current timestep (relative to the 'absolute' time of the generative process).
    - 't_horizon'[int]:
        Temporal horizon of inference for states and policies.
    - 'T' [int]:
        Temporal horizon of the generative process (absolute time)
    - `qs_bma` [None or array]:
        Accepted for interface compatibility; not used by this function.
    - 'prior' [numpy 1D array, array of arrays (with 1D numpy array entries) or None]:
        Prior beliefs at the beginning of the time horizon. If absent, a uniform
        distribution over hidden states is used. It is only applied at the first step
        of the window when the window starts at absolute time 0.
    - 'num_iter' [int]:
        Number of variational iterations to run. (optional)
    - 'dF' [float], 'dF_tol' [float]:
        Accepted for interface compatibility; no early stopping is implemented here.
    - 'previous_actions' [numpy.ndarray with shape (num_steps, n_control_factors) or None]:
        Previous actions, stacked in front of `policy`. If `None`, a single row of
        zeros is used.
    - 'use_gradient_descent' [bool]:
        Flag to indicate whether to use gradient descent to optimise posterior beliefs.
    - 'tau' [float]:
        Learning rate for gradient descent (only used if use_gradient_descent is True)

    Returns
    ----------
    - 'qs' [list of array of arrays]:
        Marginal posterior beliefs over hidden states, one entry per timestep of the
        inference window (window length + 1).
    - 'qss' [list of lists]:
        One list per variational iteration, each holding a reference to `qs`
        (the same object, not a snapshot).
    - 'F' [2D np.ndarray]:
        Array of shape (len(qs), num_iter). It is allocated as zeros and not filled in.
    - 'F_pol' [float]:
        Free energy accumulated on the gradient-descent path; 0.0 otherwise.
    """

    # get temporal window for inference
    min_time = max(0, curr_t - t_horizon)
    max_time = min(T, curr_t + t_horizon)
    window_idxs = np.array([t for t in range(min_time, max_time)])
    window_len = len(window_idxs)
    # TODO: needs a better name - the point at which we ignore future messages
    future_cutoff = window_len - 1
    inference_len = window_len + 1

    # get relevant observations, given our current time point
    if curr_t == 0:
        obs_range = [0]
    else:
        obs_range = range(max(0, curr_t - t_horizon), curr_t + 1)

    # get model dimensions
    # TODO: make a general function in `utils` for extracting model dimensions
    if utils.is_arr_of_arr(B):
        num_states = [b.shape[0] for b in B]
    else:
        num_states = [B[0].shape[0]]
    B = utils.to_arr_of_arr(B)
    num_factors = len(num_states)

    # =========== Step 1 ===========
    # Joint likelihood of each observation in the window under configurations of
    # hidden states. Each timestep gets its own observation `obs_t[obs_idx]`, not
    # the whole sequence. Karl SPM version: the log is taken only after dotting
    # with the posterior (see the iteration loop).
    likelihood = np.empty(len(obs_range), dtype=object)
    for t, obs_idx in enumerate(obs_range):
        likelihood[t] = get_joint_likelihood(A, obs_t[obs_idx], num_states)

    # =========== Step 2 ===========
    # Initialise a flat posterior (and a flat prior if none was provided)
    qs = [np.empty(num_factors, dtype=object) for _ in range(inference_len)]
    for t in range(inference_len):
        for f in range(num_factors):
            qs[t][f] = np.ones(num_states[f]) / num_states[f]

    if prior is None:
        prior = np.empty(num_factors, dtype=object)
        for f in range(num_factors):
            prior[f] = np.ones(num_states[f]) / num_states[f]

    # =========== Step 3 ===========
    # Normalised transpose of the transition distribution, used for computing
    # backwards messages 'from the future'
    B_transposed = np.empty(num_factors, dtype=object)
    for f in range(num_factors):
        B_transposed[f] = np.zeros_like(B[f])
        for u in range(B[f].shape[2]):
            B_transposed[f][:, :, u] = spm_norm(B[f][:, :, u].T)

    # zero out final message
    # TODO: may be redundant now we skip final step
    last_message = np.empty(num_factors, dtype=object)
    for f in range(num_factors):
        last_message[f] = np.zeros(num_states[f])

    # if previous actions not given, zero out to stop any influence on inference
    if previous_actions is None:
        previous_actions = np.zeros((1, policy.shape[1]))
    full_policy = np.vstack((previous_actions, policy))

    # =========== Step 4 ===========
    # Loop over time indices of time window, updating posterior as we go.
    # This includes past time steps and future time steps.
    qss = [[] for _ in range(num_iter)]
    free_energy = np.zeros((len(qs), num_iter))
    free_energy_pol = 0.0

    for n in range(num_iter):
        for t in range(inference_len):
            for f in range(num_factors):
                # Step 4.a: likelihood (Karl SPM version)
                if t < len(obs_range):
                    lnA = np.log(spm_dot(likelihood[t], qs[t], [f]) + 1e-16)
                else:
                    lnA = np.zeros(num_states[f])

                # Step 4.b: past message (Karl SPM version)
                if t == 0 and window_idxs[0] == 0:
                    lnB_past = np.log(prior[f] + 1e-16)
                else:
                    # NOTE(review): when t == 0 and the window does not start at
                    # absolute time 0, `t - 1` wraps around to the last entry of
                    # `full_policy` / `qs` - confirm this is intended.
                    # `full_policy` is a float array whenever the zero-filled
                    # default is stacked in, so cast before indexing.
                    past_action = int(full_policy[t - 1, f])
                    lnB_past = np.log(
                        B[f][:, :, past_action].dot(qs[t - 1][f]) + 1e-16)

                # Step 4.c: future message (Karl Friston SPM version)
                if t >= future_cutoff:
                    lnB_future = last_message[f]
                else:
                    B_future = B_transposed[f][:, :, int(full_policy[t, f])].dot(
                        qs[t + 1][f])
                    lnB_future = np.log(B_future + 1e-16)

                # Step 4.d: update posterior
                if use_gradient_descent:
                    lns = np.log(qs[t][f] + 1e-16)

                    # Karl SPM version of the prediction error
                    if t >= future_cutoff:
                        error = lnA + lnB_past - lns
                    else:
                        error = (2 * lnA + lnB_past + lnB_future) - 2 * lns

                    error -= error.mean()
                    lns = lns + tau * error
                    qs_t_f = softmax(lns)
                    # uses the belief from before this update
                    free_energy_pol += 0.5 * qs[t][f].dot(error)
                    qs[t][f] = qs_t_f
                else:
                    qs[t][f] = softmax(lnA + lnB_past + lnB_future)

            # TODO: per-(t, n) free energy via `calc_free_energy` is not computed;
            # `free_energy` is returned as zeros.
        qss[n].append(qs)

    return qs, qss, free_energy, free_energy_pol
def run_mmp(
    A,
    B,
    obs_t,
    policy,
    curr_t,
    t_horizon,
    T,
    qs_bma=None,
    prior=None,
    num_iter=10,
    dF=1.0,
    dF_tol=0.001,
    previous_actions=None,
    use_gradient_descent=False,
    tau=0.25,
):
    """
    Optimise marginal posterior beliefs about hidden states using marginal message-passing
    scheme (MMP) developed by Thomas Parr and colleagues,
    see https://github.com/tejparr/nmpassing

    Parameters
    ----------
    - 'A' [numpy nd.array (matrix or tensor or array-of-arrays)]:
        Observation likelihood of the generative model, mapping from hidden states to
        observations. Used to get the likelihood of each observation.
    - 'B' [numpy.ndarray (tensor or array-of-arrays)]:
        Transition likelihood of the generative model, mapping from hidden states at t
        to hidden states at t+1. Used for the past and future messages.
    - 'obs_t' [list of numpy 1D array or array of arrays (with 1D numpy array entries)]:
        Sequence of observations. It is indexed here with the absolute timesteps
        `max(0, curr_t - t_horizon) .. curr_t`.
    - 'policy' [2D np.ndarray]:
        Array of actions constituting a single policy, of shape
        (n_steps, n_control_factors).
    - 'curr_t' [int]:
        Current timestep (relative to the 'absolute' time of the generative process).
    - 't_horizon'[int]:
        Temporal horizon of inference for states and policies.
    - 'T' [int]:
        Temporal horizon of the generative process (absolute time)
    - `qs_bma` [None or array]:
        Accepted for interface compatibility; not used by this function.
    - 'prior' [numpy 1D array, array of arrays (with 1D numpy array entries) or None]:
        Prior beliefs at the beginning of the time horizon, used for the past message
        at the first step of the window. If absent, a uniform distribution is used.
    - 'num_iter' [int]:
        Number of variational iterations to run. (optional)
    - 'dF' [float], 'dF_tol' [float]:
        Accepted for interface compatibility; no early stopping is implemented here.
    - 'previous_actions' [numpy.ndarray with shape (num_steps, n_control_factors) or None]:
        Previous actions, stacked in front of `policy`. If `None`, a single row of
        zeros is used.
    - 'use_gradient_descent' [bool]:
        Flag to indicate whether to use gradient descent to optimise posterior beliefs.
    - 'tau' [float]:
        Learning rate for gradient descent (only used if use_gradient_descent is True)

    Returns
    ----------
    - 'qs' [list of array of arrays]:
        Marginal posterior beliefs over hidden states, one entry per timestep of the
        inference window (window length + 1).
    - 'qss' [list of lists]:
        One list per variational iteration, each holding a reference to `qs`
        (the same object, not a snapshot).
    - 'F' [2D np.ndarray]:
        Array of shape (len(qs), num_iter). It is allocated as zeros and not filled in.
    - 'F_pol' [float]:
        Free energy accumulated on the gradient-descent path; 0.0 otherwise.
    """

    # get model dimensions
    time_window_idxs = np.array([
        i for i in range(max(0, curr_t - t_horizon), min(T, curr_t + t_horizon))
    ])
    window_len = len(time_window_idxs)

    if utils.is_arr_of_arr(B):
        n_states = [sub_B.shape[0] for sub_B in B]
    else:
        n_states = [B[0].shape[0]]
    B = utils.to_arr_of_arr(B)
    n_factors = len(n_states)

    # =========== Step 1 ===========
    # Joint likelihood over hidden states for each observation in the window.
    # Karl SPM version: the log is taken only after dotting with the posterior.

    # compute time-window, taking into account boundary conditions
    if curr_t == 0:
        obs_range = [0]
    else:
        obs_range = range(max(0, curr_t - t_horizon), curr_t + 1)

    likelihood = np.empty(len(obs_range), dtype=object)
    for t in range(len(obs_range)):
        likelihood[t] = get_joint_likelihood(A, obs_t[obs_range[t]], n_states)

    # =========== Step 2 ===========
    # Flat posterior for every step of the window; the extra final entry is all
    # zeros. A flat prior is created if none was provided.
    qs = [np.empty(n_factors, dtype=object) for _ in range(window_len + 1)]
    for t in range(window_len + 1):
        for f in range(n_factors):
            if t == window_len:
                qs[t][f] = np.zeros(n_states[f])
            else:
                qs[t][f] = np.ones(n_states[f]) / n_states[f]

    if prior is None:
        prior = np.empty(n_factors, dtype=object)
        for f in range(n_factors):
            prior[f] = np.ones(n_states[f]) / n_states[f]

    # Normalised transpose of the transition likelihood, for computing backwards
    # messages 'from the future'. `B` is an array of arrays at this point, so the
    # single-factor case goes through the same per-factor loop.
    B_t = np.empty(n_factors, dtype=object)
    for f in range(n_factors):
        B_t[f] = np.zeros_like(B[f])
        for u in range(B[f].shape[2]):
            B_t[f][:, :, u] = spm_norm(B[f][:, :, u].T)

    # final future message at the time horizon: zeros in log space, i.e. no
    # information from beyond the horizon
    last_message = np.empty(n_factors, dtype=object)
    for f in range(n_factors):
        last_message[f] = np.zeros(n_states[f])

    # =========== Step 3 ===========
    # Loop over time indices of the time window, which includes time before the
    # policy horizon as well as the policy horizon itself.
    if previous_actions is None:
        previous_actions = np.zeros((1, policy.shape[1]))
    full_policy = np.vstack((previous_actions, policy))

    qss = [[] for _ in range(num_iter)]
    F = np.zeros((len(qs), num_iter))
    F_pol = 0.0

    for n in range(num_iter):
        for t in range(0, len(qs)):
            for f in range(n_factors):
                # likelihood: one entry of `likelihood` per observed timestep
                if t < len(likelihood):
                    lnA = np.log(spm_dot(likelihood[t], qs[t], [f]) + 1e-16)
                else:
                    lnA = np.zeros(n_states[f])

                # past message (the Karl SPM version)
                if t == 0:
                    lnBpast = np.log(prior[f] + 1e-16)
                else:
                    # `full_policy` is a float array whenever the zero-filled
                    # default is stacked in, so cast before indexing
                    past_action = int(full_policy[t - 1, f])
                    lnBpast = np.log(
                        B[f][:, :, past_action].dot(qs[t - 1][f]) + 1e-16)

                # future message (the Karl SPM version, without the 0.5 in front)
                if t >= len(qs) - 2:
                    # at the end of the inference chain the last message is zeros
                    lnBfuture = last_message[f]
                else:
                    lnBfuture = np.log(
                        B_t[f][:, :, int(full_policy[t, f])].dot(qs[t + 1][f])
                        + 1e-16)

                if use_gradient_descent:
                    lns = np.log(qs[t][f] + 1e-16)  # current estimate

                    # prediction error, Karl SPM version
                    if t >= len(qs) - 2:
                        e = lnA + lnBpast - lns
                    else:
                        e = (2 * lnA + lnBpast + lnBfuture) - 2 * lns
                    e -= e.mean()

                    # increment the current (log) belief with the prediction error
                    lns += tau * e
                    qs_t_f = softmax(lns)
                    # uses the belief from before this update
                    F_pol += 0.5 * qs[t][f].dot(e)
                    qs[t][f] = qs_t_f
                else:
                    # free energy minimum for the factor in question
                    qs[t][f] = softmax(lnA + lnBpast + lnBfuture)

            # TODO: per-(t, n) free energy via `calc_free_energy` is not computed;
            # `F` is returned as zeros.
        qss[n].append(qs)

    return qs, qss, F, F_pol
def update_posterior_states_v2(
    A,
    B,
    prev_obs,
    policies,
    prev_actions=None,
    prior=None,
    return_numpy=True,
    policy_sep_prior=True,
    **kwargs,
):
    """
    Update posterior over hidden states using marginal message passing

    Runs `run_mmp` once per policy and collects, for every policy, the
    sequence of posterior beliefs and the associated free energy.

    Parameters
    ----------
    - `A`
        Observation likelihood; converted with `utils.to_numpy` and
        `utils.to_arr_of_arr` before use
    - `B`
        Transition likelihood; converted with `utils.to_arr_of_arr`
    - `prev_obs`
        Sequence of previous observations, formatted with
        `utils.process_observation_seq`
    - `policies`
        Iterable of policies; each one is passed to `run_mmp` unchanged
    - `prev_actions` (optional)
        Sequence of previous actions; stacked along axis 0 when provided
    - `prior` (optional)
        Prior over hidden states. When `policy_sep_prior` is True it is
        indexed by policy (`prior[p_idx]`), otherwise a single prior is
        shared by all policies. When `None`, `None` is forwarded to
        `run_mmp` for every policy
    - `return_numpy` [bool] (optional)
        Accepted for interface compatibility; not used in this function
    - `policy_sep_prior` [bool] (optional)
        Whether `prior` holds one prior per policy
    - `**kwargs`
        Forwarded verbatim to `run_mmp`

    Returns
    -------
    - `qs_seq_pi` [object array, one entry per policy]
        First value returned by `run_mmp` for each policy
    - `F` [1D numpy.ndarray, one entry per policy]
        Second value returned by `run_mmp` for each policy
        (variational free energy of policies)
    """

    # safe convert to numpy
    A = utils.to_numpy(A)

    num_obs, num_states, num_modalities, num_factors = utils.get_model_dimensions(A, B)
    A = utils.to_arr_of_arr(A)
    B = utils.to_arr_of_arr(B)

    prev_obs = utils.process_observation_seq(prev_obs, num_modalities, num_obs)

    num_policies = len(policies)

    # Build the per-policy priors in a fresh list: this keeps the caller's
    # `prior` container untouched and avoids subscripting `None` when no
    # prior was supplied (the default call).
    if prior is None:
        policy_priors = [None] * num_policies
    elif policy_sep_prior:
        policy_priors = [
            utils.process_prior(prior[p_idx], num_factors)
            for p_idx in range(num_policies)
        ]
    else:
        shared_prior = utils.process_prior(prior, num_factors)
        policy_priors = [shared_prior] * num_policies

    lh_seq = get_joint_likelihood_seq(A, prev_obs, num_states)

    if prev_actions is not None:
        prev_actions = np.stack(prev_actions, 0)

    qs_seq_pi = utils.obj_array(num_policies)
    F = np.zeros(num_policies)  # variational free energy of policies

    for p_idx, policy in enumerate(policies):
        # get sequence and the free energy for policy
        qs_seq_pi[p_idx], F[p_idx] = run_mmp(
            lh_seq,
            B,
            policy,
            prev_actions=prev_actions,
            prior=policy_priors[p_idx],
            **kwargs,
        )

    return qs_seq_pi, F
def run_mmp_v2(A, B, ll_seq, policy, prev_actions=None, prior=None, num_iter=10, grad_descent=False, tau=0.25):
    """
    Marginal message passing over a window of past and future timesteps

    Beliefs for every timestep and hidden state factor start out uniform and
    are updated in place, sweeping over time and factors `num_iter` times.
    Each update combines a likelihood term (past timesteps only), a message
    from the previous timestep (or the prior at `t == 0`) and a message from
    the next timestep (zeros for the last two timesteps of the window).

    Parameters
    ----------
    - `A`
        Observation likelihood; only used to read the model dimensions
    - `B`
        Transition likelihood; `B[f][:, :, u]` is the transition matrix of
        factor `f` under action `u`
    - `ll_seq`
        Sequence of likelihoods, one per past timestep; its length sets the
        number of past timesteps
    - `policy` [2D numpy.ndarray]
        Future actions, indexed as `policy[t, f]`; the number of rows sets
        the number of future timesteps
    - `prev_actions` (optional)
        Actions preceding `policy`, stacked on top of it. Defaults to zeros
        with one row per past timestep
    - `prior` (optional)
        Per-factor prior over hidden states at the first timestep,
        indexed as `prior[f]`. Defaults to uniform
    - `num_iter` [int] (optional)
        Number of sweeps over the whole window
    - `grad_descent` [bool] (optional)
        If True, take a step of size `tau` along the mean-centred
        prediction error in log space; otherwise set each belief directly
        to the softmax of the summed messages
    - `tau` [float] (optional)
        Step size used when `grad_descent` is True

    Returns
    -------
    - `qs_seq` [list of object arrays]
        One entry per timestep in the window; `qs_seq[t][f]` is the belief
        over factor `f` at timestep `t`
    """

    # window
    past_len = len(ll_seq)
    future_len = policy.shape[0]
    infer_len = past_len + future_len
    future_cutoff = infer_len - 2

    # dimensions
    _, num_states, _, num_factors = get_model_dimensions(A, B)
    B = to_arr_of_arr(B)

    # beliefs (uniform initialisation)
    qs_seq = [np.empty(num_factors, dtype=object) for _ in range(infer_len)]
    for t in range(infer_len):
        for f in range(num_factors):
            qs_seq[t][f] = np.ones(num_states[f]) / num_states[f]

    # last message
    qs_T = np.empty(num_factors, dtype=object)
    for f in range(num_factors):
        qs_T[f] = np.zeros(num_states[f])

    # prior
    if prior is None:
        prior = np.empty(num_factors, dtype=object)
        for f in range(num_factors):
            prior[f] = np.ones(num_states[f]) / num_states[f]

    # transposed transition
    trans_B = np.empty(num_factors, dtype=object)
    for f in range(num_factors):
        # Force a float buffer: inheriting an integer dtype from `B[f]`
        # would truncate the normalised values on assignment.
        trans_B[f] = np.zeros_like(B[f], dtype=float)
        for u in range(B[f].shape[2]):
            trans_B[f][:, :, u] = spm_norm(B[f][:, :, u].T)

    # full policy; cast once so entries can be used directly as action indices
    if prev_actions is None:
        prev_actions = np.zeros((past_len, policy.shape[1]))
    policy = np.vstack((prev_actions, policy)).astype(int)

    for _ in range(num_iter):
        for t in range(infer_len):
            at_horizon = t >= future_cutoff
            for f in range(num_factors):
                # likelihood
                if t < past_len:
                    lnA = np.log(spm_dot(ll_seq[t], qs_seq[t], [f]) + 1e-16)
                else:
                    lnA = np.zeros(num_states[f])

                # past message
                if t == 0:
                    lnB_past = np.log(prior[f] + 1e-16)
                else:
                    past_msg = B[f][:, :, policy[t - 1, f]].dot(qs_seq[t - 1][f])
                    lnB_past = np.log(past_msg + 1e-16)

                # future message
                if at_horizon:
                    lnB_future = qs_T[f]
                else:
                    future_msg = trans_B[f][:, :, policy[t, f]].dot(qs_seq[t + 1][f])
                    lnB_future = np.log(future_msg + 1e-16)

                # inference
                if grad_descent:
                    lnqs = np.log(qs_seq[t][f] + 1e-16)
                    coeff = 1 if at_horizon else 2
                    err = (coeff * lnA + lnB_past + lnB_future) - coeff * lnqs
                    err -= err.mean()
                    lnqs = lnqs + tau * err
                    qs_seq[t][f] = softmax(lnqs)
                else:
                    qs_seq[t][f] = softmax(lnA + lnB_past + lnB_future)

    return qs_seq