def plot_raster_trials(
        s, ax=None, sub_colors_dict=None, align=(0, 0, 0, 0), reindex=None):
    '''
    Plot raster of behaviour related ts aligned to different points.

    Parameters
    ----------
    s : session object
        Must provide get_metadata, get_stage and get_trial_df_norm.
    ax : plt.axe, default None
        Optional ax object to plot into. A new 20x10 figure is created
        when omitted.
    sub_colors_dict: dict, None
        dict with subject: colors, used to assign color to title
    align : sequence of bool, (0, 0, 0, 0)
        input structure - [align_rw, align_pell, align_FI, align_FRes]
        The first truthy flag wins; if all are falsy, trials are
        start aligned.
    reindex: list, None
        list of which to reorder trials by

    Returns
    -------
    fig : the figure created here, or None when `ax` was supplied.

    Raises
    ------
    ValueError
        If the alignment reference of a trial is not a single value.

    '''
    if ax is None:
        fig, ax = plt.subplots(figsize=(20, 10))
    else:
        fig = None

    # alignment decision
    align_rw, align_pell, align_FI, align_FRes = align

    # Retrieve session related variables
    date = s.get_metadata('start_date').replace('/', '_')
    sub = s.get_metadata('subject')
    stage = s.get_stage()
    trial_df = s.get_trial_df_norm()
    if reindex is not None:
        trial_df = trial_df.reindex(reindex, copy=True).reset_index()
        # Label the y axis on the target axes (not pyplot's current axes)
        ax.set_yticks(np.arange(len(reindex)))
        ax.set_yticklabels(reindex, fontsize=10)

    def df_col_to_2d_list(df, col):
        """Copy every cell of df[col] into its own numpy array."""
        df_col = df[col]
        # Label based lookup, exactly like the per-trial indexing below
        return [np.copy(df_col.at[i]) for i in range(len(df_col))]

    # Extract data from pandas_df (copies, so the session data is untouched)
    norm_lever = df_col_to_2d_list(trial_df, "Levers_ts")
    norm_err = df_col_to_2d_list(trial_df, "Err_ts")
    norm_dr = df_col_to_2d_list(trial_df, "D_Pellet_ts")
    norm_pell = df_col_to_2d_list(trial_df, "Pellet_ts")
    norm_rw = df_col_to_2d_list(trial_df, "Reward_ts")
    norm_FRes = df_col_to_2d_list(trial_df, "First_response")
    schedule_type = trial_df['Schedule'].copy(deep=True)
    norm_tone = trial_df['Tone_s'].copy(deep=True)
    norm_start = trial_df['Trial_s'].copy(deep=True)

    # Alignment Specific Parameters
    # Default to automatic x limits so every alignment mode defines them.
    xmin = xmax = None
    if align_rw:
        plot_type = 'Reward-Aligned'
        norm_arr = norm_rw
    elif align_pell:
        plot_type = 'Pell-Aligned'
        norm_arr = norm_pell
        xmin, xmax = -30, 5
    elif align_FI:
        plot_type = 'Interval-Aligned'
        # One fixed 30 s offset per trial, independent of reward contents
        norm_arr = np.full((len(norm_rw), 1), 30.0)
    elif align_FRes:
        plot_type = 'First_Resp-Aligned'
        norm_arr = norm_FRes
        xmin, xmax = -10, 40
    else:
        plot_type = 'Start-Aligned'
        norm_arr = norm_start

    plot_name = 'Raster ({})'.format(plot_type)
    ax.axvline(0, linestyle='-.', color='g', linewidth=1)
    ax.text(0.1, -1, plot_type.split('-')[0], fontsize=12,
            color='g', ha='left', va='top')

    schedule_colors = {'FR': 'black', 'FI': 'b'}
    press_colors = []
    for i in range(len(norm_rw)):
        # color assignment for trial type
        press_colors.append(schedule_colors.get(schedule_type[i], 'g'))

        # The scalar is extracted before any in-place shift, because
        # norm_arr may be the very list that is shifted below.
        subtract_val = np.ravel(norm_arr[i])
        if len(subtract_val) != 1:
            raise ValueError("Can't align to non single value")
        subtract_val = subtract_val[0]

        norm_lever[i] -= subtract_val
        norm_err[i] -= subtract_val
        norm_dr[i] -= subtract_val
        norm_pell[i] -= subtract_val
        norm_rw[i] -= subtract_val
        norm_tone[i] -= subtract_val
        norm_start[i] -= subtract_val

    # Plotting of raster
    ax.eventplot(norm_lever, color=press_colors)
    ax.eventplot(norm_err, color='red')

    # Pair every reward timestamp with its trial row; trials without a
    # reward simply contribute no point.
    rw_counts = [np.size(x) for x in norm_rw]
    rw_y = np.repeat(np.arange(len(norm_rw)), rw_counts)
    if norm_rw:
        rw_x = np.concatenate([np.ravel(x) for x in norm_rw])
    else:
        rw_x = np.array([])
    rw_plot = ax.scatter(rw_x, rw_y, s=5, color='orange',
                         label='Reward Collection')
    ax.eventplot(norm_dr, color='magenta')

    # Legend construction (Standard items)
    FR_label = lines.Line2D([], [], color='black', marker='|',
                            linestyle='None', markersize=10,
                            markeredgewidth=1.5, label='FR press')
    FI_label = lines.Line2D([], [], color='b', marker='|',
                            linestyle='None', markersize=10,
                            markeredgewidth=1.5, label='FI press')
    drw_label = lines.Line2D([], [], color='magenta', marker='|',
                             linestyle='None', markersize=10,
                             markeredgewidth=1.5, label='Double Reward')
    handles = [FR_label, FI_label, drw_label, rw_plot]

    # Figure labels
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(-3, len(norm_rw) + 3)
    ax.tick_params(axis='both', labelsize=15)
    ax.set_xlabel('Time (s)', fontsize=20)
    ax.set_ylabel('Trials', fontsize=20)
    if sub_colors_dict is None:
        title_color = "k"
    else:
        title_color = mycolors(sub, sub_colors_dict)
    ax.set_title(' {}'.format(plot_name),
                 y=1.04, ha='center', fontsize=20, color=title_color)
    ax.text(0.5, 1.015, '{} {} S{}'.format(sub, date, stage),
            ha='center', transform=ax.transAxes, fontsize=12,
            color=title_color)

    # Optional plot options [Pell, Tone, dr_win]
    opt_plot = [1, 1, 1]
    if opt_plot[0]:  # Plot pellet drops
        ax.eventplot(norm_pell, color='green')
        pell_label = lines.Line2D([], [], color='green', marker='|',
                                  linestyle='None', markersize=10,
                                  markeredgewidth=1.5, label='Pell')
        handles.append(pell_label)
    if opt_plot[1]:
        for i, x in enumerate(norm_tone):  # Plot Tone presentation
            if x:
                ax.broken_barh([(x, 5)], (i - .5, 1), color='grey',
                               alpha=0.2, hatch='///', edgecolor='k')
        tone_label = mpatches.Patch(
            facecolor='grey', hatch='///', edgecolor='k',
            alpha=0.2, label='Tone')
        handles.append(tone_label)
    if opt_plot[2]:
        for i, (x, sch) in enumerate(zip(norm_start, schedule_type)):
            # Plot DR window
            if sch == 'FI':
                ax.broken_barh([(x + 20, 20)], (i - .4, 0.8),
                               color='magenta', alpha=0.05)
        drwin_label = mpatches.Patch(
            color='magenta', alpha=0.05, label='dr_win')
        handles.append(drwin_label)

    ax.legend(handles=handles, fontsize=12, loc='upper right')

    # Highlight specific trials
    hline, h_ref = [], []
    h_color = None
    h_dr, h_err, h_fr = [1, 0, 0]
    if h_dr:
        h_color = 'pink'
        h_ref = norm_dr
    elif h_err:
        h_color = 'magenta'
        h_ref = norm_err
    elif h_fr:
        h_color = 'k'
        # `sch` (not `s`) so the session argument is never shadowed
        for sch in schedule_type:
            h_ref.append([1] if sch == 'FR' else [])

    # highlight if array is not empty
    for i, ts in enumerate(h_ref):
        if len(ts) > 0:
            hline.append(i)
    for row in hline:
        ax.axhline(row, linestyle='-', color=h_color, linewidth=5, alpha=0.1)

    return fig
