Example #1
0
def fit_comparison(sessions_A, sessions_B, agent, fig_no = 1, title = None, max_change = 0.005):
    '''Fit both groups of sessions with the same agent and plot the two fits together.

    Each group is fitted with mf.fit_population (group A first, then group B) using
    the supplied max_change, and the pair of fits is handed to rp.scatter_plot_comp
    to be drawn on figure fig_no.  A title is added only when a non-empty one is given.
    '''
    group_fits = [mf.fit_population(group, agent, max_change = max_change)
                  for group in (sessions_A, sessions_B)]
    rp.scatter_plot_comp(group_fits[0], group_fits[1], fig_no = fig_no)
    if not title:
        return
    p.title(title)
Example #2
0
def test_population_fitting(task, agent, n_sessions=8, n_trials=1000, pop_params=None):
    """Check parameter recovery of the population fitting on simulated data.

    Sessions are generated by simulate_sessions from task, agent, n_sessions,
    n_trials and pop_params.  The agent is then fitted to those sessions with
    mf.fit_population (max_iter=15) and the true versus fitted parameter values
    are plotted with rp.plot_true_fitted_params.

    Returns the tuple (sessions, ML_fits, MAP_fits, pop_params).  Note that the
    returned pop_params is the third element of the fit result, not the value
    passed in as an argument.
    """
    sessions = simulate_sessions(task, agent, n_sessions, n_trials, pop_params)
    fit_result = mf.fit_population(sessions, agent, max_iter=15)
    ML_fits, MAP_fits, pop_params = fit_result
    rp.plot_true_fitted_params(sessions, ML_fits, MAP_fits)
    return sessions, ML_fits, MAP_fits, pop_params
Example #3
0
def abs_preference_comparison(sessions_A, sessions_B, population_fit_A, population_fit_B, agent,
                               fig_no = 1, title = None):
    '''Bar plot comparing the mean absolute preferences of two groups.

    rp.abs_preference_plot is evaluated (without plotting) for each group using
    that group's sessions and population fit; it is expected to return a pair of
    values, labelled here as the model based and model free preferences.  Group A
    is drawn in the default bar colour and group B in red on figure fig_no, which
    is cleared first.  A title is added only when a non-empty one is given.
    '''
    mb_A, td_A = rp.abs_preference_plot(sessions_A, population_fit_A, agent, to_plot = False)
    mb_B, td_B = rp.abs_preference_plot(sessions_B, population_fit_B, agent, to_plot = False)

    p.figure(fig_no)
    p.clf()
    # One entry per group: bar x positions, bar heights, extra keyword arguments.
    bar_specs = (([1  , 3  ], [mb_A, td_A], {}),
                 ([1.8, 3.8], [mb_B, td_B], {'color': 'r'}))
    for x_positions, heights, bar_kwargs in bar_specs:
        p.bar(x_positions, heights, **bar_kwargs)
    p.xticks([1.8, 3.8], ['Model based', 'Model free'])
    p.xlim(0.8, 4.8)
    p.ylabel('Mean abs. preference')
    if title:
        p.title(title)
