Example #1
0
def abs_preference_comparison(sessions_A, sessions_B, population_fit_A, population_fit_B, agent,
                               fig_no = 1, title = None):
    ''' Plot mean absolute preference of model based and model free system based on population fits.

    Group A bars (default colour) and group B bars (red) are drawn side by side for the
    model based and model free systems, with each tick label centred on its A/B pair.

    Arguments:
    sessions_A, sessions_B             : sessions for the two groups.
    population_fit_A, population_fit_B : population fits for the two groups, passed to
                                         rp.abs_preference_plot together with the sessions.
    agent  : agent passed to rp.abs_preference_plot.
    fig_no : number of the figure that is cleared and drawn into.
    title  : optional figure title (omitted if falsy).
    '''
    mean_preference_mb_A, mean_preference_td_A = rp.abs_preference_plot(sessions_A, population_fit_A, agent, to_plot = False)
    mean_preference_mb_B, mean_preference_td_B = rp.abs_preference_plot(sessions_B, population_fit_B, agent, to_plot = False)
    p.figure(fig_no)
    p.clf()
    # The layout below assumes each bar's left edge is at its x position (bar spans x to x + 0.8).
    # matplotlib >= 2.0 centres bars on x by default, which would shift every bar half a width left,
    # put the tick labels under the B bars only and clip the first bar at the xlim below, so the
    # alignment and width are given explicitly.
    p.bar([1  , 3],[mean_preference_mb_A, mean_preference_td_A], width = 0.8, align = 'edge')
    p.bar([1.8,3.8],[mean_preference_mb_B, mean_preference_td_B], width = 0.8, align = 'edge', color = 'r')
    # 1.8 and 3.8 are the shared edge of each A/B pair, i.e. the centre of the pair.
    p.xticks([1.8, 3.8], ['Model based', 'Model free'])
    p.xlim(0.8,4.8)
    p.ylabel('Mean abs. preference')
    if title:
        p.title(title)
Example #2
0
def _permutation_p_values(shuffled_distances, true_distances):
    '''Rank true distances against their permutation distributions.

    Arguments:
    shuffled_distances : 2D array, one row per permutation, one column per measure.
    true_distances     : array-like with one entry per column of shuffled_distances.

    Returns (dist_ranks, p_vals): dist_ranks[j] is the number of permutations whose distance
    for column j is >= the true distance, and p_vals[j] = dist_ranks[j] / number of permutations.
    '''
    shuffled_distances = np.asarray(shuffled_distances)
    # Broadcasting compares every permutation row against the row of true distances.
    dist_ranks = np.sum(shuffled_distances >= np.asarray(true_distances), axis = 0)
    # float() forces true division even where int / int floors (Python 2).
    # TODO(review): the commonly recommended estimator (dist_ranks + 1) / (n_resample + 1) never
    # gives p = 0; the plain ratio is kept here so results match previously computed tests.
    p_vals = dist_ranks / float(shuffled_distances.shape[0])
    return dist_ranks, p_vals


def model_fit_test(sessions_A, sessions_B, agent,  perm_type, n_resample = 100, 
                   max_change = 0.001, max_iter = 300, true_init = False, parallel = True, mft = None):
    '''Permutation test for significant differences in model fits between two groups of
    sessions.  If a previous model_fit_test object (mft) is passed in, additional
    permutations are performed and the results added to the current test.

    Outline of procedure:
    1. Perform model fitting separately on both groups of sessions to give mean and standard
    deviation of population level distributions for each group.
    2. Evaluate distance metric (KL divergence or difference of means) between these population
    level distributions for each parameter.
    3. Generate population of resampled groups in which sessions are randomly allocated to
    the A or B groups.  For more information on how permutations are created see _permuted_dataset doc.
    4. Perform model fitting and evaluate distance metric for these resampled groups to get a
    distribution of the distance metric under the null hypothesis that there is no difference
    between groups.
    5. Compare the true distance metric for each parameter with the distribution for the
    resampled groups to get a p value.

    Arguments:
    sessions_A, sessions_B : the two groups of sessions (lists; concatenated when true_init is set).
    agent      : model passed to mf.fit_population; must expose n_params.  For an _RL_agent the
                 mean absolute preferences of the two groups are compared as well.
    perm_type  : 'within_subject', 'cross_subject' or 'ignore_subject' (see _permuted_dataset).
    n_resample : number of new permutations to run in this call.
    max_change, max_iter, parallel : passed through to mf.fit_population.
    true_init  : if True, all fits are initialised from a fit to the combined sessions.
    mft        : dict returned by a previous call; its permutations are kept and extended.

    Returns:
    dict with keys 'fit_A', 'fit_B', 'n_resample' (total permutations so far), 'perm_type',
    'shuffled_fits', 'KL_data', 'means_data' and, for an _RL_agent, 'pref_data'.  Each *_data
    dict holds 'true_distances', 'shuffled_distances', 'dist_ranks' and 'p_vals', where p_vals
    is the fraction of permutations whose distance is >= the true distance.

    Raises:
    AssertionError if perm_type is invalid.
    ValueError if the total number of permutations would be zero.
    '''
    assert perm_type in ('within_subject', 'cross_subject', 'ignore_subject'), \
        'Invalid permutation type.'

    n_resample_orig = mft['n_resample'] if mft else 0  # Permutations already in mft.
    n_resample = n_resample + n_resample_orig           # Total permutations after this call.
    if n_resample < 1:
        # Fail before any (expensive) model fitting rather than dividing by zero at the end.
        raise ValueError('At least one permutation is required (n_resample >= 1).')

    if true_init:
        comb_fit = mf.fit_population(sessions_A + sessions_B, agent, eval_BIC = False, parallel = parallel,
                                     max_change = max_change * 2, max_iter = max_iter)
        init_params = comb_fit['pop_params']
    else:
        init_params = None

    n_params = agent.n_params
    is_RL = isinstance(agent, _RL_agent)  # Preference comparison is only done for RL agents.

    if not mft:  # No previously calculated permutations passed in.
        true_model_fit_A = mf.fit_population(sessions_A, agent, eval_BIC = False, parallel = parallel, max_change = max_change, max_iter = max_iter, pop_init_params = init_params)
        true_model_fit_B = mf.fit_population(sessions_B, agent, eval_BIC = False, parallel = parallel, max_change = max_change, max_iter = max_iter, pop_init_params = init_params)

        true_distances_KL = _population_fit_distance(true_model_fit_A, true_model_fit_B, 'KL')
        true_distances_means = _population_fit_distance(true_model_fit_A, true_model_fit_B, 'means')

        if is_RL:  # Evaluate mean abs. preference.
            true_preferences_A = rp.abs_preference_plot(sessions_A, true_model_fit_A, agent, to_plot = False)
            true_preferences_B = rp.abs_preference_plot(sessions_B, true_model_fit_B, agent, to_plot = False)
            true_pref_dists = np.abs(np.array(true_preferences_A) - np.array(true_preferences_B))

    else:  # Previously calculated permutation test passed in.
        true_model_fit_A, true_model_fit_B = mft['fit_A'], mft['fit_B']
        true_distances_KL = mft['KL_data']['true_distances']
        true_distances_means = mft['means_data']['true_distances']
        if is_RL:
            true_preferences_A = mft['pref_data']['true_preferences_A']
            true_preferences_B = mft['pref_data']['true_preferences_B']
            true_pref_dists = mft['pref_data']['true_distances']

    # Create structures to store permuted data, one row per permutation.
    shuffled_distances_KL = np.zeros((n_resample, n_params))
    shuffled_distances_means = np.zeros((n_resample, n_params))
    shuffled_pref_dists = np.zeros((n_resample, 2))
    shuffled_fits = []

    if mft:  # Fill first part of arrays with previously calculated data.
        shuffled_distances_KL[:n_resample_orig, :] = mft['KL_data']['shuffled_distances']
        shuffled_distances_means[:n_resample_orig, :] = mft['means_data']['shuffled_distances']
        shuffled_fits += mft['shuffled_fits']  # Builds a new list; mft itself is not modified.
        if is_RL:
            shuffled_pref_dists[:n_resample_orig, :] = mft['pref_data']['shuffled_distances']

    for i in range(n_resample_orig, n_resample):
        print('Fitting permuted sessions, round: {} of {}'.format(i+1, n_resample))

        shuffled_ses_A, shuffled_ses_B = _permuted_dataset(sessions_A, sessions_B, perm_type)

        shuffled_fit_A = mf.fit_population(shuffled_ses_A, agent, eval_BIC = False, max_change = max_change, max_iter = max_iter,
                                           pop_init_params = init_params, parallel = parallel)
        shuffled_fit_B = mf.fit_population(shuffled_ses_B, agent, eval_BIC = False, max_change = max_change, max_iter = max_iter,
                                           pop_init_params = init_params, parallel = parallel)
        # Only the population means and SDs are kept, not the full fits.
        shuffled_fits.append(({'means':shuffled_fit_A['pop_params']['means'],'SDs':shuffled_fit_A['pop_params']['SDs']},
                              {'means':shuffled_fit_B['pop_params']['means'],'SDs':shuffled_fit_B['pop_params']['SDs']}))
        shuffled_distances_KL[i,:]    = _population_fit_distance(shuffled_fit_A, shuffled_fit_B, 'KL')
        shuffled_distances_means[i,:] = _population_fit_distance(shuffled_fit_A, shuffled_fit_B, 'means')

        if is_RL:
            shuffled_preferences_A = rp.abs_preference_plot(shuffled_ses_A, shuffled_fit_A, agent, to_plot = False)
            shuffled_preferences_B = rp.abs_preference_plot(shuffled_ses_B, shuffled_fit_B, agent, to_plot = False)
            shuffled_pref_dists[i,:] = np.abs(np.array(shuffled_preferences_A) -
                                       np.array(shuffled_preferences_B))

    dist_ranks_KL, p_vals_KL = _permutation_p_values(shuffled_distances_KL, true_distances_KL)
    dist_ranks_means, p_vals_means = _permutation_p_values(shuffled_distances_means, true_distances_means)

    mft =  {'fit_A': true_model_fit_A,
            'fit_B': true_model_fit_B,
            'n_resample': n_resample,
            'perm_type': perm_type,
            'shuffled_fits': shuffled_fits, 
            'KL_data':     {'true_distances': true_distances_KL,
                            'shuffled_distances': shuffled_distances_KL,
                            'dist_ranks': dist_ranks_KL,
                            'p_vals': p_vals_KL},
            'means_data':  {'true_distances': true_distances_means,
                            'shuffled_distances': shuffled_distances_means,
                            'dist_ranks': dist_ranks_means,
                            'p_vals': p_vals_means}
            }

    if is_RL:
        dist_ranks_pref, p_vals_pref = _permutation_p_values(shuffled_pref_dists, true_pref_dists)
        mft['pref_data'] = {'true_preferences_A' : true_preferences_A,
                            'true_preferences_B' : true_preferences_B,
                            'true_distances': true_pref_dists,
                            'shuffled_distances': shuffled_pref_dists,
                            'dist_ranks': dist_ranks_pref,
                            'p_vals': p_vals_pref}
    return mft