Ejemplo n.º 1
0
def plot_grasp_duration_distro(LogDf, SessionDf, bin_width, max_grasp_dur,
                               percentile):
    """Plot the distribution of reach/grasp durations, one panel per side.

    For each side ('LEFT', 'RIGHT') the faint histogram shows the duration of
    every REACH_<SIDE>_ON -> REACH_<SIDE>_OFF span found in ``LogDf``; the
    solid histogram shows ``grasp_dur`` of the trials in ``SessionDf`` whose
    choice was made on that side. A vertical line (with its value as text)
    marks the requested percentile of the choice grasps, or of all grasps when
    that side has no choice trials. NaN durations are ignored.

    Args:
        LogDf: event log passed to ``bhv.get_spans_from_names``.
        SessionDf: per-trial table with ``has_choice``, ``chosen_side`` and
            ``grasp_dur`` columns.
        bin_width: histogram bin width (same unit as the durations).
        max_grasp_dur: upper bound of the histogram range.
        percentile: percentile (0-100) marked by the vertical line.

    Returns:
        The array of matplotlib axes, one per side.
    """
    sides = ['LEFT', 'RIGHT']

    no_bins = round(max_grasp_dur / bin_width)
    kwargs = dict(bins=no_bins, range=(0, max_grasp_dur), edgecolor='none')

    fig, axes = plt.subplots(ncols=len(sides),
                             figsize=[6, 3],
                             sharex=True,
                             sharey=True)

    colors = sns.color_palette('hls', n_colors=len(sides))

    for ax, side, color in zip(axes, sides, colors):

        # All reaches of this side, straight from the event log
        event_on = 'REACH_' + str(side) + '_ON'
        event_off = 'REACH_' + str(side) + '_OFF'

        grasp_spansDf = bhv.get_spans_from_names(LogDf, event_on, event_off)
        # Float (not object) dtype so NaNs can be removed and percentiles taken
        grasp_durs = np.asarray(grasp_spansDf['dt'].values, dtype=float)
        grasp_durs = grasp_durs[~np.isnan(grasp_durs)]
        ax.hist(grasp_durs,
                **kwargs,
                alpha=0.25,
                color=color,
                label='All grasps')

        # Choice-inducing reaches; when present they replace grasp_durs so the
        # percentile below describes them rather than all reaches
        choiceDf = bhv.groupby_dict(
            SessionDf, dict(has_choice=True, chosen_side=side.lower()))
        if len(choiceDf) != 0:
            grasp_durs = choiceDf['grasp_dur'].dropna().values.astype(float)
            ax.hist(grasp_durs,
                    **kwargs,
                    alpha=1,
                    color=color,
                    label='Choice grasps')

        # np.percentile raises on an empty array: skip the marker instead
        if grasp_durs.size > 0:
            perc = np.percentile(grasp_durs, percentile)
            ax.axvline(perc, color=color, alpha=1)  # perc line
            # x in axes coordinates, assuming the x axis spans 0..max_grasp_dur
            ax.text(perc / max_grasp_dur, 0.9, perc, transform=ax.transAxes)

    for ax, side in zip(axes, sides):
        ax.legend(frameon=False, fontsize='small')
        ax.set_title(side)
        ax.set_xlabel('Time (ms)')

    axes[0].set_ylabel('No. of occurrences')

    fig.suptitle("Histogram of grasps' duration, vertical bar is " +
                 str(percentile) + "th percentile")
    fig.tight_layout()

    return axes
Ejemplo n.º 2
0
def filter_trials_by(SessionDf, TrialDfs, filter_dict):
    """Return the TrialDfs whose SessionDf row matches every pair in filter_dict.

    Example: given dict(outcome='correct', chosen_side='left') only the trials
    that are correct and to the left side are returned.

    Args:
        SessionDf: per-trial table; its index values are used as positions
            into ``TrialDfs``.
        TrialDfs: list of per-trial DataFrames aligned with SessionDf's index.
        filter_dict: mapping column name -> required value.

    Returns:
        List of the selected entries of ``TrialDfs``, in SessionDf order.

    Raises:
        ValueError: if no trial matches the requested combination.
    """
    msg = 'No trials with given input filter_pair combination'

    if len(filter_dict) == 1:  # in case its only one pair
        key, value = next(iter(filter_dict.items()))
        # Group by the scalar key: grouping by a length-1 list and calling
        # get_group with a scalar is deprecated in pandas
        grouped = SessionDf.groupby(key)
        try:
            SDf = grouped.get_group(value)
        except KeyError as err:  # get_group signals a missing group this way
            raise ValueError(msg) from err

    else:  # more than 1 pair
        try:
            SDf = bhv.groupby_dict(SessionDf, filter_dict)
        except utils.NoFiltPairError as err:
            raise ValueError(msg) from err

    return [TrialDfs[i] for i in SDf.index.values.astype(int)]
Ejemplo n.º 3
0
def plot_psychometric(SessionDf, N=1000, axes=None, discrete=False):
    """Plot the psychometric curve of a session.

    Choices (True when ``chosen_side == 'right'``) of timing trials that have
    a choice are plotted against ``this_interval``. The curve returned by
    ``bhv.log_reg`` (presumably a logistic-regression fit) is drawn on top,
    together with shaded bands spanning the fits of ``N`` random sessions
    that only reproduce the overall right-choice bias.

    Args:
        SessionDf: per-trial table with ``has_choice``, ``timing_trial``,
            ``chosen_side`` and ``this_interval`` columns (and
            ``in_corr_loop`` when ``discrete`` is True).
        N: number of random-choice models used for the chance bands.
        axes: matplotlib axes to draw into; a new figure is made if None.
        discrete: if True, plot the fraction of 'right' choices per interval
            instead of the individual choices.

    Returns:
        The axes drawn into (left untouched if no trials qualify).
    """
    if axes is None:
        _, axes = plt.subplots()

    # Timing trials that have a choice. Correction-loop trials are NOT
    # excluded here (the per-interval points in the discrete branch do).
    try:
        SDf = bhv.groupby_dict(SessionDf,
                               dict(has_choice=True, timing_trial=True))
    except KeyError:
        SDf = None
    if SDf is None or SDf.shape[0] == 0:
        # Nothing to fit: bail out instead of crashing on an unbound SDf
        print('No trials fulfil criteria')
        return axes

    y = SDf['chosen_side'].values == 'right'
    x = SDf['this_interval'].values

    # plot the choices
    if not discrete:
        axes.plot(x, y, '.', color='k', alpha=0.5)

    axx = plt.twinx(axes)
    axx.set_yticks([0, 1])
    axx.set_yticklabels(['short', 'long'])
    axx.set_ylabel('choice')
    w = 0.05
    axx.set_ylim(0 - w, 1 + w)
    axes.set_ylim(0 - w, 1 + w)
    axes.set_ylabel('p')

    axes.axvline(1500, linestyle=':', alpha=0.5, lw=1, color='k')
    axes.axhline(0.5, linestyle=':', alpha=0.5, lw=1, color='k')

    # plot the fit
    x_fit = np.linspace(0, 3000, 100)
    axes.plot(x_fit,
              bhv.log_reg(x, y, x_fit),
              color='red',
              linewidth=2,
              alpha=0.75)

    # random models based on the choice bias (fraction of 'right' choices)
    bias = y.mean()
    R = []
    n_skipped = 0
    for _ in tqdm(range(N)):
        rand_choices = np.random.rand(SDf.shape[0]) < bias
        try:
            R.append(bhv.log_reg(x, rand_choices, x_fit))
        except ValueError:
            # thrown when all samples are true or false: nothing to fit
            n_skipped += 1
    if n_skipped:
        print(str(n_skipped) + " random models skipped: all true or false")
    R = np.array(R)

    # Several statistical boundaries (skipped if every random fit failed,
    # since percentiles of an empty array are undefined)
    if R.size > 0:
        alphas = [5, 0.5, 0.05]
        opacities = [0.2, 0.2, 0.2]
        for alpha, a in zip(alphas, opacities):
            # np.percentile: the scipy.percentile alias is deprecated/removed
            R_pc = np.percentile(R, (alpha, 100 - alpha), axis=0)
            axes.fill_between(x_fit,
                              R_pc[0],
                              R_pc[1],
                              color='blue',
                              alpha=a,
                              linewidth=0)

    axes.set_xlabel('time (ms)')

    if discrete:
        for interval in SessionDf.groupby('this_interval').groups:
            try:
                IntervalDf = bhv.groupby_dict(
                    SessionDf,
                    dict(this_interval=interval,
                         has_choice=True,
                         in_corr_loop=False,
                         timing_trial=True))
            except KeyError:
                continue  # no qualifying trials at this interval
            if IntervalDf.shape[0] == 0:
                continue  # avoid 0/0
            f = (IntervalDf['chosen_side'] == 'right').sum() / IntervalDf.shape[0]
            axes.plot(interval, f, 'o', color='r')

    return axes
Ejemplo n.º 4
0
def grasp_dur_across_sess(animal_fd_path, tasks_names):
    """Violin plots of reach and grasp durations for each session of an animal.

    For every session whose task is in ``tasks_names``, the durations of all
    REACH_<SIDE>_ON -> REACH_<SIDE>_OFF spans ("reaches") and the
    ``grasp_dur`` of trials with a choice on that side ("grasps") are
    collected. They are then drawn as split violins (LEFT vs RIGHT) per
    session index: reaches on the first axis, grasps on the second.

    Args:
        animal_fd_path: ``pathlib.Path`` of the animal folder (must contain
            ``animal_meta.csv``).
        tasks_names: iterable of task names to include.

    Returns:
        The array of the two matplotlib axes (reaches, grasps).
    """
    SessionsDf = utils.get_sessions(animal_fd_path)
    animal_meta = pd.read_csv(animal_fd_path / 'animal_meta.csv')

    # Filter sessions to the ones of the task we want to see
    FilteredSessionsDf = pd.concat(
        [SessionsDf.groupby('task').get_group(name) for name in tasks_names])
    log_paths = [
        Path(path) / 'arduino_log.txt' for path in FilteredSessionsDf['path']
    ]

    fig, axes = plt.subplots(ncols=2, figsize=[8, 4], sharey=True, sharex=True)
    sides = ['LEFT', 'RIGHT']

    # Long format: one record per duration, so seaborn can split by day/side
    records = []
    for day, log_path in enumerate(log_paths):

        LogDf = bhv.get_LogDf_from_path(log_path)
        TrialSpans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_STATE",
                                              "ITI_STATE")

        TrialDfs = [
            bhv.time_slice(LogDf, row['t_on'], row['t_off'])
            for _, row in tqdm(TrialSpans.iterrows(), position=0, leave=True)
        ]

        # NOTE(review): 'has_choice' and 'grasp_dur' are used below but no
        # metric producing them is listed here - confirm parse_trials adds them
        metrics = (met.get_start, met.get_stop, met.get_correct_side,
                   met.get_outcome, met.get_interval_category,
                   met.get_chosen_side, met.has_reach_left,
                   met.has_reach_right)
        SessionDf = bhv.parse_trials(TrialDfs, metrics)

        for side in sides:
            event_on = 'REACH_' + str(side) + '_ON'
            event_off = 'REACH_' + str(side) + '_OFF'

            # reaches: every ON -> OFF span of this side in the log
            grasp_spansDf = bhv.get_spans_from_names(LogDf, event_on,
                                                     event_off)
            records.extend(
                dict(day=day, side=side, type='r', durs=dur)
                for dur in grasp_spansDf['dt'].values)

            # grasps: choice trials on this side, NaNs filtered out
            choiceDf = bhv.groupby_dict(
                SessionDf, dict(has_choice=True, chosen_side=side.lower()))
            records.extend(
                dict(day=day, side=side, type='g', durs=dur)
                for dur in choiceDf['grasp_dur'].dropna().values)

    Df = pd.DataFrame(records, columns=['day', 'side', 'type', 'durs'])
    Df['durs'] = Df['durs'].astype(float)

    for ax, kind, title in zip(axes, ['r', 'g'], ['reaches', 'grasps']):
        data = Df[Df['type'] == kind]
        if not data.empty:  # seaborn cannot draw an empty violin plot
            sns.violinplot(data=data,
                           x='day',
                           y='durs',
                           hue='side',
                           split=True,
                           cut=0,
                           ax=ax)
        ax.set_title(title)

    fig.suptitle('Reach and grasp durations across sessions \n' +
                 animal_meta['value'][5] + '-' + animal_meta['value'][0])

    return axes
Ejemplo n.º 5
0
# Single-trial snapshot: the video frame at t_on with the body-part positions,
# overlaid with the body-part trajectories between t_on and t_off.
# NOTE(review): t_on, t_off, Sync, Vid, DlcDf, bodyparts, bp_cols, SessionDf and
# TrialDfs come from earlier cells that are not visible here.
fig, axes = plt.subplots()
# Convert t_on from the 'arduino' clock to 'dlc'; the result is used as a frame index
frame_ix = Sync.convert(t_on, 'arduino', 'dlc')
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)
dlc.plot_bodyparts(bodyparts, DlcDf, frame_ix, colors=bp_cols, axes=axes)

# trajectory
DlcDfSlice = bhv.time_slice(DlcDf, t_on, t_off)
# NOTE(review): meaning of p=0.99 is defined by dlc.plot_trajectories
# (presumably a likelihood threshold) - confirm
dlc.plot_trajectories(DlcDfSlice, bodyparts, axes=axes, colors=bp_cols, lw=1, p=0.99)

# %% plot all of the selected trial type

# SDf = bhv.groupby_dict(SessionDf, dict(outcome='correct', correct_side='right', paw_resting=False))

# trial selection: correct trials with a choice whose correct side is left
SDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, correct_side='left', outcome='correct'))

# plot some random frame (fixed index 1000) as a background
fig, axes = plt.subplots()

frame_ix = 1000
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)

# plot all traj in selection
# NOTE(review): the loop only computes t_on/t_off (from PRESENT_INTERVAL_STATE
# to CHOICE_EVENT); no plotting call follows in this snippet - looks truncated.
for i in tqdm(SDf.index):
    TrialDf = TrialDfs[i]
    Df = bhv.event_slice(TrialDf,'PRESENT_INTERVAL_STATE','CHOICE_EVENT')
    # Df = bhv.time_slice(Df, Df.iloc[-1]['t']-500, Df.iloc[-1]['t'])
    t_on = Df.iloc[0]['t']
    t_off = Df.iloc[-1]['t']
Ejemplo n.º 6
0
# Single-trial snapshot plus trajectories, then background frame for all
# left-side incorrect trials.
# NOTE(review): t_on, t_off, Sync, Vid, DlcDf, bodyparts, SessionDf and TrialDfs
# come from earlier cells that are not visible here.
bp_cols = make_bodypart_colors(bodyparts)

fig, axes = plt.subplots()
# Convert t_on from the 'arduino' clock to 'dlc'; the result is used as a frame index
frame_ix = Sync.convert(t_on, 'arduino', 'dlc')
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)
dlc.plot_bodyparts(bodyparts, DlcDf, frame_ix, colors=bp_cols, axes=axes)

# trajectory
DlcDfSlice = bhv.time_slice(DlcDf, t_on, t_off)
# NOTE(review): meaning of p=0.99 is defined by dlc.plot_trajectories
# (presumably a likelihood threshold) - confirm
dlc.plot_trajectories(DlcDfSlice, bodyparts, axes=axes, colors=bp_cols, lw=1, p=0.99)

# %% plot all of the selected trial type
# trial selection: incorrect trials whose correct side is left
SDf = bhv.groupby_dict(SessionDf, dict(correct_side='left',outcome='incorrect'))
# SDf = SessionDf.groupby('correct_side').get_group('right')

# plot some random frame (fixed index 1000) as a background
fig, axes = plt.subplots()
frame_ix = 1000
frame = dlc.get_frame(Vid, frame_ix)
dlc.plot_frame(frame, axes=axes)

# plot all traj in selection
# NOTE(review): the loop only computes t_on/t_off (from PRESENT_INTERVAL_STATE
# to ITI_STATE); no plotting call follows in this snippet - looks truncated.
for i in tqdm(SDf.index):
    TrialDf = TrialDfs[i]
    Df = bhv.event_slice(TrialDf,'PRESENT_INTERVAL_STATE','ITI_STATE')
    t_on = Df.iloc[0]['t']
    t_off = Df.iloc[-1]['t']
Ejemplo n.º 7
0
    # Label for this session: month abbreviation + '-' + date[2]
    date_abbr.append(month_abbr+'-'+str(date[2]))

    # Getting metrics: slice the log into trials (TRIAL_ENTRY_STATE -> ITI_STATE)
    TrialSpans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_STATE", "ITI_STATE")

    TrialDfs = []
    for i, row in tqdm(TrialSpans.iterrows(),position=0, leave=True):
        TrialDfs.append(bhv.time_slice(LogDf, row['t_on'], row['t_off']))

    # One SessionDf row per trial, with the columns produced by these metrics
    metrics = (met.get_start, met.get_stop, met.get_correct_side, met.get_outcome, met.get_chosen_side, met.has_choice)
    SessionDf = bhv.parse_trials(TrialDfs, metrics)

    # Session metrics: trial subsets by outcome / choice / correct side
    # NOTE(review): if bhv.groupby_dict raises when no trial matches, these lines
    # sit outside the try blocks below - confirm its behaviour for empty groups
    MissedDf = SessionDf[SessionDf['outcome'] == 'missed']
    choiceDf = SessionDf[SessionDf['has_choice'] == True]
    left_trials_missedDf = bhv.groupby_dict(SessionDf, dict(outcome='missed', correct_side='left'))
    right_trials_missedDf = bhv.groupby_dict(SessionDf, dict(outcome='missed', correct_side='right'))

    corr_leftDf = bhv.groupby_dict(SessionDf, dict(outcome='correct', correct_side='left'))
    left_trials_with_choiceDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, correct_side='left'))
    corr_rightDf = bhv.groupby_dict(SessionDf, dict(outcome='correct', correct_side='right'))
    right_trials_with_choiceDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, correct_side='right'))

    # Metrics of evolution: % correct among trials with a choice, per correct side
    # NOTE(review): the bare except presumably targets division by zero when no
    # such trials exist, but it also hides any other error (records NaN instead)
    try:
        perc_corr_left.append(len(corr_leftDf)/len(left_trials_with_choiceDf)*100)
    except:
        perc_corr_left.append(np.NaN)

    try:
        perc_corr_right.append(len(corr_rightDf)/len(right_trials_with_choiceDf)*100)
Ejemplo n.º 8
0
# Add one boolean column 'is_<outcome>' per distinct outcome value
# (e.g. is_premature, used in the filter below).
# NOTE(review): if 'outcome' contains NaN, 'is_'+outcome raises TypeError -
# confirm outcomes are never missing.
outcomes = SessionDf['outcome'].unique()
for outcome in outcomes:
    SessionDf['is_'+outcome] = SessionDf['outcome'] == outcome

# setup general filter: rows with exclude=True are dropped below
SessionDf['exclude'] = False

# %%

# presumably the number of random-choice models, as N in plot_psychometric - confirm
N = 100
# def plot_psychometric(SessionDf, N=1000, axes=None, discrete=False):
fig, axes = plt.subplots(figsize=[4,3])

# get only the subset with choices - excludes missed, excluded, correction-loop,
# premature and non-timing trials
SDf = bhv.groupby_dict(SessionDf, dict(has_choice=True, exclude=False,
                                        in_corr_loop=False, is_premature=False,
                                        timing_trial=True))

# NOTE(review): SDf is already restricted to timing_trial=True above, so this
# looks redundant; on KeyError it only prints and leaves SDf unchanged.
try:
    SDf = SDf.groupby('timing_trial').get_group(True)
except KeyError:
    print("no timing trials in session")

# y: True where the chosen side is 'right'; x: the interval of each trial
y = SDf['chosen_side'].values == 'right'
x = SDf['this_interval'].values

# NOTE(review): commented-out twin-axis setup; 'w' is not defined in this cell.
# axx = plt.twinx(axes)
# axx.set_yticks([0,1])
# axx.set_yticklabels(['short','long'])
# axx.set_ylabel('choice')
# axx.set_ylim(0-w, 1+w)