def plot_grasp_duration_distro(LogDf, SessionDf, bin_width, max_grasp_dur, percentile):
    """
    Histogram reach durations per side, overlaying the choice-inducing grasps.

    One panel is drawn per side (LEFT, RIGHT). Each panel shows the 'dt' of
    every REACH_<side>_ON / REACH_<side>_OFF span found in LogDf (faint) and,
    when SessionDf contains trials with a choice to that side, their non-NaN
    'grasp_dur' values (solid) plus a vertical line and a text label at the
    requested percentile of those values.

    Original author's note: works for sessions that allow mistakes, but the
    percentile text/line may not be drawn because of NaNs.

    Args:
        LogDf: event log passed to bhv.get_spans_from_names.
        SessionDf: per-trial table; must provide 'has_choice', 'chosen_side'
            and 'grasp_dur'.
        bin_width: histogram bin width, same unit as max_grasp_dur.
        max_grasp_dur: upper edge of the histogram range.
        percentile: percentile (0-100) marked on the choice-grasp histogram.

    Returns:
        The array of matplotlib axes, one per side.
    """
    sides = ['LEFT', 'RIGHT']
    n_panels = len(sides)

    # Shared histogram settings so both overlays use identical binning
    hist_opts = {
        'bins': round(max_grasp_dur / bin_width),
        'range': (0, max_grasp_dur),
        'edgecolor': 'none',
    }

    fig, axes = plt.subplots(ncols=n_panels, figsize=[6, 3], sharex=True, sharey=True)
    palette = sns.color_palette('hls', n_colors=n_panels)

    for panel, side, shade in zip(axes, sides, palette):
        # Every reach towards this side, whether or not it produced a choice
        stem = 'REACH_' + str(side)
        reach_spans = bhv.get_spans_from_names(LogDf, stem + '_ON', stem + '_OFF')
        all_durs = np.array(reach_spans['dt'].values, dtype=object)
        panel.hist(all_durs, alpha=0.25, color=shade, label='All grasps', **hist_opts)

        # Only the reaches that registered as a choice on this side
        chosen = bhv.groupby_dict(
            SessionDf, dict(has_choice=True, chosen_side=side.lower()))
        if len(chosen) == 0:
            continue

        has_dur = chosen['grasp_dur'].notna()  # drop NaNs
        choice_durs = chosen[has_dur]['grasp_dur'].values
        panel.hist(choice_durs, alpha=1, color=shade, label='Choice grasps', **hist_opts)

        marker = np.percentile(choice_durs, percentile)
        panel.axvline(marker, color=shade, alpha=1)
        # x is given in axes-fraction coordinates, hence the division by the range
        panel.text(marker / max_grasp_dur, 0.9, marker, transform=panel.transAxes)

    for panel, side in zip(axes, sides):
        panel.legend(frameon=False, fontsize='small')
        panel.set_title(side)
        panel.set_xlabel('Time (ms)')

    axes[0].set_ylabel('No. of occurrences')
    fig.suptitle("Histogram of grasps' duration, vertical bar is " +
                 str(percentile) + "th percentile")
    fig.tight_layout()

    return axes
def filter_trials_by(SessionDf, TrialDfs, filter_dict):
    """
    Select the entries of TrialDfs whose SessionDf row matches filter_dict.

    Example: given dict(outcome='correct', chosen_side='left') only the trials
    that are correct to the left side are returned.

    Args:
        SessionDf: per-trial table whose index gives positions into TrialDfs.
        TrialDfs: sequence of per-trial objects, indexable by integer.
        filter_dict: mapping of column name -> required value.

    Returns:
        List of the matching elements of TrialDfs, in SessionDf order. An empty
        list is returned if no trial matches and `error` does not itself raise.
    """
    try:
        if len(filter_dict) == 1:
            # Group by the bare column name (not a one-element list) so that
            # get_group accepts the plain scalar value across pandas versions.
            (column, wanted), = filter_dict.items()
            SDf = SessionDf.groupby(column).get_group(wanted)
        else:
            SDf = bhv.groupby_dict(SessionDf, filter_dict)
    except (KeyError, utils.NoFiltPairError):
        # pandas signals a missing group with KeyError; NoFiltPairError is kept
        # for bhv.groupby_dict, as in the original handler.
        # NOTE(review): `error` is not defined in the visible code - confirm it
        # is provided elsewhere in this module.
        error('No trials with given input filter_pair combination')
        return []

    return [TrialDfs[i] for i in SDf.index.values.astype(int)]
