Example #1
0
def plot_overview_aligned_1st_2nd(TrialDfs, SessionDf, pre, post, axes=None):
    """
    Session overview aligned to the 1st cue and the 2nd (timing) cue, with choice RT markers.

    Trials are split by outcome ('correct', 'incorrect', 'missed') into three stacked
    panels and, within each panel, sorted by interval (similar to Gallinanes Fig.5C).
    The x axis is in ms from the window start: the 1st cue sits at x = pre, the 2nd cue
    at x = pre + this_interval and the choice at x = pre + this_interval + choice_rt.

    Args:
        TrialDfs: iterable of per-trial event DataFrames. Reach spans (> 5 ms) are
            extracted from each but are not drawn yet (see NOTE below).
        SessionDf: per-trial DataFrame with 'this_interval', 'choice_rt' and 'outcome'.
        pre: time (ms) shown before the 1st cue.
        post: time (ms) shown after the 1st cue.
        axes: optional matplotlib Axes; its figure hosts the three panels.
            If None, a new figure is created.

    Returns:
        The bottom panel's Axes (the one carrying the time axis).
    """
    if axes is None:
        fig = plt.figure(constrained_layout=True, figsize=(5, 4))
    else:
        # previously `fig` was only defined when axes was None (NameError otherwise)
        fig = axes.figure

    # Obtain reach spans of every trial, keeping only reaches longer than 5 ms.
    # NOTE(review): these spans are computed but not yet plotted.
    for TrialDf in TrialDfs:
        left_reach_Df = bhv.get_spans_from_names(TrialDf, 'REACH_LEFT_ON',
                                                 'REACH_LEFT_OFF')
        right_reach_Df = bhv.get_spans_from_names(TrialDf, 'REACH_RIGHT_ON',
                                                  'REACH_RIGHT_OFF')
        left_reach_Df = left_reach_Df[left_reach_Df['dt'] > 5]
        right_reach_Df = right_reach_Df[right_reach_Df['dt'] > 5]

    intervals = SessionDf['this_interval'].values + pre  # 2nd cue, after the pre time
    choice_rt = SessionDf['choice_rt'].values + intervals  # choice, after the interval

    # Positional indices per outcome (not index labels), so that indexing the
    # .values arrays above stays correct for a non-default SessionDf index.
    outcomes = SessionDf['outcome'].values
    split_sorted_idxs_list = []
    for outcome in ('correct', 'incorrect', 'missed'):
        idx = np.flatnonzero(outcomes == outcome)
        split_sorted_idxs_list.append(idx[np.argsort(intervals[idx])])

    """ Plotting """
    # Panel heights follow the number of trials; an empty outcome still gets a
    # minimal row so that the ratios never sum to zero.
    heights = [max(len(idxs), 1) for idxs in split_sorted_idxs_list]
    gs = fig.add_gridspec(ncols=1, nrows=3, height_ratios=heights)
    ylabels = ['Correct', 'Incorrect', 'Missed']

    for i, (idxs, label) in enumerate(zip(split_sorted_idxs_list, ylabels)):

        axes = fig.add_subplot(gs[i])

        axes.set_aspect('auto')
        axes.axvline(pre, linestyle='solid', alpha=1, lw=1, color='k')  # 1st cue
        # NOTE(review): fixed reference at x = 2000 (1.5 s after the 1st cue when
        # pre == 500) - confirm what it marks and whether it should follow `pre`.
        axes.axvline(2000, linestyle='dashed', alpha=0.5, lw=1, color='k')

        # One bar per trial row: row k spans y in [k - 0.5, k + 0.45]
        ymin = np.arange(-0.5, len(idxs) - 1)
        ymax = np.arange(0.45, len(idxs))
        axes.vlines(intervals[idxs], ymin, ymax, colors='#FFC0CB')  # 2nd cue
        axes.vlines(choice_rt[idxs], ymin, ymax, colors='#0000FF', linewidth=2)  # choice

        if i == 0:
            axes.set_title('Session overview aligned to 1st and 2nd cues')

        axes.set_ylabel(label)

        axes.set_xticks([])  # only the bottom panel gets ticks (below)
        axes.set_xlim(0, pre + post)

    # Formatting: time axis on the bottom panel, labelled in s relative to the 1st cue.
    # Labels are derived from the ticks so both always have the same length.
    ticks = np.arange(0, pre + post + 1, 500)
    axes.xaxis.set_ticks_position('bottom')
    axes.set_xticks(ticks)
    axes.set_xticklabels((ticks - pre) / 1000)
    axes.set_xlabel('Time')

    return axes
