import pandas as pd
import numpy as np
import seaborn as sns
from os.path import join
import matplotlib.pyplot as plt
from paper_behavior_functions import seaborn_style, figpath, datapath, FIGURE_WIDTH, FIGURE_HEIGHT

# Settings
DECODER = 'bayes'
FIG_PATH = figpath()
colors = [[1, 1, 1], [1, 1, 1], [0.6, 0.6, 0.6]]
seaborn_style()

# Load in results from the pickle file written by the decoding script
decoding_result = pd.read_pickle(join(datapath(),
                                      'classification_results_perf_%s.pkl' % DECODER))

# Calculate if decoder performs above chance: the mean score obtained with
# shuffled labels is compared against the 2.5th percentile of the scores
# obtained with the original labels.
chance_level = decoding_result['original_shuffled'].mean()
significance = np.percentile(decoding_result['original'], 2.5)
# Percentile of the positive-control scores (not referenced further in this section)
sig_control = np.percentile(decoding_result['control'], 0.001)
if chance_level > significance:
    print('\n%s classifier did not perform above chance' % DECODER)
else:
    # Previously this branch duplicated the message above, so an above-chance
    # result was reported as "did not perform above chance".
    print('\n%s classifier performed above chance' % DECODER)
print('Chance level: %.2f (F1 score)' % chance_level)
print('F1 score: %.2f ± %.3f' % (decoding_result['original'].mean(),
                                 decoding_result['original'].std()))
Exemplo n.º 2
0
    params['accuracy'] = np.mean(acc)

    return params  # wide df


# ========================================== #
#%% 3. FIT FOR EACH MOUSE
# ========================================== #

# NOTE(review): 'traini' / 'biased' are matched against ``behav.task``;
# presumably that column holds a truncated task-protocol name — confirm upstream.
print('fitting GLM to BASIC task...')
params_basic = behav.loc[behav.task == 'traini', :].groupby([
    'institution_code', 'subject_nickname'
]).progress_apply(fit_glm, prior_blocks=False).reset_index()
print('The mean condition number for the basic model is',
      params_basic['condition_number'].mean())

print('fitting GLM to FULL task...')
params_full = behav.loc[behav.task == 'biased', :].groupby([
    'institution_code', 'subject_nickname'
]).progress_apply(fit_glm, prior_blocks=True).reset_index()
print('The mean condition number for the full model is',
      params_full['condition_number'].mean())

# ========================================== #
# SAVE FOR NEXT TIME
# ========================================== #

data_path = Path(datapath(), 'model_results')
# DataFrame.to_csv does not create missing directories, so make sure the
# output folder exists before writing.
data_path.mkdir(parents=True, exist_ok=True)
params_basic.to_csv(data_path / 'params_basic.csv')
params_full.to_csv(data_path / 'params_full.csv')
Exemplo n.º 3
0
    # Randomly select N mice from each lab to equalize classes
    use_index = np.empty(0, dtype=int)
    for j, lab in enumerate(np.unique(labels)):
        use_index = np.concatenate([use_index, np.random.choice(labels_nr[labels == lab],
                                                               N_MICE, replace=False)])

    # Original data
    decoding_result.loc[i, 'original'], conf_matrix = decoding(decoding_set[use_index],
                                                               labels_decod, clf)
    decoding_result.loc[i, 'confusion_matrix'] = (conf_matrix
                                                  / conf_matrix.sum(axis=1)[:, np.newaxis])

    # Shuffled data
    np.random.shuffle(labels_shuffle)
    decoding_result.loc[i, 'original_shuffled'], _ = decoding(decoding_set[use_index],
                                                              labels_shuffle, clf)

    # Positive control data
    decoding_result.loc[i, 'control'], conf_matrix = decoding(control_set[use_index],
                                                              labels_decod, clf)
    decoding_result.loc[i, 'control_cm'] = (conf_matrix
                                            / conf_matrix.sum(axis=1)[:, np.newaxis])

    # Shuffled positive control data
    np.random.shuffle(labels_shuffle)
    decoding_result.loc[i, 'control_shuffled'], _ = decoding(control_set[use_index],
                                                             labels_shuffle, clf)

