Exemple #1
0
def plot_choice_RTs(session_folder, save=None):
    """Plot histograms of choice reaction times for one session.

    The 2x2 grid has one column per correct (requested) side and one row per
    outcome ('correct' / 'incorrect'); only trials with a choice are used.
    Bars are colored (via the module-level ``colors`` mapping) by the side the
    animal reached to, which follows from the requested side and the outcome.

    Args:
        session_folder: pathlib.Path of the session; must contain
            ``arduino_log.txt``.
        save: optional pathlib.Path. If given, the figure is saved there at
            600 dpi (parent folders are created) and then closed.
    """
    # numpy instead of scipy: the numpy aliases in the scipy namespace
    # (sp.linspace, ...) are deprecated and removed in recent SciPy releases
    import numpy as np

    LogDf = bhv.get_LogDf_from_path(session_folder / 'arduino_log.txt')
    session_metrics = (metrics.get_start, metrics.get_choice_rt,
                       metrics.has_choice, metrics.get_chosen_side,
                       metrics.get_correct_side, metrics.get_outcome)

    SessionDf, TrialDfs = bhv.get_SessionDf(LogDf, session_metrics)

    fig, axes = plt.subplots(nrows=2, ncols=2, sharey=True)

    sides = ['left', 'right']
    outcomes = ['correct', 'incorrect']

    bins = np.linspace(0, 3000, 40)

    for i, side in enumerate(sides):
        for j, outcome in enumerate(outcomes):
            SDf = bhv.intersect(SessionDf,
                                has_choice=True,
                                correct_side=side,
                                outcome=outcome)

            # columns are the requested trial type, bars are colored by the
            # side the animal reached to: left iff (left & correct) or
            # (right & incorrect)
            reached_left = (side == 'left') == (outcome == 'correct')
            color = colors['left'] if reached_left else colors['right']
            axes[j, i].hist(SDf['choice_rt'].values, bins=bins, color=color)

    for ax, outcome in zip(axes[:, 0], outcomes):
        ax.set_ylabel(outcome)

    for ax, side in zip(axes[0, :], sides):
        ax.set_title(side)

    for ax in axes[-1, :]:
        ax.set_xlabel('time (ms)')

    Session = utils.Session(session_folder)
    Animal = utils.Animal(session_folder.parent)
    title = ' - '.join(
        [Animal.display(), Session.date,
         'day: %s' % Session.day])

    sns.despine(fig)
    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=0.85)

    if save is not None:
        os.makedirs(save.parent, exist_ok=True)
        # save this figure explicitly, not whatever pyplot considers current
        fig.savefig(save, dpi=600)
        plt.close(fig)
Exemple #2
0
def reaches_during_delay_across_sess(animal_fd_path, tasks_names,
                                     init_day_idx):
    """Overlay, per session, the CDFs of reaches during the delay for one animal.

    Sessions whose task is listed in *tasks_names* are collected (grouped in
    the order of *tasks_names*). Starting at position *init_day_idx* of that
    list, each session's trials are parsed and drawn with
    ``CDF_of_reaches_during_delay`` on a shared pair of axes, with one
    'turbo' palette color per session.

    Args:
        animal_fd_path: pathlib.Path of the animal folder; must contain
            ``animal_meta.csv``.
        tasks_names: task names whose sessions are included.
        init_day_idx: index of the first session (in the filtered list) to plot.

    Returns:
        The array of two matplotlib axes.
    """
    sessions = utils.get_sessions(animal_fd_path)
    animal_meta = pd.read_csv(animal_fd_path / 'animal_meta.csv')

    # keep only the sessions that ran one of the requested tasks
    task_sessions = pd.concat(
        [sessions.groupby('task').get_group(task) for task in tasks_names])
    log_paths = [Path(p) / 'arduino_log.txt' for p in task_sessions['path']]

    fig, axes = plt.subplots(ncols=2, figsize=[6, 4], sharey=True, sharex=True)
    palette = sns.color_palette(palette='turbo', n_colors=len(log_paths))

    trial_metrics = (met.get_start, met.get_stop, met.get_correct_side,
                     met.get_outcome, met.get_interval_category,
                     met.get_chosen_side, met.has_reach_left,
                     met.has_reach_right)

    for day_idx, log_path in enumerate(log_paths[init_day_idx:]):
        # single GO_CUE event added to the log before slicing into trials
        session_log = bhv.add_go_cue_LogDf(bhv.get_LogDf_from_path(log_path))
        spans = bhv.get_spans_from_names(session_log, "TRIAL_ENTRY_STATE",
                                         "ITI_STATE")

        span_rows = tqdm(spans.iterrows(), position=0, leave=True)
        trial_logs = [bhv.time_slice(session_log, span['t_on'], span['t_off'])
                      for _, span in span_rows]

        session_table = bhv.parse_trials(trial_logs, trial_metrics)

        CDF_of_reaches_during_delay(session_table,
                                    trial_logs,
                                    axes=axes,
                                    color=palette[day_idx],
                                    alpha=0.75,
                                    label='day ' + str(day_idx + 1))

    meta_values = animal_meta['value']
    fig.suptitle('CDF of first reach split on trial type \n' +
                 meta_values[5] + '-' + meta_values[0])

    first_ax = axes[0]
    first_ax.set_ylabel('Fraction of trials')
    first_ax.legend(frameon=False, fontsize='x-small')
    fig.tight_layout()

    return axes
Exemple #3
0
def plot_reward_collection_rts(session_folder, save=None):
    """Plot histograms of reward-collection reaction times, one panel per side.

    Only trials where the reward was collected are used; the panels split
    them by the correct side and are colored via the module-level ``colors``
    mapping.

    Args:
        session_folder: pathlib.Path of the session; must contain
            ``arduino_log.txt``.
        save: optional pathlib.Path. If given, the figure is saved there at
            600 dpi (parent folders are created) and then closed.
    """
    # numpy instead of scipy: the numpy aliases in the scipy namespace
    # (sp.linspace, ...) are deprecated and removed in recent SciPy releases
    import numpy as np

    LogDf = bhv.get_LogDf_from_path(session_folder / 'arduino_log.txt')
    session_metrics = (metrics.get_start, metrics.get_reward_collection_rt,
                       metrics.has_choice, metrics.get_chosen_side,
                       metrics.get_correct_side, metrics.get_outcome,
                       metrics.get_autodeliver_trial,
                       metrics.has_reward_collected)

    SessionDf, TrialDfs = bhv.get_SessionDf(LogDf, session_metrics)

    fig, axes = plt.subplots(ncols=2, sharey=True)
    bins = np.linspace(0, 5000, 50)
    sides = ['left', 'right']
    for ax, side in zip(axes, sides):
        SDf = bhv.intersect(SessionDf,
                            has_reward_collected=True,
                            correct_side=side)
        ax.hist(SDf['reward_collection_rt'].values,
                bins=bins,
                color=colors[side])
        ax.set_title(side)

    Session = utils.Session(session_folder)
    Animal = utils.Animal(session_folder.parent)
    title = ' - '.join(
        [Animal.display(), Session.date,
         'day: %s' % Session.day])

    axes[0].set_ylabel('count')
    for ax in axes:
        ax.set_xlabel('time (ms)')

    sns.despine(fig)
    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=0.85)

    if save is not None:
        os.makedirs(save.parent, exist_ok=True)
        # save this figure explicitly, not whatever pyplot considers current
        fig.savefig(save, dpi=600)
        plt.close(fig)
Exemple #4
0
def grasp_dur_across_sess(animal_fd_path, tasks_names):
    """Plot reach and grasp durations across the sessions of one animal.

    For each session whose task is in *tasks_names*:
      * reach durations are the ``dt`` of REACH_<SIDE>_ON -> REACH_<SIDE>_OFF
        spans in the Arduino log;
      * grasp durations are the non-NaN ``grasp_dur`` values of choice trials
        towards that side.
    Reaches are drawn on the left axes and grasps on the right, as split
    (LEFT / RIGHT) violins with one violin per session index.

    Args:
        animal_fd_path: pathlib.Path of the animal folder; must contain
            ``animal_meta.csv``.
        tasks_names: task names whose sessions are included.

    Returns:
        The array of two matplotlib axes.
    """
    SessionsDf = utils.get_sessions(animal_fd_path)
    animal_meta = pd.read_csv(animal_fd_path / 'animal_meta.csv')

    # Filter sessions to the ones of the task we want to see
    FilteredSessionsDf = pd.concat(
        [SessionsDf.groupby('task').get_group(name) for name in tasks_names])
    log_paths = [
        Path(path) / 'arduino_log.txt' for path in FilteredSessionsDf['path']
    ]

    fig, axes = plt.subplots(ncols=2, figsize=[8, 4], sharey=True, sharex=True)
    sides = ['LEFT', 'RIGHT']

    metrics = (met.get_start, met.get_stop, met.get_correct_side,
               met.get_outcome, met.get_interval_category,
               met.get_chosen_side, met.has_reach_left,
               met.has_reach_right)
    # NOTE(review): the grasp selection below needs 'has_choice' and
    # 'grasp_dur' columns in SessionDf; confirm parse_trials provides them.

    # gather data in long format: one record per duration
    records = []
    for day, log_path in enumerate(log_paths):

        LogDf = bhv.get_LogDf_from_path(log_path)
        TrialSpans = bhv.get_spans_from_names(LogDf, "TRIAL_ENTRY_STATE",
                                              "ITI_STATE")

        TrialDfs = [bhv.time_slice(LogDf, row['t_on'], row['t_off'])
                    for _, row in tqdm(TrialSpans.iterrows(), position=0,
                                       leave=True)]
        SessionDf = bhv.parse_trials(TrialDfs, metrics)

        for side in sides:
            # reaches
            reach_spansDf = bhv.get_spans_from_names(
                LogDf, 'REACH_' + side + '_ON', 'REACH_' + side + '_OFF')
            records.extend(dict(day=day, durs=dur, type='r', side=side)
                           for dur in reach_spansDf['dt'].values)

            # grasps (NaNs filtered out)
            choiceDf = bhv.groupby_dict(
                SessionDf, dict(has_choice=True, chosen_side=side.lower()))
            records.extend(dict(day=day, durs=dur, type='g', side=side)
                           for dur in choiceDf['grasp_dur'].dropna().values)

    Df = pd.DataFrame(records, columns=['day', 'durs', 'type', 'side'])
    Df['durs'] = Df['durs'].astype(float)

    # one panel per event type; sessions on x, LEFT/RIGHT as the two halves
    for ax, (event_type, label) in zip(axes, [('r', 'reaches'),
                                              ('g', 'grasps')]):
        ax.set_title(label)
        data = Df[Df['type'] == event_type]
        if data.empty:
            continue
        sns.violinplot(data=data,
                       x='day',
                       y='durs',
                       hue='side',
                       split=True,
                       cut=0,
                       ax=ax)

    axes[0].set_ylabel('duration')
    fig.suptitle('Reach and grasp durations across sessions \n' +
                 animal_meta['value'][5] + '-' + animal_meta['value'][0])
    fig.tight_layout()

    return axes
Exemple #5
0
def get_arduino_sync(log_path, sync_event_name="TRIAL_ENTRY_EVENT"):
    """Return the Arduino-side sync events of a session log.

    Args:
        log_path: path to the session's Arduino log file.
        sync_event_name: name of the logged event used as the sync signal.

    Returns:
        The events with that name, as returned by
        ``bhv.get_events_from_name``.
    """
    return bhv.get_events_from_name(bhv.get_LogDf_from_path(log_path),
                                    sync_event_name)
Exemple #6
0
os.chdir(session_folder)

# numpy instead of scipy: the numpy aliases in the scipy namespace
# (sp.unique, ...) are deprecated and removed in recent SciPy releases
import numpy as np
from Utils import sync

### DeepLabCut data
# os.listdir order is arbitrary, so sort to pick the same .h5 file on every
# run. Filter on 'filtered.h5' instead to use the filtered DLC output.
h5_fnames = sorted(fname for fname in os.listdir(session_folder)
                   if fname.endswith('.h5'))
if not h5_fnames:
    raise FileNotFoundError("no DeepLabCut .h5 file in %s" % session_folder)
h5_path = session_folder / h5_fnames[0]
DlcDf = dlc.read_dlc_h5(h5_path)
# all dlc body parts: unique first element of each column key, skipping column 0
bodyparts = np.unique([col[0] for col in DlcDf.columns[1:]])

### Camera data
video_path = session_folder / "bonsai_video.avi"
Vid = dlc.read_video(str(video_path))

### Arduino data
log_path = session_folder / 'arduino_log.txt'
LogDf = bhv.get_LogDf_from_path(log_path)

### LoadCell Data
# LoadCellDf = bhv.parse_bonsai_LoadCellData(session_folder / 'bonsai_LoadCellData.csv')

# Syncer
cam_sync_event = sync.parse_cam_sync(session_folder / 'bonsai_frame_stamps.csv')
# lc_sync_event = sync.parse_harp_sync(session_folder / 'bonsai_harp_sync.csv')
arduino_sync_event = sync.get_arduino_sync(log_path)

# Syncer is defined in Utils.sync; use the module-qualified name
Sync = sync.Syncer()
Sync.data['arduino'] = arduino_sync_event['t'].values
# Sync.data['loadcell'] = lc_sync_event['t'].values
Sync.data['dlc'] = cam_sync_event.index.values  # frame indices (the DLC time base)
Sync.data['cam'] = cam_sync_event['t'].values  # per-frame times from bonsai_frame_stamps.csv
Exemple #7
0
def plot_forces_on_init(session_folder, save=None):
    """Plot load-cell forces aligned to each TRIAL_ENTRY_EVENT of a session.

    Arduino event times are converted to the load-cell clock via the harp sync
    pulses. The x and y forces are drift-corrected by subtracting a rolling
    mean. Then, for every trial entry, the window [-1000, 1000] around it is
    drawn, with the across-trial average overlaid in black.

    Args:
        session_folder: pathlib.Path of the session; must contain
            ``arduino_log.txt``, ``bonsai_LoadCellData.csv`` and
            ``bonsai_harp_sync.csv``.
        save: optional pathlib.Path. If given, the figure is saved there at
            600 dpi (parent folders are created) and then closed.

    Returns:
        None. Returns early, after printing a message, if the clocks cannot
        be synchronized.
    """
    LogDf = bhv.get_LogDf_from_path(session_folder / "arduino_log.txt")

    ### LoadCell Data
    LoadCellDf = bhv.parse_bonsai_LoadCellData(session_folder / 'bonsai_LoadCellData.csv')

    # Syncer
    from Utils import sync
    lc_sync_event = sync.parse_harp_sync(session_folder / 'bonsai_harp_sync.csv', trig_len=100, ttol=5)
    arduino_sync_event = sync.get_arduino_sync(session_folder / 'arduino_log.txt')

    Sync = sync.Syncer()
    Sync.data['arduino'] = arduino_sync_event['t'].values
    Sync.data['loadcell'] = lc_sync_event['t'].values

    # abort if sync fails: converted times would be meaningless
    if not Sync.sync('arduino', 'loadcell'):
        utils.printer(
            "trying to plot_forces_on_init, but failed to sync in file %s, - aborting"
            % session_folder)
        return None

    LogDf['t_orig'] = LogDf['t']
    LogDf['t'] = Sync.convert(LogDf['t'].values, 'arduino', 'loadcell')

    # preprocessing
    samples = 10000 # 10s buffer: harp samples at 1khz, arduino at 100hz, LC controller has 1000 samples in buffer
    LoadCellDf['x'] = LoadCellDf['x'] - LoadCellDf['x'].rolling(samples).mean()
    LoadCellDf['y'] = LoadCellDf['y'] - LoadCellDf['y'].rolling(samples).mean()

    # plot forces
    times = LogDf.groupby('name').get_group('TRIAL_ENTRY_EVENT')['t'].values
    pre, post = -1000, 1000
    fig, axes = plt.subplots(nrows=2, sharex=True, sharey=False)

    x_traces = []
    y_traces = []
    for t in tqdm(times):
        Df = bhv.time_slice(LoadCellDf, t + pre, t + post, reset_index=False)
        # these colors need to be thoroughly checked
        axes[0].plot(Df['t'].values - t, Df['x'])
        axes[1].plot(Df['t'].values - t, Df['y'])

        x_traces.append(Df['x'].values)
        y_traces.append(Df['y'].values)

    if x_traces:
        # windows can differ by a sample (or be cut short at the recording
        # edges); truncate to the shortest so they stack into a 2D array
        n = min(trace.shape[0] for trace in x_traces)
        x_avgs = np.average(np.stack([trace[:n] for trace in x_traces]), axis=0)
        y_avgs = np.average(np.stack([trace[:n] for trace in y_traces]), axis=0)

        tvec = np.linspace(pre, post, n)
        axes[0].plot(tvec, x_avgs, color='k', lw=2)
        axes[1].plot(tvec, y_avgs, color='k', lw=2)

    kws = dict(linestyle=':', lw=1, alpha=0.8, color='k')
    for ax in axes:
        ax.axhline(-500, **kws)
        ax.axvline(0, **kws)

    # deco
    Session = utils.Session(session_folder)
    Animal = utils.Animal(session_folder.parent)
    title = ' - '.join([Animal.display(), Session.date, 'day: %s' % Session.day])

    for ax in axes:
        ax.set_ylim(-2500, 2500)
        ax.set_ylabel('Force [au]')
    axes[1].set_xlabel('time (ms)')

    sns.despine(fig)
    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)

    if save is not None:
        # create the folder the figure is actually written to
        os.makedirs(save.parent, exist_ok=True)
        fig.savefig(save, dpi=600)
        plt.close(fig)
Exemple #8
0
def plot_bias_over_sessions(Animal_folder, task_name, save=None):
    """Plot the logged 'bias' variable across all sessions of one task.

    Each session's bias trace is drawn compressed into a slot of width 0.5
    around its session index, with its average as a black dot. A thin top
    axis marks sessions with ``autodeliver_rewards == 1``, read from that
    session's ``interface_variables.h``.

    Args:
        Animal_folder: path of the animal folder.
        task_name: task whose sessions are plotted.
        save: optional pathlib.Path. If given, the figure is saved there at
            600 dpi (parent folders are created) and then closed.
    """
    Animal = utils.Animal(Animal_folder)

    # get BiasDfs
    SessionsDf = utils.get_sessions(Animal.folder).groupby('task').get_group(task_name)
    BiasDfs = []

    autodeliver = []
    p_lefts = []

    for _, row in SessionsDf.iterrows():
        session_folder = Path(row['path'])
        LogDf = bhv.get_LogDf_from_path(session_folder / "arduino_log.txt")

        # one session bias; copy so the new column is not written into a
        # view of LogDf (pandas SettingWithCopy)
        BiasDf = LogDf.groupby('var').get_group('bias').copy()
        t = BiasDf['t'].values
        duration = t[-1] - t[0]
        # session-relative time in [0, 1] (0 if only a single bias entry)
        BiasDf['t_rel'] = (t - t[0]) / duration if duration > 0 else 0.0

        BiasDfs.append(BiasDf)

        # interface variables of this session, parsed once for both lookups
        fname = session_folder / task_name / 'Arduino' / 'src' / 'interface_variables.h'
        ArduinoVars = utils.parse_arduino_vars(fname).groupby('name')

        # get autodeliver value for session
        autodeliver.append(ArduinoVars.get_group('autodeliver_rewards').iloc[0]['value'])

        # get static bias corr if possible
        try:
            p_left = ArduinoVars.get_group('p_left').iloc[0]['value']
        except KeyError:
            p_left = 0.5
        p_lefts.append(p_left)

    fig, axes = plt.subplots(nrows=2, sharex=True, gridspec_kw=dict(height_ratios=(0.1, 1)))
    w = 0.5  # width of the slot each session's trace is compressed into

    axes[0].set_ylabel('auto\nrewards')
    axes[0].set_xticks([])
    axes[0].set_yticks([])

    for i, BiasDf in enumerate(BiasDfs):
        tvec = np.linspace(i - w / 2, i + w / 2, BiasDf.shape[0])
        axes[1].plot(tvec, BiasDf['value'])
        axes[1].plot(i, np.average(BiasDf['value']), 'o', color='k')
        if autodeliver[i] == 1:
            axes[0].plot(i, 0, 'o', color='black')
        # axes[0].text(i,0.03,str(p_lefts[i]),ha='center')
    axes[1].set_ylim(-0.1, 1.1)

    axes[1].axhline(0.5, linestyle=':', lw=1, alpha=0.5, color='k')
    axes[1].set_xticks(range(SessionsDf.shape[0]))
    axes[1].set_xticklabels(SessionsDf['date'], rotation=45, ha="right")
    axes[1].set_xlabel('date')
    axes[1].set_ylabel('bias\n1=right')

    title = Animal.Nickname + ' - bias over sessions'
    axes[0].set_title(title)
    sns.despine(fig)
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.1)

    if save is not None:
        os.makedirs(save.parent, exist_ok=True)
        fig.savefig(save, dpi=600)
        plt.close(fig)

# %%
# Animal_folder = Path("/media/georg/htcondor/shared-paton/georg/Animals_reaching/JJP-02911_Lumberjack")
# task_name = 'learn_to_choose_v2'
# plot_bias_over_sessions(Animal_folder, task_name=task_name, save=None)
Exemple #9
0
def plot_session_overview(session_folder, save=None, on_t=True):
    """Plot a per-trial overview of one session plus rolling success rates.

    For every trial, the plot draws:
      * a vertical line colored by outcome (module-level ``colors`` mapping);
      * a black tick at y=0 (left) or y=1 (right) for the correct side;
      * a black dot at the chosen side, if a choice was made;
      * background bands for correction-loop (red), timing (cyan) and
        autodelivered-reward (pink) trials;
      * a small dot at the correct side when the reward was collected.
    On top, it draws the rolling (10-trial) fraction of missed trials and the
    rolling fraction correct over non-missed trials.

    Args:
        session_folder: pathlib.Path of the session; must contain
            ``arduino_log.txt``.
        save: optional pathlib.Path. If given, the figure is saved there at
            600 dpi (parent folders are created) and then closed.
        on_t: if True the x axis is time in minutes (t_on / 60000), otherwise
            the trial index.
    """
    LogDf = bhv.get_LogDf_from_path(session_folder / 'arduino_log.txt')
    session_metrics = (metrics.get_start, metrics.has_choice,
                       metrics.get_chosen_side, metrics.get_outcome,
                       metrics.get_correct_side, metrics.get_timing_trial,
                       metrics.has_reward_collected,
                       metrics.get_autodeliver_trial, metrics.get_in_corr_loop)

    SessionDf, TrialDfs = bhv.get_SessionDf(LogDf, session_metrics)
    SessionDf = bhv.expand_columns(SessionDf, ['outcome'])

    # trials without an outcome are labelled as autodelivered rewards
    if SessionDf['outcome'].isna().any():
        SessionDf.loc[pd.isna(SessionDf['outcome']),
                      'outcome'] = 'reward_autodelivered'

    fig, axes = plt.subplots(figsize=[10, 2.6])

    side_y = dict(left=0.0, right=1.0)  # y position of each side
    w = 0.05  # half height of the correct-side tick
    # flag column -> band color, drawn in this order
    bands = (('in_corr_loop', 'red'),
             ('timing_trial', 'cyan'),
             ('autodeliver_rewards', 'pink'))

    for i, row in SessionDf.iterrows():
        t = row['t_on'] / 60000 if on_t else i

        axes.plot([t, t], [0, 1],
                  lw=2.5,
                  color=colors[row['outcome']],
                  zorder=-1)

        correct_y = side_y.get(row['correct_side'])
        if correct_y is not None:
            axes.plot([t, t], [correct_y - w, correct_y + w], lw=1, color='k')

        if row['has_choice']:
            chosen_y = side_y.get(row['chosen_side'])
            if chosen_y is not None:
                axes.plot(t, chosen_y, '.', color='k')

        for flag, band_color in bands:
            # flags may be NaN for some trials; NaN is truthy, so check both
            if row[flag] and not np.isnan(row[flag]):
                axes.plot([t, t], [-0.1, 1.1],
                          color=band_color,
                          alpha=0.5,
                          zorder=-2,
                          lw=3)

        if row['has_reward_collected'] and correct_y is not None:
            axes.plot(t,
                      correct_y,
                      '.',
                      color=colors['reward'],
                      markersize=3,
                      alpha=0.5)

    # success rate
    hist = 10  # rolling window, in trials
    for outcome in ['missed']:
        srate = (SessionDf['outcome'] == outcome).rolling(hist).mean()
        if on_t:
            k = SessionDf['t_on'].values / 60000
        else:
            k = range(SessionDf.shape[0])
        axes.plot(k, srate, lw=1.5, color='black', alpha=0.75)
        axes.plot(k, srate, lw=1, color=colors[outcome], alpha=0.75)

    # valid trials
    SDf = bhv.intersect(SessionDf, is_missed=False)
    srate = (SDf.outcome == 'correct').rolling(hist).mean()
    k = SDf['t_on'] / 60000 if on_t else SDf.index

    axes.plot(k, srate, lw=1.5, color='k')
    axes.axhline(0.5, linestyle=':', color='k', alpha=0.5)

    # deco
    axes.set_xlabel('time (min)' if on_t else 'trial #')
    axes.set_ylabel('success rate')

    Session = utils.Session(session_folder)
    Animal = utils.Animal(session_folder.parent)
    title = ' - '.join(
        [Animal.display(), Session.date,
         'day: %s' % Session.day])

    sns.despine(fig)
    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)

    if save is not None:
        os.makedirs(save.parent, exist_ok=True)
        # save this figure explicitly, not whatever pyplot considers current
        fig.savefig(save, dpi=600)
        plt.close(fig)
Exemple #10
0
        Policeman="/media/georg/htcondor/shared-paton/georg/Animals_reaching/JJP-02996",
        Therapist="/media/georg/htcondor/shared-paton/georg/Animals_reaching/JJP-02997")

# %%

# For every animal in `folders`, walk its 'learn_to_choose_v2' sessions and
# collect left reaches whose duration lies strictly between min_th and max_th.
# NOTE(review): this snippet appears truncated; ReachesLeftDf is not used
# further within what is shown here.
for animal, folder in folders.items():
    print("processing %s" % animal)
    folder = Path(folder)
    Animal = utils.Animal(folder)
    SessionsDf = utils.get_sessions(Animal.folder).groupby('task').get_group('learn_to_choose_v2')
    # reset to a 0..n-1 index so `day` can be used with .loc below
    SessionsDf = SessionsDf.reset_index()

    for day, row in SessionsDf.iterrows():
        print(day)
        # NOTE: rebinds `folder` (the animal folder above) to the session folder
        folder = Path(SessionsDf.loc[day,'path'])
        LogDf = bhv.get_LogDf_from_path(folder / "arduino_log.txt")
        # time in minutes (t / 60000, i.e. assuming t is in ms)
        LogDf['min'] = LogDf['t'] / 60000


        # check each reach
        ReachesLeftDf = bhv.get_spans_from_names(LogDf, "REACH_LEFT_ON", "REACH_LEFT_OFF")

        # drop invalid: keep reaches with min_th < dt < max_th
        min_th = 5
        max_th = 2000

        binds = np.logical_and(ReachesLeftDf['dt'].values > min_th, ReachesLeftDf['dt'].values < max_th)

        # NOTE(review): assigning a column to this boolean-indexed result may
        # trigger pandas' SettingWithCopyWarning; consider `.loc[binds].copy()`
        ReachesLeftDf = ReachesLeftDf.loc[binds]

        ReachesLeftDf['is_grasp'] = False
Exemple #11
0
def plot_init_hist(session_folder, save=None):
    """Histogram the latency between trial availability and the nearest pushes.

    The load cell is synchronized to the Arduino via the harp sync pulses.
    The forces are drift-corrected and box-filtered. A push is the onset of an
    epoch in which the filtered force is below ``-th`` on both axes. For every
    TRIAL_AVAILABLE_EVENT, the plot collects the time to the next push
    ("post") and the time since the previous push ("pre").

    Args:
        session_folder: pathlib.Path of the session; must contain
            ``arduino_log.txt``, ``bonsai_harp_sync.csv`` and
            ``bonsai_LoadCellData.csv``.
        save: optional pathlib.Path. If given, the figure is saved there at
            600 dpi (parent folders are created) and then closed.

    Returns:
        None. Returns early, after printing a message, if synchronization
        fails.
    """
    LogDf = bhv.get_LogDf_from_path(session_folder / "arduino_log.txt")

    # Sync first
    loadcell_sync_event = sync.parse_harp_sync(session_folder /
                                               'bonsai_harp_sync.csv',
                                               trig_len=100,
                                               ttol=5)
    arduino_sync_event = sync.get_arduino_sync(session_folder /
                                               'arduino_log.txt')

    Sync = sync.Syncer()
    Sync.data['arduino'] = arduino_sync_event['t'].values
    Sync.data['loadcell'] = loadcell_sync_event['t'].values
    success = Sync.sync('arduino', 'loadcell')

    # abort if sync fails
    if not success:
        utils.printer(
            "trying to plot_init_hist, but failed to sync in file %s, - aborting"
            % session_folder)
        return None

    LogDf['t_orig'] = LogDf['t']
    LogDf['t'] = Sync.convert(LogDf['t'].values, 'arduino', 'loadcell')

    LoadCellDf = bhv.parse_bonsai_LoadCellData(session_folder /
                                               'bonsai_LoadCellData.csv')

    # preprocessing
    samples = 10000  # 10s buffer: harp samples at 1khz, arduino at 100hz, LC controller has 1000 samples in buffer
    LoadCellDf['x'] = LoadCellDf['x'] - LoadCellDf['x'].rolling(samples).mean()
    LoadCellDf['y'] = LoadCellDf['y'] - LoadCellDf['y'].rolling(samples).mean()

    # smoothing forces: an unnormalized box filter (sum over 100 samples), so
    # the push threshold below is in these summed units
    F = LoadCellDf[['x', 'y']].values
    w = np.ones(100)
    F[:, 0] = np.convolve(F[:, 0], w, mode='same')
    F[:, 1] = np.convolve(F[:, 1], w, mode='same')

    # detect pushes: onsets of epochs with both force axes below -th.
    # Cast to int before diffing - np.diff of a boolean array is elementwise
    # inequality, so "== 1" would match push offsets as well as onsets.
    th = 500
    pushing = np.logical_and(F[:, 0] < -th, F[:, 1] < -th).astype(np.int8)
    onsets = np.where(np.diff(pushing) == 1)[0] + 1  # first sample below -th
    times = LoadCellDf['t'].values[onsets]

    # histogram of pushes pre vs pushes post trial available
    trial_times = bhv.get_events_from_name(LogDf,
                                           'TRIAL_AVAILABLE_EVENT')['t'].values
    post = []
    pre = []

    for t in trial_times:
        dt = times - t
        after = dt[dt > 0]
        before = -dt[dt < 0]
        # no pushes after the last / before the first init: skip that side
        if after.size:
            post.append(after.min())
        if before.size:
            pre.append(before.min())

    fig, axes = plt.subplots()
    bins = np.linspace(0, 5000, 25)
    axes.hist(pre, bins=bins, alpha=0.5, label='pre')
    axes.hist(post, bins=bins, alpha=0.5, label='post')
    axes.set_xlabel('time (ms)')
    axes.set_ylabel('count')
    axes.legend()

    Session = utils.Session(session_folder)
    Animal = utils.Animal(session_folder.parent)
    title = ' - '.join(
        [Animal.display(), Session.date,
         'day: %s' % Session.day])

    sns.despine(fig)
    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=0.85)

    if save is not None:
        os.makedirs(save.parent, exist_ok=True)
        fig.savefig(save, dpi=600)
        plt.close(fig)
Exemple #12
0
def plot_psychometric(session_folder, N=1000, kind='true', fit_lapses=True, save=None):
    """Plot p(choice = long) against interval, with a logistic fit, for one session.

    Args:
        session_folder: pathlib.Path of the session; must contain
            ``arduino_log.txt``.
        N: number of random-choice simulations (respecting the session's side
            bias) used to draw chance bands; None skips them.
        kind: which trials to use:
            'true' means timing trials without a premature choice,
            'cued' means non-timing trials without a premature choice,
            'premature' means premature choices only.
        fit_lapses: passed to ``bhv.log_reg_cf``. If True, the fitted lapse
            rates are written on the axes.
        save: optional pathlib.Path. If given, the figure is saved there at
            600 dpi (parent folders are created) and then closed.

    Returns:
        None. Returns early, without plotting, when the session has no timing
        trials or no trial matches *kind*.

    Raises:
        ValueError: if *kind* is not 'true', 'cued' or 'premature'.
    """
    # the three variants
    trial_filters = {
        'true': dict(has_choice=True, is_premature=False, timing_trial=True),
        'cued': dict(has_choice=True, is_premature=False, timing_trial=False),
        'premature': dict(has_choice=True, is_premature=True),
    }
    if kind not in trial_filters:
        raise ValueError("kind must be one of %s, got %r" % (sorted(trial_filters), kind))

    LogDf = bhv.get_LogDf_from_path(session_folder / 'arduino_log.txt')

    # exit here if there are no timing trials (or the variable was never logged)
    try:
        timing_values = LogDf.groupby('var').get_group('timing_trial')['value']
    except KeyError:
        return None
    if not np.any(timing_values):
        return None

    session_metrics = (metrics.get_start, metrics.has_choice, metrics.get_chosen_side,
                       metrics.get_outcome, metrics.get_correct_side, metrics.get_timing_trial,
                       metrics.get_interval, metrics.get_interval_category, metrics.get_in_corr_loop,
                       metrics.has_reward_collected, metrics.get_autodeliver_trial, metrics.get_chosen_interval)

    SessionDf, TrialDfs = bhv.get_SessionDf(LogDf, session_metrics)
    SessionDf = bhv.expand_columns(SessionDf, ['outcome'])

    if SessionDf['outcome'].isna().any():
        SessionDf.loc[pd.isna(SessionDf['outcome']), 'outcome'] = 'reward_autodelivered'

    SDf = bhv.intersect(SessionDf, **trial_filters[kind])
    if SDf.shape[0] == 0:
        # nothing to fit or plot
        return None

    fig, axes = plt.subplots()

    # plot the choices as p(Long)
    intervals = list(np.sort(SDf['this_interval'].unique()))
    for interval in intervals:
        Df = bhv.intersect(SDf, this_interval=interval)
        f = np.sum(Df['chosen_interval'] == 'long') / Df.shape[0]
        axes.plot(interval, f, 'o', color='r')
    axes.set_ylabel('p(choice = long)')

    # plot the fit
    # NOTE(review): the points use chosen_interval == 'long' while the fit uses
    # chosen_side == 'right'; this assumes long maps to right - confirm.
    y = SDf['chosen_side'].values == 'right'
    x = SDf['this_interval'].values
    x_fit = np.linspace(0, 3000, 100)

    y_fit, p_fit = bhv.log_reg_cf(x, y, x_fit, fit_lapses=fit_lapses)
    axes.plot(x_fit, y_fit, color='red', linewidth=2, alpha=0.75)

    if fit_lapses:
        # add lapse rate as text to the axes
        lapse_upper = p_fit[3] + p_fit[2]
        lapse_lower = p_fit[3]
        axes.text(intervals[0], lapse_lower + 0.05, "%.2f" % lapse_lower, ha='center', va='center')
        axes.text(intervals[-1], lapse_upper + 0.05, "%.2f" % lapse_upper, ha='center', va='center')

    if N is not None:
        # simulating random choices, respecting session bias
        bias = (SDf['chosen_side'] == 'right').sum() / SDf.shape[0]
        R = np.full((x_fit.shape[0], N), np.nan)
        for i in tqdm(range(N)):
            # numpy instead of the removed scipy alias sp.rand
            rand_choices = np.random.rand(SDf.shape[0]) < bias
            try:
                y_sim, _ = bhv.log_reg_cf(x, rand_choices, x_fit, fit_lapses=fit_lapses)
                R[:, i] = y_sim
            except RuntimeError:
                # fit failed for this simulation; its column stays NaN
                pass

        # drop failed simulations -> shape (n_successful, len(x_fit))
        R = R[:, ~np.isnan(R[0, :])].T

        # Several statistical boundaries (only if any simulation succeeded)
        if R.shape[0] > 0:
            alphas = [5, 0.5, 0.05]
            opacities = [0.2, 0.2, 0.2]
            for alpha, a in zip(alphas, opacities):
                R_pc = np.percentile(R, (alpha, 100 - alpha), 0)
                axes.fill_between(x_fit, R_pc[0], R_pc[1], color='blue', alpha=a, linewidth=0)

    # deco
    w = 0.05
    axes.set_ylim(0 - w, 1 + w)
    axes.axvline(1500, linestyle=':', alpha=0.5, lw=1, color='k')
    axes.axhline(0.5, linestyle=':', alpha=0.5, lw=1, color='k')

    axes.set_xlim(500, 2500)
    axes.set_xlabel('time (ms)')

    Session = utils.Session(session_folder)
    Animal = utils.Animal(session_folder.parent)
    title = ' - '.join([Animal.display(), Session.date, 'day: %s' % Session.day])

    sns.despine(fig)
    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)

    if save is not None:
        os.makedirs(save.parent, exist_ok=True)
        fig.savefig(save, dpi=600)
        plt.close(fig)
Exemple #13
0
#     if save:
#         outpath = plot_dir / 'session_overview_coarse.png'
#         plt.savefig(outpath, dpi=600)
#         plt.close(fig)
#     else:
#         return fig, axes


# %%

# Interactive analysis of a single, hard-coded session: these module-level
# names (session_folder, Session, Animal, LogDf) are used by the cells below.
session_folder = Path("/media/georg/htcondor/shared-paton/georg/Animals_reaching/JJP-01975_Marquez/2021-05-18_09-41-58_learn_to_fixate_discrete_v1")
Session = utils.Session(session_folder)
Animal = utils.Animal(session_folder.parent)

LogDf = bhv.get_LogDf_from_path(session_folder / "arduino_log.txt")
# time in minutes (t / 60000, i.e. assuming t is in ms)
LogDf['min'] = LogDf['t'] / 60000

#  slice into trials
def get_SessionDf(LogDf, metrics, trial_entry_event="TRIAL_AVAILABLE_STATE", trial_exit_event="ITI_STATE"):
    """Cut a session log into per-trial slices and compute trial metrics.

    Trial spans come from ``bhv.get_spans_from_names`` between
    *trial_entry_event* and *trial_exit_event*. Each span's ``t_on`` / ``t_off``
    window is sliced out of *LogDf*.

    Args:
        LogDf: the session log.
        metrics: iterable of metric functions passed to ``bhv.parse_trials``.
        trial_entry_event: name marking the start of a trial.
        trial_exit_event: name marking the end of a trial.

    Returns:
        (SessionDf, TrialDfs): the metrics table from ``bhv.parse_trials`` and
        the list of per-trial log slices it was computed from.
    """
    spans = bhv.get_spans_from_names(LogDf, trial_entry_event, trial_exit_event)
    span_rows = tqdm(spans.iterrows(), position=0, leave=True)
    TrialDfs = [bhv.time_slice(LogDf, span['t_on'], span['t_off'])
                for _, span in span_rows]
    return bhv.parse_trials(TrialDfs, metrics), TrialDfs

from Utils.metrics import *