Example #2
0
def grasp_dur_across_sess(animal_fd_path, tasks_names):
    """Violin plots of reach and grasp durations per session, split by side.

    Args:
        animal_fd_path: pathlib.Path of the animal folder; must contain 'animal_meta.csv'.
        tasks_names: iterable of task names; only sessions of these tasks are used
            (a task with no session raises KeyError from groupby().get_group).

    Returns:
        The array of two Axes: [reaches, grasps]. The x axis is the session number
        (order of the filtered sessions), the hue is the side (LEFT/RIGHT).
    """
    SessionsDf = utils.get_sessions(animal_fd_path)
    animal_meta = pd.read_csv(animal_fd_path / 'animal_meta.csv')

    # Filter sessions to the ones of the task we want to see
    FilteredSessionsDf = pd.concat(
        [SessionsDf.groupby('task').get_group(name) for name in tasks_names])
    log_paths = [
        Path(path) / 'arduino_log.txt' for path in FilteredSessionsDf['path']
    ]

    fig, axes = plt.subplots(ncols=2, figsize=[8, 4], sharey=True, sharex=True)
    sides = ['LEFT', 'RIGHT']

    # has_choice is required by the groupby_dict filter below
    metrics = (met.get_start, met.get_stop, met.get_correct_side,
               met.get_outcome, met.get_interval_category,
               met.get_chosen_side, met.has_choice, met.has_reach_left,
               met.has_reach_right)

    # Long-format records: one row per duration (day, durs, type 'r'/'g', side)
    rows = []
    for day, log_path in enumerate(log_paths):

        LogDf = bhv.get_LogDf_from_path(log_path)
        TrialSpans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_STATE",
                                              "ITI_STATE")
        TrialDfs = [
            bhv.time_slice(LogDf, row['t_on'], row['t_off'])
            for _, row in tqdm(TrialSpans.iterrows(), position=0, leave=True)
        ]
        SessionDf = bhv.parse_trials(TrialDfs, metrics)

        for side in sides:
            # reaches: all REACH_<SIDE>_ON/OFF spans of the session
            reachDf = bhv.get_spans_from_names(LogDf, 'REACH_' + side + '_ON',
                                               'REACH_' + side + '_OFF')
            rows += [dict(day=day, durs=dur, type='r', side=side)
                     for dur in reachDf['dt'].values]

            # grasps: grasp duration of trials with a choice on this side
            choiceDf = bhv.groupby_dict(
                SessionDf, dict(has_choice=True, chosen_side=side.lower()))
            # NOTE(review): requires a 'grasp_dur' column in SessionDf; none of the
            # metrics above is visibly named for it - confirm which metric adds it.
            grasp_durs = choiceDf['grasp_dur'].dropna().values  # filter out NaNs
            rows += [dict(day=day, durs=dur, type='g', side=side)
                     for dur in grasp_durs]

    Df = pd.DataFrame(rows, columns=['day', 'durs', 'type', 'side'])

    # One panel per duration type, all sessions on the x axis
    for ax, (kind, title) in zip(axes, [('r', 'reaches'), ('g', 'grasps')]):
        sub = Df[Df['type'] == kind]
        if not sub.empty:
            sns.violinplot(data=sub,
                           x='day',
                           y='durs',
                           hue='side',
                           split=True,
                           cut=0,
                           ax=ax)
        ax.set_title(title)

    fig.suptitle('Reach and grasp durations across sessions \n' +
                 animal_meta['value'][5] + '-' + animal_meta['value'][0])

    return axes