Example #4
0
def eval_calibration(sessions, agent, population_fit, use_MAP=True, n_bins=10, fixed_widths=False, to_plot=False):
    """Calculate real choice probabilities as a function of model choice probabilities.

    Trials from all sessions are pooled and grouped into n_bins bins according to
    the model's choice probability (column 1 of the per-trial "choice_probs"
    returned by agent.session_likelihood).  For each bin the mean of the actual
    choices and the mean model choice probability are computed.

    Args:
        sessions: Sessions to evaluate; each must expose CTSO["choices"] as an array.
        agent: Agent providing name, n_params and session_likelihood().
        population_fit: Dict with "MAP_fits" (one fit per session, each holding
            "params_T"), "agent_name" and, when use_MAP is False, "pop_params".
        use_MAP: If True use each session's fitted "params_T"; otherwise draw
            parameters with ru.sample_params_T_from_pop_params.
        n_bins: Number of bins.
        fixed_widths: If True bins have equal width on [0, 1]; otherwise bin edges
            are taken from the sorted choice probabilities so that bins hold
            roughly equal numbers of trials.
        to_plot: If True pass the result to rp.calibration_plot.

    Returns:
        Dict with "true_probs" and "model_probs", each an array of length n_bins.
        Bins that contain no trials are NaN in both arrays.

    Also prints the fraction of trials on which the more probable choice under
    the model was the one made, and the geometric mean probability of the
    choices that were made.
    """
    session_fits = population_fit["MAP_fits"]

    assert len(session_fits[0]["params_T"]) == agent.n_params, "agent n_params does not match population_fit."
    assert len(sessions) == len(session_fits), "Number of fits does not match number of sessions."
    assert population_fit["agent_name"] == agent.name, "Agent name different from that used for fits."

    # Create arrays containing model choice probabilities and true choices for each trial.
    session_choices, session_choice_probs = [], []
    for fit, session in zip(session_fits, sessions):
        if use_MAP:
            params_T = fit["params_T"]
        else:
            params_T = ru.sample_params_T_from_pop_params(population_fit["pop_params"], agent)
        session_choices.append(session.CTSO["choices"].tolist())
        session_choice_probs.append(agent.session_likelihood(session, params_T, return_trial_data=True)["choice_probs"])
    choices = np.hstack(session_choices)
    choice_probs = np.vstack(session_choice_probs)[:, 1]

    if fixed_widths:  # Bins of equal width in model choice probability.
        bin_edges = np.linspace(0, 1, n_bins + 1)
    else:  # Bins of equal trial number.
        order = np.argsort(choice_probs)
        choices = choices[order]
        choice_probs = choice_probs[order]
        bin_edges = choice_probs[np.round(np.linspace(0, len(choice_probs) - 1, n_bins + 1)).astype(int)]
        bin_edges[0] = 0.0  # Lowest bin starts at 0 so the smallest probabilities are included.

    # Calculate true vs model choice probs.  Bins are half open, (lower, upper],
    # so a choice probability of exactly 0 falls in no bin.
    true_probs = np.full(n_bins, np.nan)
    model_probs = np.full(n_bins, np.nan)
    for b in range(n_bins):
        in_bin = np.logical_and(bin_edges[b] < choice_probs, choice_probs <= bin_edges[b + 1])
        if in_bin.any():  # Empty bins stay NaN without triggering a numpy warning.
            true_probs[b] = np.mean(choices[in_bin])
            model_probs[b] = np.mean(choice_probs[in_bin])
    # Built once, after the loop, so it is also defined when n_bins is 0.
    calibration = {"true_probs": true_probs, "model_probs": model_probs}

    if to_plot:
        rp.calibration_plot(calibration)
    print("Fraction correct: {}".format(np.mean((choice_probs > 0.5) == choices.astype(bool))))
    # Probability the model assigned to the choice that was actually made on each trial.
    chosen_probs = np.hstack([choice_probs[choices == 1], 1.0 - choice_probs[choices == 0]])
    print("Geometric mean choice prob: {}".format(np.exp(np.mean(np.log(chosen_probs)))))
    return calibration