def trial_length_hist(s, valid=True, ax=None, loop=None, sub_colors_dict=None):
    '''
    Plot histogram of trial durations, one distribution per schedule.

    Parameters
    ----------
    s : session object
    valid : bool, default True
        If True the trials come from s.get_valid_tdf(norm=True),
        otherwise from s.get_trial_df_norm().
    ax : plt.axe, default None
        Optional ax object to plot into.
    loop: str, None
        indicates number of iterations through plot without changing axis
        *Mainly used to identify morning/afternoon sessions for each animal
    sub_colors_dict: dict, None
        dict with subject: colors, used to assign color to title

    Returns
    -------
    fig : the figure created here, or None when `ax` was supplied.

    '''
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 10))

    trial_df = s.get_valid_tdf(norm=True) if valid else s.get_trial_df_norm()

    # 'Reward_ts' column of the FI trials and of the FR trials
    durations = {
        schedule: trial_df[trial_df['Schedule'] == schedule]['Reward_ts']
        for schedule in ('FI', 'FR')}

    gm = bv_plot.GroupManager(['FI', 'FR'])

    suffix = "_p" + str(loop) if loop else ""

    use_log = False  # manual switch: plot log-durations instead
    for schedule, values in durations.items():
        group_color = gm.get_next_color()
        if use_log:
            values = np.log(values)
        sns.distplot(values, ax=ax,
                     label='{}{}'.format(schedule, suffix),
                     color=group_color)

    # Plot customization
    date = s.get_metadata('start_date').replace('/', '_')
    sub = s.get_metadata('subject')
    stage = s.get_stage()
    plot_name = 'Trial Length Hist' + ('_v' if valid else '')

    ax.tick_params(axis='both', labelsize=12)
    ax.set_ylabel('Probability Density', fontsize=20)
    ax.set_xlabel('Time (s)', fontsize=20)

    if sub_colors_dict is None:
        title_color = "k"
    else:
        title_color = mycolors(sub, sub_colors_dict)

    ax.set_title(' {}'.format(plot_name), y=1.04, ha='center',
                 fontsize=25, color=title_color)
    subtitle = '{} {} S{}'.format(sub, date, stage)
    ax.text(0.5, 1.015, subtitle, ha='center',
            transform=ax.transAxes, fontsize=12, color=title_color)
    ax.legend(fontsize=20)
    return fig