Example #3
0
def plot_overview_simple(LogDf, align_event, pre, post, how='bars', axes=None):
    """Plot one row per occurrence of `align_event`, with key events and spans.

    Each row i is centred on y = i and covers [t - pre, t + post] around its
    alignment time t. Names ending in '_EVENT'/'_STATE' are drawn as events,
    names ending in '_ON' as filled spans up to the matching '_OFF' (reach spans
    of 10 ms or less are dropped).

    Args:
        LogDf: event log DataFrame with a 't' column, as consumed by `bhv`.
        align_event: event name whose occurrences define the rows.
        pre, post: window before/after each alignment time, in units of LogDf['t'].
        how: 'bars' draws events as vertical ticks, 'dots' as points.
        axes: optional Axes to draw into; a new figure is created if None.

    Returns:
        The Axes drawn into.
    """
    if axes is None:
        fig, axes = plt.subplots()

    # Key Events and Spans
    key_list = [
        'REACH_LEFT_ON', 'REACH_RIGHT_ON', 'PRESENT_INTERVAL_STATE',
        'GO_CUE_SHORT_EVENT', 'GO_CUE_LONG_EVENT'
    ]

    colors = sns.color_palette('hls', n_colors=len(key_list))
    cdict = dict(zip(key_list, colors))

    t_ref = bhv.get_events_from_name(LogDf, align_event).values

    for i, t in enumerate(t_ref):

        Df = bhv.time_slice(LogDf, t - pre, t + post, 't')

        for name, color in cdict.items():
            # plot events
            if name.endswith(("_EVENT", "_STATE")):
                times = bhv.get_events_from_name(Df, name).values - t

                if how == 'dots':
                    axes.plot(times, [i] * len(times), '.', color=color, alpha=0.75)

                if how == 'bars':
                    for time in times:
                        axes.plot([time, time], [i - 0.5, i + 0.5],
                                  lw=2,
                                  color=color,
                                  alpha=0.75)  # a bar

            # plot spans: strip only the trailing "_ON" to get the span's base name
            if name.endswith("_ON"):
                span_name = name[:-len("_ON")]
                SpansDf = bhv.get_spans_from_names(Df, name, span_name + '_OFF')

                if 'REACH' in span_name:
                    SpansDf = SpansDf[SpansDf['dt'] > 10]  # drop reaches of 10 ms or less

                for _, row_s in SpansDf.iterrows():
                    rect = plt.Rectangle((row_s['t_on'] - t, i - 0.5),
                                         row_s['dt'],
                                         1,
                                         facecolor=color,
                                         linewidth=1)
                    axes.add_patch(rect)

    # Dummy lines used only to build the legend entries
    for key, color in cdict.items():
        axes.plot([0], [0], color=color, label=key, lw=4)

    # Formatting
    axes.legend(loc="center",
                bbox_to_anchor=(0.5, -0.2),
                prop={'size': len(key_list)},
                ncol=len(key_list),
                frameon=False)
    axes.set_title('Trials aligned to ' + str(align_event))
    # Labels derived from the ticks (ms -> s) so both always have the same length
    ticks = np.arange(-pre, post + 1, 500)
    axes.set_xticks(ticks)
    axes.set_xticklabels(ticks / 1000)
    # Rows span [i - 0.5, i + 0.5]; show all of them, including the first
    axes.set_ylim([-0.5, len(t_ref) - 0.5])
    axes.invert_yaxis()  # Needs to be after set_ylim
    axes.set_xlabel('Time (s)')  # tick labels are in seconds
    axes.set_ylabel('Trial No.')

    # previously `fig` was only defined when axes was None (NameError otherwise)
    axes.figure.tight_layout()

    return axes
Example #4
0
# %%  Loadcell
csv_path = log_path.parent / "bonsai_LoadCellData.csv"
LoadCellDf = bhv.parse_bonsai_LoadCellData(csv_path)

# %%
# keep the original arduino timestamps, then express 't' on the loadcell clock
LogDf['t_original'] = LogDf['t']
LogDf['t'] = Sync.convert(LogDf['t'].values, 'arduino', 'loadcell')

# %% median correction for loadcell data
# subtract a rolling median from each axis (the first samples-1 rows become NaN)
samples = 10000 # 10s buffer: harp samples at 1khz, arduino at 100hz, LC controller has 1000 samples in buffer
LoadCellDf['x'] = LoadCellDf['x'] - LoadCellDf['x'].rolling(samples).median()
LoadCellDf['y'] = LoadCellDf['y'] - LoadCellDf['y'].rolling(samples).median()