Example #5
0
def plots(sessions_A, sessions_B, RL_agent, LR_agent = None, title = None,
                   max_change_LR = 0.001, max_change_RL = 0.01,
                   test_time = 20, parallel = False, test_data = None):
    '''Produce the set of comparison figures for two groups of sessions.

    Figures:
        1. trial_rate_comparison of the two groups (using test_time).
        2. reversal_comparison of the two groups.
        3. rp.scatter_plot_comp of the LR_agent population fits (only drawn when
           LR fits are available).
        4. rp.pop_fit_comparison of the RL_agent population fits.
        5. abs_preference_comparison based on the RL_agent population fits.

    If test_data is provided, the population fits and the title are read from it
    (test_data['RL_fit'] and test_data['LR_fit'], each with 'fit_A' and 'fit_B',
    plus test_data['title']) and no fitting is done.  Otherwise the fits are
    computed with mf.fit_population, using max_change_RL / max_change_LR and the
    parallel flag.  LR_agent is optional: when it is None the LR fits and figure 3
    are skipped rather than passing None to the fitting code.
    '''
    LR_fit_A = LR_fit_B = None
    if test_data:
        RL_fit_A = test_data['RL_fit']['fit_A']
        RL_fit_B = test_data['RL_fit']['fit_B']
        LR_fit_A = test_data['LR_fit']['fit_A']
        LR_fit_B = test_data['LR_fit']['fit_B']
        title = test_data['title']
    else:
        RL_fit_A = mf.fit_population(sessions_A, RL_agent, max_change = max_change_RL, parallel = parallel)
        RL_fit_B = mf.fit_population(sessions_B, RL_agent, max_change = max_change_RL, parallel = parallel)
        if LR_agent is not None:
            LR_fit_A = mf.fit_population(sessions_A, LR_agent, max_change = max_change_LR, parallel = parallel)
            LR_fit_B = mf.fit_population(sessions_B, LR_agent, max_change = max_change_LR, parallel = parallel)

    trial_rate_comparison(sessions_A, sessions_B, test_time, 1, title)
    reversal_comparison(sessions_A, sessions_B,  2, title)
    if LR_fit_A is not None and LR_fit_B is not None:
        rp.scatter_plot_comp(LR_fit_A, LR_fit_B, fig_no = 3)
        if title:  # Only title the figure when a title was actually supplied.
            p.title(title)
    rp.pop_fit_comparison(RL_fit_A, RL_fit_B, fig_no = 4, normalize = False)
    if title:
        p.suptitle(title)
    abs_preference_comparison(sessions_A, sessions_B, RL_fit_A, RL_fit_B, RL_agent, 5, title)
Example #6
0
def plot_resampled_dists(mft, fig_title = 'Permutation test', fig_no = 1, x_offset = 0.1):
    '''Print and plot the results of a permutation test.

    mft is a results dict holding 'n_resample', 'fit_A', 'fit_B', 'KL_data' and
    'means_data' (and optionally 'pref_data'); each *_data entry has a 'p_vals'
    item.  The number of permutations and the p values are printed, then figure
    fig_no is drawn with three rows: the two population fits overlaid (A in blue
    shifted by -x_offset, B in red shifted by +x_offset), followed by the 'KL'
    and 'means' panels produced by _plot_dist.
    '''
    n_resample = mft['n_resample']
    print('Permutations evaluated: {}'.format(n_resample))

    p_value_reports = [('P values    KL: {}', 'KL_data'),
                       ('P values means: {}', 'means_data')]
    if 'pref_data' in mft:
        p_value_reports.append(('P values pref: {}', 'pref_data'))
    for template, data_key in p_value_reports:
        print(template.format(mft[data_key]['p_vals']))

    # Plotting
    p.figure(fig_no)
    p.clf()
    fit_layers = (('fit_A', 'b', True , -x_offset),
                  ('fit_B', 'r', False,  x_offset))
    for fit_key, colour, clear_first, offset in fit_layers:
        rp.pop_scatter_plot(mft[fit_key], col = colour, clf = clear_first,
                            subplot = (3,1,1), x_offset = offset)

    if fig_title:
        p.suptitle(fig_title)

    for row, metric in enumerate(('KL', 'means'), start = 2):
        p.subplot(3, 1, row)
        _plot_dist(mft, metric)
Example #7
0
def _permutation_ranks_and_p_vals(shuffled_distances, true_distances, n_resample):
    '''Count, per column, the permutations whose distance is >= the true distance and
    turn the counts into p values by dividing by n_resample.'''
    # Broadcasting compares every row of shuffled_distances against true_distances.
    dist_ranks = np.sum(shuffled_distances >= np.asarray(true_distances), axis = 0)
    # Explicit float so the ratio is never truncated by integer division.
    # NOTE(review): the original author asked whether the denominator should be
    # n_resample + 1; it is left as n_resample to keep results unchanged.
    p_vals = dist_ranks / float(n_resample)
    return dist_ranks, p_vals


