def fit_comparison(sessions_A, sessions_B, agent, fig_no=1, title=None, max_change=0.005):
    '''Fit both groups of sessions with the same agent and plot them together.

    Each group is fitted independently with mf.fit_population (group A first,
    then group B), forwarding max_change unchanged.  The two fits are handed
    to rp.scatter_plot_comp on figure fig_no, and the plot is given a title
    only when a truthy title is supplied.  Nothing is returned.
    '''
    # Tuple order fixes the fitting order: A is always fitted before B.
    group_fits = [mf.fit_population(group, agent, max_change=max_change)
                  for group in (sessions_A, sessions_B)]
    rp.scatter_plot_comp(group_fits[0], group_fits[1], fig_no=fig_no)
    if title:
        p.title(title)
def fit_agent_to_sim_data(agent, simulated_datasets, n_eval):
    """Fit one agent to every simulated dataset, chaining the initial parameters.

    The first dataset is fitted with pop_init_params=None.  Each subsequent
    fit is initialised with the "pop_params" entry of the fit that preceded
    it.  n_eval is forwarded to mf.fit_population as eval_BIC.

    Returns the list of fits, in the same order as simulated_datasets.
    """
    fits = []
    previous_pop_params = None
    for dataset in simulated_datasets:
        fit = mf.fit_population(dataset, agent, eval_BIC=n_eval,
                                pop_init_params=previous_pop_params)
        fits.append(fit)
        # Warm-start the next dataset's fit from this one.
        previous_pop_params = fit["pop_params"]
    return fits
def test_population_fitting(task, agent, n_sessions=8, n_trials=1000, pop_params=None):
    """Simulate sessions, fit the agent to them and plot true vs fitted parameters.

    Sessions are generated by simulate_sessions from task, agent, n_sessions,
    n_trials and pop_params.  The agent is then fitted to the simulated
    sessions with mf.fit_population (max_iter=15) and the correspondence
    between true and fitted parameter values is plotted with
    rp.plot_true_fitted_params.

    Returns the tuple (sessions, ML_fits, MAP_fits, pop_params), where the
    last three items come from unpacking the value returned by
    mf.fit_population (so the returned pop_params is the fitted one, not the
    argument).
    """
    sessions = simulate_sessions(task, agent, n_sessions, n_trials, pop_params)
    population_fit = mf.fit_population(sessions, agent, max_iter=15)
    # NOTE(review): other functions in this file index the result of
    # mf.fit_population by string key (e.g. ['pop_params'], ['MAP_fits']).
    # Confirm that unpacking it into three values still matches its return type.
    ML_fits, MAP_fits, fitted_pop_params = population_fit
    rp.plot_true_fitted_params(sessions, ML_fits, MAP_fits)
    return sessions, ML_fits, MAP_fits, fitted_pop_params
def estimate_test_time(sessions_A, sessions_B, RL_agent, LR_agent, perm_type,
                       max_change_LR=0.01, max_change_RL=0.05, n_test_perm=3,
                       parallel=False):
    '''Estimate the wall-clock time taken per permutation.

    For each of n_test_perm rounds a permuted pair of groups is drawn with
    _permuted_dataset, both groups are fitted with RL_agent and then with
    LR_agent (eval_BIC disabled), and pl.reversal_analysis is run on each
    group.  The mean elapsed time per round is printed; nothing is returned.
    Intended as a cost estimate for compare_groups (not shown in this chunk).
    '''
    # (agent, max_change) pairs in the order the fits are run: RL first, then LR.
    agent_settings = ((RL_agent, max_change_RL), (LR_agent, max_change_LR))
    start_time = time.time()
    for _ in range(n_test_perm):
        shuffled_ses_A, shuffled_ses_B = _permuted_dataset(sessions_A, sessions_B, perm_type)
        shuffled_groups = (shuffled_ses_A, shuffled_ses_B)
        for test_agent, max_change in agent_settings:
            for shuffled_ses in shuffled_groups:
                mf.fit_population(shuffled_ses, test_agent, eval_BIC=False,
                                  max_change=max_change, parallel=parallel)
        for shuffled_ses in shuffled_groups:
            pl.reversal_analysis(shuffled_ses, return_fits=True, by_type=False)
    seconds_per_perm = (time.time() - start_time) / n_test_perm
    print('Estimated time per permuation: ' + str(seconds_per_perm))