# %%
# sp.zeros / sp.arange were deprecated numpy aliases and are gone from recent SciPy
import numpy as np

Spans = bhv.get_spans_from_names(LogDf, "REWARD_RIGHT_VALVE_ON", "REWARD_RIGHT_VALVE_OFF")
N = 2000  # window length after each valve opening, in loadcell time units
# (samples, spans, 2); size the span axis from the data instead of a hard-coded 27
Data = np.zeros((N, len(Spans), 2))
for i, (_, row) in enumerate(Spans.iterrows()):  # positional i, whatever the index labels
    t_on = row["t_on"]
    t_off = t_on + N
    data = bhv.time_slice(LoadCellDf, t_on, t_off)
    # drop the first column (presumably 't') and keep at most N rows, in case the
    # slice includes an extra boundary sample; assumes the remaining columns are x, y
    Data[:, i, :] = data.values[:N, 1:]

# %%
fs = np.arange(60, 310, 10)
Data = Data[:, -fs.shape[0]:, :]  # keep the last len(fs) spans, one per value in fs

# %%
fig, axes = plt.subplots()
for i in range(fs.shape[0]):
Example #5
0
Sync.sync('arduino', 'cam')

# %% get frames per file
from skimage.io import imread

# sorted .tif stacks of the current working directory; frame count = first axis
fnames = np.sort([fname for fname in os.listdir() if fname.endswith('.tif')])
nFrames_per_file = []
for fname in tqdm(fnames):
    D = imread(fname)
    nFrames_per_file.append(D.shape[0])

np.save('Frames_per_file.npy', np.array(nFrames_per_file))

# %% reconstructing frame times
nFrames_per_file = np.load('Frames_per_file.npy')
spans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_EVENT", "FRAME_EVENT")

recorded_frame_events = LogDf[LogDf['name'] == 'FRAME_EVENT']['t'].values
# known frame rate (Hz)
fr = 3.06188
dt = 1 / fr  # inter-frame interval in s; converted to ms (x1000) below

# Infer one FRAME_INF_EVENT per frame, starting at each span's FRAME_EVENT.
# NOTE(review): file i+1 is paired with span i (iterrows label) - presumably the
# first file precedes the first trial; confirm.
# Frames are collected and concatenated once: DataFrame.append was removed in
# pandas 2.0 and re-copying LogDf in the loop was quadratic.
inferred = []
for i, row in spans.iterrows():
    n_frames = nFrames_per_file[i + 1]
    inf_times = row['t_off'] + np.arange(n_frames) * dt * 1000
    inferred.append(pd.DataFrame({'name': ['FRAME_INF_EVENT'] * n_frames, 't': inf_times}))
LogDf = pd.concat([LogDf, *inferred])