def model_fit_test(sessions_A, sessions_B, agent,  perm_type, n_resample = 100,
                   max_change = 0.001, max_iter = 300, true_init = False, parallel = True, mft = None):
    '''Permutation test for significant differences in model fits between two groups of
    sessions.  If a previous model_fit_test result (mft) is passed in, additional
    permutations are performed and the results added to the current test.

    Outline of procedure:
    1. Perform model fitting separately on both groups of sessions to give mean and standard
    deviation of population level distributions for each group.
    2. Evaluate distance metric (KL divergence or difference of means) between these population
    level distributions for each parameter.
    3. Generate population of resampled groups in which sessions are randomly allocated to
    the A or B groups.  For more information on how permutations are created see _permuted_dataset doc.
    4. Perform model fitting and evaluate distance metric for these resampled groups to get a
    distribution of the distance metric under the null hypothesis that there is no difference
    between groups.
    5. Compare the true distance metric for each parameter with the distribution for the
    resampled groups to get a confidence value.

    Arguments:
    sessions_A, sessions_B : the two groups of sessions (concatenated with + when true_init is set).
    agent      : agent to fit; for instances of _RL_agent the mean absolute preferences are also tested.
    perm_type  : 'within_subject', 'cross_subject' or 'ignore_subject', passed to _permuted_dataset.
    n_resample : number of permutations to run in this call (added to those already in mft).
    max_change, max_iter : passed to mf.fit_population for every fit.
    true_init  : if True, first fit the combined sessions (with max_change * 2) and use the
                 resulting pop_params as pop_init_params for all other fits.
    parallel   : passed to mf.fit_population.
    mft        : result of a previous call to extend, or None to start a new test.
                 NOTE(review): perm_type is not checked against mft['perm_type'].

    Returns a dict with keys 'fit_A', 'fit_B', 'n_resample' (total permutations), 'perm_type',
    'shuffled_fits' (population means and SDs for each permuted pair of fits), 'KL_data' and
    'means_data', plus 'pref_data' for _RL_agent instances.  Each *_data dict holds
    'true_distances', 'shuffled_distances', 'dist_ranks' and 'p_vals'; 'pref_data' also
    holds 'true_preferences_A' and 'true_preferences_B'.
    '''
    assert perm_type in ('within_subject', 'cross_subject', 'ignore_subject'), \
        'Invalid permutation type.'

    if true_init:
        comb_fit = mf.fit_population(sessions_A + sessions_B, agent, eval_BIC = False, parallel = parallel,
                                     max_change = max_change * 2, max_iter = max_iter)
        init_params = comb_fit['pop_params']
    else:
        init_params = None

    # Keyword arguments shared by the true and the permuted group fits.
    fit_kwargs = dict(eval_BIC = False, parallel = parallel, max_change = max_change,
                      max_iter = max_iter, pop_init_params = init_params)

    n_params = agent.n_params
    is_RL_agent = isinstance(agent, _RL_agent)  # Preferences are only evaluated for RL agents.

    if not mft: # No previously calculated permutations passed in.

        true_model_fit_A = mf.fit_population(sessions_A, agent, **fit_kwargs)
        true_model_fit_B = mf.fit_population(sessions_B, agent, **fit_kwargs)

        true_distances_KL = _population_fit_distance(true_model_fit_A, true_model_fit_B, 'KL')
        true_distances_means = _population_fit_distance(true_model_fit_A, true_model_fit_B, 'means')

        if is_RL_agent:  # Evaluate mean abs. preference.
            true_preferences_A = rp.abs_preference_plot(sessions_A, true_model_fit_A, agent, to_plot = False)
            true_preferences_B = rp.abs_preference_plot(sessions_B, true_model_fit_B, agent, to_plot = False)
            true_pref_dists = np.abs(np.array(true_preferences_A) - np.array(true_preferences_B))

    else: # Previously calculated permutation test passed in.

        n_resample_orig = mft['n_resample']
        n_resample = n_resample + n_resample_orig  # From here on n_resample is the total count.
        true_model_fit_A, true_model_fit_B = mft['fit_A'], mft['fit_B']
        true_distances_KL = mft['KL_data']['true_distances']
        true_distances_means = mft['means_data']['true_distances']
        if is_RL_agent:
            true_preferences_A = mft['pref_data']['true_preferences_A']
            true_preferences_B = mft['pref_data']['true_preferences_B']
            true_pref_dists = mft['pref_data']['true_distances']

    # Create structures to store permuted data.
    shuffled_distances_KL = np.zeros((n_resample, n_params))
    shuffled_distances_means = np.zeros((n_resample, n_params))
    shuffled_pref_dists = np.zeros((n_resample, 2))
    shuffled_fits = []

    if not mft:
        perm_indices = range(n_resample)

    else:  # Fill first part of arrays with previously calculated data.
        perm_indices = range(n_resample_orig, n_resample)
        shuffled_distances_KL   [:n_resample_orig,:] = mft['KL_data']   ['shuffled_distances']
        shuffled_distances_means[:n_resample_orig,:] = mft['means_data']['shuffled_distances']
        shuffled_fits += mft['shuffled_fits']
        if is_RL_agent:
            shuffled_pref_dists [:n_resample_orig,:] = mft['pref_data'] ['shuffled_distances']

    for i in perm_indices:
        print('Fitting permuted sessions, round: {} of {}'.format(i+1, n_resample))

        shuffled_ses_A, shuffled_ses_B = _permuted_dataset(sessions_A, sessions_B, perm_type)

        shuffled_fit_A = mf.fit_population(shuffled_ses_A, agent, **fit_kwargs)
        shuffled_fit_B = mf.fit_population(shuffled_ses_B, agent, **fit_kwargs)
        # Only the population means and SDs of the permuted fits are kept.
        shuffled_fits.append(({'means':shuffled_fit_A['pop_params']['means'],'SDs':shuffled_fit_A['pop_params']['SDs']},
                              {'means':shuffled_fit_B['pop_params']['means'],'SDs':shuffled_fit_B['pop_params']['SDs']}))
        shuffled_distances_KL[i,:]    = _population_fit_distance(shuffled_fit_A, shuffled_fit_B, 'KL')
        shuffled_distances_means[i,:] = _population_fit_distance(shuffled_fit_A, shuffled_fit_B, 'means')

        if is_RL_agent:
            shuffled_preferences_A = rp.abs_preference_plot(shuffled_ses_A, shuffled_fit_A, agent, to_plot = False)
            shuffled_preferences_B = rp.abs_preference_plot(shuffled_ses_B, shuffled_fit_B, agent, to_plot = False)
            shuffled_pref_dists[i,:] = np.abs(np.array(shuffled_preferences_A) -
                                              np.array(shuffled_preferences_B))

    dist_ranks_KL, p_vals_KL = _permutation_ranks_and_p_vals(
        shuffled_distances_KL, true_distances_KL, n_resample)
    dist_ranks_means, p_vals_means = _permutation_ranks_and_p_vals(
        shuffled_distances_means, true_distances_means, n_resample)

    mft =  {'fit_A': true_model_fit_A,
            'fit_B': true_model_fit_B,
            'n_resample': n_resample,
            'perm_type': perm_type,
            'shuffled_fits': shuffled_fits,
            'KL_data':     {'true_distances': true_distances_KL,
                            'shuffled_distances': shuffled_distances_KL,
                            'dist_ranks': dist_ranks_KL,
                            'p_vals': p_vals_KL},
            'means_data':  {'true_distances': true_distances_means,
                            'shuffled_distances': shuffled_distances_means,
                            'dist_ranks': dist_ranks_means,
                            'p_vals': p_vals_means}
            }

    if is_RL_agent:
        dist_ranks_pref, p_vals_pref = _permutation_ranks_and_p_vals(
            shuffled_pref_dists, true_pref_dists, n_resample)
        mft['pref_data'] = {'true_preferences_A' : true_preferences_A,
                            'true_preferences_B' : true_preferences_B,
                            'true_distances': true_pref_dists,
                            'shuffled_distances': shuffled_pref_dists,
                            'dist_ranks': dist_ranks_pref,
                            'p_vals': p_vals_pref}
    return mft