def plots(sessions_A, sessions_B, RL_agent, LR_agent=None, title=None,
          max_change_LR=0.001, max_change_RL=0.01, test_time=20,
          parallel=False, test_data=None):
    '''Produce the set of comparison plots for two groups of sessions.

    If test_data is truthy, the RL and LR population fits for both groups and
    the title are read from it (keys 'RL_fit' and 'LR_fit', each holding
    'fit_A' and 'fit_B', plus 'title'); the title argument is then ignored.
    Otherwise the fits are computed with mf.fit_population, RL fits first and
    then LR fits, using max_change_RL / max_change_LR respectively.

    Plots made, in order: trial rate comparison, reversal comparison, LR fit
    scatter plot (figure 3), RL population fit comparison (figure 4) and
    absolute preference comparison.  The integers 1, 2 and 5 passed
    positionally to the comparison helpers are presumably figure numbers.
    '''
    if not test_data:
        RL_fit_A = mf.fit_population(sessions_A, RL_agent, max_change=max_change_RL, parallel=parallel)
        RL_fit_B = mf.fit_population(sessions_B, RL_agent, max_change=max_change_RL, parallel=parallel)
        # NOTE(review): LR_agent defaults to None and is passed straight to the
        # fitter here - confirm callers always supply it when test_data is not given.
        LR_fit_A = mf.fit_population(sessions_A, LR_agent, max_change=max_change_LR, parallel=parallel)
        LR_fit_B = mf.fit_population(sessions_B, LR_agent, max_change=max_change_LR, parallel=parallel)
    else:
        stored_RL_fits = test_data['RL_fit']
        RL_fit_A = stored_RL_fits['fit_A']
        RL_fit_B = stored_RL_fits['fit_B']
        stored_LR_fits = test_data['LR_fit']
        LR_fit_A = stored_LR_fits['fit_A']
        LR_fit_B = stored_LR_fits['fit_B']
        title = test_data['title']

    trial_rate_comparison(sessions_A, sessions_B, test_time, 1, title)
    reversal_comparison(sessions_A, sessions_B, 2, title)

    rp.scatter_plot_comp(LR_fit_A, LR_fit_B, fig_no=3)
    p.title(title)

    rp.pop_fit_comparison(RL_fit_A, RL_fit_B, fig_no=4, normalize=False)
    p.suptitle(title)

    abs_preference_comparison(sessions_A, sessions_B, RL_fit_A, RL_fit_B, RL_agent, 5, title)
def model_comparison_robustness(sessions, agents, task, n_eval=100, n_sim=100):
    """Model comparison with an estimate of how robust the conclusion about
    which model is best is.

    The approach taken is as follows:

    1. Fit every agent to the data provided and evaluate the quality of fit
       using the "BIC_score" entry of each fit.
    2. Using the best fitting agent (lowest BIC score), generate n_sim
       simulated datasets from its population fit via
       sim_sessions_from_pop_fit.
    3. Fit all agents to each simulated dataset and evaluate the BIC scores
       for the fits.
    4. Plot the resulting BIC score distributions with plot_BIC_dists.

    n_eval is forwarded to mf.fit_population as eval_BIC.

    Returns a dictionary with keys "agents", "sessions", "task",
    "best_agent_n", "model_fits", "simulated_datasets" and
    "simulated_data_fits".
    """
    print ("Fitting real data.")
    model_fits = []
    for agent in agents:
        model_fits.append(mf.fit_population(sessions, agent, eval_BIC=n_eval))

    BIC_scores = [fit["BIC_score"] for fit in model_fits]
    best_agent_n = np.argmin(BIC_scores)
    best_agent = agents[best_agent_n]
    best_agent_fit = model_fits[best_agent_n]

    simulated_datasets = [
        sim_sessions_from_pop_fit(task, best_agent, best_agent_fit, use_MAP=False)
        for _ in range(n_sim)
    ]

    # One mp_pool task per agent; within a task the datasets are fitted one
    # after another by fit_agent_to_sim_data.
    fit_func = partial(fit_agent_to_sim_data,
                       simulated_datasets=simulated_datasets, n_eval=n_eval)
    simulated_data_fits = mp_pool.map(fit_func, agents)

    mod_comp = {
        "agents": agents,
        "sessions": sessions,
        "task": task,
        "best_agent_n": best_agent_n,
        "model_fits": model_fits,
        "simulated_datasets": simulated_datasets,
        "simulated_data_fits": simulated_data_fits,
    }
    plot_BIC_dists(mod_comp)
    return mod_comp
def MAP_fit_test(sessions_A, sessions_B, agent, perm_type, n_resample=1000,
                 max_change=0.01, parallel=False, use_median=False):
    '''Permutation test on per-session MAP fits from a single combined population fit.

    A single population distribution is fitted to sessions_A + sessions_B.
    Each session object then has its MAP fit stored on it as the attribute
    MAP_fit (this mutates the session objects passed in).  The test statistic
    for each parameter is the absolute difference between the group averages
    of the MAP 'params_T' values, using the median if use_median is true and
    the mean otherwise.  The same statistic is evaluated for n_resample
    permuted groupings produced by _permuted_dataset; no refitting is done
    for the permutations.

    Returns, for each parameter, the fraction of permutations whose statistic
    is greater than or equal to the true statistic.
    '''
    all_sessions = sessions_A + sessions_B
    all_sessions_fit = mf.fit_population(all_sessions, agent, parallel=parallel,
                                         max_change=max_change)
    # Attach each MAP fit to its session so it travels with the session
    # when the groups are permuted.
    for session_n, MAP_fit in enumerate(all_sessions_fit['MAP_fits']):
        all_sessions[session_n].MAP_fit = MAP_fit

    ave_func = np.median if use_median else np.mean

    def _group_distance(group_A, group_B):
        # Absolute difference of the per-parameter group averages.
        MAP_fits_A = np.array([s.MAP_fit['params_T'] for s in group_A])
        MAP_fits_B = np.array([s.MAP_fit['params_T'] for s in group_B])
        return np.abs(ave_func(MAP_fits_A, 0) - ave_func(MAP_fits_B, 0))

    true_fit_dists = _group_distance(sessions_A, sessions_B)

    shuffled_fit_dists = np.zeros([n_resample, agent.n_params])
    for perm_n in range(n_resample):
        print('Evaluating permuted sessions, round: {} of {}'.format(perm_n+1, n_resample))
        shuffled_ses_A, shuffled_ses_B = _permuted_dataset(sessions_A, sessions_B, perm_type)
        shuffled_fit_dists[perm_n, :] = _group_distance(shuffled_ses_A, shuffled_ses_B)

    # true_fit_dists broadcasts across the permutation axis; the sum over
    # axis 0 counts permutations at least as extreme as the true value.
    dist_ranks = sum(shuffled_fit_dists >= true_fit_dists, 0)
    p_vals = dist_ranks / n_resample
    return p_vals
def model_fit_test(sessions_A, sessions_B, agent, perm_type, n_resample=100,
                   max_change=0.001, max_iter=300, true_init=False,
                   parallel=True, mft=None):
    '''Permutation test for significant differences in model fits between two
    groups of sessions.

    If a previous model_fit_test result (mft) is passed in, n_resample
    additional permutations are performed and appended to the stored ones;
    the true fits and true distances are taken from mft rather than refitted.

    Outline of procedure:
    1. Perform model fitting separately on both groups of sessions to give the
       population level distribution ('pop_params') for each group.
    2. Evaluate two distance metrics ('KL' and 'means') between these
       population level distributions for each parameter, using
       _population_fit_distance.
    3. Generate resampled groups in which sessions are reallocated to the A or
       B groups.  For more information on how permutations are created see the
       _permuted_dataset doc.
    4. Perform model fitting and evaluate the distance metrics for these
       resampled groups to get a distribution of each metric under the null
       hypothesis that there is no difference between groups.
    5. Compare the true distance for each parameter with the distribution for
       the resampled groups to get a p value.

    If true_init is true, a fit to the combined sessions (with max_change
    doubled) provides the initial population parameters for every other fit.
    If agent is an instance of _RL_agent, the same test is also applied to the
    values returned by rp.abs_preference_plot.

    Returns a dictionary with keys 'fit_A', 'fit_B', 'n_resample',
    'perm_type', 'shuffled_fits', 'KL_data', 'means_data' and, for _RL_agent
    instances, 'pref_data'.  'n_resample' holds the total number of
    permutations, including any carried over from mft.

    NOTE(review): when mft is supplied its stored 'perm_type' is not compared
    with the perm_type argument - confirm callers keep them consistent.
    '''
    assert perm_type in ('within_subject', 'cross_subject', 'ignore_subject'), \
        'Invalid permutation type.'

    init_params = None
    if true_init:
        comb_fit = mf.fit_population(sessions_A + sessions_B, agent, eval_BIC=False,
                                     parallel=parallel, max_change=max_change * 2,
                                     max_iter=max_iter)
        init_params = comb_fit['pop_params']

    n_params = agent.n_params
    is_RL_agent = isinstance(agent, _RL_agent)

    # Keyword arguments shared by every per-group fit (true and permuted).
    fit_kwargs = dict(eval_BIC=False, parallel=parallel, max_change=max_change,
                      max_iter=max_iter, pop_init_params=init_params)

    if mft:  # Previously calculated permutation test passed in.
        n_resample_orig = mft['n_resample']
        n_resample = n_resample + n_resample_orig
        true_model_fit_A, true_model_fit_B = mft['fit_A'], mft['fit_B']
        true_distances_KL = mft['KL_data']['true_distances']
        true_distances_means = mft['means_data']['true_distances']
        if is_RL_agent:
            true_preferences_A = mft['pref_data']['true_preferences_A']
            true_preferences_B = mft['pref_data']['true_preferences_B']
            true_pref_dists = mft['pref_data']['true_distances']
    else:  # No previously calculated permutations passed in.
        n_resample_orig = 0
        true_model_fit_A = mf.fit_population(sessions_A, agent, **fit_kwargs)
        true_model_fit_B = mf.fit_population(sessions_B, agent, **fit_kwargs)
        true_distances_KL = _population_fit_distance(true_model_fit_A, true_model_fit_B, 'KL')
        true_distances_means = _population_fit_distance(true_model_fit_A, true_model_fit_B, 'means')
        if is_RL_agent:  # Evaluate mean abs. preference.
            true_preferences_A = rp.abs_preference_plot(sessions_A, true_model_fit_A,
                                                        agent, to_plot=False)
            true_preferences_B = rp.abs_preference_plot(sessions_B, true_model_fit_B,
                                                        agent, to_plot=False)
            true_pref_dists = np.abs(np.array(true_preferences_A) -
                                     np.array(true_preferences_B))

    # Create structures to store permuted data, sized for the total number of
    # permutations (carried over + new).
    shuffled_distances_KL = np.zeros((n_resample, n_params))
    shuffled_distances_means = np.zeros((n_resample, n_params))
    shuffled_pref_dists = np.zeros((n_resample, 2))
    shuffled_fits = []

    if mft:  # Fill first part of arrays with previously calculated data.
        shuffled_distances_KL[:n_resample_orig, :] = mft['KL_data']['shuffled_distances']
        shuffled_distances_means[:n_resample_orig, :] = mft['means_data']['shuffled_distances']
        shuffled_fits += mft['shuffled_fits']
        if is_RL_agent:
            shuffled_pref_dists[:n_resample_orig, :] = mft['pref_data']['shuffled_distances']

    # Only the rows after the carried-over permutations need computing.
    for perm_n in range(n_resample_orig, n_resample):
        print('Fitting permuted sessions, round: {} of {}'.format(perm_n+1, n_resample))
        shuffled_ses_A, shuffled_ses_B = _permuted_dataset(sessions_A, sessions_B, perm_type)
        shuffled_fit_A = mf.fit_population(shuffled_ses_A, agent, **fit_kwargs)
        shuffled_fit_B = mf.fit_population(shuffled_ses_B, agent, **fit_kwargs)
        # Keep only the population means and SDs of each permuted fit.
        shuffled_fits.append(tuple(
            {'means': fit['pop_params']['means'], 'SDs': fit['pop_params']['SDs']}
            for fit in (shuffled_fit_A, shuffled_fit_B)))
        shuffled_distances_KL[perm_n, :] = _population_fit_distance(
            shuffled_fit_A, shuffled_fit_B, 'KL')
        shuffled_distances_means[perm_n, :] = _population_fit_distance(
            shuffled_fit_A, shuffled_fit_B, 'means')
        if is_RL_agent:
            shuffled_preferences_A = rp.abs_preference_plot(shuffled_ses_A, shuffled_fit_A,
                                                            agent, to_plot=False)
            shuffled_preferences_B = rp.abs_preference_plot(shuffled_ses_B, shuffled_fit_B,
                                                            agent, to_plot=False)
            shuffled_pref_dists[perm_n, :] = np.abs(np.array(shuffled_preferences_A) -
                                                    np.array(shuffled_preferences_B))

    def _ranks_and_p_vals(shuffled_distances, true_distances):
        # Count permutations at least as extreme as the true value, per column.
        # NOTE: the denominator is n_resample; whether it should instead be
        # n_resample + 1 is an open question left by the original author.
        dist_ranks = sum(shuffled_distances >= true_distances, 0)
        return dist_ranks, dist_ranks / n_resample

    dist_ranks_KL, p_vals_KL = _ranks_and_p_vals(shuffled_distances_KL, true_distances_KL)
    dist_ranks_means, p_vals_means = _ranks_and_p_vals(shuffled_distances_means,
                                                       true_distances_means)

    result = {'fit_A': true_model_fit_A,
              'fit_B': true_model_fit_B,
              'n_resample': n_resample,
              'perm_type': perm_type,
              'shuffled_fits': shuffled_fits,
              'KL_data': {'true_distances': true_distances_KL,
                          'shuffled_distances': shuffled_distances_KL,
                          'dist_ranks': dist_ranks_KL,
                          'p_vals': p_vals_KL},
              'means_data': {'true_distances': true_distances_means,
                             'shuffled_distances': shuffled_distances_means,
                             'dist_ranks': dist_ranks_means,
                             'p_vals': p_vals_means}}

    if is_RL_agent:
        dist_ranks_pref, p_vals_pref = _ranks_and_p_vals(shuffled_pref_dists, true_pref_dists)
        result['pref_data'] = {'true_preferences_A': true_preferences_A,
                               'true_preferences_B': true_preferences_B,
                               'true_distances': true_pref_dists,
                               'shuffled_distances': shuffled_pref_dists,
                               'dist_ranks': dist_ranks_pref,
                               'p_vals': p_vals_pref}

    return result