def IRT(session, out_dir, ax=None, showIRT=False):
    """
    Perform an inter-response time plot for a Session.

    IRT calculated from prev reward to next lever press resulting in reward

    Parameters
    ----------
    session : session object
    out_dir : str
        Directory the figure is saved to when no `ax` is supplied.
    ax : plt.axe, default None
        Optional ax object to plot into. When omitted, a new figure is
        created, saved into `out_dir` and closed.
    showIRT : bool, default False
        If True, show_IRT_details is called with the computed values.

    Returns
    -------
    The supplied `ax`, or None when the figure was created and saved here.

    """
    # General session info
    date = session.get_metadata('start_date').replace('/', '_')
    session_type = session.get_metadata('name')
    stage = session_type[:2].replace('_', '')
    subject = session.get_metadata('subject')
    ratio = session.get_ratio()
    interval = session.get_interval()

    # Timestamps data extraction
    time_taken = session.time_taken()
    # lever ts without unnecessary presses
    good_lever_ts = session.get_lever_ts(False)
    # Only consider rewards for lever pressing
    rewards_i = session.get_rw_ts()
    reward_idxs = np.nonzero(rewards_i >= good_lever_ts[0])
    rewards = rewards_i[reward_idxs]

    # b assigns ascending numbers to rewards within lever presses
    b = np.digitize(rewards, bins=good_lever_ts)
    _, a = np.unique(b, return_index=True)  # returns index for good rewards
    good_rewards = rewards[a]

    # lever press preceding each reward (fixed ratio sessions only)
    if session_type == '5a_FixedRatio_p':
        last_lever_ts = np.asarray([good_lever_ts[i - 1] for i in b])
    else:
        last_lever_ts = good_lever_ts

    if len(last_lever_ts[1:]) > len(good_rewards[:-1]):
        # Ended sess w lever press
        irt = last_lever_ts[1:] - good_rewards[:]
    else:
        # Ended session w nosepoke
        irt = last_lever_ts[1:] - good_rewards[:-1]

    # The axes must exist before anything is drawn on them.
    single_plot = ax is None
    fig = None
    if single_plot:
        fig, ax = plt.subplots()

    # One bin per second; at least one bin so hist() always gets a valid
    # bin count.
    n_bins = max(1, math.ceil(np.amax(irt)))
    hist_count, hist_bins, _ = ax.hist(
        irt, bins=n_bins, range=(0, n_bins), color=mycolors(subject))

    # Plotting of IRT Graphs
    if single_plot:
        ax.set_title('Inter-Response Time\n', fontsize=15)
        if session_type == '5a_FixedRatio_p':
            sub_title = '\n({}, {} {}, {})'.format(
                subject, session_type[:-2], ratio, date)
        elif session_type == '5b_FixedInterval_p':
            sub_title = '\n({}, {} {}s, {})'.format(
                subject, session_type[:-2], interval, date)
        else:
            sub_title = '\n({}, {}, {})'.format(
                subject, session_type[:-2], date)
        fig.suptitle(sub_title, fontsize=10, y=.98, x=.51)
    else:
        ax.set_title('\n{}, S{}, IRT'.format(subject, stage),
                     color=mycolors(subject), fontsize=10, y=1, x=.51)

    ax.set_xlabel('IRT (s)')
    ax.set_ylabel('Counts')

    # Centre of the most populated histogram bin
    maxidx = np.argmax(np.array(hist_count))
    maxval = (hist_bins[maxidx + 1] - hist_bins[maxidx]) / \
        2 + hist_bins[maxidx]
    ax.text(0.45, 0.85, 'Session Duration: {} mins\nMost Freq. IRT Bin: {} s'
            .format(time_taken, maxval), transform=ax.transAxes)

    if showIRT:
        show_IRT_details(irt, maxidx, hist_bins)

    if not single_plot:
        return ax

    # Text Display on Graph
    text = (
        'Session Duration: ' +
        '{} mins\nMost Freq. IRT Bin: {} s'.format(
            time_taken, maxval))
    ax.text(0.55, 0.8, text, transform=ax.transAxes)
    out_name = (subject.zfill(3) + "_IRT_Hist_" +
                session_type[:-2] + "_" + date + ".png")
    print("Saved figure to {}".format(
        os.path.join(out_dir, out_name)))
    bv_plot.savefig(fig, os.path.join(out_dir, out_name))
    # Close the figure created here, not whatever happens to be current
    plt.close(fig)
def lever_hist(s, ax=None, valid=True, excl_dr=False,
               split_t=False, sub_colors_dict=None):
    '''
    Plot histogram of lever presses

    Parameters
    ----------
    s : session object
    ax : plt.axe, default None
        Optional. ax object to plot into.
    valid : bool, True
        If True the trials come from s.get_valid_tdf(excl_dr=excl_dr,
        norm=True), otherwise from s.get_trial_df_norm().
    excl_dr : bool, False
        Forwarded to s.get_valid_tdf; only used when `valid` is True.
    split_t : bool, False
        Optional. Plots lever histogram by trials
    sub_colors_dict: dict, None
        dict with subject: colors, used to assign color to title

    Returns
    -------
    fig : the figure created here, or None when `ax` was supplied.

    '''
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 10))
    else:
        fig = None

    if valid:
        t_df = s.get_valid_tdf(excl_dr=excl_dr, norm=True)
    else:
        t_df = s.get_trial_df_norm()

    if split_t:
        gm = bv_plot.GroupManager(t_df['Schedule'].values.tolist())
        for idx, row in t_df.iterrows():
            # Always advance the colour cycle so a skipped trial does not
            # shift the colours of the trials that follow.
            trial_color = gm.get_next_color()
            lev_ts = row['Levers_ts']
            x = lev_ts[~np.isnan(lev_ts)]  # Remove NaN from lever timestamps
            if x.size == 0:
                continue  # no presses in this trial: nothing to draw
            schedule = row['Schedule']
            if schedule in ('FR', 'FI'):
                sns.distplot(x, ax=ax,
                             label='{}-t{}'.format(schedule, idx + 1),
                             color=trial_color, hist=True)
        legend_size = 6
    else:
        gm = bv_plot.GroupManager(['FI', 'FR'])
        t_lev = {
            'FI': t_df[t_df['Schedule'] == 'FI']['Levers_ts'],
            'FR': t_df[t_df['Schedule'] == 'FR']['Levers_ts']}
        for key, col in t_lev.items():
            group_color = gm.get_next_color()
            per_trial = col.to_numpy()  # convert pandas to numpy
            if len(per_trial) == 0:
                # np.concatenate raises on an empty sequence; a schedule
                # without trials has nothing to plot.
                continue
            # flatten nested numpy into single numpy
            x = np.concatenate(per_trial).ravel()
            x = x[~np.isnan(x)]  # remove NaN from numpy
            if x.size == 0:
                continue
            sns.distplot(x, ax=ax, label=key, color=group_color)
        legend_size = 10

    # Plot customization
    date = s.get_metadata('start_date').replace('/', '_')
    sub = s.get_metadata('subject')
    stage = s.get_stage()
    plot_name = 'Lever Response Hist'
    if split_t:
        plot_name += ' (Trials)'
    if valid:
        plot_name += '_v'
    if excl_dr:
        plot_name += '_exdr'
    ax.tick_params(axis='both', labelsize=12)
    ax.set_ylabel('Probability Density', fontsize=20)
    ax.set_xlabel('Time (s)', fontsize=20)
    if sub_colors_dict is None:
        title_color = "k"
    else:
        title_color = mycolors(sub, sub_colors_dict)
    ax.set_title(' {}'.format(plot_name),
                 y=1.04, ha='center', fontsize=25, color=title_color)
    ax.text(0.5, 1.015, '{} {} S{}'.format(sub, date, stage),
            ha='center', transform=ax.transAxes, fontsize=12,
            color=title_color)
    ax.legend(fontsize=legend_size, ncol=2)
    return fig
def _cumplot_mark_rewards(ax, bins, reward_ts, double_ts):
    """Mark rewards (grey) and double rewards (magenta) on a step curve."""
    # digitize - 1 gives the cumulative press count reached at each reward
    reward_y = np.digitize(reward_ts, bins) - 1
    double_y = np.digitize(double_ts, bins) - 1
    ax.scatter(reward_ts, reward_y, marker="x", c="grey", s=25)
    ax.scatter(double_ts, double_y, marker="x", c="magenta", s=25)


def cumplot(session, out_dir, ax=None, int_only=False, zoom=False,
            zoom_sch=False, plot_error=False, plot_all=False):
    """
    Perform a cumulative plot for a Session.

    Parameters
    ----------
    session : session object
    out_dir : str
        Directory the figure is saved to in the default (whole session)
        mode when no `ax` is supplied.
    ax : plt.axe, default None
        Optional ax object to plot into.
    int_only : bool, default False
        In the block-split plot of mixed sessions, draw only blocks whose
        'Trial Type' is 0 (labelled FI).
    zoom : bool, default False
        Plot every trial separately, aligned to the previous reward.
    zoom_sch : bool, default False
        Plot the session split into blocks.
    plot_error : bool, default False
        Forwarded to session.split_sess in the block-split plot.
    plot_all : bool, default False
        Forwarded to session.split_sess as all_levers; for stage 7 the
        error presses are also drawn.

    Returns
    -------
    fig : the figure created here (already saved and closed in the default
        mode), or None when `ax` was supplied or the session cannot be
        split.

    """
    date = session.get_metadata('start_date').replace('/', '_')
    timestamps = session.get_arrays()
    lever_ts = session.get_lever_ts()
    session_type = session.get_metadata('name')
    stage = session_type[:2].replace('_', '')
    subject = session.get_metadata('subject')
    reward_times = session.get_rw_ts()
    pell_ts = timestamps["Reward"]
    # pellets followed by another pellet within 0.5 s
    pell_double = np.nonzero(np.diff(pell_ts) < 0.5)

    # for printing of error rates and rewards on graph
    err_FI = 0
    err_FR = 0
    rw_FR = 0
    rw_FI = 0

    # First reward collection after each double pellet. Double pellets
    # dropped after the last collection have no matching reward, so they
    # are skipped instead of indexing past the end of reward_times.
    double_idx = np.searchsorted(
        reward_times, pell_ts[pell_double], side='right')
    double_idx = double_idx[double_idx < len(reward_times)]
    reward_double = reward_times[double_idx]

    ratio = session.get_ratio()
    interval = session.get_interval()
    subject_color = mycolors(subject)
    is_mixed = session_type == '6_RandomisedBlocks_p' or stage == '7'

    # `fig` stays None when the caller supplies the axes, so every
    # `return fig` below is well defined.
    fig = None
    single_plot = ax is None

    if single_plot:
        fig, ax = plt.subplots()
        ax.set_title('Cumulative Lever Presses\n', fontsize=15)
        if session_type == '5a_FixedRatio_p':
            sub_title = '\n{}, {} {}, {}'.format(
                subject, session_type[:-2], ratio, date)
        elif session_type == '5b_FixedInterval_p':
            sub_title = '\n{}, {} {}s, {}'.format(
                subject, session_type[:-2], interval, date)
        elif session_type == '6_RandomisedBlocks_p':
            sub_title = '\n{}, {} FR{}/FI{}s, {}'.format(
                subject, session_type[:-2], ratio, interval, date)
        else:
            sub_title = '\n{}, S{}, {}'.format(subject, stage, date)
        fig.suptitle(sub_title, color=subject_color,
                     fontsize=10, y=.98, x=.51)
    elif session_type == '5a_FixedRatio_p':
        ax.set_title('\n{}, S{}, FR{}, {}'.format(
            subject, stage, ratio, date),
            color=subject_color, fontsize=10)
    elif session_type == '5b_FixedInterval_p':
        ax.set_title('\n{}, S{}, FI{}s, {}'.format(
            subject, stage, interval, date),
            color=subject_color, fontsize=10)
    elif is_mixed:
        # Vertical markers every 305 s, drawn on the supplied axes
        for x in np.arange(5, 1830, 305):
            ax.axvline(x, color='g', linestyle='-.', linewidth=0.4)
        ax.set_title('\n{}, S{}, FR{}/FI{}s, {}'.format(
            subject, stage, ratio, interval, date),
            color=subject_color, fontsize=10)
    else:
        ax.set_title('\n{}, S{}, {}'.format(
            subject, stage, date),
            color=subject_color, fontsize=10, y=1, x=.51)

    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Cumulative Lever Presses')

    # Broken and plots are ugly - Plots single trials
    if zoom:
        trial_lever_ts = np.split(
            lever_ts, np.searchsorted(lever_ts, reward_times))
        reward_times_0 = np.append([0], reward_times, axis=0)
        norm_reward_ts = []
        norm_lever_ts = []
        for i, presses in enumerate(trial_lever_ts[:-1]):
            norm_lever_ts.append(
                np.append([0], presses - reward_times_0[i], axis=0))
            norm_reward_ts.append(reward_times[i] - reward_times_0[i])
        ax.set_xlim(0, np.max(norm_reward_ts))
        cmap = plt.get_cmap('autumn')
        for i, presses in enumerate(norm_lever_ts):
            ax.step(presses, np.arange(presses.size),
                    c=cmap(i * 20), where="post")
            reward_y = np.digitize(norm_reward_ts[i], presses) - 1
            ax.scatter(norm_reward_ts[i], reward_y,
                       marker="x", c="grey", s=25)
        ax.set_title('\n{}, Trial-Based'.format(subject),
                     color=subject_color, fontsize=10)
        ax.legend(loc='lower right')
        return fig

    if zoom_sch and is_mixed:
        # plots cum graph based on schedule type (i.e. FI/FR)
        norm_r_ts, norm_l_ts, norm_err_ts, norm_dr_ts, incl = \
            session.split_sess(plot_error=plot_error, all_levers=plot_all)
        sch_type = session.get_arrays('Trial Type')
        ratio_c = plt.get_cmap('Wistia')
        interval_c = plt.get_cmap('winter')
        ax.set_xlim(0, 305)
        for i, presses in enumerate(norm_l_ts):
            if sch_type[i] == 1 and not int_only:
                ax.step(presses, np.arange(presses.size),
                        c=ratio_c(i * 45), where="post",
                        label='B' + str(i + 1) + ' - FR', zorder=1)
            elif sch_type[i] == 0:
                ax.step(presses, np.arange(presses.size),
                        c=interval_c(i * 45), where="post",
                        label='B' + str(i + 1) + ' - FI', zorder=1)
            if stage == '7' and plot_all:
                # plots all responses incl. errors
                ax.scatter(norm_err_ts[i], np.isin(
                    presses, norm_err_ts[i]).nonzero()[0],
                    c='r', s=1, zorder=2)
                incl = '_All'
            # With int_only, rewards are only marked on type 0 blocks
            if not int_only or sch_type[i] == 0:
                _cumplot_mark_rewards(
                    ax, presses, norm_r_ts[i], norm_dr_ts[i])
        ax.set_title('\n{}, Block-Split {}'.format(subject, incl),
                     color=subject_color, fontsize=10)
        ax.legend(loc='upper left')
        return fig

    if zoom_sch:
        # plots cum graph split into 5 min blocks
        if session_type == '5a_FixedRatio_p':
            sch_type = 'FR'
            ax.set_title('\n{}, FR{} Split'.format(subject, ratio),
                         color=subject_color, fontsize=10)
        elif session_type == '5b_FixedInterval_p':
            sch_type = 'FI'
            ax.set_title('\n{}, FI{}s Split'.format(subject, interval),
                         color=subject_color, fontsize=10)
        else:
            # NOTE(review): the previous `stage == 6 or stage == 7` test
            # compared the string stage with ints and could never match,
            # so every other session ended up here. Behaviour is kept.
            print("Unable to split session")
            return None

        # Change values to set division blocks
        blocks = np.arange(0, 60 * 30, 300)
        norm_r_ts, norm_l_ts, norm_err_ts, norm_dr_ts, _ = \
            session.split_sess(blocks)
        ax.set_xlim(0, 305)
        for i, presses in enumerate(norm_l_ts):
            ax.step(presses, np.arange(presses.size),
                    c=mycolors(i), where="post",
                    label='B' + str(i + 1) + ' - {}'.format(sch_type))
            _cumplot_mark_rewards(ax, presses, norm_r_ts[i], norm_dr_ts[i])
        ax.legend(loc='upper left')
        return fig

    # Default: cumulative plot over the whole session
    if stage == '7':
        err_lever_ts = session.get_err_lever_ts()
        lever_ts = np.sort(np.concatenate(
            (lever_ts, err_lever_ts), axis=None))
    lever_times = np.insert(lever_ts, 0, 0, axis=0)
    ax.step(lever_times, np.arange(lever_times.size),
            c=subject_color, where="post",
            label='Animal' + subject, zorder=1)
    if stage == '7':
        # plots error press in red
        ax.scatter(err_lever_ts, np.isin(
            lever_times, err_lever_ts).nonzero()[0],
            c='r', label='Errors', s=1, zorder=2)
    # Extend the curve horizontally up to the final reward collection
    if len(reward_times) > 0 and reward_times[-1] > lever_times[-1]:
        ax.plot(
            [lever_times[-1], reward_times[-1] + 2],
            [lever_times.size - 1, lever_times.size - 1],
            c=subject_color)
    reward_y = np.digitize(reward_times, lever_times) - 1
    double_y = np.digitize(reward_double, lever_times) - 1

    # for printing of error rates on graph
    norm_r_ts, _, norm_err_ts, _, _ = session.split_sess(all_levers=True)
    sch_type = session.get_arrays('Trial Type')
    for i, errs in enumerate(norm_err_ts):
        if sch_type[i] == 1:
            err_FR += len(errs)
        elif sch_type[i] == 0:
            err_FI += len(errs)

    show_summary = stage == '6' or stage == '7'
    rw_print = ""
    if show_summary:
        for i, block_rewards in enumerate(norm_r_ts):
            if sch_type[i] == 1:
                rw_FR += len(block_rewards)
            elif sch_type[i] == 0:
                rw_FI += len(block_rewards)
        rw_print = "\nCorrect FR \\ FI: " + \
            str(rw_FR) + r" \ " + str(rw_FI)

    ax.scatter(reward_times, reward_y, marker="x", c="grey",
               label='Reward Collected', s=25)
    if len(reward_double) > 0:
        dr_print = "\nTotal # of Double Rewards: " + str(len(reward_double))
        ax.scatter(reward_double, double_y, marker="x", c="magenta",
                   label='Double Reward', s=25)
    else:
        dr_print = ""
    ax.legend(loc='lower right')

    if err_FR > 0 or err_FI > 0:
        err_print = "\nErrors FR \\ FI: " + \
            str(err_FR) + r" \ " + str(err_FI)
    else:
        err_print = ""

    # Text Display on Graph (only for stage 6 and 7 sessions)
    text = (
        'Total # of Lever Press: ' +
        '{}\nTotal # of Rewards: {}{}{}{}'.format(
            len(lever_ts), len(reward_times) + len(reward_double),
            dr_print, rw_print, err_print))

    if single_plot:
        out_name = (subject.zfill(3) + "_CumulativeHist_" +
                    date + "_" + session_type[:-2] + ".png")
        out_name = os.path.join(out_dir, out_name)
        print("Saved figure to {}".format(out_name))
        if show_summary:
            ax.text(0.55, 0.15, text, transform=ax.transAxes)
        bv_plot.savefig(fig, out_name)
        plt.close(fig)
    elif show_summary:
        ax.text(0.05, 0.75, text, transform=ax.transAxes)
    return fig