def plot_psychometric(SessionDf, N=1000, axes=None, discrete=False):
    """
    Plot choices against interval duration with a logistic fit and null bands.

    The fit from bhv.log_reg is drawn in red. N surrogate choice vectors are
    drawn with the session's overall probability of choosing 'right', each is
    fitted the same way, and percentile bands of those fits are shaded blue.

    Args:
        SessionDf: per-trial table; must provide 'has_choice', 'timing_trial',
            'chosen_side' and 'this_interval' (plus 'in_corr_loop' when
            discrete=True).
        N: number of surrogate (bias-only) fits.
        axes: matplotlib axes to draw on; a new figure is created if None.
        discrete: if True, skip the raw choice dots and instead plot, per
            distinct interval, the fraction of 'right' choices.

    Returns:
        The axes that were drawn on. If the trial selection raises KeyError the
        axes are returned without any data.
    """
    if axes is None:
        fig, axes = plt.subplots()

    # Subset of timing trials with a choice.
    # NOTE(review): the original comment says "outside correction loops", but
    # in_corr_loop is not part of this filter - confirm which is intended.
    try:
        SDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, timing_trial=True))
    except KeyError:
        print('No trials fulfil criteria')
        return axes  # nothing to plot; SDf would otherwise be unbound below

    y = SDf['chosen_side'].values == 'right'
    x = SDf['this_interval'].values

    # plot the choices
    if not discrete:
        axes.plot(x, y, '.', color='k', alpha=0.5)

    # Secondary y axis labelling the two choice categories
    axx = plt.twinx(axes)
    axx.set_yticks([0, 1])
    axx.set_yticklabels(['short', 'long'])
    axx.set_ylabel('choice')

    w = 0.05  # padding so the dots at 0 and 1 are not clipped
    axx.set_ylim(0 - w, 1 + w)
    axes.set_ylim(0 - w, 1 + w)
    axes.set_ylabel('p')

    axes.axvline(1500, linestyle=':', alpha=0.5, lw=1, color='k')
    axes.axhline(0.5, linestyle=':', alpha=0.5, lw=1, color='k')

    # plot the fit
    x_fit = np.linspace(0, 3000, 100)
    axes.plot(x_fit, bhv.log_reg(x, y, x_fit), color='red', linewidth=2, alpha=0.75)

    # plot the random models based on the choice bias
    n_trials = SDf.shape[0]
    bias = (SDf['chosen_side'] == 'right').sum() / n_trials
    R = []
    for _ in tqdm(range(N)):
        rand_choices = np.random.rand(n_trials) < bias
        try:
            R.append(bhv.log_reg(x, rand_choices, x_fit))
        except ValueError:
            # thrown when all samples are true or false
            print("all true or false")
    R = np.array(R)

    # Several statistical boundaries
    alphas = [5, 0.5, 0.05]
    opacities = [0.2, 0.2, 0.2]
    if R.shape[0] > 0:  # every surrogate fit may have failed, or N == 0
        for alpha, a in zip(alphas, opacities):
            # np.percentile is called directly rather than through SciPy's
            # deprecated top-level alias of the same NumPy function.
            R_pc = np.percentile(R, (alpha, 100 - alpha), 0)
            axes.fill_between(x_fit, R_pc[0], R_pc[1], color='blue', alpha=a, linewidth=0)

    axes.set_xlabel('time (ms)')

    if discrete:
        intervals = list(SessionDf.groupby('this_interval').groups.keys())
        correct_sides = ['right'] * len(intervals)
        for interval, ref_side in zip(intervals, correct_sides):
            IntervalDf = bhv.groupby_dict(
                SessionDf,
                dict(this_interval=interval, has_choice=True,
                     in_corr_loop=False, timing_trial=True))
            f = (IntervalDf['chosen_side'] == ref_side).sum() / IntervalDf.shape[0]
            axes.plot(interval, f, 'o', color='r')

    return axes