# Save the decoding results as a pickle file
decoding_result.to_pickle(join(datapath(), 'classification_results_first_biased_%s.pkl' % DECODER))
        use_index = np.concatenate([
            use_index,
            np.random.choice(labels_nr[labels == lab], N_MICE, replace=False)
        ])

    # Original data
    decoding_result.loc[i, 'original'], conf_matrix = decoding(
        decoding_set[use_index], labels_decod, clf)
    decoding_result.loc[i, 'confusion_matrix'] = (
        conf_matrix / conf_matrix.sum(axis=1)[:, np.newaxis])

    # Shuffled data
    np.random.shuffle(labels_shuffle)
    decoding_result.loc[i, 'original_shuffled'], _ = decoding(
        decoding_set[use_index], labels_shuffle, clf)

    # Positive control data
    decoding_result.loc[i, 'control'], conf_matrix = decoding(
        control_set[use_index], labels_decod, clf)
    decoding_result.loc[i, 'control_cm'] = (
        conf_matrix / conf_matrix.sum(axis=1)[:, np.newaxis])

    # Shuffled positive control data
    np.random.shuffle(labels_shuffle)
    decoding_result.loc[i, 'control_shuffled'], _ = decoding(
        control_set[use_index], labels_shuffle, clf)

# Save the decoding results as a pickle file
decoding_result.to_pickle(
    join(datapath(), 'classification_results_full_%s.pkl' % DECODER))
# Date on which the trained_1b status was implemented in the DataJoint
# pipeline (the literal is day-month-year).
DATE_IMPL = datetime.strptime('12-09-2019', '%d-%m-%Y').date()

# Get the sessions: from the cached csv, or by querying DataJoint
if QUERY is not True:
    ses = pd.read_csv(join(datapath(), 'Fig2c.csv'))
    # Unique subject ids — only used for counting the subjects
    use_subjects = ses['subject_uuid'].unique()
else:
    use_subjects = query_subjects()
    # Restrict to sessions whose training status is trained_1a or trained_1b
    status_filter = 'training_status = "trained_1a" OR training_status = "trained_1b"'
    joined = (use_subjects * behavior_analysis.SessionTrainingStatus
              * behavior_analysis.PsychResults)
    ses = (joined & status_filter).proj(
        'subject_nickname', 'n_trials_stim', 'institution_short', 'training_status')
    ses = ses.fetch(format='frame').reset_index()
    # Total trials per session: the sum of that row's n_trials_stim entries
    ses['n_trials'] = list(map(sum, ses['n_trials_stim']))

# Order sessions by subject, then by start time
ses = ses.sort_values(by=['subject_uuid', 'session_start_time'])
uni_sub = np.unique(ses['subject_uuid'])

training_time = pd.DataFrame(columns=['sessions'])
# Loop over subjects
for i_sub in range(0, len(uni_sub)):
    subj = uni_sub[i_sub]

    # Construct dataframe
    df = ses.loc[ses['subject_uuid'] == subj]
    if len(np.unique(df['training_status'])) == 2:  # Append

        # Check that the session start date is different for when reaching 1a/1b
from os.path import join, isdir
import pandas as pd
from paper_behavior_functions import (query_subjects,
                                      query_sessions_around_criterion,
                                      institution_map, CUTOFF_DATE, dj2pandas,
                                      datapath,
                                      query_session_around_performance)
from ibl_pipeline.analyses import behavior as behavioral_analyses
from ibl_pipeline import reference, subject, behavior, acquisition
import csv

# Get map of lab number to institute. The helper returns a tuple; its first
# element is rebound to the name ``institution_map``.
institution_map, _ = institution_map()