LogDf = LogDf.sort_values('t')
# %% get dFF data
folder = "/media/georg/data/mesoscope/first data/with behavior/2021-08-25_day2_square_2/reshaped"
os.chdir(folder)
Example #6
0
# For each animal (row j) and each of its first `no_sessions_to_analyze`
# 'learn_to_choose_v2' sessions (column k): parse the arduino log into per-trial
# metrics and store session summaries in n_rews_collected, n_anticipatory,
# no_trials and session_length.
# NOTE(review): those arrays (and animals, animals_dir, no_sessions_to_analyze) are
# defined outside this view; presumably they are at least
# len(animals) x no_sessions_to_analyze - confirm.
for j, animal in enumerate(animals):

    SessionsDf = utils.get_sessions(animals_dir / animal)
    
    # Filter sessions to the ones of the task we want to see
    # (groupby(...).get_group raises KeyError if the animal has no session of this task)
    task_name = ['learn_to_choose_v2']
    FilteredSessionsDf = pd.concat([SessionsDf.groupby('task').get_group(name) for name in task_name])
    log_paths = [Path(path)/'arduino_log.txt' for path in FilteredSessionsDf['path']]

    for k,log_path in enumerate(tqdm(log_paths[:no_sessions_to_analyze], position=0, leave=True, desc=animal)):
        
        LogDf = bhv.get_LogDf_from_path(log_path)

        # Getting metrics: one LogDf slice per trial, from TRIAL_ENTRY_STATE to ITI_STATE
        TrialSpans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_STATE", "ITI_STATE")

        TrialDfs = []
        for i, row in TrialSpans.iterrows():
            TrialDfs.append(bhv.time_slice(LogDf, row['t_on'], row['t_off']))

        metrics = ( met.get_start, met.get_stop, met.get_correct_side, met.get_outcome, met.get_chosen_side, met.has_choice,
                    met.rew_collected)
        SessionDf = bhv.parse_trials(TrialDfs, metrics)

        # Session metrics
        # NOTE(review): the metric is met.rew_collected but the column read is
        # 'rew_collect' - confirm the column name parse_trials actually produces.
        n_rews_collected[j,k] = SessionDf['rew_collect'].sum()
        # NOTE(review): this counts trials with has_choice set; confirm that is the
        # intended meaning of "anticipatory".
        n_anticipatory[j,k] = SessionDf['has_choice'].sum()

        no_trials[j,k] = len(SessionDf)
        # duration from the first to the last logged timestamp
        session_length [j,k] = (LogDf['t'].iloc[-1]-LogDf['t'].iloc[0])/(1000*60) # convert msec. -> sec.-> min.
Example #7
0
for animal, folder in folders.items():
    print("processing %s" % animal)
    folder = Path(folder)
    Animal = utils.Animal(folder)
    SessionsDf = utils.get_sessions(Animal.folder).groupby('task').get_group('learn_to_choose_v2')
    SessionsDf = SessionsDf.reset_index()

    for day, row in SessionsDf.iterrows():
        print(day)
        folder = Path(SessionsDf.loc[day,'path'])
        LogDf = bhv.get_LogDf_from_path(folder / "arduino_log.txt")
        LogDf['min'] = LogDf['t'] / 60000


        # check each reach
        ReachesLeftDf = bhv.get_spans_from_names(LogDf, "REACH_LEFT_ON", "REACH_LEFT_OFF")

        # drop invalid
        min_th = 5
        max_th = 2000

        binds = np.logical_and(ReachesLeftDf['dt'].values > min_th, ReachesLeftDf['dt'].values < max_th)

        ReachesLeftDf = ReachesLeftDf.loc[binds]

        ReachesLeftDf['is_grasp'] = False
        for i, row in ReachesLeftDf.iterrows():
            t_on = row['t_on']
            t_off = row['t_off']
            Df = bhv.time_slice(LogDf, t_on, t_off)
            if 'GRASP_LEFT_ON' in Df.name.values:
Example #8
0
# NOTE(review): `folder` must already be defined at this point (e.g. by an earlier
# cell); it is reassigned below and this line does not see that value.
Animals = utils.get_Animals(folder)

# %%

# lifeguard
folder = "/media/georg/htcondor/shared-paton/georg/Animals_reaching/JJP-02909"
Animal = utils.Animal(folder)

day = 0
# positional lookup: the day-th session row, independent of the index labels
folder = Path(utils.get_sessions(Animal.folder).path.iloc[day])

LogDf = bhv.get_LogDf_from_path(folder / "arduino_log.txt")
LogDf['min'] = LogDf['t'] / 60000  # presumably ms -> min

# %% check each reach
ReachesDf = bhv.get_spans_from_names(LogDf, "REACH_ON", "REACH_OFF")

# drop invalid: keep reaches with min_th < dt < max_th
min_th = 5
max_th = 2000

binds = np.logical_and(ReachesDf['dt'].values > min_th, ReachesDf['dt'].values < max_th)

# .copy(): the 'is_grasp' column added below goes into an independent frame,
# not a view of the unfiltered one (avoids SettingWithCopyWarning)
ReachesDf = ReachesDf.loc[binds].copy()

ReachesDf['is_grasp'] = False
for i, row in ReachesDf.iterrows():
    t_on = row['t_on']
    t_off = row['t_off']
    Df = bhv.time_slice(LogDf, t_on, t_off)
    if 'GRASP_ON' in Df.name.values: