def estimate_test_time(sessions_A, sessions_B, RL_agent, LR_agent, perm_type,
                       max_change_LR = 0.01, max_change_RL = 0.05,
                       n_test_perm = 3, parallel = False, groups = None):
    '''Estimate time taken per permutation to run compare_groups.

    Each timed round repeats the per-permutation work:
    - shuffle sessions between the two groups with _permuted_dataset;
    - fit RL_agent and LR_agent to each shuffled group with mf.fit_population;
    - run pl.reversal_analysis on each shuffled group.

    The mean wall-clock time per round is printed and returned.

    Arguments:
        sessions_A, sessions_B : the two groups of sessions to permute between.
        RL_agent, LR_agent : agents passed to mf.fit_population.
        perm_type : permutation type, passed through to _permuted_dataset.
        max_change_LR : max_change used when fitting LR_agent.
        max_change_RL : max_change used when fitting RL_agent.
        n_test_perm : number of permutations to time; must be at least 1.
        parallel : passed through to mf.fit_population.
        groups : optional; when not None it is forwarded to _permuted_dataset
            as its fourth argument, as reversal_test does. When None,
            _permuted_dataset is called exactly as before (three arguments).

    Returns:
        Mean time per permutation, in seconds.

    Raises:
        ValueError : if n_test_perm < 1.
    '''
    if n_test_perm < 1:
        raise ValueError('n_test_perm must be at least 1, got {}'.format(n_test_perm))
    # Only forward groups when supplied, so the default call is unchanged.
    if groups is None:
        perm_args = (sessions_A, sessions_B, perm_type)
    else:
        perm_args = (sessions_A, sessions_B, perm_type, groups)
    # perf_counter is monotonic and high resolution, unlike time.time().
    start_time = time.perf_counter()
    for _ in range(n_test_perm):
        shuffled_ses_A, shuffled_ses_B = _permuted_dataset(*perm_args)
        mf.fit_population(shuffled_ses_A, RL_agent, eval_BIC = False, max_change = max_change_RL, parallel = parallel)
        mf.fit_population(shuffled_ses_B, RL_agent, eval_BIC = False, max_change = max_change_RL, parallel = parallel)
        mf.fit_population(shuffled_ses_A, LR_agent, eval_BIC = False, max_change = max_change_LR, parallel = parallel)
        mf.fit_population(shuffled_ses_B, LR_agent, eval_BIC = False, max_change = max_change_LR, parallel = parallel)
        pl.reversal_analysis(shuffled_ses_A, return_fits = True, by_type = False)
        pl.reversal_analysis(shuffled_ses_B, return_fits = True, by_type = False)
    time_per_perm = (time.perf_counter() - start_time) / n_test_perm
    print('Estimated time per permutation: ' + str(time_per_perm))
    return time_per_perm
def reversal_test(sessions_A, sessions_B, perm_type, n_resample = 1000, by_type = False, groups = None): ''' Permutation test for differences in the fraction correct at end of blocks and the time constant of adaptation to block transitions. ''' fit_A = pl.reversal_analysis(sessions_A, return_fits = True, by_type = by_type) fit_B = pl.reversal_analysis(sessions_B, return_fits = True, by_type = by_type) true_reversal_fit_distances = _reversal_fit_distances(fit_A,fit_B) permuted_reversal_fit_distances = np.zeros([n_resample, 4]) for i in range(n_resample): print('Fitting permuted sessions, round: {} of {}'.format(i+1, n_resample)) shuffled_ses_A, shuffled_ses_B = _permuted_dataset(sessions_A, sessions_B, perm_type, groups) shuffled_fit_A = pl.reversal_analysis(shuffled_ses_A, return_fits = True, by_type = by_type) shuffled_fit_B = pl.reversal_analysis(shuffled_ses_B, return_fits = True, by_type = by_type) permuted_reversal_fit_distances[i,:] = _reversal_fit_distances(shuffled_fit_A, shuffled_fit_B) dist_ranks = sum(permuted_reversal_fit_distances>=np.tile(true_reversal_fit_distances,(n_resample,1)),0) p_vals = dist_ranks / float(n_resample) print('Block end choice probability P value : {}'.format(p_vals[0])) print('All reversals tau P value : {}'.format(p_vals[1])) if by_type: print('Reward probability reversal tau P value: {}'.format(p_vals[2])) print('Trans. probability reversal tau P value: {}'.format(p_vals[3])) return {'block_end_P_value': p_vals[0], 'tau_P_value' : p_vals[1]}
def reversal_comparison(sessions_A, sessions_B, fig_no = 1, title = None, groups = None):
    '''Plot choice trajectories around reversals for both groups on one figure.

    Calls pl.reversal_analysis for sessions_A (cols = 0) and then for
    sessions_B (cols = 1) on figure fig_no, with by_type = False. If title is
    truthy, it is then set with p.title.

    groups is accepted for interface compatibility but is not used.
    '''
    # Group A is plotted without a clf argument, so pl.reversal_analysis' own
    # default applies. Group B passes clf = False, presumably so it is drawn
    # onto the same figure rather than clearing it.
    plot_specs = ((0, sessions_A, {}),
                  (1, sessions_B, {'clf': False}))
    for column, group_sessions, extra_kwargs in plot_specs:
        pl.reversal_analysis(group_sessions, cols = column, fig_no = fig_no,
                             by_type = False, **extra_kwargs)
    if not title:
        return
    p.title(title)