# Create the data directory if it doesn't exist yet. ``mkdir`` was called here
# without being imported, which raised NameError on a fresh checkout;
# ``makedirs(..., exist_ok=True)`` also creates missing parent directories.
from os import makedirs
root = datapath()
makedirs(root, exist_ok=True)

# Create list of subjects used
subjects = query_subjects(as_dataframe=True)
subjects.to_csv(join(root, 'subjects.csv'))

# %%=============================== #
# FIGURE 2
# ================================= #
print('Starting figure 2.')
# Figure 2af
use_subjects = query_subjects()
b = (behavioral_analyses.BehavioralSummaryByDate * use_subjects *
     behavioral_analyses.BehavioralSummaryByDate.PsychResults)
Exemplo n.º 7
0
pal = group_colors()
institution_map, col_names = institution_map()
col_names = col_names[:-1]  # drop the last column name

# %% ============================== #
# GET DATA FROM TRAINED ANIMALS
# ================================= #

if QUERY is True:
    use_subjects = query_subjects()
    b = (behavioral_analyses.BehavioralSummaryByDate * use_subjects)
    behav = b.fetch(order_by='institution_short, subject_nickname, training_day',
                    format='frame').reset_index()
    behav['institution_code'] = behav.institution_short.map(institution_map)
else:
    behav = pd.read_pickle(os.path.join(datapath(), 'Fig2af.pkl'))

# Keep only sessions with more than 100 trials. ``.copy()`` makes the filtered
# frame independent, so the in-place ``.loc`` writes below are not chained
# assignments on a slice.
behav = behav[behav['n_trials_date'] > 100].copy()

# Smooth each subject's performance with a 3-point running median over that
# subject's consecutive rows. (The previous np.convolve result was computed
# and then discarded, so it has been removed.)
# NOTE(review): assumes medfilt is scipy.signal.medfilt, which zero-pads the
# edges; subjects are keyed by nickname only — confirm nicknames are unique
# across labs.
for nickname in behav['subject_nickname'].unique():
    is_subject = behav['subject_nickname'] == nickname
    perf = behav.loc[is_subject, 'performance_easy'].values
    behav.loc[is_subject, 'performance_easy'] = medfilt(perf, kernel_size=3)

# how many mice are there for each lab?
N = behav.groupby(['institution_code'])['subject_nickname'].nunique().to_dict()
behav['n_mice'] = behav.institution_code.map(N)
Exemplo n.º 8
0
    # Randomly select N mice from each lab to equalize classes
    use_index = np.empty(0, dtype=int)
    for j, lab in enumerate(np.unique(labels)):
        use_index = np.concatenate([use_index, np.random.choice(labels_nr[labels == lab],
                                                               N_MICE, replace=False)])

    # Original data
    decoding_result.loc[i, 'original'], conf_matrix = decoding(decoding_set[use_index],
                                                               labels_decod, clf)
    decoding_result.loc[i, 'confusion_matrix'] = (conf_matrix
                                                  / conf_matrix.sum(axis=1)[:, np.newaxis])

    # Shuffled data
    np.random.shuffle(labels_shuffle)
    decoding_result.loc[i, 'original_shuffled'], _ = decoding(decoding_set[use_index],
                                                              labels_shuffle, clf)

    # Positive control data
    decoding_result.loc[i, 'control'], conf_matrix = decoding(control_set[use_index],
                                                              labels_decod, clf)
    decoding_result.loc[i, 'control_cm'] = (conf_matrix
                                            / conf_matrix.sum(axis=1)[:, np.newaxis])

    # Shuffled positive control data
    np.random.shuffle(labels_shuffle)
    decoding_result.loc[i, 'control_shuffled'], _ = decoding(control_set[use_index],
                                                             labels_shuffle, clf)

# Write the decoding results to disk as a pickle file
results_file = join(datapath(), 'classification_results_perf_%s.pkl' % DECODER)
decoding_result.to_pickle(results_file)