def fit_bhv_model(data, model_path=bmp, targ_field='LABthetaTarget',
                  dist_field='LABthetaDist', resp_field='LABthetaResp',
                  prior_dict=None, stan_iters=2000, stan_chains=4,
                  arviz=mixture_arviz, adapt_delta=.9, diagnostics=True,
                  **stan_params):
    """Fit the Stan mixture model of behavioral error to pooled session data.

    Parameters
    ----------
    data : DataFrame-like
        Must expose per-session sequences under ``targ_field``, ``dist_field``,
        ``resp_field``, plus ``'animal'`` and ``'date'`` columns.
    model_path : str
        Path to a pickled, compiled Stan model.
    targ_field, dist_field, resp_field : str
        Column names for target angle, distractor angle, and response angle.
    prior_dict : dict or None
        Prior hyperparameters passed into the Stan data dict; defaults to the
        module-level ``default_prior_dict``.
    stan_iters, stan_chains : int
        Sampling iterations and number of chains.
    arviz : dict
        Keyword arguments forwarded to ``az.from_pystan``.
    adapt_delta : float
        HMC ``adapt_delta`` target (may be overridden via ``stan_params``).
    diagnostics : bool
        If True, run PyStan HMC diagnostics on the fit.
    **stan_params :
        Extra keyword arguments forwarded to ``sm.sampling``; ``adapt_delta``
        and ``max_treedepth`` are extracted into the ``control`` dict.

    Returns
    -------
    fit, diag, fit_av, stan_data, mapping_dict
        The raw fit, diagnostics (or None), the arviz InferenceData, the Stan
        data dict, and a mapping from pooled trial index to
        ``(original_index, animal, date)``.
    """
    if prior_dict is None:
        prior_dict = default_prior_dict
    targs_is = data[targ_field]
    session_list = np.array(data[['animal', 'date']])
    mapping_list = []
    session_nums = np.array([], dtype=int)
    # build a 1-based session index per trial and remember where each pooled
    # trial came from (original index, animal, date)
    for i, x in enumerate(targs_is):
        sess = np.ones(len(x), dtype=int) * (i + 1)
        session_nums = np.concatenate((session_nums, sess))
        indices = x.index
        sess_info0 = (str(session_list[i, 0]),) * len(x)
        sess_info1 = (str(session_list[i, 1]),) * len(x)
        mapping_list = mapping_list + list(zip(indices, sess_info0,
                                               sess_info1))
    mapping_dict = {i: mapping_list[i] for i in range(len(session_nums))}
    targs = np.concatenate(targs_is, axis=0)
    dists = np.concatenate(data[dist_field], axis=0)
    resps = np.concatenate(data[resp_field], axis=0)
    errs = u.normalize_periodic_range(targs - resps)
    dist_errs = u.normalize_periodic_range(dists - resps)
    dists_per = u.normalize_periodic_range(dists - targs)
    stan_data = dict(T=dist_errs.shape[0], S=len(targs_is), err=errs,
                     dist_err=dist_errs, run_ind=session_nums,
                     dist_loc=dists_per, **prior_dict)
    # BUG FIX: the adapt_delta keyword argument was previously ignored in
    # favor of a hard-coded .8 fallback; it is now the fallback value when
    # the caller does not supply adapt_delta through stan_params.
    control = {
        'adapt_delta': stan_params.pop('adapt_delta', adapt_delta),
        'max_treedepth': stan_params.pop('max_treedepth', 10),
    }
    # context manager closes the model file promptly (was a leaked handle)
    with open(model_path, 'rb') as fobj:
        sm = pickle.load(fobj)
    fit = sm.sampling(data=stan_data, iter=stan_iters, chains=stan_chains,
                      control=control, **stan_params)
    if diagnostics:
        diag = ps.diagnostics.check_hmc_diagnostics(fit)
    else:
        diag = None
    fit_av = az.from_pystan(posterior=fit, **arviz)
    return fit, diag, fit_av, stan_data, mapping_dict
def compute_diff_dependence(data, targ_field='LABthetaTarget',
                            dist_field='LABthetaDist',
                            resp_field='LABthetaResp'):
    """Compute pairwise circular differences among target, distractor, response.

    Pools trials across sessions, then returns three arrays of angles wrapped
    to the periodic range: target-distractor, target-response, and
    distractor-response differences.
    """
    fields = (targ_field, dist_field, resp_field)
    targ, dist, resp = (np.concatenate(data[field]) for field in fields)
    td_diff = u.normalize_periodic_range(targ - dist)
    resp_diff = u.normalize_periodic_range(targ - resp)
    dist_diff = u.normalize_periodic_range(dist - resp)
    return td_diff, resp_diff, dist_diff
def plot_error_swap_distribs_err(errs, dist_errs, axs=None, fwid=3, label='',
                                 model_data=None, color=None,
                                 model_derr=None):
    """Plot histograms of response errors and distractor-relative errors.

    Parameters
    ----------
    errs : array-like
        Response errors relative to the target (radians).
    dist_errs : array-like
        Response errors relative to the distractor (radians).
    axs : pair of matplotlib Axes or None
        Axes to draw on; a new 1x2 figure is created when None.
    fwid : float
        Width of each panel when a new figure is created.
    label : str
        Legend label for the distractor-error histogram.
    model_data : array or None
        Model-sampled errors, overlaid as a dashed step histogram on panel 0.
    color : color or None
        Fill color for the data histograms.
    model_derr : array or None
        Model distractor offsets; combined with ``model_data`` to overlay the
        model's distractor-relative error distribution on panel 1.

    Returns
    -------
    axs : the pair of axes drawn on.
    """
    if axs is None:
        fsize = (2 * fwid, fwid)
        f, axs = plt.subplots(1, 2, figsize=fsize, sharey=True, sharex=True)
    # unused return value of hist was previously bound to a throwaway local
    axs[0].hist(errs, density=True, color=color)
    if model_data is not None:
        axs[0].hist(model_data.flatten(), histtype='step', density=True,
                    color='k', linestyle='dashed')
    axs[1].hist(dist_errs, label=label, density=True, color=color)
    if model_derr is not None:
        # model error relative to the distractor, wrapped to periodic range
        m_derr = u.normalize_periodic_range(model_derr - model_data)
        axs[1].hist(m_derr.flatten(), histtype='step', density=True,
                    color='k', linestyle='dashed')
    axs[1].legend(frameon=False)
    axs[0].set_xlabel('error (rads)')
    axs[0].set_ylabel('density')
    axs[1].set_xlabel('distractor distance (rads)')
    gpl.clean_plot(axs[0], 0)
    gpl.clean_plot(axs[1], 1)
    return axs
def load_bhv_data(fl, flname='bhv.mat', const_fields=('Date', 'Monkey'),
                  extract_fields=busch_bhv_fields, add_color=True,
                  add_err=True):
    """Load a behavioral .mat file into constant and per-trial dictionaries.

    Reads ``bhv`` from ``fl/flname``, extracts session-constant fields and
    per-trial fields (empty trial entries become NaN), and optionally derives
    upper/lower sample colors and the periodic response error.

    Returns
    -------
    const_dict, trl_dict : dict, dict
        Session-level constants and per-trial arrays.
    """
    bhv = sio.loadmat(os.path.join(fl, flname))['bhv']
    const_dict = {field: np.squeeze(bhv[field][0, 0])
                  for field in const_fields}
    trl_dict = {}
    nan_entry = np.array([[np.nan]])
    for field in extract_fields:
        raw = bhv['Trials'][0, 0][field][0]
        # replace empty per-trial entries with a NaN placeholder so the
        # stack below has uniform shapes
        for j, entry in enumerate(raw):
            if len(entry) == 0:
                raw[j] = nan_entry
        trl_dict[field] = np.squeeze(np.stack(raw, axis=0))
    if add_color:
        targ_color = trl_dict['LABthetaTarget']
        dist_color = trl_dict['LABthetaDist']
        is_upper = trl_dict['IsUpperSample'] == 1
        is_lower = np.logical_not(is_upper)
        # when the sample is in the upper position, the upper color is the
        # target color and the lower color is the distractor (and vice versa)
        upper_col = np.where(is_upper, targ_color, dist_color)
        lower_col = np.where(is_lower, targ_color, dist_color)
        trl_dict['upper_color'] = upper_col
        trl_dict['lower_color'] = lower_col
    if add_err:
        raw_err = trl_dict['LABthetaTarget'] - trl_dict['LABthetaResp']
        trl_dict['err'] = u.normalize_periodic_range(raw_err)
    return const_dict, trl_dict
def plot_error_swap_distribs(data, err_field='err', dist_field='LABthetaDist',
                             resp_field='LABthetaResp', **kwargs):
    """Pool errors and distractor-relative errors across sessions and plot.

    Thin wrapper around ``plot_error_swap_distribs_err``: concatenates the
    precomputed errors and the wrapped distractor-response differences, then
    forwards any extra keyword arguments to the plotting routine.
    """
    pooled_errs = np.concatenate(data[err_field])
    raw_dist_errs = np.concatenate(data[dist_field] - data[resp_field])
    pooled_dist_errs = u.normalize_periodic_range(raw_dist_errs)
    return plot_error_swap_distribs_err(pooled_errs, pooled_dist_errs,
                                        **kwargs)
def _get_cmean(trls, trl_cols, targ_col, all_cols, color_window=.2,
               positions=None):
    """Average trials whose color lies within ``color_window`` of ``targ_col``.

    For each unique position (or a single dummy position when ``positions`` is
    None), selects trials whose color in ``trl_cols`` matches one of the
    ``all_cols`` entries within ``color_window`` of ``targ_col`` and returns
    the nan-mean over those trials.

    Parameters
    ----------
    trls : ndarray
        Trial data; indexed as ``trls[:, 0, trial_mask]`` with the final axis
        preserved in the output — assumes trials lie on axis 2; TODO confirm.
    trl_cols : ndarray
        Per-trial color values.
    targ_col : float
        Reference color.
    all_cols : ndarray
        Candidate color values to match against.
    color_window : float
        Periodic distance threshold for a color to count as matching.
    positions : ndarray or None
        Per-trial position labels; when None all trials share one position
        and the singleton position axis is squeezed from the output.

    Returns
    -------
    ndarray of shape (n_positions, trls.shape[0], trls.shape[-1]), or with
    the leading axis removed when ``positions`` is None.
    """
    # capture the flag BEFORE reassigning positions -- the original code
    # re-tested `positions is None` after the reassignment, so the intended
    # squeeze of the singleton position axis was unreachable dead code
    squeeze_output = positions is None
    if squeeze_output:
        u_pos = [0]
        positions = np.zeros(trls.shape[2])
    else:
        u_pos = np.unique(positions)
    # the color mask does not depend on position -- compute it once outside
    # the loop instead of once per position
    col_dists = np.abs(u.normalize_periodic_range(targ_col - all_cols))
    use_cols = all_cols[col_dists < color_window]
    col_trl_mask = np.isin(trl_cols, use_cols)
    out = np.zeros((len(u_pos), trls.shape[0], trls.shape[-1]))
    for i, pos in enumerate(u_pos):
        trial_mask = np.logical_and(pos == positions, col_trl_mask)
        out[i] = np.nanmean(trls[:, 0, trial_mask], axis=1)
    if squeeze_output:
        out = out[0]
    return out