def grasp_dur_across_sess(animal_fd_path, tasks_names):
    """
    Violin plots of reach and grasp durations for each session of one animal.

    For every session whose task is in tasks_names, the arduino log is parsed
    into trials. Per side, the 'dt' of all REACH_<side>_ON/OFF spans ("reaches")
    and the non-NaN 'grasp_dur' of trials with a choice to that side ("grasps")
    are gathered into one long-format table with columns
    day / durs / type / side. The left panel shows reaches, the right panel
    grasps, each with one split violin (LEFT vs RIGHT) per session.

    Args:
        animal_fd_path: pathlib.Path of the animal folder; must contain
            'animal_meta.csv' and be readable by utils.get_sessions.
        tasks_names: iterable of task names selecting the sessions to include.

    Returns:
        The array of the two matplotlib axes (reaches, grasps).
    """
    SessionsDf = utils.get_sessions(animal_fd_path)
    animal_meta = pd.read_csv(animal_fd_path / 'animal_meta.csv')

    # Filter sessions to the ones of the task we want to see
    FilteredSessionsDf = pd.concat(
        [SessionsDf.groupby('task').get_group(name) for name in tasks_names])

    log_paths = [
        Path(path) / 'arduino_log.txt' for path in FilteredSessionsDf['path']
    ]

    fig, axes = plt.subplots(ncols=2, figsize=[8, 4], sharey=True, sharex=True)

    sides = ['LEFT', 'RIGHT']
    records = []  # one dict per single duration; turned into a DataFrame below

    # gather data
    for day, log_path in enumerate(log_paths):
        LogDf = bhv.get_LogDf_from_path(log_path)

        TrialSpans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_STATE", "ITI_STATE")
        TrialDfs = [
            bhv.time_slice(LogDf, row['t_on'], row['t_off'])
            for _, row in tqdm(TrialSpans.iterrows(), position=0, leave=True)
        ]

        # NOTE(review): the table built from these metrics must expose
        # 'has_choice' and 'grasp_dur' for the selection below - confirm that
        # the listed metrics provide them.
        metrics = (met.get_start, met.get_stop, met.get_correct_side,
                   met.get_outcome, met.get_interval_category,
                   met.get_chosen_side, met.has_reach_left,
                   met.has_reach_right)
        SessionDf = bhv.parse_trials(TrialDfs, metrics)

        for side in sides:
            event_on = 'REACH_' + str(side) + '_ON'
            event_off = 'REACH_' + str(side) + '_OFF'

            # reaches: every ON/OFF span on this side
            reach_spansDf = bhv.get_spans_from_names(LogDf, event_on, event_off)
            records.extend(
                dict(day=day, durs=dur, type='r', side=side)
                for dur in reach_spansDf['dt'].values)

            # grasps: only the trials with a choice to this side
            choiceDf = bhv.groupby_dict(
                SessionDf, dict(has_choice=True, chosen_side=side.lower()))
            records.extend(
                dict(day=day, durs=dur, type='g', side=side)
                for dur in choiceDf['grasp_dur'].dropna().values)

    Df = pd.DataFrame(records, columns=['day', 'durs', 'type', 'side'])

    for ax, (kind, label) in zip(axes, (('r', 'reaches'), ('g', 'grasps'))):
        subset = Df[Df['type'] == kind]
        if not subset.empty:
            # split=True draws LEFT and RIGHT as the two halves of one violin;
            # presumably both sides need data for this to render - verify.
            sns.violinplot(data=subset, x='day', y='durs', hue='side',
                           hue_order=sides, split=True, cut=0, ax=ax)
            ax.legend(frameon=False, fontsize='small')
        ax.set_title(label)

    # The original title ("CDF of first reach split on trial type") described a
    # different plot; replaced to match what is drawn here.
    fig.suptitle('Reach and grasp durations across sessions \n' +
                 animal_meta['value'][5] + '-' + animal_meta['value'][0])

    return axes
# Single-trial view: draw the video frame at t_on, mark the body parts on it
# and overlay their trajectories between t_on and t_off.
# t_on, t_off, Sync, Vid, DlcDf, bodyparts and bp_cols are expected to be
# defined earlier in the script.
fig, axes = plt.subplots()

# Translate the time t_on from the 'arduino' clock to a 'dlc' frame index
frame_ix = Sync.convert(t_on, 'arduino', 'dlc')
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)
dlc.plot_bodyparts(bodyparts, DlcDf, frame_ix, colors=bp_cols, axes=axes)

# trajectory
DlcDfSlice = bhv.time_slice(DlcDf, t_on, t_off)
dlc.plot_trajectories(DlcDfSlice, bodyparts, axes=axes, colors=bp_cols, lw=1, p=0.99)

# %% plot all of the selected trial type
# SDf = bhv.groupby_dict(SessionDf, dict(outcome='correct', correct_side='right', paw_resting=False))

# trial selection
# Keeps trials with a choice whose correct side is left and outcome is correct
SDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, correct_side='left', outcome='correct'))

# plot some random frame
# (frame 1000 is a fixed, arbitrary background image for the overlay)
fig, axes = plt.subplots()

frame_ix = 1000
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)

# plot all traj in selection
# SDf.index is used directly as the position of the trial inside TrialDfs
for i in tqdm(SDf.index):
    TrialDf = TrialDfs[i]
    # Restrict the trial to the events between interval presentation and choice
    Df = bhv.event_slice(TrialDf,'PRESENT_INTERVAL_STATE','CHOICE_EVENT')
    # Df = bhv.time_slice(Df, Df.iloc[-1]['t']-500, Df.iloc[-1]['t'])
    # First and last timestamps of that slice delimit the trajectory window
    t_on = Df.iloc[0]['t']
    t_off = Df.iloc[-1]['t']
# Single-trial view: draw the video frame at t_on, mark the body parts on it
# and overlay their trajectories between t_on and t_off.
# t_on, t_off, Sync, Vid, DlcDf and bodyparts are expected to be defined
# earlier in the script.
bp_cols = make_bodypart_colors(bodyparts)

fig, axes = plt.subplots()

# Translate the time t_on from the 'arduino' clock to a 'dlc' frame index
frame_ix = Sync.convert(t_on, 'arduino', 'dlc')
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)
dlc.plot_bodyparts(bodyparts, DlcDf, frame_ix, colors=bp_cols, axes=axes)

# trajectory
DlcDfSlice = bhv.time_slice(DlcDf, t_on, t_off)
dlc.plot_trajectories(DlcDfSlice, bodyparts, axes=axes, colors=bp_cols, lw=1, p=0.99)

# %% plot all of the selected trial type

# trial selection
# Keeps trials whose correct side is left and whose outcome is incorrect
SDf = bhv.groupby_dict(SessionDf, dict(correct_side='left',outcome='incorrect'))
# SDf = SessionDf.groupby('correct_side').get_group('right')

# plot some random frame
# (frame 1000 is a fixed, arbitrary background image for the overlay)
fig, axes = plt.subplots()

frame_ix = 1000
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)

# plot all traj in selection
# SDf.index is used directly as the position of the trial inside TrialDfs
for i in tqdm(SDf.index):
    TrialDf = TrialDfs[i]
    # Restrict the trial to the events between interval presentation and ITI
    Df = bhv.event_slice(TrialDf,'PRESENT_INTERVAL_STATE','ITI_STATE')
    # First and last timestamps of that slice delimit the trajectory window
    t_on = Df.iloc[0]['t']
    t_off = Df.iloc[-1]['t']
date_abbr.append(month_abbr+'-'+str(date[2])) # Getting metrics TrialSpans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_STATE", "ITI_STATE") TrialDfs = [] for i, row in tqdm(TrialSpans.iterrows(),position=0, leave=True): TrialDfs.append(bhv.time_slice(LogDf, row['t_on'], row['t_off'])) metrics = (met.get_start, met.get_stop, met.get_correct_side, met.get_outcome, met.get_chosen_side, met.has_choice) SessionDf = bhv.parse_trials(TrialDfs, metrics) # Session metrics MissedDf = SessionDf[SessionDf['outcome'] == 'missed'] choiceDf = SessionDf[SessionDf['has_choice'] == True] left_trials_missedDf = bhv.groupby_dict(SessionDf, dict(outcome='missed', correct_side='left')) right_trials_missedDf = bhv.groupby_dict(SessionDf, dict(outcome='missed', correct_side='right')) corr_leftDf = bhv.groupby_dict(SessionDf, dict(outcome='correct', correct_side='left')) left_trials_with_choiceDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, correct_side='left')) corr_rightDf = bhv.groupby_dict(SessionDf, dict(outcome='correct', correct_side='right')) right_trials_with_choiceDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, correct_side='right')) # Metrics of evolution try: perc_corr_left.append(len(corr_leftDf)/len(left_trials_with_choiceDf)*100) except: perc_corr_left.append(np.NaN) try: perc_corr_right.append(len(corr_rightDf)/len(right_trials_with_choiceDf)*100)
# Add one boolean column per distinct outcome, named 'is_<outcome>'
outcomes = SessionDf['outcome'].unique()
for outcome in outcomes:
    SessionDf['is_'+outcome] = SessionDf['outcome'] == outcome

# setup general filter
# Every trial starts as not excluded
SessionDf['exclude'] = False

# %%
# Script version of plot_psychometric; N is set here in place of the argument
N = 100
# def plot_psychometric(SessionDf, N=1000, axes=None, discrete=False):
fig, axes = plt.subplots(figsize=[4,3])

# get only the subset with choices - excludes missed
SDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, exclude=False,
                                       in_corr_loop=False, is_premature=False,
                                       timing_trial=True))

# NOTE(review): timing_trial=True is already part of the filter above, and on
# KeyError only a message is printed, so the unfiltered SDf is used below.
try:
    SDf = SDf.groupby('timing_trial').get_group(True)
except KeyError:
    print("no timing trials in session")

# y: True where the chosen side is 'right'; x: the value of 'this_interval'
y = SDf['chosen_side'].values == 'right'
x = SDf['this_interval'].values

# axx = plt.twinx(axes)
# axx.set_yticks([0,1])
# axx.set_yticklabels(['short','long'])
# axx.set_ylabel('choice')
# axx.set_ylim(0-w, 1+w)